# e2e_asr_paraformer.py

import logging

import torch
import torch.nn as nn

from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.models.encoder.sanm_encoder import SANMEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.models.predictor.cif import CifPredictorV2
from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export


class Paraformer(nn.Module):
    """
    Author: Speech Lab, Alibaba Group, China
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='model',
        **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        # Swap each training-time module for its export-friendly counterpart.
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        if isinstance(model.predictor, CifPredictorV2):
            self.predictor = CifPredictorV2_export(model.predictor)
        if isinstance(model.decoder, ParaformerSANMDecoder):
            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
        self.feats_dim = feats_dim
        self.model_name = model_name
        # ONNX export uses the traceable MakePadMask module; otherwise fall
        # back to the plain sequence_mask helper.
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        # batch = to_device(batch, device=self.device)
        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        # CIF predictor: acoustic embeddings plus a per-utterance token count.
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.round().type(torch.int32)
        # Single-pass (non-autoregressive) decoding over the predicted token slots.
        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # sample_ids = decoder_out.argmax(dim=-1)
        return decoder_out, pre_token_length
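
    # Note: argmax decoding is intentionally left outside the exported graph
    # (see the commented-out line above). A downstream sketch, assuming a
    # batch index `i`, would be:
    #   token_ids = decoder_out.argmax(dim=-1)
    #   tokens_i = token_ids[i, : pre_token_length[i]]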

    def get_dummy_inputs(self):
        # Dummy batch of 2 utterances (6 and 30 frames) for tracing/export.
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)

    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
        import numpy as np

        fbank = np.loadtxt(txt_file)
        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
        return (speech, speech_lengths)

    def get_input_names(self):
        return ['speech', 'speech_lengths']

    def get_output_names(self):
        return ['logits', 'token_num']

    def get_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
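

# Usage sketch (not part of the original file): a minimal example of how this
# wrapper's helpers are typically wired into torch.onnx.export. `asr_model` is
# an assumed name for a trained funasr Paraformer whose encoder/predictor/
# decoder match the types checked in __init__; the opset version is likewise
# an assumption, not a value taken from this file.
#
# model = Paraformer(asr_model, onnx=True)
# model.eval()
# dummy_inputs = model.get_dummy_inputs()
# torch.onnx.export(
#     model,
#     dummy_inputs,
#     f"{model.model_name}.onnx",
#     input_names=model.get_input_names(),
#     output_names=model.get_output_names(),
#     dynamic_axes=model.get_dynamic_axes(),
#     opset_version=13,
# )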