# sanm_decoder.py
import os

import torch
import torch.nn as nn

# from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr.export.utils.torch_function import MakePadMask
from funasr.modules.attention import MultiHeadedAttentionCrossAtt
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
  13. class ParaformerSANMDecoder(nn.Module):
  14. def __init__(self, model,
  15. max_seq_len=512,
  16. model_name='decoder'):
  17. super().__init__()
  18. # self.embed = model.embed #Embedding(model.embed, max_seq_len)
  19. self.model = model
  20. self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
  21. for i, d in enumerate(self.model.decoders):
  22. if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
  23. d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
  24. if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
  25. d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
  26. if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
  27. d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
  28. self.model.decoders[i] = DecoderLayerSANM_export(d)
  29. if self.model.decoders2 is not None:
  30. for i, d in enumerate(self.model.decoders2):
  31. if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
  32. d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
  33. if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
  34. d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
  35. self.model.decoders2[i] = DecoderLayerSANM_export(d)
  36. for i, d in enumerate(self.model.decoders3):
  37. if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
  38. d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
  39. self.model.decoders3[i] = DecoderLayerSANM_export(d)
  40. self.output_layer = model.output_layer
  41. self.after_norm = model.after_norm
  42. self.model_name = model_name
  43. def prepare_mask(self, mask):
  44. mask_3d_btd = mask[:, :, None]
  45. if len(mask.shape) == 2:
  46. mask_4d_bhlt = 1 - mask[:, None, None, :]
  47. elif len(mask.shape) == 3:
  48. mask_4d_bhlt = 1 - mask[:, None, :]
  49. mask_4d_bhlt = mask_4d_bhlt * -10000.0
  50. return mask_3d_btd, mask_4d_bhlt
  51. def forward(
  52. self,
  53. hs_pad: torch.Tensor,
  54. hlens: torch.Tensor,
  55. ys_in_pad: torch.Tensor,
  56. ys_in_lens: torch.Tensor,
  57. ):
  58. tgt = ys_in_pad
  59. tgt_mask = self.make_pad_mask(ys_in_lens)
  60. tgt_mask, _ = self.prepare_mask(tgt_mask)
  61. # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
  62. memory = hs_pad
  63. memory_mask = self.make_pad_mask(hlens)
  64. _, memory_mask = self.prepare_mask(memory_mask)
  65. # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
  66. x = tgt
  67. x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
  68. x, tgt_mask, memory, memory_mask
  69. )
  70. if self.model.decoders2 is not None:
  71. x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
  72. x, tgt_mask, memory, memory_mask
  73. )
  74. x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
  75. x, tgt_mask, memory, memory_mask
  76. )
  77. x = self.after_norm(x)
  78. x = self.output_layer(x)
  79. return x, ys_in_lens
  80. def get_dummy_inputs(self, enc_size):
  81. tgt = torch.LongTensor([0]).unsqueeze(0)
  82. memory = torch.randn(1, 100, enc_size)
  83. pre_acoustic_embeds = torch.randn(1, 1, enc_size)
  84. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  85. cache = [
  86. torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
  87. for _ in range(cache_num)
  88. ]
  89. return (tgt, memory, pre_acoustic_embeds, cache)
  90. def is_optimizable(self):
  91. return True
  92. def get_input_names(self):
  93. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  94. return ['tgt', 'memory', 'pre_acoustic_embeds'] \
  95. + ['cache_%d' % i for i in range(cache_num)]
  96. def get_output_names(self):
  97. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  98. return ['y'] \
  99. + ['out_cache_%d' % i for i in range(cache_num)]
  100. def get_dynamic_axes(self):
  101. ret = {
  102. 'tgt': {
  103. 0: 'tgt_batch',
  104. 1: 'tgt_length'
  105. },
  106. 'memory': {
  107. 0: 'memory_batch',
  108. 1: 'memory_length'
  109. },
  110. 'pre_acoustic_embeds': {
  111. 0: 'acoustic_embeds_batch',
  112. 1: 'acoustic_embeds_length',
  113. }
  114. }
  115. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  116. ret.update({
  117. 'cache_%d' % d: {
  118. 0: 'cache_%d_batch' % d,
  119. 2: 'cache_%d_length' % d
  120. }
  121. for d in range(cache_num)
  122. })
  123. return ret
  124. def get_model_config(self, path):
  125. return {
  126. "dec_type": "XformerDecoder",
  127. "model_path": os.path.join(path, f'{self.model_name}.onnx'),
  128. "n_layers": len(self.model.decoders) + len(self.model.decoders2),
  129. "odim": self.model.decoders[0].size
  130. }