sanm_decoder.py
import os

import torch
import torch.nn as nn

from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr.modules.attention import MultiHeadedAttentionCrossAtt
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export


class ParaformerSANMDecoder(nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True,):
        super().__init__()
        # self.embed = model.embed  # Embedding(model.embed, max_seq_len)
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        # Replace each decoder layer's submodules with their export-friendly
        # counterparts, then wrap the layer itself for export.
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANM_export(d)

        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                    d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANM_export(d)

        for i, d in enumerate(self.model.decoders3):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            self.model.decoders3[i] = DecoderLayerSANM_export(d)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def prepare_mask(self, mask):
        # Returns the mask both as a (B, T, 1) gate and as an additive
        # (B, 1, 1, T) / (B, 1, L, T) bias for attention scores.
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)

        return x, ys_in_lens

    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        return True

    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
            + ['cache_%d' % i for i in range(cache_num)]

    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['y'] \
            + ['out_cache_%d' % i for i in range(cache_num)]

    def get_dynamic_axes(self):
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        return ret

    def get_model_config(self, path):
        return {
            "dec_type": "XformerDecoder",
            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
            "n_layers": len(self.model.decoders) + len(self.model.decoders2),
            "odim": self.model.decoders[0].size
        }
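
# Illustrative usage sketch (not from the original file; shapes and sizes below are
# assumed example values). The wrapper takes encoder outputs plus the predictor's
# pre-acoustic embeddings and returns token logits:
#
#   decoder = ParaformerSANMDecoder(paraformer.decoder)        # hypothetical wrapped model
#   enc_out = torch.randn(2, 100, 512)                         # (batch, frames, enc_size)
#   enc_lens = torch.tensor([100, 80], dtype=torch.int32)
#   acoustic_embeds = torch.randn(2, 10, 512)                  # (batch, tokens, size)
#   token_lens = torch.tensor([10, 7], dtype=torch.int32)
#   logits, _ = decoder(enc_out, enc_lens, acoustic_embeds, token_lens)  # (batch, tokens, vocab)

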
class ParaformerSANMDecoderOnline(nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True, ):
        super().__init__()
        # self.embed = model.embed  # Embedding(model.embed, max_seq_len)
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        for i, d in enumerate(self.model.decoders):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANM_export(d)

        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                    d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANM_export(d)

        for i, d in enumerate(self.model.decoders3):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            self.model.decoders3[i] = DecoderLayerSANM_export(d)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        *args,
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        x = tgt
        # Each layer consumes its incoming cache (one tensor per layer, passed via
        # *args) and emits an updated cache for the next streaming chunk.
        out_caches = list()
        for i, decoder in enumerate(self.model.decoders):
            in_cache = args[i]
            x, tgt_mask, memory, memory_mask, out_cache = decoder(
                x, tgt_mask, memory, memory_mask, cache=in_cache
            )
            out_caches.append(out_cache)
        if self.model.decoders2 is not None:
            for i, decoder in enumerate(self.model.decoders2):
                in_cache = args[i + len(self.model.decoders)]
                x, tgt_mask, memory, memory_mask, out_cache = decoder(
                    x, tgt_mask, memory, memory_mask, cache=in_cache
                )
                out_caches.append(out_cache)
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)

        return x, out_caches

    def get_dummy_inputs(self, enc_size):
        enc = torch.randn(2, 100, enc_size).type(torch.float32)
        enc_len = torch.tensor([30, 100], dtype=torch.int32)
        acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
        acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
        cache_num = len(self.model.decoders)
        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        cache = [
            torch.zeros((2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1), dtype=torch.float32)
            for _ in range(cache_num)
        ]
        return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)

    def get_input_names(self):
        cache_num = len(self.model.decoders)
        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        return ['enc', 'enc_len', 'acoustic_embeds', 'acoustic_embeds_len'] \
            + ['in_cache_%d' % i for i in range(cache_num)]

    def get_output_names(self):
        cache_num = len(self.model.decoders)
        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        return ['logits', 'sample_ids'] \
            + ['out_cache_%d' % i for i in range(cache_num)]

    def get_dynamic_axes(self):
        ret = {
            'enc': {
                0: 'batch_size',
                1: 'enc_length'
            },
            'acoustic_embeds': {
                0: 'batch_size',
                1: 'token_length'
            },
            'enc_len': {
                0: 'batch_size',
            },
            'acoustic_embeds_len': {
                0: 'batch_size',
            },
        }
        cache_num = len(self.model.decoders)
        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        ret.update({
            'in_cache_%d' % d: {
                0: 'batch_size',
            }
            for d in range(cache_num)
        })
        ret.update({
            'out_cache_%d' % d: {
                0: 'batch_size',
            }
            for d in range(cache_num)
        })
        return ret
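

# --- Illustrative export sketch (not part of the original file). It shows how an
# ONNX export of the streaming decoder could be driven from the helper methods above,
# assuming `decoder` wraps a trained Paraformer decoder and `enc_size` matches the
# encoder output dimension; the opset version and output path are example choices,
# and the real funasr.export pipeline may configure torch.onnx.export differently.
def _export_online_decoder_sketch(decoder: ParaformerSANMDecoderOnline, enc_size: int, out_dir: str):
    dummy_inputs = decoder.get_dummy_inputs(enc_size)
    torch.onnx.export(
        decoder,                                                # wrapped nn.Module traced by the exporter
        dummy_inputs,                                           # (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
        os.path.join(out_dir, f'{decoder.model_name}.onnx'),
        input_names=decoder.get_input_names(),
        output_names=decoder.get_output_names(),
        dynamic_axes=decoder.get_dynamic_axes(),
        opset_version=13,                                       # assumed; pick the opset your runtime supports
    )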