# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import os.path
from pathlib import Path
from typing import List, Union, Tuple
import copy

import torch
import librosa
import numpy as np

from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                          OrtInferSession, TokenIDConverter, get_logger,
                          read_yaml)
from .utils.postprocess_utils import (sentence_postprocess,
                                      sentence_postprocess_sentencepiece)
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.utils import pad_list, make_pad_mask

logging = get_logger()


class Paraformer():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 plot_timestamp_to: str = "",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except ImportError:
                raise ImportError(
                    "Downloading models from ModelScope requires the modelscope package. "
                    "Install it with:\n"
                    "\npip3 install -U modelscope\n"
                    "For users in China, you can install it from a mirror:\n"
                    "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple")
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception:
                raise ValueError(
                    "model_dir must be a model name on ModelScope or a local path "
                    "downloaded from ModelScope, but is {}".format(model_dir))
        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            print("{} does not exist, exporting the model to ONNX".format(model_file))
            try:
                from funasr.export.export_model import ModelExport
            except ImportError:
                raise ImportError(
                    "Exporting to ONNX requires the funasr package. Install it with:\n"
                    "\npip3 install -U funasr\n"
                    "For users in China, you can install it from a mirror:\n"
                    "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple")
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.plot_timestamp_to = plot_timestamp_to
        if "predictor_bias" in config['model_conf'].keys():
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0
        if "lang" in config:
            self.language = config['lang']
        else:
            self.language = None

    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            try:
                outputs = self.infer(feats, feats_len)
                am_scores, valid_token_lens = outputs[0], outputs[1]
                if len(outputs) == 4:
                    # for BiCifParaformer inference, which also returns CIF alphas/peaks
                    us_alphas, us_peaks = outputs[2], outputs[3]
                else:
                    us_alphas, us_peaks = None, None
            except ONNXRuntimeError:
                # logging.warning(traceback.format_exc())
                logging.warning("input wav is silence or noise")
                preds = ['']
            else:
                preds = self.decode(am_scores, valid_token_lens)
                if us_peaks is None:
                    for pred in preds:
                        if self.language == "en-bpe":
                            pred = sentence_postprocess_sentencepiece(pred)
                        else:
                            pred = sentence_postprocess(pred)
                        asr_res.append({'preds': pred})
                else:
                    # with CIF peaks available, also emit character-level timestamps
                    for pred, us_peaks_ in zip(preds, us_peaks):
                        raw_tokens = pred
                        timestamp, timestamp_raw = time_stamp_lfr6_onnx(us_peaks_, copy.copy(raw_tokens))
                        text_proc, timestamp_proc, _ = sentence_postprocess(raw_tokens, timestamp_raw)
                        if len(self.plot_timestamp_to):
                            self.plot_wave_timestamp(waveform_list[0], timestamp, self.plot_timestamp_to)
                        asr_res.append({'preds': text_proc, 'timestamp': timestamp_proc, "raw_tokens": raw_tokens})
        return asr_res

    def plot_wave_timestamp(self, wav, text_timestamp, dest):
        # Plot the waveform and character-level timestamps with matplotlib
        import matplotlib
        matplotlib.use('Agg')
        matplotlib.rc("font", family='Alibaba PuHuiTi')  # set this to a font that your system supports
        import matplotlib.pyplot as plt

        fig, ax1 = plt.subplots(figsize=(11, 3.5), dpi=320)
        ax2 = ax1.twinx()
        ax2.set_ylim([0, 2.0])
        # plot waveform
        ax1.set_ylim([-0.3, 0.3])
        time = np.arange(wav.shape[0]) / 16000
        ax1.plot(time, wav / wav.max() * 0.3, color='gray', alpha=0.4)
        # plot lines and text
        for (char, start, end) in text_timestamp:
            ax1.vlines(start, -0.3, 0.3, ls='--')
            ax1.vlines(end, -0.3, 0.3, ls='--')
            x_adj = 0.045 if char != '<sil>' else 0.12
            ax1.text((start + end) * 0.5 - x_adj, 0, char)
        plotname = "{}/timestamp.png".format(dest)
        plt.savefig(plotname, bbox_inches='tight')

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]

        if isinstance(wav_content, str):
            return [load_wav(wav_content)]

        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]

        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def extract_feat(self,
                     waveform_list: List[np.ndarray]
                     ) -> Tuple[np.ndarray, np.ndarray]:
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer([feats, feats_len])
        return outputs

    def decode(self, am_scores: np.ndarray, token_nums: np.ndarray) -> List[str]:
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]

    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
        # greedy decoding: take the highest-scoring token id per frame
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)

        # pad with sos/eos ids to ensure compatibility with Hypothesis
        # asr_model.sos:1 asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)

        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()

        # remove blank symbol id (assumed to be 0) and eos id
        token_int = list(filter(lambda x: x not in (0, 2), token_int))

        # change integer ids to tokens and keep only the valid ones
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num - self.pred_bias]
        return token


class ContextualParaformer(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 plot_timestamp_to: str = "",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except ImportError:
                raise ImportError(
                    "Downloading models from ModelScope requires the modelscope package. "
                    "Install it with:\n"
                    "\npip3 install -U modelscope\n"
                    "For users in China, you can install it from a mirror:\n"
                    "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple")
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception:
                raise ValueError(
                    "model_dir must be a model name on ModelScope or a local path "
                    "downloaded from ModelScope, but is {}".format(model_dir))

        if quantize:
            model_bb_file = os.path.join(model_dir, 'model_quant.onnx')
            model_eb_file = os.path.join(model_dir, 'model_eb_quant.onnx')
        else:
            model_bb_file = os.path.join(model_dir, 'model.onnx')
            model_eb_file = os.path.join(model_dir, 'model_eb.onnx')

        token_list_file = os.path.join(model_dir, 'tokens.txt')
        self.vocab = {}
        with open(Path(token_list_file), 'r') as fin:
            for i, line in enumerate(fin.readlines()):
                self.vocab[line.strip()] = i

        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        # bb: backbone ASR model; eb: hotword (bias) embedding model
        self.ort_infer_bb = OrtInferSession(model_bb_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.ort_infer_eb = OrtInferSession(model_eb_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.plot_timestamp_to = plot_timestamp_to
        if "predictor_bias" in config['model_conf'].keys():
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0

    def __call__(self,
                 wav_content: Union[str, np.ndarray, List[str]],
                 hotwords: str,
                 **kwargs) -> List:
        # make hotword list
        hotwords, hotwords_length = self.proc_hotword(hotwords)
        [bias_embed] = self.eb_infer(hotwords, hotwords_length)
        # pick, for each hotword, the embedding at its last character position
        bias_embed = bias_embed.transpose(1, 0, 2)
        _ind = np.arange(0, len(hotwords)).tolist()
        bias_embed = bias_embed[_ind, hotwords_length.cpu().numpy().tolist()]

        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            # tile the hotword embeddings across the batch; use a local copy so
            # bias_embed is not re-expanded on every iteration
            batch_bias_embed = np.repeat(np.expand_dims(bias_embed, axis=0), feats.shape[0], axis=0)
            try:
                outputs = self.bb_infer(feats, feats_len, batch_bias_embed)
                am_scores, valid_token_lens = outputs[0], outputs[1]
            except ONNXRuntimeError:
                # logging.warning(traceback.format_exc())
                logging.warning("input wav is silence or noise")
                preds = ['']
            else:
                preds = self.decode(am_scores, valid_token_lens)
                for pred in preds:
                    pred = sentence_postprocess(pred)
                    asr_res.append({'preds': pred})
        return asr_res

    def proc_hotword(self, hotwords):
        hotwords = hotwords.split(" ")
        # length of each hotword minus one, i.e. the index of its last character
        hotwords_length = [len(i) - 1 for i in hotwords]
        hotwords_length.append(0)
        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)

        def word_map(word):
            ids = []
            for c in word:
                if c not in self.vocab.keys():
                    # 8403 is assumed to be the id of <unk> in the token list
                    ids.append(8403)
                    logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
                else:
                    ids.append(self.vocab[c])
            return torch.tensor(ids)

        hotword_int = [word_map(i) for i in hotwords]
        # append <s> (id 1) as the trailing no-bias entry
        hotword_int.append(torch.tensor([1]))
        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
        return hotwords, hotwords_length
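    # For example, proc_hotword("abc de") yields a (3, 10) id tensor -- one
    # row per hotword plus the trailing <s> row -- and the length tensor
    # [2, 1, 0] (last-character indices, with 0 for the <s> entry).
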
    def bb_infer(self, feats: np.ndarray,
                 feats_len: np.ndarray, bias_embed) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
        return outputs

    def eb_infer(self, hotwords, hotwords_length):
        outputs = self.ort_infer_eb([hotwords.to(torch.int32).numpy(), hotwords_length.to(torch.int32).numpy()])
        return outputs

    def decode(self, am_scores: np.ndarray, token_nums: np.ndarray) -> List[str]:
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]

    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
        # greedy decoding: take the highest-scoring token id per frame
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)

        # pad with sos/eos ids to ensure compatibility with Hypothesis
        # asr_model.sos:1 asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)

        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()

        # remove blank symbol id (assumed to be 0) and eos id
        token_int = list(filter(lambda x: x not in (0, 2), token_int))

        # change integer ids to tokens and keep only the valid ones
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num - self.pred_bias]
        return token