# export_model.py
  1. import json
  2. from typing import Union, Dict
  3. from pathlib import Path
  4. from typeguard import check_argument_types
  5. import os
  6. import logging
  7. import torch
  8. from funasr.export.models import get_model
  9. import numpy as np
  10. import random
  11. # torch_version = float(".".join(torch.__version__.split(".")[:2]))
  12. # assert torch_version > 1.9
  13. class ASRModelExportParaformer:
  14. def __init__(
  15. self,
  16. cache_dir: Union[Path, str] = None,
  17. onnx: bool = True,
  18. quant: bool = True,
  19. fallback_num: int = 0,
  20. audio_in: str = None,
  21. calib_num: int = 200,
  22. ):
  23. assert check_argument_types()
  24. self.set_all_random_seed(0)
  25. if cache_dir is None:
  26. cache_dir = Path.home() / ".cache" / "export"
  27. self.cache_dir = Path(cache_dir)
  28. self.export_config = dict(
  29. feats_dim=560,
  30. onnx=False,
  31. )
  32. print("output dir: {}".format(self.cache_dir))
  33. self.onnx = onnx
  34. self.quant = quant
  35. self.fallback_num = fallback_num
  36. self.frontend = None
  37. self.audio_in = audio_in
  38. self.calib_num = calib_num
  39. def _export(
  40. self,
  41. model,
  42. tag_name: str = None,
  43. verbose: bool = False,
  44. ):
  45. export_dir = self.cache_dir / tag_name.replace(' ', '-')
  46. os.makedirs(export_dir, exist_ok=True)
  47. # export encoder1
  48. self.export_config["model_name"] = "model"
  49. model = get_model(
  50. model,
  51. self.export_config,
  52. )
  53. model.eval()
  54. # self._export_onnx(model, verbose, export_dir)
  55. if self.onnx:
  56. self._export_onnx(model, verbose, export_dir)
  57. else:
  58. self._export_torchscripts(model, verbose, export_dir)
  59. print("output dir: {}".format(export_dir))
  60. def _torch_quantize(self, model):
  61. def _run_calibration_data(m):
  62. # using dummy inputs for a example
  63. if self.audio_in is not None:
  64. feats, feats_len = self.load_feats(self.audio_in)
  65. for feat, len in zip(feats, feats_len):
  66. m(feat, len)
  67. else:
  68. dummy_input = model.get_dummy_inputs()
  69. m(*dummy_input)
  70. from torch_quant.module import ModuleFilter
  71. from torch_quant.quantizer import Backend, Quantizer
  72. from funasr.export.models.modules.decoder_layer import DecoderLayerSANM
  73. from funasr.export.models.modules.encoder_layer import EncoderLayerSANM
  74. module_filter = ModuleFilter(include_classes=[EncoderLayerSANM, DecoderLayerSANM])
  75. module_filter.exclude_op_types = [torch.nn.Conv1d]
  76. quantizer = Quantizer(
  77. module_filter=module_filter,
  78. backend=Backend.FBGEMM,
  79. )
  80. model.eval()
  81. calib_model = quantizer.calib(model)
  82. _run_calibration_data(calib_model)
  83. if self.fallback_num > 0:
  84. # perform automatic mixed precision quantization
  85. amp_model = quantizer.amp(model)
  86. _run_calibration_data(amp_model)
  87. quantizer.fallback(amp_model, num=self.fallback_num)
  88. print('Fallback layers:')
  89. print('\n'.join(quantizer.module_filter.exclude_names))
  90. quant_model = quantizer.quantize(model)
  91. return quant_model
  92. def _export_torchscripts(self, model, verbose, path, enc_size=None):
  93. if enc_size:
  94. dummy_input = model.get_dummy_inputs(enc_size)
  95. else:
  96. dummy_input = model.get_dummy_inputs()
  97. # model_script = torch.jit.script(model)
  98. model_script = torch.jit.trace(model, dummy_input)
  99. model_script.save(os.path.join(path, f'{model.model_name}.torchscripts'))
  100. if self.quant:
  101. quant_model = self._torch_quantize(model)
  102. model_script = torch.jit.trace(quant_model, dummy_input)
  103. model_script.save(os.path.join(path, f'{model.model_name}_quant.torchscripts'))
  104. def set_all_random_seed(self, seed: int):
  105. random.seed(seed)
  106. np.random.seed(seed)
  107. torch.random.manual_seed(seed)
  108. def parse_audio_in(self, audio_in):
  109. wav_list, name_list = [], []
  110. if audio_in.endswith(".scp"):
  111. f = open(audio_in, 'r')
  112. lines = f.readlines()[:self.calib_num]
  113. for line in lines:
  114. name, path = line.strip().split()
  115. name_list.append(name)
  116. wav_list.append(path)
  117. else:
  118. wav_list = [audio_in,]
  119. name_list = ["test",]
  120. return wav_list, name_list
  121. def load_feats(self, audio_in: str = None):
  122. import torchaudio
  123. wav_list, name_list = self.parse_audio_in(audio_in)
  124. feats = []
  125. feats_len = []
  126. for line in wav_list:
  127. name, path = line.strip().split()
  128. waveform, sampling_rate = torchaudio.load(path)
  129. if sampling_rate != self.frontend.fs:
  130. waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
  131. new_freq=self.frontend.fs)(waveform)
  132. fbank, fbank_len = self.frontend(waveform, [waveform.size(1)])
  133. feats.append(fbank)
  134. feats_len.append(fbank_len)
  135. return feats, feats_len
  136. def export(self,
  137. tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
  138. mode: str = 'paraformer',
  139. ):
  140. model_dir = tag_name
  141. if model_dir.startswith('damo/'):
  142. from modelscope.hub.snapshot_download import snapshot_download
  143. model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
  144. asr_train_config = os.path.join(model_dir, 'config.yaml')
  145. asr_model_file = os.path.join(model_dir, 'model.pb')
  146. cmvn_file = os.path.join(model_dir, 'am.mvn')
  147. json_file = os.path.join(model_dir, 'configuration.json')
  148. if mode is None:
  149. import json
  150. with open(json_file, 'r') as f:
  151. config_data = json.load(f)
  152. mode = config_data['model']['model_config']['mode']
  153. if mode.startswith('paraformer'):
  154. from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  155. elif mode.startswith('uniasr'):
  156. from funasr.tasks.asr import ASRTaskUniASR as ASRTask
  157. model, asr_train_args = ASRTask.build_model_from_file(
  158. asr_train_config, asr_model_file, cmvn_file, 'cpu'
  159. )
  160. self._export(model, tag_name)
  161. def _export_onnx(self, model, verbose, path, enc_size=None):
  162. if enc_size:
  163. dummy_input = model.get_dummy_inputs(enc_size)
  164. else:
  165. dummy_input = model.get_dummy_inputs()
  166. # model_script = torch.jit.script(model)
  167. model_script = model #torch.jit.trace(model)
  168. model_path = os.path.join(path, f'{model.model_name}.onnx')
  169. torch.onnx.export(
  170. model_script,
  171. dummy_input,
  172. model_path,
  173. verbose=verbose,
  174. opset_version=14,
  175. input_names=model.get_input_names(),
  176. output_names=model.get_output_names(),
  177. dynamic_axes=model.get_dynamic_axes()
  178. )
  179. if self.quant:
  180. from onnxruntime.quantization import QuantType, quantize_dynamic
  181. import onnx
  182. quant_model_path = os.path.join(path, f'{model.model_name}_quant.onnx')
  183. onnx_model = onnx.load(model_path)
  184. nodes = [n.name for n in onnx_model.graph.node]
  185. nodes_to_exclude = [m for m in nodes if 'output' in m]
  186. quantize_dynamic(
  187. model_input=model_path,
  188. model_output=quant_model_path,
  189. op_types_to_quantize=['MatMul'],
  190. per_channel=True,
  191. reduce_range=False,
  192. weight_type=QuantType.QUInt8,
  193. nodes_to_exclude=nodes_to_exclude,
  194. )
  195. if __name__ == '__main__':
  196. import argparse
  197. parser = argparse.ArgumentParser()
  198. parser.add_argument('--model-name', type=str, required=True)
  199. parser.add_argument('--export-dir', type=str, required=True)
  200. parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
  201. parser.add_argument('--quantize', action='store_true', help='export quantized model')
  202. parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
  203. parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
  204. parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
  205. args = parser.parse_args()
  206. export_model = ASRModelExportParaformer(
  207. cache_dir=args.export_dir,
  208. onnx=args.type == 'onnx',
  209. quant=args.quantize,
  210. fallback_num=args.fallback_num,
  211. audio_in=args.audio_in,
  212. calib_num=args.calib_num,
  213. )
  214. export_model.export(args.model_name)