# export_model.py -- export FunASR Paraformer models to ONNX / TorchScript.
  1. import json
  2. from typing import Union, Dict
  3. from pathlib import Path
  4. from typeguard import check_argument_types
  5. import os
  6. import logging
  7. import torch
  8. from funasr.bin.asr_inference_paraformer import Speech2Text
  9. from funasr.export.models import get_model
  10. import numpy as np
  11. import random
  12. class ASRModelExportParaformer:
  13. def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True):
  14. assert check_argument_types()
  15. self.set_all_random_seed(0)
  16. if cache_dir is None:
  17. cache_dir = Path.home() / ".cache" / "export"
  18. self.cache_dir = Path(cache_dir)
  19. self.export_config = dict(
  20. feats_dim=560,
  21. onnx=False,
  22. )
  23. print("output dir: {}".format(self.cache_dir))
  24. self.onnx = onnx
  25. def _export(
  26. self,
  27. model: Speech2Text,
  28. tag_name: str = None,
  29. verbose: bool = False,
  30. ):
  31. export_dir = self.cache_dir / tag_name.replace(' ', '-')
  32. os.makedirs(export_dir, exist_ok=True)
  33. # export encoder1
  34. self.export_config["model_name"] = "model"
  35. model = get_model(
  36. model,
  37. self.export_config,
  38. )
  39. model.eval()
  40. # self._export_onnx(model, verbose, export_dir)
  41. if self.onnx:
  42. self._export_onnx(model, verbose, export_dir)
  43. else:
  44. self._export_torchscripts(model, verbose, export_dir)
  45. print("output dir: {}".format(export_dir))
  46. def _export_torchscripts(self, model, verbose, path, enc_size=None):
  47. if enc_size:
  48. dummy_input = model.get_dummy_inputs(enc_size)
  49. else:
  50. dummy_input = model.get_dummy_inputs()
  51. # model_script = torch.jit.script(model)
  52. model_script = torch.jit.trace(model, dummy_input)
  53. model_script.save(os.path.join(path, f'{model.model_name}.torchscripts'))
  54. def set_all_random_seed(self, seed: int):
  55. random.seed(seed)
  56. np.random.seed(seed)
  57. torch.random.manual_seed(seed)
  58. def export(self,
  59. tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
  60. mode: str = 'paraformer',
  61. ):
  62. model_dir = tag_name
  63. if model_dir.startswith('damo/'):
  64. from modelscope.hub.snapshot_download import snapshot_download
  65. model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
  66. asr_train_config = os.path.join(model_dir, 'config.yaml')
  67. asr_model_file = os.path.join(model_dir, 'model.pb')
  68. cmvn_file = os.path.join(model_dir, 'am.mvn')
  69. json_file = os.path.join(model_dir, 'configuration.json')
  70. if mode is None:
  71. import json
  72. with open(json_file, 'r') as f:
  73. config_data = json.load(f)
  74. mode = config_data['model']['model_config']['mode']
  75. if mode.startswith('paraformer'):
  76. from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  77. elif mode.startswith('uniasr'):
  78. from funasr.tasks.asr import ASRTaskUniASR as ASRTask
  79. model, asr_train_args = ASRTask.build_model_from_file(
  80. asr_train_config, asr_model_file, cmvn_file, 'cpu'
  81. )
  82. self._export(model, tag_name)
  83. def _export_onnx(self, model, verbose, path, enc_size=None):
  84. if enc_size:
  85. dummy_input = model.get_dummy_inputs(enc_size)
  86. else:
  87. dummy_input = model.get_dummy_inputs()
  88. # model_script = torch.jit.script(model)
  89. model_script = model #torch.jit.trace(model)
  90. torch.onnx.export(
  91. model_script,
  92. dummy_input,
  93. os.path.join(path, f'{model.model_name}.onnx'),
  94. verbose=verbose,
  95. opset_version=14,
  96. input_names=model.get_input_names(),
  97. output_names=model.get_output_names(),
  98. dynamic_axes=model.get_dynamic_axes()
  99. )
  100. class ASRModelExport:
  101. def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True):
  102. assert check_argument_types()
  103. self.set_all_random_seed(0)
  104. if cache_dir is None:
  105. cache_dir = Path.home() / ".cache" / "export"
  106. self.cache_dir = Path(cache_dir)
  107. self.export_config = dict(
  108. feats_dim=560,
  109. onnx=False,
  110. )
  111. print("output dir: {}".format(self.cache_dir))
  112. self.onnx = onnx
  113. def _export(
  114. self,
  115. model: Speech2Text,
  116. tag_name: str = None,
  117. verbose: bool = False,
  118. ):
  119. export_dir = self.cache_dir / tag_name.replace(' ', '-')
  120. os.makedirs(export_dir, exist_ok=True)
  121. # export encoder1
  122. self.export_config["model_name"] = "model"
  123. model = get_model(
  124. model,
  125. self.export_config,
  126. )
  127. model.eval()
  128. # self._export_onnx(model, verbose, export_dir)
  129. if self.onnx:
  130. self._export_onnx(model, verbose, export_dir)
  131. else:
  132. self._export_torchscripts(model, verbose, export_dir)
  133. print("output dir: {}".format(export_dir))
  134. def _export_torchscripts(self, model, verbose, path, enc_size=None):
  135. if enc_size:
  136. dummy_input = model.get_dummy_inputs(enc_size)
  137. else:
  138. dummy_input = model.get_dummy_inputs_txt()
  139. # model_script = torch.jit.script(model)
  140. model_script = torch.jit.trace(model, dummy_input)
  141. model_script.save(os.path.join(path, f'{model.model_name}.torchscripts'))
  142. def set_all_random_seed(self, seed: int):
  143. random.seed(seed)
  144. np.random.seed(seed)
  145. torch.random.manual_seed(seed)
  146. def export(self,
  147. tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
  148. mode: str = 'paraformer',
  149. ):
  150. model_dir = tag_name
  151. if model_dir.startswith('damo/'):
  152. from modelscope.hub.snapshot_download import snapshot_download
  153. model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
  154. asr_train_config = os.path.join(model_dir, 'config.yaml')
  155. asr_model_file = os.path.join(model_dir, 'model.pb')
  156. cmvn_file = os.path.join(model_dir, 'am.mvn')
  157. json_file = os.path.join(model_dir, 'configuration.json')
  158. if mode is None:
  159. import json
  160. with open(json_file, 'r') as f:
  161. config_data = json.load(f)
  162. mode = config_data['model']['model_config']['mode']
  163. if mode.startswith('paraformer'):
  164. from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  165. elif mode.startswith('uniasr'):
  166. from funasr.tasks.asr import ASRTaskUniASR as ASRTask
  167. model, asr_train_args = ASRTask.build_model_from_file(
  168. asr_train_config, asr_model_file, cmvn_file, 'cpu'
  169. )
  170. self._export(model, tag_name)
  171. def _export_onnx(self, model, verbose, path, enc_size=None):
  172. if enc_size:
  173. dummy_input = model.get_dummy_inputs(enc_size)
  174. else:
  175. dummy_input = model.get_dummy_inputs()
  176. # model_script = torch.jit.script(model)
  177. model_script = model # torch.jit.trace(model)
  178. torch.onnx.export(
  179. model_script,
  180. dummy_input,
  181. os.path.join(path, f'{model.model_name}.onnx'),
  182. verbose=verbose,
  183. opset_version=12,
  184. input_names=model.get_input_names(),
  185. output_names=model.get_output_names(),
  186. dynamic_axes=model.get_dynamic_axes()
  187. )
  188. if __name__ == '__main__':
  189. import sys
  190. model_path = sys.argv[1]
  191. output_dir = sys.argv[2]
  192. onnx = sys.argv[3]
  193. onnx = onnx.lower()
  194. onnx = onnx == 'true'
  195. # model_path = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
  196. # output_dir = "../export"
  197. export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=onnx)
  198. export_model.export(model_path)
  199. # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')