#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import codecs
import copy
import logging
import os
import re
import tempfile
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import requests
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.build_utils.build_asr_model import frontend_choices
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.modules.beam_search.beam_search import BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard

class Speech2Text:
    """Speech2Text class

    Examples:
        >>> import soundfile
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]
    """

    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        frontend_conf: dict = None,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device, mode="asr"
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            if asr_train_args.frontend == "wav_frontend":
                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
            else:
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        decoder = asr_model.decoder
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
            decoder=decoder,
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device
            )
            scorers["lm"] = lm.lm

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
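        # The beam-search score is a weighted sum of the scorers configured above:
        # (1 - ctc_weight) * decoder + ctc_weight * ctc + lm_weight * lm
        # + ngram_weight * ngram + penalty * length_bonus.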
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.frontend = frontend

    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ) -> List[
        Tuple[
            Optional[str],
            List[str],
            List[int],
            Union[Hypothesis],
        ]
    ]:
        """Inference

        Args:
            speech: Input speech data
        Returns:
            text, token, token_int, hyp
        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            self.asr_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
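        # Estimate the LFR (low-frame-rate) stacking factor from the feature
        # dimension, assuming 80-dim fbank features per stacked frame.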
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        enc, _ = self.asr_model.encode(**batch)
        if isinstance(enc, tuple):
            enc = enc[0]
        assert len(enc) == 1, len(enc)

        # c. Pass the encoder result to the beam search
        nbest_hyps = self.beam_search(
            x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )
        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, Hypothesis), type(hyp)
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))

        assert check_return_type(results)
        return results


class Speech2TextParaformer:
    """Speech2Text class

    Examples:
        >>> import soundfile
        >>> speech2text = Speech2TextParaformer("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]
    """

    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        frontend_conf: dict = None,
        hotword_list_or_file: str = None,
        decoding_ind: int = 0,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        if asr_model.ctc is not None:
            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
            scorers.update(ctc=ctc)
        token_list = asr_model.token_list
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build Language model
        if lm_train_config is not None:
            # pass cmvn_file=None so device lands in the device slot,
            # matching the call signature used in Speech2Text above
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device
            )
            scorers["lm"] = lm.lm

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
        from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch

        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
        logging.info(f"Decoding device={device}, dtype={dtype}")

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer

        # 6. [Optional] Build hotword list from str, local file or url
        self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)

        is_use_lm = lm_weight != 0.0 and lm_file is not None
        if (ctc_weight == 0.0 or asr_model.ctc is None) and not is_use_lm:
            beam_search = None
        self.beam_search = beam_search
        logging.info(f"Beam_search: {self.beam_search}")
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.frontend = frontend
        self.encoder_downsampling_factor = 1
        self.decoding_ind = decoding_ind
        if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
            self.encoder_downsampling_factor = 4

    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
        begin_time: int = 0, end_time: int = None,
    ):
        """Inference

        Args:
            speech: Input speech data
        Returns:
            text, token, token_int, hyp
        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            self.asr_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
        if isinstance(enc, tuple):
            enc = enc[0]
        # assert len(enc) == 1, len(enc)
        enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
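        # The CIF predictor emits one acoustic embedding per predicted token;
        # pre_token_length is the continuous token-count estimate, rounded below.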
        predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
            predictor_outs[0],
            predictor_outs[1],
            predictor_outs[2],
            predictor_outs[3],
        )
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(
            self.asr_model, NeatContextualParaformer
        ):
            if self.hotword_list:
                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
            decoder_outs = self.asr_model.cal_decoder_with_predictor(
                enc, enc_len, pre_acoustic_embeds, pre_token_length
            )
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        else:
            decoder_outs = self.asr_model.cal_decoder_with_predictor(
                enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list
            )
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        if isinstance(self.asr_model, BiCifParaformer):
            _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(
                enc, enc_len, pre_token_length
            )  # test no bias cif2

        results = []
        b, n, d = decoder_out.size()
        for i in range(b):
            x = enc[i, :enc_len[i], :]
            am_scores = decoder_out[i, :pre_token_length[i], :]
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
                )
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                if pre_token_length[i] == 0:
                    yseq = torch.tensor(
                        [self.asr_model.sos] + [self.asr_model.eos], device=pre_acoustic_embeds.device
                    )
                    score = torch.tensor(0.0, device=pre_acoustic_embeds.device)
                else:
                    yseq = am_scores.argmax(dim=-1)
                    score = am_scores.max(dim=-1)[0]
                    score = torch.sum(score, dim=-1)
                    # pad with mask tokens to ensure compatibility with sos/eos tokens
                    yseq = torch.tensor(
                        [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
                    )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for hyp in nbest_hyps:
                assert isinstance(hyp, Hypothesis), type(hyp)
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # remove blank (assumed id 0) and special symbol (assumed id 2)
                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
                # Change integer-ids to tokens
                token = self.converter.ids2tokens(token_int)
                if self.tokenizer is not None:
                    text = self.tokenizer.tokens2text(token)
                else:
                    text = None
                timestamp = []
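                # BiCifParaformer additionally predicts per-token timestamps from
                # upsampled CIF alphas/peaks (assumed 3x the encoder frame rate).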
                if isinstance(self.asr_model, BiCifParaformer):
                    _, timestamp = ts_prediction_lfr6_standard(
                        us_alphas[i][:enc_len[i] * 3],
                        us_peaks[i][:enc_len[i] * 3],
                        copy.copy(token),
                        vad_offset=begin_time,
                    )
                results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
        # assert check_return_type(results)
        return results

    def generate_hotwords_list(self, hotword_list_or_file):
        # for None
        if hotword_list_or_file is None:
            hotword_list = None
        # for local txt inputs
        elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
            logging.info("Attempting to parse hotwords from local txt...")
            hotword_list = []
            hotword_str_list = []
            with codecs.open(hotword_list_or_file, 'r') as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hotword_str_list.append(hw)
                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
            hotword_list.append([self.asr_model.sos])
            hotword_str_list.append('<s>')
            logging.info("Initialized hotword list from file: {}, hotword list: {}."
                         .format(hotword_list_or_file, hotword_str_list))
        # for url, download and generate txt
        elif hotword_list_or_file.startswith('http'):
            logging.info("Attempting to parse hotwords from url...")
            work_dir = tempfile.mkdtemp()
            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
            local_file = requests.get(hotword_list_or_file)
            with open(text_file_path, "wb") as fout:
                fout.write(local_file.content)
            hotword_list_or_file = text_file_path
            hotword_list = []
            hotword_str_list = []
            with codecs.open(hotword_list_or_file, 'r') as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hotword_str_list.append(hw)
                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
            hotword_list.append([self.asr_model.sos])
            hotword_str_list.append('<s>')
            logging.info("Initialized hotword list from file: {}, hotword list: {}."
                         .format(hotword_list_or_file, hotword_str_list))
        # for text str input
        elif not hotword_list_or_file.endswith('.txt'):
            logging.info("Attempting to parse hotwords as str...")
            hotword_list = []
            hotword_str_list = []
            for hw in hotword_list_or_file.strip().split():
                hotword_str_list.append(hw)
                hotword_list.append(self.converter.tokens2ids([i for i in hw]))
            hotword_list.append([self.asr_model.sos])
            hotword_str_list.append('<s>')
            logging.info("Hotword list: {}.".format(hotword_str_list))
        else:
            hotword_list = None
        return hotword_list
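
    # A minimal hotword usage sketch (path and contents are hypothetical):
    #     speech2text = Speech2TextParaformer(
    #         "asr_config.yml", "asr.pb", hotword_list_or_file="hotwords.txt")
    # where hotwords.txt holds one hotword per line; a plain string of
    # space-separated hotwords or an http(s) URL to such a file also works.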


class Speech2TextParaformerOnline:
    """Speech2Text class

    Examples:
        >>> import soundfile
        >>> speech2text = Speech2TextParaformerOnline("asr_config.yml", "asr.pth")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]
    """

    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        frontend_conf: dict = None,
        hotword_list_or_file: str = None,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontendOnline(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        if asr_model.ctc is not None:
            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
            scorers.update(ctc=ctc)
        token_list = asr_model.token_list
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build Language model
        if lm_train_config is not None:
            # pass cmvn_file=None so device lands in the device slot,
            # matching the call signature used in Speech2Text above
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device
            )
            scorers["lm"] = lm.lm

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
        from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch

        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
        logging.info(f"Decoding device={device}, dtype={dtype}")

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer

        is_use_lm = lm_weight != 0.0 and lm_file is not None
        if (ctc_weight == 0.0 or asr_model.ctc is None) and not is_use_lm:
            beam_search = None
        self.beam_search = beam_search
        logging.info(f"Beam_search: {self.beam_search}")
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.frontend = frontend
        self.encoder_downsampling_factor = 1
        if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
            self.encoder_downsampling_factor = 4
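
    # The streaming cache passed to __call__ is expected to hold an "encoder"
    # sub-dict carrying at least "is_final", "start_idx", "feats" and
    # "tail_chunk" entries, as used below.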
    @torch.no_grad()
    def __call__(
        self, cache: dict, speech: Union[torch.Tensor], speech_lengths: Union[torch.Tensor] = None
    ):
        """Inference

        Args:
            speech: Input speech data
        Returns:
            partial recognition results (a list of postprocessed strings)
        """
        assert check_argument_types()

        results = []
        cache_en = cache["encoder"]
        # On the final call, a chunk shorter than 960 samples (60 ms at 16 kHz)
        # is decoded from the cached features alone.
        if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
            if cache_en["start_idx"] == 0:
                return []
            cache_en["tail_chunk"] = True
            feats = cache_en["feats"]
            feats_len = torch.tensor([feats.shape[1]])
            self.asr_model.frontend = None
            self.frontend.cache_reset()
            results = self.infer(feats, feats_len, cache)
            return results
        else:
            if self.frontend is not None:
                if cache_en["start_idx"] == 0:
                    self.frontend.cache_reset()
                feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"])
                feats = to_device(feats, device=self.device)
                feats_len = feats_len.int()
                self.asr_model.frontend = None
            else:
                feats = speech
                feats_len = speech_lengths
            if feats.shape[1] != 0:
                results = self.infer(feats, feats_len, cache)
            return results

    @torch.no_grad()
    def infer(self, feats: Union[torch.Tensor], feats_len: Union[torch.Tensor], cache: List = None):
        batch = {"speech": feats, "speech_lengths": feats_len}
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        enc, enc_len = self.asr_model.encode_chunk(feats, feats_len, cache=cache)
        if isinstance(enc, tuple):
            enc = enc[0]
        # assert len(enc) == 1, len(enc)
        enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor

        predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache)
        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
        if torch.max(pre_token_length) < 1:
            return []
        decoder_out = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache)

        results = []
        b, n, d = decoder_out.size()
        for i in range(b):
            x = enc[i, :enc_len[i], :]
            am_scores = decoder_out[i, :pre_token_length[i], :]
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
                )
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # pad with mask tokens to ensure compatibility with sos/eos tokens
                yseq = torch.tensor(
                    [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
                )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for hyp in nbest_hyps:
                assert isinstance(hyp, Hypothesis), type(hyp)
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # remove blank (assumed id 0) and special symbol (assumed id 2)
                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
                # Change integer-ids to tokens
                token = self.converter.ids2tokens(token_int)
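                # Merge BPE pieces into text: '@@'-suffixed pieces glue onto the
                # next token; bare alphabetic tokens get a trailing space.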
                postprocessed_result = ""
                for item in token:
                    if item.endswith('@@'):
                        postprocessed_result += item[:-2]
                    elif re.match('^[a-zA-Z]+$', item):
                        postprocessed_result += item + " "
                    else:
                        postprocessed_result += item
                results.append(postprocessed_result)
        # assert check_return_type(results)
        return results


class Speech2TextUniASR:
    """Speech2Text class

    Examples:
        >>> import soundfile
        >>> speech2text = Speech2TextUniASR("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]
    """

    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        token_num_relax: int = 1,
        decoding_ind: int = 0,
        decoding_mode: str = "model1",
        frontend_conf: dict = None,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device, mode="uniasr"
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        if decoding_mode == "model1":
            decoder = asr_model.decoder
        else:
            decoder = asr_model.decoder2
        if asr_model.ctc is not None:
            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
            scorers.update(ctc=ctc)
        token_list = asr_model.token_list
        scorers.update(
            decoder=decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build Language model
        if lm_train_config is not None:
            # pass cmvn_file=None so device and the "lm" mode land in their
            # intended slots, matching the ASR build call above
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device, "lm"
            )
            scorers["lm"] = lm.lm

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
        from funasr.modules.beam_search.beam_search import BeamSearchScama as BeamSearch

        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
        # logging.info(f"Beam_search: {beam_search}")
        logging.info(f"Decoding device={device}, dtype={dtype}")

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.token_num_relax = token_num_relax
        self.decoding_ind = decoding_ind
        self.decoding_mode = decoding_mode
        self.frontend = frontend

    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ) -> List[
        Tuple[
            Optional[str],
            List[str],
            List[int],
            Union[Hypothesis],
        ]
    ]:
        """Inference

        Args:
            speech: Input speech data
        Returns:
            text, token, token_int, hyp
        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            self.asr_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        feats_raw = feats.clone().to(self.device)
        batch = {"speech": feats, "speech_lengths": feats_len}

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        _, enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
        if isinstance(enc, tuple):
            enc = enc[0]
        assert len(enc) == 1, len(enc)
        if self.decoding_mode == "model1":
            predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len)
        else:
            enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind)
            predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len)
        scama_mask = predictor_outs[4]
        pre_token_length = predictor_outs[1]
        pre_acoustic_embeds = predictor_outs[0]
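        # Bound the decoded length by the predictor's token-count estimate,
        # relaxed by token_num_relax in both directions.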
        maxlen = pre_token_length.sum().item() + self.token_num_relax
        minlen = max(0, pre_token_length.sum().item() - self.token_num_relax)

        # c. Pass the encoder result to the beam search
        nbest_hyps = self.beam_search(
            x=enc[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=self.maxlenratio,
            minlenratio=self.minlenratio, maxlen=int(maxlen), minlen=int(minlen),
        )
        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, Hypothesis), type(hyp)
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            token = list(filter(lambda x: x != "<gbg>", token))
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))

        assert check_return_type(results)
        return results


class Speech2TextMFCCA:
    """Speech2Text class

    Examples:
        >>> import soundfile
        >>> speech2text = Speech2TextMFCCA("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]
    """

    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        decoder = asr_model.decoder
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
            decoder=decoder,
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build Language model
        if lm_train_config is not None:
            # pass cmvn_file=None so device lands in the device slot,
            # matching the call signature used in Speech2Text above
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device
            )
            lm.to(device)
            scorers["lm"] = lm.lm

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
        # beam_search.__class__ = BatchBeamSearch

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest

    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ) -> List[
        Tuple[
            Optional[str],
            List[str],
            List[int],
            Union[Hypothesis],
        ]
    ]:
        """Inference

        Args:
            speech: Input speech data
        Returns:
            text, token, token_int, hyp
        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
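        # MFCCA consumes multi-channel input; drop a trailing singleton axis
        # if present (assumed layout: (batch, samples, channels)).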
        if speech.dim() == 3:
            speech = torch.squeeze(speech, 2)
        # speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        speech = speech.to(getattr(torch, self.dtype))
        # lengths: (1,)
        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
        batch = {"speech": speech, "speech_lengths": lengths}

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        enc, _ = self.asr_model.encode(**batch)
        assert len(enc) == 1, len(enc)

        # c. Pass the encoder result to the beam search
        nbest_hyps = self.beam_search(
            x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )
        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, Hypothesis), type(hyp)
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))

        assert check_return_type(results)
        return results


class Speech2TextTransducer:
    """Speech2Text class for Transducer models.

    Args:
        asr_train_config: ASR model training config path.
        asr_model_file: ASR model path.
        beam_search_config: Beam search config path.
        lm_train_config: Language Model training config path.
        lm_file: Language Model config path.
        token_type: Type of token units.
        bpemodel: BPE model path.
        device: Device to use for inference.
        beam_size: Size of beam during search.
        dtype: Data type.
        lm_weight: Language model weight.
        quantize_asr_model: Whether to apply dynamic quantization to ASR model.
        quantize_modules: List of module names to apply dynamic quantization on.
        quantize_dtype: Dynamic quantization data type.
        nbest: Number of final hypothesis.
        streaming: Whether to perform chunk-by-chunk inference.
        chunk_size: Number of frames in chunk AFTER subsampling.
        left_context: Number of frames in left context AFTER subsampling.
        right_context: Number of frames in right context AFTER subsampling.
        display_partial_hypotheses: Whether to display partial hypotheses.
    """

    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        beam_search_config: Dict[str, Any] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        beam_size: int = 5,
        dtype: str = "float32",
        lm_weight: float = 1.0,
        quantize_asr_model: bool = False,
        quantize_modules: List[str] = None,
        quantize_dtype: str = "qint8",
        nbest: int = 1,
        streaming: bool = False,
        simu_streaming: bool = False,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
        display_partial_hypotheses: bool = False,
    ) -> None:
        """Construct a Speech2Text object."""
        super().__init__()
        assert check_argument_types()

        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
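
        # Optional post-training dynamic quantization: only the module types
        # listed in quantize_modules (Linear/LSTM) are replaced in place.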
        if quantize_asr_model:
            if quantize_modules is not None:
                if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
                    raise ValueError(
                        "Only 'Linear' and 'LSTM' modules are currently supported"
                        " by PyTorch and in --quantize_modules"
                    )
                q_config = set([getattr(torch.nn, q) for q in quantize_modules])
            else:
                q_config = {torch.nn.Linear}
            if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
                raise ValueError(
                    "float16 dtype for dynamic quantization is not supported with"
                    " torch version < 1.5.0. Please use qint8 dtype instead."
                )
            q_dtype = getattr(torch, quantize_dtype)
            asr_model = torch.quantization.quantize_dynamic(
                asr_model, q_config, dtype=q_dtype
            ).eval()
        else:
            asr_model.to(dtype=getattr(torch, dtype)).eval()

        if lm_train_config is not None:
            # pass cmvn_file=None so device lands in the device slot,
            # matching the call signature used in Speech2Text above
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device
            )
            lm_scorer = lm.lm
        else:
            lm_scorer = None

        # 4. Build BeamSearch object
        if beam_search_config is None:
            beam_search_config = {}
        beam_search = BeamSearchTransducer(
            asr_model.decoder,
            asr_model.joint_network,
            beam_size,
            lm=lm_scorer,
            lm_weight=lm_weight,
            nbest=nbest,
            **beam_search_config,
        )

        token_list = asr_model.token_list
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.streaming = streaming
        self.simu_streaming = simu_streaming
        self.chunk_size = max(chunk_size, 0)
        self.left_context = left_context
        self.right_context = max(right_context, 0)

        if not streaming or chunk_size == 0:
            self.streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False
        if not simu_streaming or chunk_size == 0:
            self.simu_streaming = False
            self.asr_model.encoder.dynamic_chunk_training = False
        self.frontend = frontend
        self.window_size = self.chunk_size + self.right_context
        if self.streaming:
            self._ctx = self.asr_model.encoder.get_encoder_input_size(
                self.window_size
            )
            self.last_chunk_length = (
                self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
            )
        self.reset_inference_cache()

    def reset_inference_cache(self) -> None:
        """Reset Speech2Text parameters."""
        self.frontend_cache = None
        self.asr_model.encoder.reset_streaming_cache(
            self.left_context, device=self.device
        )
        self.beam_search.reset_inference_cache()
        self.num_processed_frames = torch.tensor([[0]], device=self.device)

    @torch.no_grad()
    def streaming_decode(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        is_final: bool = True,
    ) -> List[HypothesisTransducer]:
        """Speech2Text streaming call.

        Args:
            speech: Chunk of speech data. (S)
            is_final: Whether speech corresponds to the final chunk of data.
        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if is_final:
            if self.streaming and speech.size(0) < self.last_chunk_length:
                pad = torch.zeros(
                    self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype
                )
                speech = torch.cat([speech, pad], dim=0)
                # feats, feats_length = self.apply_frontend(speech, is_final=is_final)
        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        enc_out = self.asr_model.encoder.chunk_forward(
            feats,
            feats_lengths,
            self.num_processed_frames,
            chunk_size=self.chunk_size,
            left_context=self.left_context,
            right_context=self.right_context,
        )
        nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
        self.num_processed_frames += self.chunk_size
        if is_final:
            self.reset_inference_cache()
        return nbest_hyps
    @torch.no_grad()
    def simu_streaming_decode(
        self, speech: Union[torch.Tensor, np.ndarray]
    ) -> List[HypothesisTransducer]:
        """Speech2Text call.

        Args:
            speech: Speech data. (S)

        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()

        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        if self.frontend is not None:
            speech = speech.unsqueeze(0)
            speech_lengths = speech.new_full(
                [1], dtype=torch.long, fill_value=speech.size(1)
            )
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
            feats_lengths = feats.new_full(
                [1], dtype=torch.long, fill_value=feats.size(1)
            )

        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)

        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)

        enc_out = self.asr_model.encoder.simu_chunk_forward(
            feats, feats_lengths, self.chunk_size, self.left_context, self.right_context
        )
        nbest_hyps = self.beam_search(enc_out[0])

        return nbest_hyps
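
    # Unlike `streaming_decode`, which consumes one chunk per call and carries
    # cache state across calls, `simu_streaming_decode` takes the whole
    # utterance and lets `simu_chunk_forward` emulate chunk-wise encoding in a
    # single pass, so no cache reset is needed afterwards.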
    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray]
    ) -> List[HypothesisTransducer]:
        """Speech2Text call.

        Args:
            speech: Speech data. (S)

        Returns:
            nbest_hypothesis: N-best hypothesis.
        """
        assert check_argument_types()

        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        if self.frontend is not None:
            speech = speech.unsqueeze(0)
            speech_lengths = speech.new_full(
                [1], dtype=torch.long, fill_value=speech.size(1)
            )
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
            feats_lengths = feats.new_full(
                [1], dtype=torch.long, fill_value=feats.size(1)
            )

        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)

        enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
        nbest_hyps = self.beam_search(enc_out[0])

        return nbest_hyps
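
    # Offline usage sketch (hypothetical file name, mirroring the class
    # docstring examples elsewhere in this file):
    #
    #     import soundfile
    #     audio, rate = soundfile.read("speech.wav")
    #     nbest = speech2text(audio)
    #     text, token, token_int, hyp = speech2text.hypotheses_to_results(nbest)[0]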
    def hypotheses_to_results(
        self, nbest_hyps: List[HypothesisTransducer]
    ) -> List[Any]:
        """Build partial or final results from the hypotheses.

        Args:
            nbest_hyps: N-best hypothesis.

        Returns:
            results: Results containing different representations of the hypothesis.
        """
        results = []

        for hyp in nbest_hyps:
            # Remove the blank symbol, whose id is assumed to be 0.
            token_int = list(filter(lambda x: x != 0, hyp.yseq))
            token = self.converter.ids2tokens(token_int)

            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None

            results.append((text, token, token_int, hyp))

        assert check_return_type(results)

        return results
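
    # Each entry is a (text, token, token_int, hyp) tuple, so partial results
    # from `streaming_decode` and final results from `__call__` can be
    # unpacked the same way.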
    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ) -> "Speech2TextTransducer":
        """Build a Speech2TextTransducer instance from a pretrained model.

        Args:
            model_tag: Model tag of the pretrained models.

        Returns:
            : Speech2TextTransducer instance.
        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader
            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))

        return Speech2TextTransducer(**kwargs)
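
    # Usage sketch ("<model_tag>" is a placeholder for a real
    # espnet_model_zoo tag):
    #
    #     speech2text = Speech2TextTransducer.from_pretrained(
    #         model_tag="<model_tag>",
    #         device="cpu",
    #     )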


class Speech2TextSAASR:
    """Speech2Text class for SA-ASR (speaker-attributed ASR).

    Examples:
        >>> import soundfile
        >>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, text_id, token, token_int, hypothesis object), ...]
    """
    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        frontend_conf: dict = None,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )

        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            if asr_train_args.frontend == 'wav_frontend':
                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
            else:
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()

        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        decoder = asr_model.decoder
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
            decoder=decoder,
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build language model
        if lm_train_config is not None:
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device
            )
            scorers["lm"] = lm.lm

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None

        from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch

        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
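        # With these weights, each partial hypothesis is scored roughly as
        #   (1 - ctc_weight) * decoder + ctc_weight * ctc
        #   + lm_weight * lm + ngram_weight * ngram + penalty * length
        # (the ngram term is inert here because scorers["ngram"] is None).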
        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.frontend = frontend
    @torch.no_grad()
    def __call__(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        speech_lengths: Union[torch.Tensor, np.ndarray],
        profile: Union[torch.Tensor, np.ndarray],
        profile_lengths: Union[torch.Tensor, np.ndarray],
    ) -> List[
        Tuple[
            Optional[str],
            Optional[str],
            List[str],
            List[int],
            HypothesisSAASR,
        ]
    ]:
        """Inference

        Args:
            speech: Input speech data
            speech_lengths: Lengths of the input speech
            profile: Speaker profile data
            profile_lengths: Lengths of the speaker profiles

        Returns:
            text, text_id, token, token_int, hyp
        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if isinstance(profile, np.ndarray):
            profile = torch.tensor(profile)

        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            self.asr_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        asr_enc, _, spk_enc = self.asr_model.encode(**batch)
        if isinstance(asr_enc, tuple):
            asr_enc = asr_enc[0]
        if isinstance(spk_enc, tuple):
            spk_enc = spk_enc[0]
        assert len(asr_enc) == 1, len(asr_enc)
        assert len(spk_enc) == 1, len(spk_enc)

        # c. Pass the encoder output to the beam search
        nbest_hyps = self.beam_search(
            asr_enc[0],
            spk_enc[0],
            profile[0],
            maxlenratio=self.maxlenratio,
            minlenratio=self.minlenratio,
        )
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, HypothesisSAASR), type(hyp)

            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()

            # Speaker attribution: average the per-token speaker weights over
            # each '$'-delimited segment and pick the profile with the highest
            # mean weight. (The attribute name `spk_weigths` follows the
            # spelling used by HypothesisSAASR.)
            spk_weights = torch.stack(hyp.spk_weigths, dim=0)
            token_ori = self.converter.ids2tokens(token_int)
            text_ori = self.tokenizer.tokens2text(token_ori)
            text_ori_spklist = text_ori.split('$')
            cur_index = 0
            spk_choose = []
            for i in range(len(text_ori_spklist)):
                text_ori_split = text_ori_spklist[i]
                n = len(text_ori_split)
                spk_weights_local = spk_weights[cur_index:cur_index + n]
                # The extra +1 skips the '$' speaker-change symbol.
                cur_index = cur_index + n + 1
                spk_weights_local = spk_weights_local.mean(dim=0)
                spk_choose_local = spk_weights_local.argmax(-1)
                # Speaker ids are reported 1-based.
                spk_choose.append(spk_choose_local.item() + 1)

            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))

            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)

            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None

            # Build a per-character speaker-id string aligned with `text`.
            text_spklist = text.split('$')
            assert len(spk_choose) == len(text_spklist)
            spk_list = []
            for i in range(len(text_spklist)):
                text_split = text_spklist[i]
                n = len(text_split)
                spk_list.append(str(spk_choose[i]) * n)
            text_id = '$'.join(spk_list)
            assert len(text) == len(text_id)
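            # Worked example of the invariant above: with text == "AB$CDE" and
            # spk_choose == [1, 2], text_id becomes "11$222", so text and
            # text_id align character-for-character with '$' marking a
            # speaker change.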
            results.append((text, text_id, token, token_int, hyp))

        assert check_return_type(results)
        return results