e2e_asr_contextual_paraformer.py

import logging

import torch
import torch.nn as nn
import numpy as np

from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
from funasr.models.predictor.cif import CifPredictorV2
from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
from funasr.export.models.decoder.contextual_decoder import ContextualSANMDecoder as ContextualSANMDecoder_export
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder


class ContextualParaformer_backbone(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='model',
        **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        # encoder
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        elif isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        # predictor
        if isinstance(model.predictor, CifPredictorV2):
            self.predictor = CifPredictorV2_export(model.predictor)
        # decoder
        if isinstance(model.decoder, ContextualParaformerDecoder):
            self.decoder = ContextualSANMDecoder_export(model.decoder, onnx=onnx)
        elif isinstance(model.decoder, ParaformerSANMDecoder):
            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
        elif isinstance(model.decoder, ParaformerDecoderSAN):
            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
        self.feats_dim = feats_dim
        self.model_name = model_name
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        bias_embed: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        # batch = to_device(batch, device=self.device)

        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.floor().type(torch.int32)

        # bias_embed = bias_embed.squeeze(0).repeat([enc.shape[0], 1, 1])
        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length, bias_embed)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # sample_ids = decoder_out.argmax(dim=-1)
        return decoder_out, pre_token_length

    def get_dummy_inputs(self):
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        bias_embed = torch.randn(2, 1, 512)
        return (speech, speech_lengths, bias_embed)

    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
        fbank = np.loadtxt(txt_file)
        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
        return (speech, speech_lengths)

    def get_input_names(self):
        return ['speech', 'speech_lengths', 'bias_embed']

    def get_output_names(self):
        return ['logits', 'token_num']

    def get_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'bias_embed': {
                0: 'batch_size',
                1: 'num_hotwords'
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
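

# A minimal sketch of how the helper methods above could drive an ONNX export of the
# backbone via torch.onnx.export. The function name, output path, and opset version
# are assumptions made only for illustration, not part of the FunASR export pipeline.
def _export_backbone_onnx_example(backbone: ContextualParaformer_backbone,
                                  output_path: str = "model.onnx",
                                  opset_version: int = 14):
    backbone.eval()
    torch.onnx.export(
        backbone,
        backbone.get_dummy_inputs(),              # (speech, speech_lengths, bias_embed)
        output_path,
        opset_version=opset_version,
        input_names=backbone.get_input_names(),
        output_names=backbone.get_output_names(),
        dynamic_axes=backbone.get_dynamic_axes(),
    )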


class ContextualParaformer_embedder(nn.Module):
    def __init__(self,
                 model,
                 max_seq_len=512,
                 feats_dim=560,
                 model_name='model',
                 **kwargs,):
        super().__init__()
        self.embedding = model.bias_embed
        model.bias_encoder.batch_first = False
        self.bias_encoder = model.bias_encoder
        # self.bias_encoder.batch_first = False
        self.feats_dim = feats_dim
        self.model_name = "{}_eb".format(model_name)

    def forward(self, hotword):
        hotword = self.embedding(hotword).transpose(0, 1)  # batch second
        hw_embed, (_, _) = self.bias_encoder(hotword)
        return hw_embed

    def get_dummy_inputs(self):
        hotword = torch.tensor([
            [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
            [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
            [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ],
            dtype=torch.int32)
        # hotword_length = torch.tensor([10, 2, 1], dtype=torch.int32)
        return hotword

    def get_input_names(self):
        return ['hotword']

    def get_output_names(self):
        return ['hw_embed']

    def get_dynamic_axes(self):
        return {
            'hotword': {
                0: 'num_hotwords',
            },
            'hw_embed': {
                0: 'num_hotwords',
            },
        }
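

# A minimal sketch of exporting the embedder in the same way as the backbone; the
# function name, output path, and opset version are assumptions for illustration.
# Presumably the per-hotword vectors derived from `hw_embed` are what the backbone
# consumes as `bias_embed` at inference time (see the commented-out repeat in
# ContextualParaformer_backbone.forward), but that wiring lives outside this module.
def _export_embedder_onnx_example(embedder: ContextualParaformer_embedder,
                                  output_path: str = "model_eb.onnx",
                                  opset_version: int = 14):
    embedder.eval()
    torch.onnx.export(
        embedder,
        embedder.get_dummy_inputs(),              # hotword token ids, shape (num_hotwords, max_len)
        output_path,
        opset_version=opset_version,
        input_names=embedder.get_input_names(),
        output_names=embedder.get_output_names(),
        dynamic_axes=embedder.get_dynamic_axes(),
    )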