e2e_asr_conformer.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import os
  2. import logging
  3. import torch
  4. import torch.nn as nn
  5. from funasr.export.utils.torch_function import MakePadMask
  6. from funasr.export.utils.torch_function import sequence_mask
  7. from funasr.models.encoder.conformer_encoder import ConformerEncoder
  8. from funasr.models.decoder.transformer_decoder import TransformerDecoder
  9. from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
  10. from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export
  11. class Conformer(nn.Module):
  12. """
  13. export conformer into onnx format
  14. """
  15. def __init__(
  16. self,
  17. model,
  18. max_seq_len=512,
  19. feats_dim=560,
  20. model_name='model',
  21. **kwargs,
  22. ):
  23. super().__init__()
  24. onnx = False
  25. if "onnx" in kwargs:
  26. onnx = kwargs["onnx"]
  27. if isinstance(model.encoder, ConformerEncoder):
  28. self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
  29. elif isinstance(model.decoder, TransformerDecoder):
  30. self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)
  31. self.feats_dim = feats_dim
  32. self.model_name = model_name
  33. if onnx:
  34. self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
  35. else:
  36. self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
  37. def _export_model(self, model, verbose, path):
  38. dummy_input = model.get_dummy_inputs()
  39. model_script = model
  40. model_path = os.path.join(path, f'{model.model_name}.onnx')
  41. if not os.path.exists(model_path):
  42. torch.onnx.export(
  43. model_script,
  44. dummy_input,
  45. model_path,
  46. verbose=verbose,
  47. opset_version=14,
  48. input_names=model.get_input_names(),
  49. output_names=model.get_output_names(),
  50. dynamic_axes=model.get_dynamic_axes()
  51. )
  52. def _export_encoder_onnx(self, verbose, path):
  53. model_encoder = self.encoder
  54. self._export_model(model_encoder, verbose, path)
  55. def _export_decoder_onnx(self, verbose, path):
  56. model_decoder = self.decoder
  57. self._export_model(model_decoder, verbose, path)
  58. def _export_onnx(self, verbose, path):
  59. self._export_encoder_onnx(verbose, path)
  60. self._export_decoder_onnx(verbose, path)