# contextual_decoder.py
import os

import torch
import torch.nn as nn

from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.modules.attention import MultiHeadedAttentionCrossAtt
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
  13. class ContextualSANMDecoder(nn.Module):
  14. def __init__(self, model,
  15. max_seq_len=512,
  16. model_name='decoder',
  17. onnx: bool = True,):
  18. super().__init__()
  19. # self.embed = model.embed #Embedding(model.embed, max_seq_len)
  20. self.model = model
  21. if onnx:
  22. self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
  23. else:
  24. self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
  25. for i, d in enumerate(self.model.decoders):
  26. if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
  27. d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
  28. if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
  29. d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
  30. if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
  31. d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
  32. self.model.decoders[i] = DecoderLayerSANM_export(d)
  33. if self.model.decoders2 is not None:
  34. for i, d in enumerate(self.model.decoders2):
  35. if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
  36. d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
  37. if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
  38. d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
  39. self.model.decoders2[i] = DecoderLayerSANM_export(d)
  40. for i, d in enumerate(self.model.decoders3):
  41. if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
  42. d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
  43. self.model.decoders3[i] = DecoderLayerSANM_export(d)
  44. self.output_layer = model.output_layer
  45. self.after_norm = model.after_norm
  46. self.model_name = model_name
  47. # bias decoder
  48. if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
  49. self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.bias_decoder.src_attn)
  50. self.bias_decoder = self.model.bias_decoder
  51. # last decoder
  52. if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
  53. self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.last_decoder.src_attn)
  54. if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
  55. self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoder_export(self.model.last_decoder.self_attn)
  56. if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
  57. self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANM_export(self.model.last_decoder.feed_forward)
  58. self.last_decoder = self.model.last_decoder
  59. self.bias_output = self.model.bias_output
  60. self.dropout = self.model.dropout
  61. def prepare_mask(self, mask):
  62. mask_3d_btd = mask[:, :, None]
  63. if len(mask.shape) == 2:
  64. mask_4d_bhlt = 1 - mask[:, None, None, :]
  65. elif len(mask.shape) == 3:
  66. mask_4d_bhlt = 1 - mask[:, None, :]
  67. mask_4d_bhlt = mask_4d_bhlt * -10000.0
  68. return mask_3d_btd, mask_4d_bhlt
  69. def forward(
  70. self,
  71. hs_pad: torch.Tensor,
  72. hlens: torch.Tensor,
  73. ys_in_pad: torch.Tensor,
  74. ys_in_lens: torch.Tensor,
  75. bias_embed: torch.Tensor,
  76. ):
  77. tgt = ys_in_pad
  78. tgt_mask = self.make_pad_mask(ys_in_lens)
  79. tgt_mask, _ = self.prepare_mask(tgt_mask)
  80. # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
  81. memory = hs_pad
  82. memory_mask = self.make_pad_mask(hlens)
  83. _, memory_mask = self.prepare_mask(memory_mask)
  84. # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
  85. x = tgt
  86. x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
  87. x, tgt_mask, memory, memory_mask
  88. )
  89. _, _, x_self_attn, x_src_attn = self.last_decoder(
  90. x, tgt_mask, memory, memory_mask
  91. )
  92. # contextual paraformer related
  93. contextual_length = torch.Tensor([bias_embed.shape[1]]).int().repeat(hs_pad.shape[0])
  94. # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
  95. contextual_mask = self.make_pad_mask(contextual_length)
  96. contextual_mask, _ = self.prepare_mask(contextual_mask)
  97. # import pdb; pdb.set_trace()
  98. contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
  99. cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask)
  100. if self.bias_output is not None:
  101. x = torch.cat([x_src_attn, cx], dim=2)
  102. x = self.bias_output(x.transpose(1, 2)).transpose(1, 2) # 2D -> D
  103. x = x_self_attn + self.dropout(x)
  104. if self.model.decoders2 is not None:
  105. x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
  106. x, tgt_mask, memory, memory_mask
  107. )
  108. x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
  109. x, tgt_mask, memory, memory_mask
  110. )
  111. x = self.after_norm(x)
  112. x = self.output_layer(x)
  113. return x, ys_in_lens
  114. def get_dummy_inputs(self, enc_size):
  115. tgt = torch.LongTensor([0]).unsqueeze(0)
  116. memory = torch.randn(1, 100, enc_size)
  117. pre_acoustic_embeds = torch.randn(1, 1, enc_size)
  118. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  119. cache = [
  120. torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
  121. for _ in range(cache_num)
  122. ]
  123. return (tgt, memory, pre_acoustic_embeds, cache)
  124. def is_optimizable(self):
  125. return True
  126. def get_input_names(self):
  127. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  128. return ['tgt', 'memory', 'pre_acoustic_embeds'] \
  129. + ['cache_%d' % i for i in range(cache_num)]
  130. def get_output_names(self):
  131. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  132. return ['y'] \
  133. + ['out_cache_%d' % i for i in range(cache_num)]
  134. def get_dynamic_axes(self):
  135. ret = {
  136. 'tgt': {
  137. 0: 'tgt_batch',
  138. 1: 'tgt_length'
  139. },
  140. 'memory': {
  141. 0: 'memory_batch',
  142. 1: 'memory_length'
  143. },
  144. 'pre_acoustic_embeds': {
  145. 0: 'acoustic_embeds_batch',
  146. 1: 'acoustic_embeds_length',
  147. }
  148. }
  149. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  150. ret.update({
  151. 'cache_%d' % d: {
  152. 0: 'cache_%d_batch' % d,
  153. 2: 'cache_%d_length' % d
  154. }
  155. for d in range(cache_num)
  156. })
  157. return ret
  158. def get_model_config(self, path):
  159. return {
  160. "dec_type": "XformerDecoder",
  161. "model_path": os.path.join(path, f'{self.model_name}.onnx'),
  162. "n_layers": len(self.model.decoders) + len(self.model.decoders2),
  163. "odim": self.model.decoders[0].size
  164. }