paraformer_online_bin.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. # -*- encoding: utf-8 -*-
  2. import os.path
  3. from pathlib import Path
  4. from typing import List, Union, Tuple
  5. import copy
  6. import librosa
  7. import numpy as np
  8. from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
  9. OrtInferSession, TokenIDConverter, get_logger,
  10. read_yaml)
  11. from .utils.postprocess_utils import sentence_postprocess
  12. from .utils.frontend import WavFrontendOnline, SinusoidalPositionEncoderOnline
  13. logging = get_logger()
  14. class Paraformer():
  15. def __init__(self, model_dir: Union[str, Path] = None,
  16. batch_size: int = 1,
  17. chunk_size: List = [5, 10, 5],
  18. device_id: Union[str, int] = "-1",
  19. quantize: bool = False,
  20. intra_op_num_threads: int = 4,
  21. cache_dir: str = None
  22. ):
  23. if not Path(model_dir).exists():
  24. try:
  25. from modelscope.hub.snapshot_download import snapshot_download
  26. except:
  27. raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
  28. "\npip3 install -U modelscope\n" \
  29. "For the users in China, you could install with the command:\n" \
  30. "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
  31. try:
  32. model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
  33. except:
  34. raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
  35. encoder_model_file = os.path.join(model_dir, 'model.onnx')
  36. decoder_model_file = os.path.join(model_dir, 'decoder.onnx')
  37. if quantize:
  38. encoder_model_file = os.path.join(model_dir, 'model_quant.onnx')
  39. decoder_model_file = os.path.join(model_dir, 'decoder_quant.onnx')
  40. if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
  41. print(".onnx is not exist, begin to export onnx")
  42. try:
  43. from funasr.export.export_model import ModelExport
  44. except:
  45. raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
  46. "\npip3 install -U funasr\n" \
  47. "For the users in China, you could install with the command:\n" \
  48. "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
  49. export_model = ModelExport(
  50. cache_dir=cache_dir,
  51. onnx=True,
  52. device="cpu",
  53. quant=quantize,
  54. )
  55. export_model.export(model_dir)
  56. config_file = os.path.join(model_dir, 'config.yaml')
  57. cmvn_file = os.path.join(model_dir, 'am.mvn')
  58. config = read_yaml(config_file)
  59. self.converter = TokenIDConverter(config['token_list'])
  60. self.tokenizer = CharTokenizer()
  61. self.frontend = WavFrontendOnline(
  62. cmvn_file=cmvn_file,
  63. **config['frontend_conf']
  64. )
  65. self.pe = SinusoidalPositionEncoderOnline()
  66. self.ort_encoder_infer = OrtInferSession(encoder_model_file, device_id,
  67. intra_op_num_threads=intra_op_num_threads)
  68. self.ort_decoder_infer = OrtInferSession(decoder_model_file, device_id,
  69. intra_op_num_threads=intra_op_num_threads)
  70. self.batch_size = batch_size
  71. self.chunk_size = chunk_size
  72. self.encoder_output_size = config["encoder_conf"]["output_size"]
  73. self.fsmn_layer = config["decoder_conf"]["num_blocks"]
  74. self.fsmn_lorder = config["decoder_conf"]["kernel_size"] - 1
  75. self.fsmn_dims = config["encoder_conf"]["output_size"]
  76. self.feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
  77. self.cif_threshold = config["predictor_conf"]["threshold"]
  78. self.tail_threshold = config["predictor_conf"]["tail_threshold"]
  79. def prepare_cache(self, cache: dict = {}, batch_size=1):
  80. if len(cache) > 0:
  81. return cache
  82. cache["start_idx"] = 0
  83. cache["cif_hidden"] = np.zeros((batch_size, 1, self.encoder_output_size)).astype(np.float32)
  84. cache["cif_alphas"] = np.zeros((batch_size, 1)).astype(np.float32)
  85. cache["chunk_size"] = self.chunk_size
  86. cache["last_chunk"] = False
  87. cache["feats"] = np.zeros((batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)).astype(np.float32)
  88. cache["decoder_fsmn"] = []
  89. for i in range(self.fsmn_layer):
  90. fsmn_cache = np.zeros((batch_size, self.fsmn_dims, self.fsmn_lorder)).astype(np.float32)
  91. cache["decoder_fsmn"].append(fsmn_cache)
  92. return cache
  93. def add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
  94. if len(cache) == 0:
  95. return feats
  96. # process last chunk
  97. overlap_feats = np.concatenate((cache["feats"], feats), axis=1)
  98. if cache["is_final"]:
  99. cache["feats"] = overlap_feats[:, -self.chunk_size[0]:, :]
  100. if not cache["last_chunk"]:
  101. padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
  102. overlap_feats = np.pad(overlap_feats, ((0, 0), (0, padding_length), (0, 0)))
  103. else:
  104. cache["feats"] = overlap_feats[:, -(self.chunk_size[0] + self.chunk_size[2]):, :]
  105. return overlap_feats
    def __call__(self, audio_in: np.ndarray, **kwargs):
        """Streaming entry point: feed one audio chunk, return partial ASR results.

        kwargs['param_dict'] carries:
            is_final: whether this is the last chunk of the utterance.
            cache: persistent streaming state (mutated in place across calls).
        """
        # single-utterance batch: (samples,) -> (1, samples)
        waveforms = np.expand_dims(audio_in, axis=0)
        param_dict = kwargs.get('param_dict', dict())
        is_final = param_dict.get('is_final', False)
        cache = param_dict.get('cache', dict())
        asr_res = []
        # Final call with (almost) no new audio: flush whatever is buffered.
        # 16 * 60 samples is presumably < one frame at 16 kHz — TODO confirm threshold intent.
        if waveforms.shape[1] < 16 * 60 and is_final and len(cache) > 0:
            cache["last_chunk"] = True
            feats = cache["feats"]
            feats_len = np.array([feats.shape[1]]).astype(np.int32)
            asr_res = self.infer(feats, feats_len, cache)
            return asr_res
        feats, feats_len = self.extract_feat(waveforms, is_final)
        if feats.shape[1] != 0:
            # scale features by sqrt(encoder width) before positional encoding
            feats *= self.encoder_output_size ** 0.5
            cache = self.prepare_cache(cache)
            cache["is_final"] = is_final
            # fbank -> position encoding -> overlap chunk
            feats = self.pe.forward(feats, cache["start_idx"])
            cache["start_idx"] += feats.shape[1]
            if is_final:
                if feats.shape[1] + self.chunk_size[2] <= self.chunk_size[1]:
                    # remainder fits in a single (last) chunk
                    cache["last_chunk"] = True
                    feats = self.add_overlap_chunk(feats, cache)
                else:
                    # remainder spans two chunks: run both and merge the results
                    # first chunk
                    feats_chunk1 = self.add_overlap_chunk(feats[:, :self.chunk_size[1], :], cache)
                    feats_len = np.array([feats_chunk1.shape[1]]).astype(np.int32)
                    asr_res_chunk1 = self.infer(feats_chunk1, feats_len, cache)
                    # last chunk
                    cache["last_chunk"] = True
                    feats_chunk2 = self.add_overlap_chunk(feats[:, -(feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]):, :], cache)
                    feats_len = np.array([feats_chunk2.shape[1]]).astype(np.int32)
                    asr_res_chunk2 = self.infer(feats_chunk2, feats_len, cache)
                    asr_res_chunk = asr_res_chunk1 + asr_res_chunk2
                    # merge the two partial results key-by-key:
                    # value[0] is concatenated text, value[1] a token list
                    res = {}
                    for pred in asr_res_chunk:
                        for key, value in pred.items():
                            if key in res:
                                res[key][0] += value[0]
                                res[key][1].extend(value[1])
                            else:
                                res[key] = [value[0], value[1]]
                    return [res]
            else:
                # normal mid-stream chunk
                feats = self.add_overlap_chunk(feats, cache)
            feats_len = np.array([feats.shape[1]]).astype(np.int32)
            asr_res = self.infer(feats, feats_len, cache)
        return asr_res
  155. def infer(self, feats: np.ndarray, feats_len: np.ndarray, cache):
  156. # encoder forward
  157. enc_input = [feats, feats_len]
  158. enc, enc_lens, cif_alphas = self.ort_encoder_infer(enc_input)
  159. # predictor forward
  160. acoustic_embeds, acoustic_embeds_len = self.cif_search(enc, cif_alphas, cache)
  161. # decoder forward
  162. asr_res = []
  163. if acoustic_embeds.shape[1] > 0:
  164. dec_input = [enc, enc_lens, acoustic_embeds, acoustic_embeds_len]
  165. dec_input.extend(cache["decoder_fsmn"])
  166. dec_output = self.ort_decoder_infer(dec_input)
  167. logits, sample_ids, cache["decoder_fsmn"] = dec_output[0], dec_output[1], dec_output[2:]
  168. cache["decoder_fsmn"] = [item[:, :, -self.fsmn_lorder:] for item in cache["decoder_fsmn"]]
  169. preds = self.decode(logits, acoustic_embeds_len)
  170. for pred in preds:
  171. pred = sentence_postprocess(pred)
  172. asr_res.append({'preds': pred})
  173. return asr_res
  174. def load_data(self,
  175. wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
  176. def load_wav(path: str) -> np.ndarray:
  177. waveform, _ = librosa.load(path, sr=fs)
  178. return waveform
  179. if isinstance(wav_content, np.ndarray):
  180. return [wav_content]
  181. if isinstance(wav_content, str):
  182. return [load_wav(wav_content)]
  183. if isinstance(wav_content, list):
  184. return [load_wav(path) for path in wav_content]
  185. raise TypeError(
  186. f'The type of {wav_content} is not in [str, np.ndarray, list]')
  187. def extract_feat(self,
  188. waveforms: np.ndarray, is_final: bool = False
  189. ) -> Tuple[np.ndarray, np.ndarray]:
  190. waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
  191. for idx, waveform in enumerate(waveforms):
  192. waveforms_lens[idx] = waveform.shape[-1]
  193. feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
  194. return feats.astype(np.float32), feats_len.astype(np.int32)
  195. def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
  196. return [self.decode_one(am_score, token_num)
  197. for am_score, token_num in zip(am_scores, token_nums)]
  198. def decode_one(self,
  199. am_score: np.ndarray,
  200. valid_token_num: int) -> List[str]:
  201. yseq = am_score.argmax(axis=-1)
  202. score = am_score.max(axis=-1)
  203. score = np.sum(score, axis=-1)
  204. # pad with mask tokens to ensure compatibility with sos/eos tokens
  205. # asr_model.sos:1 asr_model.eos:2
  206. yseq = np.array([1] + yseq.tolist() + [2])
  207. hyp = Hypothesis(yseq=yseq, score=score)
  208. # remove sos/eos and get results
  209. last_pos = -1
  210. token_int = hyp.yseq[1:last_pos].tolist()
  211. # remove blank symbol id, which is assumed to be 0
  212. token_int = list(filter(lambda x: x not in (0, 2), token_int))
  213. # Change integer-ids to tokens
  214. token = self.converter.ids2tokens(token_int)
  215. token = token[:valid_token_num]
  216. # texts = sentence_postprocess(token)
  217. return token
    def cif_search(self, hidden, alphas, cache=None):
        """Continuous Integrate-and-Fire: accumulate frame weights `alphas` and
        emit one weighted acoustic embedding each time the accumulator crosses
        `self.cif_threshold`.

        Args:
            hidden: (batch, time, hidden) encoder outputs.
            alphas: (batch, time) per-frame CIF weights (mutated in place).
            cache: streaming state; supplies/receives the residual integration
                state between chunks.

        Returns:
            (acoustic_embeds, acoustic_embeds_len): embeddings padded to the
            max token count in the batch, and the per-utterance token counts.
        """
        batch_size, len_time, hidden_size = hidden.shape
        token_length = []
        list_fires = []
        list_frames = []
        cache_alphas = []
        cache_hiddens = []
        # zero out left/right context frames so only the middle chunk can fire
        alphas[:, :self.chunk_size[0]] = 0.0
        alphas[:, sum(self.chunk_size[:2]):] = 0.0
        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            # resume integration from the previous chunk's residual state
            hidden = np.concatenate((cache["cif_hidden"], hidden), axis=1)
            alphas = np.concatenate((cache["cif_alphas"], alphas), axis=1)
        if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
            # append an artificial tail frame so any residual weight still fires
            tail_hidden = np.zeros((batch_size, 1, hidden_size)).astype(np.float32)
            tail_alphas = np.array([[self.tail_threshold]]).astype(np.float32)
            tail_alphas = np.tile(tail_alphas, (batch_size, 1))
            hidden = np.concatenate((hidden, tail_hidden), axis=1)
            alphas = np.concatenate((alphas, tail_alphas), axis=1)
        len_time = alphas.shape[1]
        for b in range(batch_size):
            integrate = 0.0
            frames = np.zeros(hidden_size).astype(np.float32)
            list_frame = []
            list_fire = []
            for t in range(len_time):
                alpha = alphas[b][t]
                if alpha + integrate < self.cif_threshold:
                    # below threshold: keep integrating weight and frame
                    integrate += alpha
                    list_fire.append(integrate)
                    frames += alpha * hidden[b][t]
                else:
                    # fire: complete the embedding with just enough of this
                    # frame, emit it, and carry the overshoot into the next one
                    frames += (self.cif_threshold - integrate) * hidden[b][t]
                    list_frame.append(frames)
                    integrate += alpha
                    list_fire.append(integrate)
                    integrate -= self.cif_threshold
                    frames = integrate * hidden[b][t]
            # stash the residual integration state for the next chunk
            cache_alphas.append(integrate)
            if integrate > 0.0:
                # un-weight the residual frame so it can be re-weighted later
                cache_hiddens.append(frames / integrate)
            else:
                cache_hiddens.append(frames)
            token_length.append(len(list_frame))
            list_fires.append(list_fire)
            list_frames.append(list_frame)
        # pad each utterance's emitted embeddings to the batch maximum
        max_token_len = max(token_length)
        list_ls = []
        for b in range(batch_size):
            pad_frames = np.zeros((max_token_len - token_length[b], hidden_size)).astype(np.float32)
            if token_length[b] == 0:
                list_ls.append(pad_frames)
            else:
                list_ls.append(np.concatenate((list_frames[b], pad_frames), axis=0))
        # NOTE(review): expand_dims(axis=0) yields shape (1, batch); this matches
        # the batch=1 streaming use but looks transposed for batch > 1 — confirm.
        cache["cif_alphas"] = np.stack(cache_alphas, axis=0)
        cache["cif_alphas"] = np.expand_dims(cache["cif_alphas"], axis=0)
        cache["cif_hidden"] = np.stack(cache_hiddens, axis=0)
        cache["cif_hidden"] = np.expand_dims(cache["cif_hidden"], axis=0)
        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(np.int32)