# export_model.py
  1. import json
  2. from typing import Union, Dict
  3. from pathlib import Path
  4. from typeguard import check_argument_types
  5. import os
  6. import logging
  7. import torch
  8. from funasr.bin.asr_inference_paraformer import Speech2Text
  9. from funasr.export.models import get_model
  10. class ASRModelExportParaformer:
  11. def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True):
  12. assert check_argument_types()
  13. if cache_dir is None:
  14. cache_dir = Path.home() / "cache" / "export"
  15. self.cache_dir = Path(cache_dir)
  16. self.export_config = dict(
  17. feats_dim=560,
  18. onnx=False,
  19. )
  20. logging.info("output dir: {}".format(self.cache_dir))
  21. self.onnx = onnx
  22. def _export(
  23. self,
  24. model: Speech2Text,
  25. tag_name: str = None,
  26. verbose: bool = False,
  27. ):
  28. export_dir = self.cache_dir / tag_name.replace(' ', '-')
  29. os.makedirs(export_dir, exist_ok=True)
  30. # export encoder1
  31. self.export_config["model_name"] = "model"
  32. model = get_model(
  33. model,
  34. self.export_config,
  35. )
  36. self._export_onnx(model, verbose, export_dir)
  37. if self.onnx:
  38. self._export_onnx(model, verbose, export_dir)
  39. else:
  40. self._export_torchscripts(model, verbose, export_dir)
  41. logging.info("output dir: {}".format(export_dir))
  42. def _export_torchscripts(self, model, verbose, path, enc_size=None):
  43. if enc_size:
  44. dummy_input = model.get_dummy_inputs(enc_size)
  45. else:
  46. dummy_input = model.get_dummy_inputs_txt()
  47. # model_script = torch.jit.script(model)
  48. model_script = torch.jit.trace(model, dummy_input)
  49. model_script.save(os.path.join(path, f'{model.model_name}.torchscripts'))
  50. def export(self,
  51. tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
  52. mode: str = 'paraformer',
  53. ):
  54. model_dir = tag_name
  55. if model_dir.startswith('damo/'):
  56. from modelscope.hub.snapshot_download import snapshot_download
  57. model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
  58. asr_train_config = os.path.join(model_dir, 'config.yaml')
  59. asr_model_file = os.path.join(model_dir, 'model.pb')
  60. cmvn_file = os.path.join(model_dir, 'am.mvn')
  61. json_file = os.path.join(model_dir, 'configuration.json')
  62. if mode is None:
  63. import json
  64. with open(json_file, 'r') as f:
  65. config_data = json.load(f)
  66. mode = config_data['model']['model_config']['mode']
  67. if mode == 'paraformer':
  68. from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  69. elif mode == 'uniasr':
  70. from funasr.tasks.asr import ASRTaskUniASR as ASRTask
  71. model, asr_train_args = ASRTask.build_model_from_file(
  72. asr_train_config, asr_model_file, cmvn_file, 'cpu'
  73. )
  74. self._export(model, tag_name)
  75. # def export_from_modelscope(
  76. # self,
  77. # tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
  78. # ):
  79. #
  80. # from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  81. # from modelscope.hub.snapshot_download import snapshot_download
  82. #
  83. # model_dir = snapshot_download(tag_name, cache_dir=self.cache_dir)
  84. # asr_train_config = os.path.join(model_dir, 'config.yaml')
  85. # asr_model_file = os.path.join(model_dir, 'model.pb')
  86. # cmvn_file = os.path.join(model_dir, 'am.mvn')
  87. # model, asr_train_args = ASRTask.build_model_from_file(
  88. # asr_train_config, asr_model_file, cmvn_file, 'cpu'
  89. # )
  90. # self.export(model, tag_name)
  91. #
  92. # def export_from_local(
  93. # self,
  94. # tag_name: str = '/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
  95. # ):
  96. #
  97. # from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  98. #
  99. # model_dir = tag_name
  100. # asr_train_config = os.path.join(model_dir, 'config.yaml')
  101. # asr_model_file = os.path.join(model_dir, 'model.pb')
  102. # cmvn_file = os.path.join(model_dir, 'am.mvn')
  103. # model, asr_train_args = ASRTask.build_model_from_file(
  104. # asr_train_config, asr_model_file, cmvn_file, 'cpu'
  105. # )
  106. # self.export(model, tag_name)
  107. def _export_onnx(self, model, verbose, path, enc_size=None):
  108. if enc_size:
  109. dummy_input = model.get_dummy_inputs(enc_size)
  110. else:
  111. dummy_input = model.get_dummy_inputs()
  112. # model_script = torch.jit.script(model)
  113. model_script = model #torch.jit.trace(model)
  114. torch.onnx.export(
  115. model_script,
  116. dummy_input,
  117. os.path.join(path, f'{model.model_name}.onnx'),
  118. verbose=verbose,
  119. opset_version=12,
  120. input_names=model.get_input_names(),
  121. output_names=model.get_output_names(),
  122. dynamic_axes=model.get_dynamic_axes()
  123. )
  124. if __name__ == '__main__':
  125. output_dir = "../export"
  126. export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
  127. export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
  128. # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')