# conv_encoder.py
  1. from typing import List
  2. from typing import Optional
  3. from typing import Sequence
  4. from typing import Tuple
  5. from typing import Union
  6. import logging
  7. import torch
  8. import torch.nn as nn
  9. from torch.nn import functional as F
  10. import numpy as np
  11. from funasr.modules.nets_utils import make_pad_mask
  12. from funasr.modules.layer_norm import LayerNorm
  13. from funasr.models.encoder.abs_encoder import AbsEncoder
  14. import math
  15. from funasr.modules.repeat import repeat
  16. class EncoderLayer(nn.Module):
  17. def __init__(
  18. self,
  19. input_units,
  20. num_units,
  21. kernel_size=3,
  22. activation="tanh",
  23. stride=1,
  24. include_batch_norm=False,
  25. residual=False
  26. ):
  27. super().__init__()
  28. left_padding = math.ceil((kernel_size - stride) / 2)
  29. right_padding = kernel_size - stride - left_padding
  30. self.conv_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
  31. self.conv1d = nn.Conv1d(
  32. input_units,
  33. num_units,
  34. kernel_size,
  35. stride,
  36. )
  37. self.activation = self.get_activation(activation)
  38. if include_batch_norm:
  39. self.bn = nn.BatchNorm1d(num_units, momentum=0.99, eps=1e-3)
  40. self.residual = residual
  41. self.include_batch_norm = include_batch_norm
  42. self.input_units = input_units
  43. self.num_units = num_units
  44. self.stride = stride
  45. @staticmethod
  46. def get_activation(activation):
  47. if activation == "tanh":
  48. return nn.Tanh()
  49. else:
  50. return nn.ReLU()
  51. def forward(self, xs_pad, ilens=None):
  52. outputs = self.conv1d(self.conv_padding(xs_pad))
  53. if self.residual and self.stride == 1 and self.input_units == self.num_units:
  54. outputs = outputs + xs_pad
  55. if self.include_batch_norm:
  56. outputs = self.bn(outputs)
  57. # add parenthesis for repeat module
  58. return self.activation(outputs), ilens
class ConvEncoder(AbsEncoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Convolution encoder in OpenNMT framework

    A stack of ``num_layers`` EncoderLayer 1-D convolutions (built via the
    ``repeat`` module as ``self.cnn_a``), with optional dropout on the input,
    an optional final projection convolution (``conv_out``), optional
    LayerNorm, and an optional residual connection from the encoder input.
    Also provides utilities to map/convert a TensorFlow checkpoint's tensors
    into this module's parameter names (``gen_tf2torch_map_dict`` /
    ``convert_tf2torch``).
    """

    def __init__(
        self,
        num_layers,
        input_units,
        num_units,
        kernel_size=3,
        dropout_rate=0.3,
        position_encoder=None,
        activation='tanh',
        auxiliary_states=True,
        out_units=None,
        out_norm=False,
        out_residual=False,
        include_batchnorm=False,
        regularization_weight=0.0,
        stride=1,
        tf2torch_tensor_name_prefix_torch: str = "speaker_encoder",
        tf2torch_tensor_name_prefix_tf: str = "EAND/speaker_encoder",
    ):
        # num_layers: number of stacked EncoderLayer convolutions.
        # input_units / num_units: channel sizes of the first layer's input and
        #   of every layer's output, respectively.
        # out_units: if not None, a final Conv1d projects num_units -> out_units.
        # stride: an int (applied to all layers) or a per-layer sequence.
        # position_encoder: optional callable applied to the raw input first.
        # tf2torch_*: name prefixes used by the checkpoint-conversion helpers.
        super().__init__()
        self._output_size = num_units
        self.num_layers = num_layers
        self.input_units = input_units
        self.num_units = num_units
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.position_encoder = position_encoder
        self.out_units = out_units
        # NOTE(review): auxiliary_states and regularization_weight are stored
        # but never read anywhere in this class — possibly consumed elsewhere.
        self.auxiliary_states = auxiliary_states
        self.out_norm = out_norm
        self.activation = activation
        self.out_residual = out_residual
        self.include_batch_norm = include_batchnorm
        self.regularization_weight = regularization_weight
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        # Normalize stride to a per-layer list.
        if isinstance(stride, int):
            self.stride = [stride] * self.num_layers
        else:
            self.stride = stride
        # Overall time-axis downsampling factor = product of all layer strides.
        self.downsample_rate = 1
        for s in self.stride:
            self.downsample_rate *= s
        self.dropout = nn.Dropout(dropout_rate)
        # First layer maps input_units -> num_units; later layers keep
        # num_units and enable the residual connection.
        self.cnn_a = repeat(
            self.num_layers,
            lambda lnum: EncoderLayer(
                input_units if lnum == 0 else num_units,
                num_units,
                kernel_size,
                activation,
                self.stride[lnum],
                include_batchnorm,
                residual=True if lnum > 0 else False
            )
        )
        if self.out_units is not None:
            # NOTE(review): this padding uses the raw `stride` ctor argument,
            # but conv_out below uses the default stride of 1 — if stride != 1
            # the padding no longer preserves length, and if stride was passed
            # as a sequence this line raises a TypeError. Confirm intent.
            left_padding = math.ceil((kernel_size - stride) / 2)
            right_padding = kernel_size - stride - left_padding
            self.out_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
            self.conv_out = nn.Conv1d(
                num_units,
                out_units,
                kernel_size,
            )
        if self.out_norm:
            # NOTE(review): LayerNorm(out_units) requires out_units to be set;
            # out_norm=True with out_units=None would fail here.
            self.after_norm = LayerNorm(out_units)

    def output_size(self) -> int:
        # NOTE(review): returns num_units even when conv_out projects the
        # output to out_units channels — callers relying on this with
        # out_units set should double-check. (self._output_size is unused.)
        return self.num_units

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode a padded batch.

        xs_pad is channel-last (batch, time, feat) — it is transposed to
        (batch, feat, time) for Conv1d and transposed back before returning.
        Returns (outputs, ilens, None); prev_states is ignored.

        NOTE(review): ilens is returned unchanged even when the layer strides
        downsample the time axis (downsample_rate > 1) — confirm callers
        rescale lengths themselves.
        """
        inputs = xs_pad
        if self.position_encoder is not None:
            inputs = self.position_encoder(inputs)
        if self.dropout_rate > 0:
            inputs = self.dropout(inputs)
        # Conv layers operate channel-first; EncoderLayer passes ilens through.
        outputs, _ = self.cnn_a(inputs.transpose(1, 2), ilens)
        if self.out_units is not None:
            outputs = self.conv_out(self.out_padding(outputs))
        outputs = outputs.transpose(1, 2)
        if self.out_norm:
            outputs = self.after_norm(outputs)
        if self.out_residual:
            # Requires output and input shapes to match (stride 1, same dims).
            outputs = outputs + inputs
        return outputs, ilens, None

    def gen_tf2torch_map_dict(self):
        """Build the torch-parameter-name -> TF-tensor-name mapping.

        Keys are torch parameter names; values give the TF tensor name plus
        optional squeeze axis and transpose permutation to apply to the TF
        array. The literal placeholder "layeridx" in both key and value is
        substituted with the actual layer index by convert_tf2torch.
        """
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            # torch: conv1d.weight in "out_channel in_channel kernel_size"
            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
            # torch: linear.weight in "out_channel in_channel"
            # tf   : dense.weight in "in_channel out_channel"
            # Layer 0 is special: its TF name carries no index suffix.
            "{}.cnn_a.0.conv1d.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.cnn_a.0.conv1d.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
            # Template entries for layers >= 1 ("layeridx" substituted later).
            "{}.cnn_a.layeridx.conv1d.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d_layeridx/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },
            "{}.cnn_a.layeridx.conv1d.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cnn_a/conv1d_layeridx/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },
        }
        if self.out_units is not None:
            # add output layer: TF names it conv1d_<num_layers>.
            map_dict_local.update({
                "{}.conv_out.weight".format(tensor_name_prefix_torch):
                    {"name": "{}/cnn_a/conv1d_{}/kernel".format(tensor_name_prefix_tf, self.num_layers),
                     "squeeze": None,
                     "transpose": (2, 1, 0),
                     },  # tf: (1, 256, 256) -> torch: (256, 256, 1)
                "{}.conv_out.bias".format(tensor_name_prefix_torch):
                    {"name": "{}/cnn_a/conv1d_{}/bias".format(tensor_name_prefix_tf, self.num_layers),
                     "squeeze": None,
                     "transpose": None,
                     },  # tf: (256,) -> torch: (256,)
            })
        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        """Convert TF checkpoint tensors into a dict of torch tensors.

        var_dict_tf: TF tensor name -> numpy array.
        var_dict_torch: torch parameter name -> torch tensor (for shape checks).
        Returns a dict of torch parameter name -> converted float32 CPU tensor.
        Parameters whose names don't match the map are logged as missing.
        """
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
                # process special (first and last) layers
                if name in map_dict:
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    if map_dict[name]["squeeze"] is not None:
                        data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                    if map_dict[name]["transpose"] is not None:
                        data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    # Shape must match exactly after squeeze/transpose.
                    assert var_dict_torch[name].size() == data_tf.size(), \
                        "{}, {}, {} != {}".format(name, name_tf,
                                                  var_dict_torch[name].size(), data_tf.size())
                    var_dict_torch_update[name] = data_tf
                    logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                    ))
                # process general layers
                else:
                    # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
                    # After replacing the prefix with "todo", the split looks like
                    # ["todo", "cnn_a", "<layeridx>", ...], so index 2 is the layer index.
                    names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
                    layeridx = int(names[2])
                    # Rewrite the concrete index to the "layeridx" template key.
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), \
                            "{}, {}, {} != {}".format(name, name_tf,
                                                      var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
                        ))
                    else:
                        logging.warning("{} is missed from tf checkpoint".format(name))
        return var_dict_torch_update