# conv_encoder.py
  1. from typing import List
  2. from typing import Optional
  3. from typing import Sequence
  4. from typing import Tuple
  5. from typing import Union
  6. import logging
  7. import torch
  8. import torch.nn as nn
  9. from torch.nn import functional as F
  10. from typeguard import check_argument_types
  11. import numpy as np
  12. from funasr.modules.nets_utils import make_pad_mask
  13. from funasr.modules.layer_norm import LayerNorm
  14. from funasr.models.encoder.abs_encoder import AbsEncoder
  15. import math
  16. from funasr.modules.repeat import repeat
  17. class EncoderLayer(nn.Module):
  18. def __init__(
  19. self,
  20. input_units,
  21. num_units,
  22. kernel_size=3,
  23. activation="tanh",
  24. stride=1,
  25. include_batch_norm=False,
  26. residual=False
  27. ):
  28. super().__init__()
  29. left_padding = math.ceil((kernel_size - stride) / 2)
  30. right_padding = kernel_size - stride - left_padding
  31. self.conv_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
  32. self.conv1d = nn.Conv1d(
  33. input_units,
  34. num_units,
  35. kernel_size,
  36. stride,
  37. )
  38. self.activation = self.get_activation(activation)
  39. if include_batch_norm:
  40. self.bn = nn.BatchNorm1d(num_units, momentum=0.99, eps=1e-3)
  41. self.residual = residual
  42. self.include_batch_norm = include_batch_norm
  43. self.input_units = input_units
  44. self.num_units = num_units
  45. self.stride = stride
  46. @staticmethod
  47. def get_activation(activation):
  48. if activation == "tanh":
  49. return nn.Tanh()
  50. else:
  51. return nn.ReLU()
  52. def forward(self, xs_pad, ilens=None):
  53. outputs = self.conv1d(self.conv_padding(xs_pad))
  54. if self.residual and self.stride == 1 and self.input_units == self.num_units:
  55. outputs = outputs + xs_pad
  56. if self.include_batch_norm:
  57. outputs = self.bn(outputs)
  58. # add parenthesis for repeat module
  59. return self.activation(outputs), ilens
  60. class ConvEncoder(AbsEncoder):
  61. """
  62. Author: Speech Lab of DAMO Academy, Alibaba Group
  63. Convolution encoder in OpenNMT framework
  64. """
  65. def __init__(
  66. self,
  67. num_layers,
  68. input_units,
  69. num_units,
  70. kernel_size=3,
  71. dropout_rate=0.3,
  72. position_encoder=None,
  73. activation='tanh',
  74. auxiliary_states=True,
  75. out_units=None,
  76. out_norm=False,
  77. out_residual=False,
  78. include_batchnorm=False,
  79. regularization_weight=0.0,
  80. stride=1,
  81. tf2torch_tensor_name_prefix_torch: str = "speaker_encoder",
  82. tf2torch_tensor_name_prefix_tf: str = "EAND/speaker_encoder",
  83. ):
  84. assert check_argument_types()
  85. super().__init__()
  86. self._output_size = num_units
  87. self.num_layers = num_layers
  88. self.input_units = input_units
  89. self.num_units = num_units
  90. self.kernel_size = kernel_size
  91. self.dropout_rate = dropout_rate
  92. self.position_encoder = position_encoder
  93. self.out_units = out_units
  94. self.auxiliary_states = auxiliary_states
  95. self.out_norm = out_norm
  96. self.activation = activation
  97. self.out_residual = out_residual
  98. self.include_batch_norm = include_batchnorm
  99. self.regularization_weight = regularization_weight
  100. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  101. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  102. if isinstance(stride, int):
  103. self.stride = [stride] * self.num_layers
  104. else:
  105. self.stride = stride
  106. self.downsample_rate = 1
  107. for s in self.stride:
  108. self.downsample_rate *= s
  109. self.dropout = nn.Dropout(dropout_rate)
  110. self.cnn_a = repeat(
  111. self.num_layers,
  112. lambda lnum: EncoderLayer(
  113. input_units if lnum == 0 else num_units,
  114. num_units,
  115. kernel_size,
  116. activation,
  117. self.stride[lnum],
  118. include_batchnorm,
  119. residual=True if lnum > 0 else False
  120. )
  121. )
  122. if self.out_units is not None:
  123. left_padding = math.ceil((kernel_size - stride) / 2)
  124. right_padding = kernel_size - stride - left_padding
  125. self.out_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
  126. self.conv_out = nn.Conv1d(
  127. num_units,
  128. out_units,
  129. kernel_size,
  130. )
  131. if self.out_norm:
  132. self.after_norm = LayerNorm(out_units)
  133. def output_size(self) -> int:
  134. return self.num_units
  135. def forward(
  136. self,
  137. xs_pad: torch.Tensor,
  138. ilens: torch.Tensor,
  139. prev_states: torch.Tensor = None,
  140. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  141. inputs = xs_pad
  142. if self.position_encoder is not None:
  143. inputs = self.position_encoder(inputs)
  144. if self.dropout_rate > 0:
  145. inputs = self.dropout(inputs)
  146. outputs, _ = self.cnn_a(inputs.transpose(1, 2), ilens)
  147. if self.out_units is not None:
  148. outputs = self.conv_out(self.out_padding(outputs))
  149. outputs = outputs.transpose(1, 2)
  150. if self.out_norm:
  151. outputs = self.after_norm(outputs)
  152. if self.out_residual:
  153. outputs = outputs + inputs
  154. return outputs, ilens, None
  155. def gen_tf2torch_map_dict(self):
  156. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  157. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  158. map_dict_local = {
  159. # torch: conv1d.weight in "out_channel in_channel kernel_size"
  160. # tf : conv1d.weight in "kernel_size in_channel out_channel"
  161. # torch: linear.weight in "out_channel in_channel"
  162. # tf : dense.weight in "in_channel out_channel"
  163. "{}.cnn_a.0.conv1d.weight".format(tensor_name_prefix_torch):
  164. {"name": "{}/cnn_a/conv1d/kernel".format(tensor_name_prefix_tf),
  165. "squeeze": None,
  166. "transpose": (2, 1, 0),
  167. },
  168. "{}.cnn_a.0.conv1d.bias".format(tensor_name_prefix_torch):
  169. {"name": "{}/cnn_a/conv1d/bias".format(tensor_name_prefix_tf),
  170. "squeeze": None,
  171. "transpose": None,
  172. },
  173. "{}.cnn_a.layeridx.conv1d.weight".format(tensor_name_prefix_torch):
  174. {"name": "{}/cnn_a/conv1d_layeridx/kernel".format(tensor_name_prefix_tf),
  175. "squeeze": None,
  176. "transpose": (2, 1, 0),
  177. },
  178. "{}.cnn_a.layeridx.conv1d.bias".format(tensor_name_prefix_torch):
  179. {"name": "{}/cnn_a/conv1d_layeridx/bias".format(tensor_name_prefix_tf),
  180. "squeeze": None,
  181. "transpose": None,
  182. },
  183. }
  184. if self.out_units is not None:
  185. # add output layer
  186. map_dict_local.update({
  187. "{}.conv_out.weight".format(tensor_name_prefix_torch):
  188. {"name": "{}/cnn_a/conv1d_{}/kernel".format(tensor_name_prefix_tf, self.num_layers),
  189. "squeeze": None,
  190. "transpose": (2, 1, 0),
  191. }, # tf: (1, 256, 256) -> torch: (256, 256, 1)
  192. "{}.conv_out.bias".format(tensor_name_prefix_torch):
  193. {"name": "{}/cnn_a/conv1d_{}/bias".format(tensor_name_prefix_tf, self.num_layers),
  194. "squeeze": None,
  195. "transpose": None,
  196. }, # tf: (256,) -> torch: (256,)
  197. })
  198. return map_dict_local
  199. def convert_tf2torch(self,
  200. var_dict_tf,
  201. var_dict_torch,
  202. ):
  203. map_dict = self.gen_tf2torch_map_dict()
  204. var_dict_torch_update = dict()
  205. for name in sorted(var_dict_torch.keys(), reverse=False):
  206. if name.startswith(self.tf2torch_tensor_name_prefix_torch):
  207. # process special (first and last) layers
  208. if name in map_dict:
  209. name_tf = map_dict[name]["name"]
  210. data_tf = var_dict_tf[name_tf]
  211. if map_dict[name]["squeeze"] is not None:
  212. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  213. if map_dict[name]["transpose"] is not None:
  214. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  215. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  216. assert var_dict_torch[name].size() == data_tf.size(), \
  217. "{}, {}, {} != {}".format(name, name_tf,
  218. var_dict_torch[name].size(), data_tf.size())
  219. var_dict_torch_update[name] = data_tf
  220. logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
  221. name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
  222. ))
  223. # process general layers
  224. else:
  225. # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
  226. names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
  227. layeridx = int(names[2])
  228. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  229. if name_q in map_dict.keys():
  230. name_v = map_dict[name_q]["name"]
  231. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  232. data_tf = var_dict_tf[name_tf]
  233. if map_dict[name_q]["squeeze"] is not None:
  234. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  235. if map_dict[name_q]["transpose"] is not None:
  236. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  237. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  238. assert var_dict_torch[name].size() == data_tf.size(), \
  239. "{}, {}, {} != {}".format(name, name_tf,
  240. var_dict_torch[name].size(), data_tf.size())
  241. var_dict_torch_update[name] = data_tf
  242. logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
  243. name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
  244. ))
  245. else:
  246. logging.warning("{} is missed from tf checkpoint".format(name))
  247. return var_dict_torch_update