# transformer_decoder.py
import os
from funasr.export import models
import torch
import torch.nn as nn
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr.modules.attention import MultiHeadedAttentionCrossAtt, MultiHeadedAttention
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
from funasr.export.models.modules.decoder_layer import DecoderLayer as DecoderLayer_export
  15. class ParaformerDecoderSAN(nn.Module):
  16. def __init__(self, model,
  17. max_seq_len=512,
  18. model_name='decoder',
  19. onnx: bool = True,):
  20. super().__init__()
  21. # self.embed = model.embed #Embedding(model.embed, max_seq_len)
  22. self.model = model
  23. if onnx:
  24. self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
  25. else:
  26. self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
  27. for i, d in enumerate(self.model.decoders):
  28. if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
  29. d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
  30. if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
  31. d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
  32. # if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
  33. # d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
  34. if isinstance(d.src_attn, MultiHeadedAttention):
  35. d.src_attn = OnnxMultiHeadedAttention(d.src_attn)
  36. self.model.decoders[i] = DecoderLayer_export(d)
  37. self.output_layer = model.output_layer
  38. self.after_norm = model.after_norm
  39. self.model_name = model_name
  40. def prepare_mask(self, mask):
  41. mask_3d_btd = mask[:, :, None]
  42. if len(mask.shape) == 2:
  43. mask_4d_bhlt = 1 - mask[:, None, None, :]
  44. elif len(mask.shape) == 3:
  45. mask_4d_bhlt = 1 - mask[:, None, :]
  46. mask_4d_bhlt = mask_4d_bhlt * -10000.0
  47. return mask_3d_btd, mask_4d_bhlt
  48. def forward(
  49. self,
  50. hs_pad: torch.Tensor,
  51. hlens: torch.Tensor,
  52. ys_in_pad: torch.Tensor,
  53. ys_in_lens: torch.Tensor,
  54. ):
  55. tgt = ys_in_pad
  56. tgt_mask = self.make_pad_mask(ys_in_lens)
  57. tgt_mask, _ = self.prepare_mask(tgt_mask)
  58. # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
  59. memory = hs_pad
  60. memory_mask = self.make_pad_mask(hlens)
  61. _, memory_mask = self.prepare_mask(memory_mask)
  62. # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
  63. x = tgt
  64. x, tgt_mask, memory, memory_mask = self.model.decoders(
  65. x, tgt_mask, memory, memory_mask
  66. )
  67. x = self.after_norm(x)
  68. x = self.output_layer(x)
  69. return x, ys_in_lens
  70. def get_dummy_inputs(self, enc_size):
  71. tgt = torch.LongTensor([0]).unsqueeze(0)
  72. memory = torch.randn(1, 100, enc_size)
  73. pre_acoustic_embeds = torch.randn(1, 1, enc_size)
  74. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  75. cache = [
  76. torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
  77. for _ in range(cache_num)
  78. ]
  79. return (tgt, memory, pre_acoustic_embeds, cache)
  80. def is_optimizable(self):
  81. return True
  82. def get_input_names(self):
  83. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  84. return ['tgt', 'memory', 'pre_acoustic_embeds'] \
  85. + ['cache_%d' % i for i in range(cache_num)]
  86. def get_output_names(self):
  87. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  88. return ['y'] \
  89. + ['out_cache_%d' % i for i in range(cache_num)]
  90. def get_dynamic_axes(self):
  91. ret = {
  92. 'tgt': {
  93. 0: 'tgt_batch',
  94. 1: 'tgt_length'
  95. },
  96. 'memory': {
  97. 0: 'memory_batch',
  98. 1: 'memory_length'
  99. },
  100. 'pre_acoustic_embeds': {
  101. 0: 'acoustic_embeds_batch',
  102. 1: 'acoustic_embeds_length',
  103. }
  104. }
  105. cache_num = len(self.model.decoders) + len(self.model.decoders2)
  106. ret.update({
  107. 'cache_%d' % d: {
  108. 0: 'cache_%d_batch' % d,
  109. 2: 'cache_%d_length' % d
  110. }
  111. for d in range(cache_num)
  112. })
  113. return ret
  114. def get_model_config(self, path):
  115. return {
  116. "dec_type": "XformerDecoder",
  117. "model_path": os.path.join(path, f'{self.model_name}.onnx'),
  118. "n_layers": len(self.model.decoders) + len(self.model.decoders2),
  119. "odim": self.model.decoders[0].size
  120. }