# export_model.py
import logging
import os
import random
from pathlib import Path
from typing import Dict, List, Union

import numpy as np
import torch

from funasr.export.models import get_model
from funasr.utils.types import str2bool, str2triple_str

# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
  12. class ModelExport:
  13. def __init__(
  14. self,
  15. cache_dir: Union[Path, str] = None,
  16. onnx: bool = True,
  17. device: str = "cpu",
  18. quant: bool = True,
  19. fallback_num: int = 0,
  20. audio_in: str = None,
  21. calib_num: int = 200,
  22. model_revision: str = None,
  23. ):
  24. self.set_all_random_seed(0)
  25. self.cache_dir = cache_dir
  26. self.export_config = dict(
  27. feats_dim=560,
  28. onnx=False,
  29. )
  30. self.onnx = onnx
  31. self.device = device
  32. self.quant = quant
  33. self.fallback_num = fallback_num
  34. self.frontend = None
  35. self.audio_in = audio_in
  36. self.calib_num = calib_num
  37. self.model_revision = model_revision
  38. def _export(
  39. self,
  40. model,
  41. tag_name: str = None,
  42. verbose: bool = False,
  43. ):
  44. export_dir = self.cache_dir
  45. os.makedirs(export_dir, exist_ok=True)
  46. # export encoder1
  47. self.export_config["model_name"] = "model"
  48. model = get_model(
  49. model,
  50. self.export_config,
  51. )
  52. if isinstance(model, List):
  53. for m in model:
  54. m.eval()
  55. if self.onnx:
  56. self._export_onnx(m, verbose, export_dir)
  57. else:
  58. self._export_torchscripts(m, verbose, export_dir)
  59. print("output dir: {}".format(export_dir))
  60. else:
  61. model.eval()
  62. # self._export_onnx(model, verbose, export_dir)
  63. if self.onnx:
  64. self._export_onnx(model, verbose, export_dir)
  65. else:
  66. self._export_torchscripts(model, verbose, export_dir)
  67. print("output dir: {}".format(export_dir))
  68. def _torch_quantize(self, model):
  69. def _run_calibration_data(m):
  70. # using dummy inputs for a example
  71. if self.audio_in is not None:
  72. feats, feats_len = self.load_feats(self.audio_in)
  73. for i, (feat, len) in enumerate(zip(feats, feats_len)):
  74. with torch.no_grad():
  75. m(feat, len)
  76. else:
  77. dummy_input = model.get_dummy_inputs()
  78. m(*dummy_input)
  79. from torch_quant.module import ModuleFilter
  80. from torch_quant.quantizer import Backend, Quantizer
  81. from funasr.export.models.modules.decoder_layer import DecoderLayerSANM
  82. from funasr.export.models.modules.encoder_layer import EncoderLayerSANM
  83. module_filter = ModuleFilter(include_classes=[EncoderLayerSANM, DecoderLayerSANM])
  84. module_filter.exclude_op_types = [torch.nn.Conv1d]
  85. quantizer = Quantizer(
  86. module_filter=module_filter,
  87. backend=Backend.FBGEMM,
  88. )
  89. model.eval()
  90. calib_model = quantizer.calib(model)
  91. _run_calibration_data(calib_model)
  92. if self.fallback_num > 0:
  93. # perform automatic mixed precision quantization
  94. amp_model = quantizer.amp(model)
  95. _run_calibration_data(amp_model)
  96. quantizer.fallback(amp_model, num=self.fallback_num)
  97. print('Fallback layers:')
  98. print('\n'.join(quantizer.module_filter.exclude_names))
  99. quant_model = quantizer.quantize(model)
  100. return quant_model
  101. def _export_torchscripts(self, model, verbose, path, enc_size=None):
  102. if enc_size:
  103. dummy_input = model.get_dummy_inputs(enc_size)
  104. else:
  105. dummy_input = model.get_dummy_inputs()
  106. if self.device == 'cuda':
  107. model = model.cuda()
  108. dummy_input = tuple([i.cuda() for i in dummy_input])
  109. # model_script = torch.jit.script(model)
  110. model_script = torch.jit.trace(model, dummy_input)
  111. model_script.save(os.path.join(path, f'{model.model_name}.torchscripts'))
  112. if self.quant:
  113. quant_model = self._torch_quantize(model)
  114. model_script = torch.jit.trace(quant_model, dummy_input)
  115. model_script.save(os.path.join(path, f'{model.model_name}_quant.torchscripts'))
  116. def set_all_random_seed(self, seed: int):
  117. random.seed(seed)
  118. np.random.seed(seed)
  119. torch.random.manual_seed(seed)
  120. def parse_audio_in(self, audio_in):
  121. wav_list, name_list = [], []
  122. if audio_in.endswith(".scp"):
  123. f = open(audio_in, 'r')
  124. lines = f.readlines()[:self.calib_num]
  125. for line in lines:
  126. name, path = line.strip().split()
  127. name_list.append(name)
  128. wav_list.append(path)
  129. else:
  130. wav_list = [audio_in,]
  131. name_list = ["test",]
  132. return wav_list, name_list
  133. def load_feats(self, audio_in: str = None):
  134. import torchaudio
  135. wav_list, name_list = self.parse_audio_in(audio_in)
  136. feats = []
  137. feats_len = []
  138. for line in wav_list:
  139. path = line.strip()
  140. waveform, sampling_rate = torchaudio.load(path)
  141. if sampling_rate != self.frontend.fs:
  142. waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
  143. new_freq=self.frontend.fs)(waveform)
  144. fbank, fbank_len = self.frontend(waveform, [waveform.size(1)])
  145. feats.append(fbank)
  146. feats_len.append(fbank_len)
  147. return feats, feats_len
  148. def export(self,
  149. tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
  150. mode: str = None,
  151. ):
  152. model_dir = tag_name
  153. if model_dir.startswith('damo'):
  154. from modelscope.hub.snapshot_download import snapshot_download
  155. model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir, revision=self.model_revision)
  156. self.cache_dir = model_dir
  157. if mode is None:
  158. import json
  159. json_file = os.path.join(model_dir, 'configuration.json')
  160. with open(json_file, 'r') as f:
  161. config_data = json.load(f)
  162. if config_data['task'] == "punctuation":
  163. mode = config_data['model']['punc_model_config']['mode']
  164. else:
  165. mode = config_data['model']['model_config']['mode']
  166. if mode.startswith('paraformer'):
  167. from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  168. config = os.path.join(model_dir, 'config.yaml')
  169. model_file = os.path.join(model_dir, 'model.pb')
  170. cmvn_file = os.path.join(model_dir, 'am.mvn')
  171. model, asr_train_args = ASRTask.build_model_from_file(
  172. config, model_file, cmvn_file, 'cpu'
  173. )
  174. self.frontend = model.frontend
  175. self.export_config["feats_dim"] = 560
  176. elif mode.startswith('offline'):
  177. from funasr.tasks.vad import VADTask
  178. config = os.path.join(model_dir, 'vad.yaml')
  179. model_file = os.path.join(model_dir, 'vad.pb')
  180. cmvn_file = os.path.join(model_dir, 'vad.mvn')
  181. model, vad_infer_args = VADTask.build_model_from_file(
  182. config, model_file, cmvn_file=cmvn_file, device='cpu'
  183. )
  184. self.export_config["feats_dim"] = 400
  185. self.frontend = model.frontend
  186. elif mode.startswith('punc'):
  187. from funasr.tasks.punctuation import PunctuationTask as PUNCTask
  188. punc_train_config = os.path.join(model_dir, 'config.yaml')
  189. punc_model_file = os.path.join(model_dir, 'punc.pb')
  190. model, punc_train_args = PUNCTask.build_model_from_file(
  191. punc_train_config, punc_model_file, 'cpu'
  192. )
  193. elif mode.startswith('punc_VadRealtime'):
  194. from funasr.tasks.punctuation import PunctuationTask as PUNCTask
  195. punc_train_config = os.path.join(model_dir, 'config.yaml')
  196. punc_model_file = os.path.join(model_dir, 'punc.pb')
  197. model, punc_train_args = PUNCTask.build_model_from_file(
  198. punc_train_config, punc_model_file, 'cpu'
  199. )
  200. self._export(model, tag_name)
  201. def _export_onnx(self, model, verbose, path, enc_size=None):
  202. if enc_size:
  203. dummy_input = model.get_dummy_inputs(enc_size)
  204. else:
  205. dummy_input = model.get_dummy_inputs()
  206. # model_script = torch.jit.script(model)
  207. model_script = model #torch.jit.trace(model)
  208. model_path = os.path.join(path, f'{model.model_name}.onnx')
  209. # if not os.path.exists(model_path):
  210. torch.onnx.export(
  211. model_script,
  212. dummy_input,
  213. model_path,
  214. verbose=verbose,
  215. opset_version=14,
  216. input_names=model.get_input_names(),
  217. output_names=model.get_output_names(),
  218. dynamic_axes=model.get_dynamic_axes()
  219. )
  220. if self.quant:
  221. from onnxruntime.quantization import QuantType, quantize_dynamic
  222. import onnx
  223. quant_model_path = os.path.join(path, f'{model.model_name}_quant.onnx')
  224. if not os.path.exists(quant_model_path):
  225. onnx_model = onnx.load(model_path)
  226. nodes = [n.name for n in onnx_model.graph.node]
  227. nodes_to_exclude = [m for m in nodes if 'output' in m]
  228. quantize_dynamic(
  229. model_input=model_path,
  230. model_output=quant_model_path,
  231. op_types_to_quantize=['MatMul'],
  232. per_channel=True,
  233. reduce_range=False,
  234. weight_type=QuantType.QUInt8,
  235. nodes_to_exclude=nodes_to_exclude,
  236. )
  237. if __name__ == '__main__':
  238. import argparse
  239. parser = argparse.ArgumentParser()
  240. # parser.add_argument('--model-name', type=str, required=True)
  241. parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
  242. parser.add_argument('--export-dir', type=str, required=True)
  243. parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
  244. parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
  245. parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
  246. parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
  247. parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
  248. parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
  249. parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
  250. args = parser.parse_args()
  251. export_model = ModelExport(
  252. cache_dir=args.export_dir,
  253. onnx=args.type == 'onnx',
  254. device=args.device,
  255. quant=args.quantize,
  256. fallback_num=args.fallback_num,
  257. audio_in=args.audio_in,
  258. calib_num=args.calib_num,
  259. model_revision=args.model_revision,
  260. )
  261. for model_name in args.model_name:
  262. print("export model: {}".format(model_name))
  263. export_model.export(model_name)