paraformer_bin.py

# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import copy

import librosa
import numpy as np

from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                          OrtInferSession, TokenIDConverter, get_logger,
                          read_yaml)
from .utils.postprocess_utils import (sentence_postprocess,
                                      sentence_postprocess_sentencepiece)
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.utils import pad_list

logging = get_logger()
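
# ONNX-runtime inference wrappers for exported Paraformer models: a plain
# recognizer (Paraformer) and a hotword-biased variant (ContextualParaformer).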


class Paraformer():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 plot_timestamp_to: str = "",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except ImportError:
                raise ImportError(
                    "Downloading the model from ModelScope requires the modelscope "
                    "package. Install it with:\n"
                    "\npip3 install -U modelscope\n"
                    "Users in China can install from a mirror:\n"
                    "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple")
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception:
                raise ValueError(
                    "model_dir must be a model name on ModelScope or a local path "
                    "downloaded from ModelScope, but is {}".format(model_dir))

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            print("{} does not exist, begin to export the ONNX model".format(model_file))
            try:
                from funasr.export.export_model import ModelExport
            except ImportError:
                raise ImportError(
                    "Exporting an ONNX model requires the funasr package. "
                    "Install it with:\n"
                    "\npip3 install -U funasr\n"
                    "Users in China can install from a mirror:\n"
                    "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple")
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)

        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.plot_timestamp_to = plot_timestamp_to
        if "predictor_bias" in config['model_conf'].keys():
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0
        if "lang" in config:
            self.language = config['lang']
        else:
            self.language = None

    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            try:
                outputs = self.infer(feats, feats_len)
                am_scores, valid_token_lens = outputs[0], outputs[1]
                if len(outputs) == 4:
                    # for BiCifParaformer inference: the extra outputs are the
                    # upsampled CIF alphas and peaks used for timestamp prediction
                    us_alphas, us_peaks = outputs[2], outputs[3]
                else:
                    us_alphas, us_peaks = None, None
            except ONNXRuntimeError:
                # logging.warning(traceback.format_exc())
                logging.warning("input wav is silence or noise")
                preds = ['']
            else:
                preds = self.decode(am_scores, valid_token_lens)
                if us_peaks is None:
                    for pred in preds:
                        if self.language == "en-bpe":
                            pred = sentence_postprocess_sentencepiece(pred)
                        else:
                            pred = sentence_postprocess(pred)
                        asr_res.append({'preds': pred})
                else:
                    for pred, us_peaks_ in zip(preds, us_peaks):
                        raw_tokens = pred
                        timestamp, timestamp_raw = time_stamp_lfr6_onnx(us_peaks_, copy.copy(raw_tokens))
                        text_proc, timestamp_proc, _ = sentence_postprocess(raw_tokens, timestamp_raw)
                        # logging.warning(timestamp)
                        if len(self.plot_timestamp_to):
                            self.plot_wave_timestamp(waveform_list[0], timestamp, self.plot_timestamp_to)
                        asr_res.append({'preds': text_proc, 'timestamp': timestamp_proc, "raw_tokens": raw_tokens})
        return asr_res

    def plot_wave_timestamp(self, wav, text_timestamp, dest):
        # TODO: plot the waveform and timestamp results with matplotlib
        import matplotlib
        matplotlib.use('Agg')
        matplotlib.rc("font", family='Alibaba PuHuiTi')  # set this to a font that your system supports
        import matplotlib.pyplot as plt

        fig, ax1 = plt.subplots(figsize=(11, 3.5), dpi=320)
        ax2 = ax1.twinx()
        ax2.set_ylim([0, 2.0])
        # plot waveform
        ax1.set_ylim([-0.3, 0.3])
        time = np.arange(wav.shape[0]) / 16000
        ax1.plot(time, wav / wav.max() * 0.3, color='gray', alpha=0.4)
        # plot boundary lines and token text
        for (char, start, end) in text_timestamp:
            ax1.vlines(start, -0.3, 0.3, ls='--')
            ax1.vlines(end, -0.3, 0.3, ls='--')
            x_adj = 0.045 if char != '<sil>' else 0.12
            ax1.text((start + end) * 0.5 - x_adj, 0, char)
        # plt.legend()
        plotname = "{}/timestamp.png".format(dest)
        plt.savefig(plotname, bbox_inches='tight')
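
    # load_data normalizes the accepted inputs (a wav path, a raw waveform
    # array, or a list of paths) into a list of numpy waveforms resampled to
    # the frontend's sampling rate.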
    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [load_wav(wav_content)]
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')
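
    # extract_feat computes fbank features per waveform, applies LFR stacking
    # and CMVN, then zero-pads the batch to the longest utterance.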
    def extract_feat(self,
                     waveform_list: List[np.ndarray]
                     ) -> Tuple[np.ndarray, np.ndarray]:
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)
        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer([feats, feats_len])
        return outputs

    def decode(self, am_scores: np.ndarray, token_nums: np.ndarray) -> List[str]:
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]
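
    # decode_one greedily takes the argmax token at every frame, wraps the
    # sequence in sos/eos, strips blanks and eos, and truncates to the token
    # count predicted by the model (minus the predictor bias).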
    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)

        # pad with sos/eos ids for compatibility with the training setup
        # asr_model.sos:1  asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)

        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()

        # remove blank symbol id, which is assumed to be 0
        token_int = list(filter(lambda x: x not in (0, 2), token_int))

        # change integer ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num - self.pred_bias]
        # texts = sentence_postprocess(token)
        return token
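

# A minimal usage sketch (the model directory and wav path below are
# illustrative; the directory must contain model.onnx, config.yaml and am.mvn,
# or be a ModelScope model name to download):
#
#     model = Paraformer("/path/to/paraformer-onnx-model", batch_size=1, quantize=False)
#     result = model("sample_16k.wav")
#     print(result[0]['preds'])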


class ContextualParaformer(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 plot_timestamp_to: str = "",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except ImportError:
                raise ImportError(
                    "Downloading the model from ModelScope requires the modelscope "
                    "package. Install it with:\n"
                    "\npip3 install -U modelscope\n"
                    "Users in China can install from a mirror:\n"
                    "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple")
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception:
                raise ValueError(
                    "model_dir must be a model name on ModelScope or a local path "
                    "downloaded from ModelScope, but is {}".format(model_dir))

        # the backbone (bb) model does recognition; the embedding (eb) model
        # encodes the hotword list into bias embeddings
        if quantize:
            model_bb_file = os.path.join(model_dir, 'model_quant.onnx')
            model_eb_file = os.path.join(model_dir, 'model_eb_quant.onnx')
        else:
            model_bb_file = os.path.join(model_dir, 'model.onnx')
            model_eb_file = os.path.join(model_dir, 'model_eb.onnx')

        token_list_file = os.path.join(model_dir, 'tokens.txt')
        self.vocab = {}
        with open(Path(token_list_file), 'r') as fin:
            for i, line in enumerate(fin.readlines()):
                self.vocab[line.strip()] = i

        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer_bb = OrtInferSession(model_bb_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.ort_infer_eb = OrtInferSession(model_eb_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.plot_timestamp_to = plot_timestamp_to
        if "predictor_bias" in config['model_conf'].keys():
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0

    def __call__(self,
                 wav_content: Union[str, np.ndarray, List[str]],
                 hotwords: str,
                 **kwargs) -> List:
        # make hotword list
        hotwords, hotwords_length = self.proc_hotword(hotwords)
        [bias_embed] = self.eb_infer(hotwords, hotwords_length)
        # index from bias_embed: for each hotword, keep the embedding at its
        # last character position
        bias_embed = bias_embed.transpose(1, 0, 2)
        _ind = np.arange(0, len(hotwords)).tolist()
        bias_embed = bias_embed[_ind, hotwords_length.tolist()]
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            # tile the (num_hotwords, dim) bias matrix across this batch,
            # keeping bias_embed itself 2-D for the next iteration
            batch_bias_embed = np.expand_dims(bias_embed, axis=0)
            batch_bias_embed = np.repeat(batch_bias_embed, feats.shape[0], axis=0)
            try:
                outputs = self.bb_infer(feats, feats_len, batch_bias_embed)
                am_scores, valid_token_lens = outputs[0], outputs[1]
            except ONNXRuntimeError:
                # logging.warning(traceback.format_exc())
                logging.warning("input wav is silence or noise")
                preds = ['']
            else:
                preds = self.decode(am_scores, valid_token_lens)
                for pred in preds:
                    pred = sentence_postprocess(pred)
                    asr_res.append({'preds': pred})
        return asr_res

    def proc_hotword(self, hotwords):
        hotwords = hotwords.split(" ")
        hotwords_length = [len(i) - 1 for i in hotwords]
        hotwords_length.append(0)
        hotwords_length = np.array(hotwords_length)

        # hotwords.append('<s>')
        def word_map(word):
            hotwords = []
            for c in word:
                if c not in self.vocab.keys():
                    # out-of-vocabulary characters map to <unk> (id 8403)
                    hotwords.append(8403)
                    logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
                else:
                    hotwords.append(self.vocab[c])
            return np.array(hotwords)

        hotword_int = [word_map(i) for i in hotwords]
        # append the <s> token (id 1) as the final entry
        hotword_int.append(np.array([1]))
        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
        return hotwords, hotwords_length
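
    # For example (a sketch of the expected shapes, not verified output),
    # proc_hotword("你好 世界") maps each character of each hotword to its
    # vocab id, appends the <s> entry, and zero-pads every row to length 10:
    #   hotwords        -> int array of shape (num_hotwords + 1, 10)
    #   hotwords_length -> array [len("你好") - 1, len("世界") - 1, 0]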
    def bb_infer(self, feats: np.ndarray,
                 feats_len: np.ndarray, bias_embed) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
        return outputs

    def eb_infer(self, hotwords, hotwords_length):
        outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
        return outputs

    def decode(self, am_scores: np.ndarray, token_nums: np.ndarray) -> List[str]:
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]

    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)

        # pad with sos/eos ids for compatibility with the training setup
        # asr_model.sos:1  asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)

        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()

        # remove blank symbol id, which is assumed to be 0
        token_int = list(filter(lambda x: x not in (0, 2), token_int))

        # change integer ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num - self.pred_bias]
        # texts = sentence_postprocess(token)
        return token
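

# A minimal usage sketch for the hotword variant (paths and hotwords are
# illustrative; the model directory must also provide model_eb.onnx and
# tokens.txt):
#
#     model = ContextualParaformer("/path/to/contextual-paraformer-onnx")
#     result = model("sample_16k.wav", hotwords="魔搭 阿里巴巴")
#     print(result[0]['preds'])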