"""ONNX-export wrapper for the FunASR Paraformer SANM decoder."""
  1. import os
  2. import torch
  3. import torch.nn as nn
  4. from funasr.export.utils.torch_function import MakePadMask
  5. from funasr.export.utils.torch_function import sequence_mask
  6. from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
  7. from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
  8. from funasr.modules.attention import MultiHeadedAttentionCrossAtt
  9. from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
  10. from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
  11. from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
  12. from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
  13. class ParaformerSANMDecoder(nn.Module):
  14. def __init__(self, model,
  15. max_seq_len=512,
  16. model_name='decoder',
  17. onnx: bool = True,):
  18. super().__init__()
  19. # self.embed = model.embed #Embedding(model.embed, max_seq_len)
  20. self.model = model
  21. if onnx:
  22. self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
  23. else:
  24. self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
  25. for i, d in enumerate(self.model.decoders):
  26. if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
  27. d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
  28. if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
  29. d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
  30. if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
  31. d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
  32. self.model.decoders[i] = DecoderLayerSANM_export(d)
  33. if self.model.decoders2 is not None:
  34. for i, d in enumerate(self.model.decoders2):
  35. if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
  36. d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
  37. if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
  38. d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
  39. self.model.decoders2[i] = DecoderLayerSANM_export(d)
  40. for i, d in enumerate(self.model.decoders3):
  41. if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
  42. d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
  43. self.model.decoders3[i] = DecoderLayerSANM_export(d)
  44. self.output_layer = model.output_layer
  45. self.after_norm = model.after_norm
  46. self.model_name = model_name
  47. def prepare_mask(self, mask):
  48. mask_3d_btd = mask[:, :, None]
  49. if len(mask.shape) == 2:
  50. mask_4d_bhlt = 1 - mask[:, None, None, :]
  51. elif len(mask.shape) == 3:
  52. mask_4d_bhlt = 1 - mask[:, None, :]
  53. mask_4d_bhlt = mask_4d_bhlt * -10000.0
  54. return mask_3d_btd, mask_4d_bhlt
  55. def forward(
  56. self,
  57. hs_pad: torch.Tensor,
  58. hlens: torch.Tensor,
  59. ys_in_pad: torch.Tensor,
  60. ys_in_lens: torch.Tensor,
  61. ):
  62. tgt = ys_in_pad
  63. tgt_mask = self.make_pad_mask(ys_in_lens)
  64. tgt_mask, _ = self.prepare_mask(tgt_mask)
  65. # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
  66. memory = hs_pad
  67. memory_mask = self.make_pad_mask(hlens)
  68. _, memory_mask = self.prepare_mask(memory_mask)
  69. # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
  70. x = tgt
  71. x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
  72. x, tgt_mask, memory, memory_mask
  73. )
  74. if self.model.decoders2 is not None:
  75. x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
  76. x, tgt_mask, memory, memory_mask
  77. )
  78. x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
  79. x, tgt_mask, memory, memory_mask
  80. )
  81. x = self.after_norm(x)
  82. x = self.output_layer(x)
  83. return x, ys_in_lens
  84. def get_dummy_inputs(self, enc_size):
  85. tgt = torch.LongTensor([0]).unsqueeze(0)
  86. memory = torch.randn(1, 100, enc_size)
  87. pre_acoustic_embeds = torch.randn(1, 1, enc_size)
  88. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  89. cache = [
  90. torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
  91. for _ in range(cache_num)
  92. ]
  93. return (tgt, memory, pre_acoustic_embeds, cache)
  94. def is_optimizable(self):
  95. return True
  96. def get_input_names(self):
  97. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  98. return ['tgt', 'memory', 'pre_acoustic_embeds'] \
  99. + ['cache_%d' % i for i in range(cache_num)]
  100. def get_output_names(self):
  101. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  102. return ['y'] \
  103. + ['out_cache_%d' % i for i in range(cache_num)]
  104. def get_dynamic_axes(self):
  105. ret = {
  106. 'tgt': {
  107. 0: 'tgt_batch',
  108. 1: 'tgt_length'
  109. },
  110. 'memory': {
  111. 0: 'memory_batch',
  112. 1: 'memory_length'
  113. },
  114. 'pre_acoustic_embeds': {
  115. 0: 'acoustic_embeds_batch',
  116. 1: 'acoustic_embeds_length',
  117. }
  118. }
  119. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  120. ret.update({
  121. 'cache_%d' % d: {
  122. 0: 'cache_%d_batch' % d,
  123. 2: 'cache_%d_length' % d
  124. }
  125. for d in range(cache_num)
  126. })
  127. return ret
  128. def get_model_config(self, path):
  129. return {
  130. "dec_type": "XformerDecoder",
  131. "model_path": os.path.join(path, f'{self.model_name}.onnx'),
  132. "n_layers": len(self.model.decoders) + len(self.model.decoders2),
  133. "odim": self.model.decoders[0].size
  134. }