# asr_inference_launch.py
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import argparse
  6. import logging
  7. import os
  8. import sys
  9. import time
  10. from pathlib import Path
  11. from typing import Dict
  12. from typing import List
  13. from typing import Optional
  14. from typing import Sequence
  15. from typing import Tuple
  16. from typing import Union
  17. import numpy as np
  18. import torch
  19. import torchaudio
  20. import soundfile
  21. import yaml
  22. from funasr.bin.asr_infer import Speech2Text
  23. from funasr.bin.asr_infer import Speech2TextMFCCA
  24. from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
  25. from funasr.bin.asr_infer import Speech2TextSAASR
  26. from funasr.bin.asr_infer import Speech2TextTransducer
  27. from funasr.bin.asr_infer import Speech2TextUniASR
  28. from funasr.bin.punc_infer import Text2Punc
  29. from funasr.bin.tp_infer import Speech2Timestamp
  30. from funasr.bin.vad_infer import Speech2VadSegment
  31. from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
  32. from funasr.fileio.datadir_writer import DatadirWriter
  33. from funasr.modules.beam_search.beam_search import Hypothesis
  34. from funasr.modules.subsampling import TooShortUttError
  35. from funasr.torch_utils.device_funcs import to_device
  36. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  37. from funasr.utils import asr_utils, postprocess_utils
  38. from funasr.utils import config_argparse
  39. from funasr.utils.cli_utils import get_commandline_args
  40. from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
  41. from funasr.utils.types import str2bool
  42. from funasr.utils.types import str2triple_str
  43. from funasr.utils.types import str_or_none
  44. from funasr.utils.vad_utils import slice_padding_fbank
def inference_asr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
):
    """Build a generic attention/CTC ASR decoder and return a callable.

    Constructs a ``Speech2Text`` model once, then returns an inner
    ``_forward`` closure that decodes a dataset (or raw waveform) each time
    it is called. Only ``batch_size == 1`` and single-GPU decoding are
    supported; a ``NotImplementedError`` is raised otherwise.
    """
    # Limit intra-op CPU threading; defaults to 1 unless callers pass ncpu.
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    # Reset any previously-installed logging handlers so basicConfig applies.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2Text(**speech2text_kwargs)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        """Decode one dataset/waveform; return a list of {'key', 'value'} dicts."""
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            # Raw in-memory audio is wrapped to look like a dataset entry.
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
        )
        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        # Per-call output dir (output_dir_v2) overrides the factory-level one.
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                # Too-short utterances are replaced by a silence placeholder
                # instead of aborting the whole run.
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest
            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
                logging.info("uttid: {}".format(key))
                logging.info("text predictions: {}\n".format(text))
        return asr_result_list
    return _forward
def inference_paraformer(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    output_dir: Optional[str] = None,
    timestamp_infer_config: Union[Path, str] = None,
    timestamp_model_file: Union[Path, str] = None,
    param_dict: dict = None,
    **kwargs,
):
    """Build a Paraformer decoder and return a callable ``_forward``.

    Optionally attaches a separate timestamp-prediction model
    (``Speech2Timestamp``) when ``timestamp_model_file`` is given, and
    supports hotword customization via ``param_dict['hotword']``.
    The returned closure decodes a dataset (or raw waveform) per call and
    returns a list of result dicts with optional ``timestamp`` entries.
    """
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    export_mode = False
    if param_dict is not None:
        # Optional decoding knobs carried in a loose dict.
        hotword_list_or_file = param_dict.get('hotword')
        export_mode = param_dict.get("export_mode", False)
        clas_scale = param_dict.get('clas_scale', 1.0)
    else:
        hotword_list_or_file = None
        clas_scale = 1.0
    # Explicit device="cpu" in kwargs forces CPU even when ngpu >= 1.
    if kwargs.get("device", None) == "cpu":
        ngpu = 0
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # Paraformer decoding here always runs with batch_size 1.
    batch_size = 1
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
        clas_scale=clas_scale,
    )
    speech2text = Speech2TextParaformer(**speech2text_kwargs)
    # Optional standalone timestamp predictor used when the ASR result
    # itself carries no timestamps.
    if timestamp_model_file is not None:
        speechtext2timestamp = Speech2Timestamp(
            timestamp_cmvn_file=cmvn_file,
            timestamp_model_file=timestamp_model_file,
            timestamp_infer_config=timestamp_infer_config,
        )
    else:
        speechtext2timestamp = None

    def _forward(
            data_path_and_name_and_type,
            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
            output_dir_v2: Optional[str] = None,
            fs: dict = None,
            param_dict: dict = None,
            **kwargs,
    ):
        """Decode one dataset/waveform with the prebuilt Paraformer model."""
        # Per-call hotwords override the factory-level ones; kwargs wins
        # over param_dict.
        hotword_list_or_file = None
        if param_dict is not None:
            hotword_list_or_file = param_dict.get('hotword')
        if 'hotword' in kwargs and kwargs['hotword'] is not None:
            hotword_list_or_file = kwargs['hotword']
        if hotword_list_or_file is not None or 'hotword' in kwargs:
            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
        )
        if param_dict is not None:
            use_timestamp = param_dict.get('use_timestamp', True)
        else:
            use_timestamp = True
        # Accumulators for real-time-factor (RTF) reporting.
        forward_time_total = 0.0
        length_total = 0.0
        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
            logging.info("decoding, utt_id: {}".format(keys))
            # N-best list of (text, token, token_int, hyp_object)
            time_beg = time.time()
            results = speech2text(**batch)
            if len(results) < 1:
                # Empty decode: substitute a silence placeholder result.
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp, 10, 6, []]] * nbest
            time_end = time.time()
            forward_time = time_end - time_beg
            # Each result row ends with (..., length, lfr_factor).
            lfr_factor = results[0][-1]
            length = results[0][-2]
            forward_time_total += forward_time
            length_total += length
            rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time,
                                                                                               100 * forward_time / (
                                                                                                   length * lfr_factor))
            logging.info(rtf_cur)
            for batch_id in range(_bs):
                # Drop the trailing (length, lfr_factor) bookkeeping fields.
                result = [results[batch_id][:-2]]
                key = keys[batch_id]
                for n, result in zip(range(1, nbest + 1), result):
                    text, token, token_int, hyp = result[0], result[1], result[2], result[3]
                    timestamp = result[4] if len(result[4]) > 0 else None
                    # conduct timestamp prediction here
                    # timestamp inference requires token length
                    # thus following inference cannot be conducted in batch
                    if timestamp is None and speechtext2timestamp:
                        ts_batch = {}
                        ts_batch['speech'] = batch['speech'][batch_id].unsqueeze(0)
                        ts_batch['speech_lengths'] = torch.tensor([batch['speech_lengths'][batch_id]])
                        ts_batch['text_lengths'] = torch.tensor([len(token)])
                        us_alphas, us_peaks = speechtext2timestamp(**ts_batch)
                        ts_str, timestamp = ts_prediction_lfr6_standard(us_alphas[0], us_peaks[0], token,
                                                                        force_time_shift=-3.0)
                    # Create a directory: outdir/{n}best_recog
                    if writer is not None:
                        ibest_writer = writer[f"{n}best_recog"]
                        # Write the result to each file
                        ibest_writer["token"][key] = " ".join(token)
                        # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                        ibest_writer["score"][key] = str(hyp.score)
                        ibest_writer["rtf"][key] = rtf_cur
                    if text is not None:
                        if use_timestamp and timestamp is not None:
                            postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
                        else:
                            postprocessed_result = postprocess_utils.sentence_postprocess(token)
                        timestamp_postprocessed = ""
                        # sentence_postprocess returns either
                        # (text, timestamp, word_lists) or (text, word_lists).
                        if len(postprocessed_result) == 3:
                            text_postprocessed, timestamp_postprocessed, word_lists = postprocessed_result[0], \
                                                                                      postprocessed_result[1], \
                                                                                      postprocessed_result[2]
                        else:
                            text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
                        item = {'key': key, 'value': text_postprocessed}
                        if timestamp_postprocessed != "":
                            item['timestamp'] = timestamp_postprocessed
                        asr_result_list.append(item)
                        finish_count += 1
                        # asr_utils.print_progress(finish_count / file_count)
                        if writer is not None:
                            ibest_writer["text"][key] = " ".join(word_lists)
                    logging.info("decoding, utt: {}, predictions: {}".format(key, text))
        # NOTE(review): lfr_factor/ibest_writer below are only bound if the
        # loader yielded at least one batch — empty input would raise here.
        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
                                                                                                           forward_time_total,
                                                                                                           100 * forward_time_total / (
                                                                                                               length_total * lfr_factor))
        logging.info(rtf_avg)
        if writer is not None:
            ibest_writer["rtf"]["rtf_avf"] = rtf_avg
        torch.cuda.empty_cache()
        return asr_result_list
    return _forward
def inference_paraformer_vad_punc(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    vad_infer_config: Optional[str] = None,
    vad_model_file: Optional[str] = None,
    vad_cmvn_file: Optional[str] = None,
    time_stamp_writer: bool = True,
    punc_infer_config: Optional[str] = None,
    punc_model_file: Optional[str] = None,
    outputs_dict: Optional[bool] = True,
    param_dict: dict = None,
    **kwargs,
):
    """Build a VAD + Paraformer (+ optional punctuation) pipeline.

    Returns a ``_forward`` closure that, per input utterance:
    1. runs VAD to split the speech into segments,
    2. batches the segments (sorted by duration) through Paraformer,
    3. reassembles the per-segment results in original order, and
    4. optionally restores punctuation with ``Text2Punc``.
    """
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if param_dict is not None:
        hotword_list_or_file = param_dict.get('hotword')
    else:
        hotword_list_or_file = None
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2vadsegment
    speech2vadsegment_kwargs = dict(
        vad_infer_config=vad_infer_config,
        vad_model_file=vad_model_file,
        vad_cmvn_file=vad_cmvn_file,
        device=device,
        dtype=dtype,
    )
    # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
    # 3. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
    )
    speech2text = Speech2TextParaformer(**speech2text_kwargs)
    # Punctuation restoration is optional.
    text2punc = None
    if punc_model_file is not None:
        text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
    if output_dir is not None:
        writer = DatadirWriter(output_dir)
        ibest_writer = writer[f"1best_recog"]
        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        """Run the VAD -> ASR -> punctuation pipeline over one input."""
        # Per-call hotwords: kwargs takes precedence over param_dict.
        hotword_list_or_file = None
        if param_dict is not None:
            hotword_list_or_file = param_dict.get('hotword')
        if 'hotword' in kwargs:
            hotword_list_or_file = kwargs['hotword']
        # Segment-length cap and batching thresholds; values appear to be
        # in milliseconds (60000 default, *1000 scaling) — TODO confirm.
        speech2vadsegment.vad_model.vad_opts.max_single_segment_time = kwargs.get("max_single_segment_time", 60000)
        batch_size_token_threshold_s = kwargs.get("batch_size_token_threshold_s", int(speech2vadsegment.vad_model.vad_opts.max_single_segment_time*0.67/1000)) * 1000
        batch_size_token = kwargs.get("batch_size_token", 6000)
        print("batch_size_token: ", batch_size_token)
        if speech2text.hotword_list is None:
            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=1,
            key_file=key_file,
            num_workers=num_workers,
        )
        if param_dict is not None:
            use_timestamp = param_dict.get('use_timestamp', True)
        else:
            use_timestamp = True
        finish_count = 0
        file_count = 1
        lfr_factor = 6
        # 7 .Start for-loop
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        writer = None
        if output_path is not None:
            writer = DatadirWriter(output_path)
            ibest_writer = writer[f"1best_recog"]
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            beg_vad = time.time()
            vad_results = speech2vadsegment(**batch)
            end_vad = time.time()
            print("time cost vad: ", end_vad - beg_vad)
            _, vadsegments = vad_results[0], vad_results[1][0]
            speech, speech_lengths = batch["speech"], batch["speech_lengths"]
            n = len(vadsegments)
            # Sort segments by duration so similar-length segments batch
            # together; remember original positions for later restoration.
            data_with_index = [(vadsegments[i], i) for i in range(n)]
            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
            results_sorted = []
            if not len(sorted_data):
                key = keys[0]
                # no active segments after VAD
                if writer is not None:
                    # Write empty results
                    ibest_writer["token"][key] = ""
                    ibest_writer["token_int"][key] = ""
                    ibest_writer["vad"][key] = ""
                    ibest_writer["text"][key] = ""
                    ibest_writer["text_with_punc"][key] = ""
                    if use_timestamp:
                        ibest_writer["time_stamp"][key] = ""
                logging.info("decoding, utt: {}, empty speech".format(key))
                continue
            batch_size_token_ms = batch_size_token*60
            # On CPU, decode segment-by-segment (no duration batching).
            if speech2text.device == "cpu":
                batch_size_token_ms = 0
            # Ensure at least the shortest segment fits in one batch.
            if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
                batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
            batch_size_token_ms_cum = 0
            beg_idx = 0
            for j, _ in enumerate(range(0, n)):
                # Accumulate segment durations until adding the next one
                # would exceed the batch budget (or it is itself too long).
                batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_ms and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_threshold_s:
                    continue
                batch_size_token_ms_cum = 0
                end_idx = j + 1
                speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
                beg_idx = end_idx
                batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
                batch = to_device(batch, device=device)
                print("batch: ", speech_j.shape[0])
                beg_asr = time.time()
                results = speech2text(**batch)
                end_asr = time.time()
                print("time cost asr: ", end_asr - beg_asr)
                if len(results) < 1:
                    results = [["", [], [], [], [], [], []]]
                results_sorted.extend(results)
            # Undo the duration sort: put results back in segment order.
            restored_data = [0] * n
            for j in range(n):
                index = sorted_data[j][1]
                restored_data[index] = results_sorted[j]
            # Concatenate per-segment text/token/token_int/timestamps into
            # one utterance-level result.
            result = ["", [], [], [], [], [], []]
            for j in range(n):
                result[0] += restored_data[j][0]
                result[1] += restored_data[j][1]
                result[2] += restored_data[j][2]
                if len(restored_data[j][4]) > 0:
                    # Shift segment-local timestamps by the segment start
                    # so they become utterance-global.
                    for t in restored_data[j][4]:
                        t[0] += vadsegments[j][0]
                        t[1] += vadsegments[j][0]
                    result[4] += restored_data[j][4]
                # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
            key = keys[0]
            # result = result_segments[0]
            text, token, token_int = result[0], result[1], result[2]
            time_stamp = result[4] if len(result[4]) > 0 else None
            if use_timestamp and time_stamp is not None:
                postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
            else:
                postprocessed_result = postprocess_utils.sentence_postprocess(token)
            text_postprocessed = ""
            time_stamp_postprocessed = ""
            text_postprocessed_punc = postprocessed_result
            # sentence_postprocess returns either (text, timestamp,
            # word_lists) or (text, word_lists).
            if len(postprocessed_result) == 3:
                text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
                                                                           postprocessed_result[1], \
                                                                           postprocessed_result[2]
            else:
                text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
                text_postprocessed_punc = text_postprocessed
            punc_id_list = []
            if len(word_lists) > 0 and text2punc is not None:
                beg_punc = time.time()
                text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
                end_punc = time.time()
                print("time cost punc: ", end_punc - beg_punc)
            item = {'key': key, 'value': text_postprocessed_punc}
            if text_postprocessed != "":
                item['text_postprocessed'] = text_postprocessed
            if time_stamp_postprocessed != "":
                item['time_stamp'] = time_stamp_postprocessed
                # Sentence-level timestamps aligned with punctuation marks.
                item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
            asr_result_list.append(item)
            finish_count += 1
            # asr_utils.print_progress(finish_count / file_count)
            if writer is not None:
                # Write the result to each file
                ibest_writer["token"][key] = " ".join(token)
                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                ibest_writer["vad"][key] = "{}".format(vadsegments)
                ibest_writer["text"][key] = " ".join(word_lists)
                ibest_writer["text_with_punc"][key] = text_postprocessed_punc
                if time_stamp_postprocessed is not None:
                    ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
            logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
        torch.cuda.empty_cache()
        return asr_result_list
    return _forward
  661. def inference_paraformer_online(
  662. maxlenratio: float,
  663. minlenratio: float,
  664. batch_size: int,
  665. beam_size: int,
  666. ngpu: int,
  667. ctc_weight: float,
  668. lm_weight: float,
  669. penalty: float,
  670. log_level: Union[int, str],
  671. # data_path_and_name_and_type,
  672. asr_train_config: Optional[str],
  673. asr_model_file: Optional[str],
  674. cmvn_file: Optional[str] = None,
  675. lm_train_config: Optional[str] = None,
  676. lm_file: Optional[str] = None,
  677. token_type: Optional[str] = None,
  678. key_file: Optional[str] = None,
  679. word_lm_train_config: Optional[str] = None,
  680. bpemodel: Optional[str] = None,
  681. allow_variable_data_keys: bool = False,
  682. dtype: str = "float32",
  683. seed: int = 0,
  684. ngram_weight: float = 0.9,
  685. nbest: int = 1,
  686. num_workers: int = 1,
  687. output_dir: Optional[str] = None,
  688. param_dict: dict = None,
  689. **kwargs,
  690. ):
  691. if word_lm_train_config is not None:
  692. raise NotImplementedError("Word LM is not implemented")
  693. if ngpu > 1:
  694. raise NotImplementedError("only single GPU decoding is supported")
  695. logging.basicConfig(
  696. level=log_level,
  697. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  698. )
  699. export_mode = False
  700. if ngpu >= 1 and torch.cuda.is_available():
  701. device = "cuda"
  702. else:
  703. device = "cpu"
  704. batch_size = 1
  705. # 1. Set random-seed
  706. set_all_random_seed(seed)
  707. # 2. Build speech2text
  708. speech2text_kwargs = dict(
  709. asr_train_config=asr_train_config,
  710. asr_model_file=asr_model_file,
  711. cmvn_file=cmvn_file,
  712. lm_train_config=lm_train_config,
  713. lm_file=lm_file,
  714. token_type=token_type,
  715. bpemodel=bpemodel,
  716. device=device,
  717. maxlenratio=maxlenratio,
  718. minlenratio=minlenratio,
  719. dtype=dtype,
  720. beam_size=beam_size,
  721. ctc_weight=ctc_weight,
  722. lm_weight=lm_weight,
  723. ngram_weight=ngram_weight,
  724. penalty=penalty,
  725. nbest=nbest,
  726. )
  727. speech2text = Speech2TextParaformerOnline(**speech2text_kwargs)
  728. def _load_bytes(input):
  729. middle_data = np.frombuffer(input, dtype=np.int16)
  730. middle_data = np.asarray(middle_data)
  731. if middle_data.dtype.kind not in 'iu':
  732. raise TypeError("'middle_data' must be an array of integers")
  733. dtype = np.dtype('float32')
  734. if dtype.kind != 'f':
  735. raise TypeError("'dtype' must be a floating point type")
  736. i = np.iinfo(middle_data.dtype)
  737. abs_max = 2 ** (i.bits - 1)
  738. offset = i.min + abs_max
  739. array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
  740. return array
  741. def _read_yaml(yaml_path: Union[str, Path]) -> Dict:
  742. if not Path(yaml_path).exists():
  743. raise FileExistsError(f'The {yaml_path} does not exist.')
  744. with open(str(yaml_path), 'rb') as f:
  745. data = yaml.load(f, Loader=yaml.Loader)
  746. return data
  747. def _prepare_cache(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
  748. if len(cache) > 0:
  749. return cache
  750. config = _read_yaml(asr_train_config)
  751. enc_output_size = config["encoder_conf"]["output_size"]
  752. feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
  753. cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
  754. "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
  755. "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
  756. cache["encoder"] = cache_en
  757. cache_de = {"decode_fsmn": None}
  758. cache["decoder"] = cache_de
  759. return cache
  760. def _cache_reset(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
  761. if len(cache) > 0:
  762. config = _read_yaml(asr_train_config)
  763. enc_output_size = config["encoder_conf"]["output_size"]
  764. feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
  765. cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
  766. "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
  767. "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
  768. "tail_chunk": False}
  769. cache["encoder"] = cache_en
  770. cache_de = {"decode_fsmn": None}
  771. cache["decoder"] = cache_de
  772. return cache
  773. def _forward(
  774. data_path_and_name_and_type,
  775. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  776. output_dir_v2: Optional[str] = None,
  777. fs: dict = None,
  778. param_dict: dict = None,
  779. **kwargs,
  780. ):
  781. # 3. Build data-iterator
  782. if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
  783. raw_inputs = _load_bytes(data_path_and_name_and_type[0])
  784. raw_inputs = torch.tensor(raw_inputs)
  785. if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
  786. try:
  787. raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
  788. except:
  789. raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
  790. if raw_inputs.ndim == 2:
  791. raw_inputs = raw_inputs[:, 0]
  792. raw_inputs = torch.tensor(raw_inputs)
  793. if data_path_and_name_and_type is None and raw_inputs is not None:
  794. if isinstance(raw_inputs, np.ndarray):
  795. raw_inputs = torch.tensor(raw_inputs)
  796. is_final = False
  797. cache = {}
  798. chunk_size = [5, 10, 5]
  799. if param_dict is not None and "cache" in param_dict:
  800. cache = param_dict["cache"]
  801. if param_dict is not None and "is_final" in param_dict:
  802. is_final = param_dict["is_final"]
  803. if param_dict is not None and "chunk_size" in param_dict:
  804. chunk_size = param_dict["chunk_size"]
  805. # 7 .Start for-loop
  806. # FIXME(kamo): The output format should be discussed about
  807. raw_inputs = torch.unsqueeze(raw_inputs, axis=0)
  808. asr_result_list = []
  809. cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
  810. item = {}
  811. if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
  812. sample_offset = 0
  813. speech_length = raw_inputs.shape[1]
  814. stride_size = chunk_size[1] * 960
  815. cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
  816. final_result = ""
  817. for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
  818. if sample_offset + stride_size >= speech_length - 1:
  819. stride_size = speech_length - sample_offset
  820. cache["encoder"]["is_final"] = True
  821. else:
  822. cache["encoder"]["is_final"] = False
  823. input_lens = torch.tensor([stride_size])
  824. asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
  825. if len(asr_result) != 0:
  826. final_result += " ".join(asr_result) + " "
  827. item = {'key': "utt", 'value': final_result.strip()}
  828. else:
  829. input_lens = torch.tensor([raw_inputs.shape[1]])
  830. cache["encoder"]["is_final"] = is_final
  831. asr_result = speech2text(cache, raw_inputs, input_lens)
  832. item = {'key': "utt", 'value': " ".join(asr_result)}
  833. asr_result_list.append(item)
  834. if is_final:
  835. cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1)
  836. return asr_result_list
  837. return _forward
  838. def inference_uniasr(
  839. maxlenratio: float,
  840. minlenratio: float,
  841. batch_size: int,
  842. beam_size: int,
  843. ngpu: int,
  844. ctc_weight: float,
  845. lm_weight: float,
  846. penalty: float,
  847. log_level: Union[int, str],
  848. # data_path_and_name_and_type,
  849. asr_train_config: Optional[str],
  850. asr_model_file: Optional[str],
  851. ngram_file: Optional[str] = None,
  852. cmvn_file: Optional[str] = None,
  853. # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  854. lm_train_config: Optional[str] = None,
  855. lm_file: Optional[str] = None,
  856. token_type: Optional[str] = None,
  857. key_file: Optional[str] = None,
  858. word_lm_train_config: Optional[str] = None,
  859. bpemodel: Optional[str] = None,
  860. allow_variable_data_keys: bool = False,
  861. streaming: bool = False,
  862. output_dir: Optional[str] = None,
  863. dtype: str = "float32",
  864. seed: int = 0,
  865. ngram_weight: float = 0.9,
  866. nbest: int = 1,
  867. num_workers: int = 1,
  868. token_num_relax: int = 1,
  869. decoding_ind: int = 0,
  870. decoding_mode: str = "model1",
  871. param_dict: dict = None,
  872. **kwargs,
  873. ):
  874. ncpu = kwargs.get("ncpu", 1)
  875. torch.set_num_threads(ncpu)
  876. if batch_size > 1:
  877. raise NotImplementedError("batch decoding is not implemented")
  878. if word_lm_train_config is not None:
  879. raise NotImplementedError("Word LM is not implemented")
  880. if ngpu > 1:
  881. raise NotImplementedError("only single GPU decoding is supported")
  882. logging.basicConfig(
  883. level=log_level,
  884. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  885. )
  886. if ngpu >= 1 and torch.cuda.is_available():
  887. device = "cuda"
  888. else:
  889. device = "cpu"
  890. if param_dict is not None and "decoding_model" in param_dict:
  891. if param_dict["decoding_model"] == "fast":
  892. decoding_ind = 0
  893. decoding_mode = "model1"
  894. elif param_dict["decoding_model"] == "normal":
  895. decoding_ind = 0
  896. decoding_mode = "model2"
  897. elif param_dict["decoding_model"] == "offline":
  898. decoding_ind = 1
  899. decoding_mode = "model2"
  900. else:
  901. raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
  902. # 1. Set random-seed
  903. set_all_random_seed(seed)
  904. # 2. Build speech2text
  905. speech2text_kwargs = dict(
  906. asr_train_config=asr_train_config,
  907. asr_model_file=asr_model_file,
  908. cmvn_file=cmvn_file,
  909. lm_train_config=lm_train_config,
  910. lm_file=lm_file,
  911. ngram_file=ngram_file,
  912. token_type=token_type,
  913. bpemodel=bpemodel,
  914. device=device,
  915. maxlenratio=maxlenratio,
  916. minlenratio=minlenratio,
  917. dtype=dtype,
  918. beam_size=beam_size,
  919. ctc_weight=ctc_weight,
  920. lm_weight=lm_weight,
  921. ngram_weight=ngram_weight,
  922. penalty=penalty,
  923. nbest=nbest,
  924. streaming=streaming,
  925. token_num_relax=token_num_relax,
  926. decoding_ind=decoding_ind,
  927. decoding_mode=decoding_mode,
  928. )
  929. speech2text = Speech2TextUniASR(**speech2text_kwargs)
  930. def _forward(data_path_and_name_and_type,
  931. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  932. output_dir_v2: Optional[str] = None,
  933. fs: dict = None,
  934. param_dict: dict = None,
  935. **kwargs,
  936. ):
  937. # 3. Build data-iterator
  938. if data_path_and_name_and_type is None and raw_inputs is not None:
  939. if isinstance(raw_inputs, torch.Tensor):
  940. raw_inputs = raw_inputs.numpy()
  941. data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
  942. loader = build_streaming_iterator(
  943. task_name="asr",
  944. preprocess_args=speech2text.asr_train_args,
  945. data_path_and_name_and_type=data_path_and_name_and_type,
  946. dtype=dtype,
  947. fs=fs,
  948. batch_size=batch_size,
  949. key_file=key_file,
  950. num_workers=num_workers,
  951. )
  952. finish_count = 0
  953. file_count = 1
  954. # 7 .Start for-loop
  955. # FIXME(kamo): The output format should be discussed about
  956. asr_result_list = []
  957. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  958. if output_path is not None:
  959. writer = DatadirWriter(output_path)
  960. else:
  961. writer = None
  962. for keys, batch in loader:
  963. assert isinstance(batch, dict), type(batch)
  964. assert all(isinstance(s, str) for s in keys), keys
  965. _bs = len(next(iter(batch.values())))
  966. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  967. # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
  968. # N-best list of (text, token, token_int, hyp_object)
  969. try:
  970. results = speech2text(**batch)
  971. except TooShortUttError as e:
  972. logging.warning(f"Utterance {keys} {e}")
  973. hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
  974. results = [[" ", ["sil"], [2], hyp]] * nbest
  975. # Only supporting batch_size==1
  976. key = keys[0]
  977. logging.info(f"Utterance: {key}")
  978. for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
  979. # Create a directory: outdir/{n}best_recog
  980. if writer is not None:
  981. ibest_writer = writer[f"{n}best_recog"]
  982. # Write the result to each file
  983. ibest_writer["token"][key] = " ".join(token)
  984. # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
  985. ibest_writer["score"][key] = str(hyp.score)
  986. if text is not None:
  987. text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
  988. item = {'key': key, 'value': text_postprocessed}
  989. asr_result_list.append(item)
  990. finish_count += 1
  991. asr_utils.print_progress(finish_count / file_count)
  992. if writer is not None:
  993. ibest_writer["text"][key] = " ".join(word_lists)
  994. return asr_result_list
  995. return _forward
  996. def inference_mfcca(
  997. maxlenratio: float,
  998. minlenratio: float,
  999. batch_size: int,
  1000. beam_size: int,
  1001. ngpu: int,
  1002. ctc_weight: float,
  1003. lm_weight: float,
  1004. penalty: float,
  1005. log_level: Union[int, str],
  1006. # data_path_and_name_and_type,
  1007. asr_train_config: Optional[str],
  1008. asr_model_file: Optional[str],
  1009. cmvn_file: Optional[str] = None,
  1010. lm_train_config: Optional[str] = None,
  1011. lm_file: Optional[str] = None,
  1012. token_type: Optional[str] = None,
  1013. key_file: Optional[str] = None,
  1014. word_lm_train_config: Optional[str] = None,
  1015. bpemodel: Optional[str] = None,
  1016. allow_variable_data_keys: bool = False,
  1017. streaming: bool = False,
  1018. output_dir: Optional[str] = None,
  1019. dtype: str = "float32",
  1020. seed: int = 0,
  1021. ngram_weight: float = 0.9,
  1022. nbest: int = 1,
  1023. num_workers: int = 1,
  1024. param_dict: dict = None,
  1025. **kwargs,
  1026. ):
  1027. ncpu = kwargs.get("ncpu", 1)
  1028. torch.set_num_threads(ncpu)
  1029. if batch_size > 1:
  1030. raise NotImplementedError("batch decoding is not implemented")
  1031. if word_lm_train_config is not None:
  1032. raise NotImplementedError("Word LM is not implemented")
  1033. if ngpu > 1:
  1034. raise NotImplementedError("only single GPU decoding is supported")
  1035. logging.basicConfig(
  1036. level=log_level,
  1037. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  1038. )
  1039. if ngpu >= 1 and torch.cuda.is_available():
  1040. device = "cuda"
  1041. else:
  1042. device = "cpu"
  1043. # 1. Set random-seed
  1044. set_all_random_seed(seed)
  1045. # 2. Build speech2text
  1046. speech2text_kwargs = dict(
  1047. asr_train_config=asr_train_config,
  1048. asr_model_file=asr_model_file,
  1049. cmvn_file=cmvn_file,
  1050. lm_train_config=lm_train_config,
  1051. lm_file=lm_file,
  1052. token_type=token_type,
  1053. bpemodel=bpemodel,
  1054. device=device,
  1055. maxlenratio=maxlenratio,
  1056. minlenratio=minlenratio,
  1057. dtype=dtype,
  1058. beam_size=beam_size,
  1059. ctc_weight=ctc_weight,
  1060. lm_weight=lm_weight,
  1061. ngram_weight=ngram_weight,
  1062. penalty=penalty,
  1063. nbest=nbest,
  1064. streaming=streaming,
  1065. )
  1066. logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
  1067. speech2text = Speech2TextMFCCA(**speech2text_kwargs)
  1068. def _forward(data_path_and_name_and_type,
  1069. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  1070. output_dir_v2: Optional[str] = None,
  1071. fs: dict = None,
  1072. param_dict: dict = None,
  1073. **kwargs,
  1074. ):
  1075. # 3. Build data-iterator
  1076. if data_path_and_name_and_type is None and raw_inputs is not None:
  1077. if isinstance(raw_inputs, torch.Tensor):
  1078. raw_inputs = raw_inputs.numpy()
  1079. data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
  1080. loader = build_streaming_iterator(
  1081. task_name="asr",
  1082. preprocess_args=speech2text.asr_train_args,
  1083. data_path_and_name_and_type=data_path_and_name_and_type,
  1084. dtype=dtype,
  1085. batch_size=batch_size,
  1086. fs=fs,
  1087. mc=True,
  1088. key_file=key_file,
  1089. num_workers=num_workers,
  1090. )
  1091. finish_count = 0
  1092. file_count = 1
  1093. # 7 .Start for-loop
  1094. # FIXME(kamo): The output format should be discussed about
  1095. asr_result_list = []
  1096. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  1097. if output_path is not None:
  1098. writer = DatadirWriter(output_path)
  1099. else:
  1100. writer = None
  1101. for keys, batch in loader:
  1102. assert isinstance(batch, dict), type(batch)
  1103. assert all(isinstance(s, str) for s in keys), keys
  1104. _bs = len(next(iter(batch.values())))
  1105. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  1106. # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
  1107. # N-best list of (text, token, token_int, hyp_object)
  1108. try:
  1109. results = speech2text(**batch)
  1110. except TooShortUttError as e:
  1111. logging.warning(f"Utterance {keys} {e}")
  1112. hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
  1113. results = [[" ", ["<space>"], [2], hyp]] * nbest
  1114. # Only supporting batch_size==1
  1115. key = keys[0]
  1116. for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
  1117. # Create a directory: outdir/{n}best_recog
  1118. if writer is not None:
  1119. ibest_writer = writer[f"{n}best_recog"]
  1120. # Write the result to each file
  1121. ibest_writer["token"][key] = " ".join(token)
  1122. # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
  1123. ibest_writer["score"][key] = str(hyp.score)
  1124. if text is not None:
  1125. text_postprocessed = postprocess_utils.sentence_postprocess(token)
  1126. item = {'key': key, 'value': text_postprocessed}
  1127. asr_result_list.append(item)
  1128. finish_count += 1
  1129. asr_utils.print_progress(finish_count / file_count)
  1130. if writer is not None:
  1131. ibest_writer["text"][key] = text
  1132. return asr_result_list
  1133. return _forward
  1134. def inference_transducer(
  1135. output_dir: str,
  1136. batch_size: int,
  1137. dtype: str,
  1138. beam_size: int,
  1139. ngpu: int,
  1140. seed: int,
  1141. lm_weight: float,
  1142. nbest: int,
  1143. num_workers: int,
  1144. log_level: Union[int, str],
  1145. # data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
  1146. asr_train_config: Optional[str],
  1147. asr_model_file: Optional[str],
  1148. cmvn_file: Optional[str] = None,
  1149. beam_search_config: Optional[dict] = None,
  1150. lm_train_config: Optional[str] = None,
  1151. lm_file: Optional[str] = None,
  1152. model_tag: Optional[str] = None,
  1153. token_type: Optional[str] = None,
  1154. bpemodel: Optional[str] = None,
  1155. key_file: Optional[str] = None,
  1156. allow_variable_data_keys: bool = False,
  1157. quantize_asr_model: Optional[bool] = False,
  1158. quantize_modules: Optional[List[str]] = None,
  1159. quantize_dtype: Optional[str] = "float16",
  1160. streaming: Optional[bool] = False,
  1161. simu_streaming: Optional[bool] = False,
  1162. chunk_size: Optional[int] = 16,
  1163. left_context: Optional[int] = 16,
  1164. right_context: Optional[int] = 0,
  1165. display_partial_hypotheses: bool = False,
  1166. **kwargs,
  1167. ) -> None:
  1168. """Transducer model inference.
  1169. Args:
  1170. output_dir: Output directory path.
  1171. batch_size: Batch decoding size.
  1172. dtype: Data type.
  1173. beam_size: Beam size.
  1174. ngpu: Number of GPUs.
  1175. seed: Random number generator seed.
  1176. lm_weight: Weight of language model.
  1177. nbest: Number of final hypothesis.
  1178. num_workers: Number of workers.
  1179. log_level: Level of verbose for logs.
  1180. data_path_and_name_and_type:
  1181. asr_train_config: ASR model training config path.
  1182. asr_model_file: ASR model path.
  1183. beam_search_config: Beam search config path.
  1184. lm_train_config: Language Model training config path.
  1185. lm_file: Language Model path.
  1186. model_tag: Model tag.
  1187. token_type: Type of token units.
  1188. bpemodel: BPE model path.
  1189. key_file: File key.
  1190. allow_variable_data_keys: Whether to allow variable data keys.
  1191. quantize_asr_model: Whether to apply dynamic quantization to ASR model.
  1192. quantize_modules: List of module names to apply dynamic quantization on.
  1193. quantize_dtype: Dynamic quantization data type.
  1194. streaming: Whether to perform chunk-by-chunk inference.
  1195. chunk_size: Number of frames in chunk AFTER subsampling.
  1196. left_context: Number of frames in left context AFTER subsampling.
  1197. right_context: Number of frames in right context AFTER subsampling.
  1198. display_partial_hypotheses: Whether to display partial hypotheses.
  1199. """
  1200. if batch_size > 1:
  1201. raise NotImplementedError("batch decoding is not implemented")
  1202. if ngpu > 1:
  1203. raise NotImplementedError("only single GPU decoding is supported")
  1204. logging.basicConfig(
  1205. level=log_level,
  1206. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  1207. )
  1208. if ngpu >= 1 and torch.cuda.is_available():
  1209. device = "cuda"
  1210. else:
  1211. device = "cpu"
  1212. # 1. Set random-seed
  1213. set_all_random_seed(seed)
  1214. # 2. Build speech2text
  1215. speech2text_kwargs = dict(
  1216. asr_train_config=asr_train_config,
  1217. asr_model_file=asr_model_file,
  1218. cmvn_file=cmvn_file,
  1219. beam_search_config=beam_search_config,
  1220. lm_train_config=lm_train_config,
  1221. lm_file=lm_file,
  1222. token_type=token_type,
  1223. bpemodel=bpemodel,
  1224. device=device,
  1225. dtype=dtype,
  1226. beam_size=beam_size,
  1227. lm_weight=lm_weight,
  1228. nbest=nbest,
  1229. quantize_asr_model=quantize_asr_model,
  1230. quantize_modules=quantize_modules,
  1231. quantize_dtype=quantize_dtype,
  1232. streaming=streaming,
  1233. simu_streaming=simu_streaming,
  1234. chunk_size=chunk_size,
  1235. left_context=left_context,
  1236. right_context=right_context,
  1237. )
  1238. speech2text = Speech2TextTransducer(**speech2text_kwargs)
  1239. def _forward(data_path_and_name_and_type,
  1240. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  1241. output_dir_v2: Optional[str] = None,
  1242. fs: dict = None,
  1243. param_dict: dict = None,
  1244. **kwargs,
  1245. ):
  1246. # 3. Build data-iterator
  1247. loader = build_streaming_iterator(
  1248. task_name="asr",
  1249. preprocess_args=speech2text.asr_train_args,
  1250. data_path_and_name_and_type=data_path_and_name_and_type,
  1251. dtype=dtype,
  1252. batch_size=batch_size,
  1253. key_file=key_file,
  1254. num_workers=num_workers,
  1255. )
  1256. asr_result_list = []
  1257. if output_dir is not None:
  1258. writer = DatadirWriter(output_dir)
  1259. else:
  1260. writer = None
  1261. # 4 .Start for-loop
  1262. for keys, batch in loader:
  1263. assert isinstance(batch, dict), type(batch)
  1264. assert all(isinstance(s, str) for s in keys), keys
  1265. _bs = len(next(iter(batch.values())))
  1266. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  1267. batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
  1268. assert len(batch.keys()) == 1
  1269. try:
  1270. if speech2text.streaming:
  1271. speech = batch["speech"]
  1272. _steps = len(speech) // speech2text._ctx
  1273. _end = 0
  1274. for i in range(_steps):
  1275. _end = (i + 1) * speech2text._ctx
  1276. speech2text.streaming_decode(
  1277. speech[i * speech2text._ctx: _end], is_final=False
  1278. )
  1279. final_hyps = speech2text.streaming_decode(
  1280. speech[_end: len(speech)], is_final=True
  1281. )
  1282. elif speech2text.simu_streaming:
  1283. final_hyps = speech2text.simu_streaming_decode(**batch)
  1284. else:
  1285. final_hyps = speech2text(**batch)
  1286. results = speech2text.hypotheses_to_results(final_hyps)
  1287. except TooShortUttError as e:
  1288. logging.warning(f"Utterance {keys} {e}")
  1289. hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
  1290. results = [[" ", ["<space>"], [2], hyp]] * nbest
  1291. key = keys[0]
  1292. for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
  1293. item = {'key': key, 'value': text}
  1294. asr_result_list.append(item)
  1295. if writer is not None:
  1296. ibest_writer = writer[f"{n}best_recog"]
  1297. ibest_writer["token"][key] = " ".join(token)
  1298. ibest_writer["token_int"][key] = " ".join(map(str, token_int))
  1299. ibest_writer["score"][key] = str(hyp.score)
  1300. if text is not None:
  1301. ibest_writer["text"][key] = text
  1302. logging.info("decoding, utt: {}, predictions: {}".format(key, text))
  1303. return asr_result_list
  1304. return _forward
  1305. def inference_sa_asr(
  1306. maxlenratio: float,
  1307. minlenratio: float,
  1308. batch_size: int,
  1309. beam_size: int,
  1310. ngpu: int,
  1311. ctc_weight: float,
  1312. lm_weight: float,
  1313. penalty: float,
  1314. log_level: Union[int, str],
  1315. # data_path_and_name_and_type,
  1316. asr_train_config: Optional[str],
  1317. asr_model_file: Optional[str],
  1318. cmvn_file: Optional[str] = None,
  1319. lm_train_config: Optional[str] = None,
  1320. lm_file: Optional[str] = None,
  1321. token_type: Optional[str] = None,
  1322. key_file: Optional[str] = None,
  1323. word_lm_train_config: Optional[str] = None,
  1324. bpemodel: Optional[str] = None,
  1325. allow_variable_data_keys: bool = False,
  1326. streaming: bool = False,
  1327. output_dir: Optional[str] = None,
  1328. dtype: str = "float32",
  1329. seed: int = 0,
  1330. ngram_weight: float = 0.9,
  1331. nbest: int = 1,
  1332. num_workers: int = 1,
  1333. mc: bool = False,
  1334. param_dict: dict = None,
  1335. **kwargs,
  1336. ):
  1337. if batch_size > 1:
  1338. raise NotImplementedError("batch decoding is not implemented")
  1339. if word_lm_train_config is not None:
  1340. raise NotImplementedError("Word LM is not implemented")
  1341. if ngpu > 1:
  1342. raise NotImplementedError("only single GPU decoding is supported")
  1343. for handler in logging.root.handlers[:]:
  1344. logging.root.removeHandler(handler)
  1345. logging.basicConfig(
  1346. level=log_level,
  1347. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  1348. )
  1349. if ngpu >= 1 and torch.cuda.is_available():
  1350. device = "cuda"
  1351. else:
  1352. device = "cpu"
  1353. # 1. Set random-seed
  1354. set_all_random_seed(seed)
  1355. # 2. Build speech2text
  1356. speech2text_kwargs = dict(
  1357. asr_train_config=asr_train_config,
  1358. asr_model_file=asr_model_file,
  1359. cmvn_file=cmvn_file,
  1360. lm_train_config=lm_train_config,
  1361. lm_file=lm_file,
  1362. token_type=token_type,
  1363. bpemodel=bpemodel,
  1364. device=device,
  1365. maxlenratio=maxlenratio,
  1366. minlenratio=minlenratio,
  1367. dtype=dtype,
  1368. beam_size=beam_size,
  1369. ctc_weight=ctc_weight,
  1370. lm_weight=lm_weight,
  1371. ngram_weight=ngram_weight,
  1372. penalty=penalty,
  1373. nbest=nbest,
  1374. streaming=streaming,
  1375. )
  1376. logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
  1377. speech2text = Speech2TextSAASR(**speech2text_kwargs)
  1378. def _forward(data_path_and_name_and_type,
  1379. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  1380. output_dir_v2: Optional[str] = None,
  1381. fs: dict = None,
  1382. param_dict: dict = None,
  1383. **kwargs,
  1384. ):
  1385. # 3. Build data-iterator
  1386. if data_path_and_name_and_type is None and raw_inputs is not None:
  1387. if isinstance(raw_inputs, torch.Tensor):
  1388. raw_inputs = raw_inputs.numpy()
  1389. data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
  1390. loader = build_streaming_iterator(
  1391. task_name="asr",
  1392. preprocess_args=speech2text.asr_train_args,
  1393. data_path_and_name_and_type=data_path_and_name_and_type,
  1394. dtype=dtype,
  1395. fs=fs,
  1396. mc=mc,
  1397. batch_size=batch_size,
  1398. key_file=key_file,
  1399. num_workers=num_workers,
  1400. )
  1401. finish_count = 0
  1402. file_count = 1
  1403. # 7 .Start for-loop
  1404. # FIXME(kamo): The output format should be discussed about
  1405. asr_result_list = []
  1406. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  1407. if output_path is not None:
  1408. writer = DatadirWriter(output_path)
  1409. else:
  1410. writer = None
  1411. for keys, batch in loader:
  1412. assert isinstance(batch, dict), type(batch)
  1413. assert all(isinstance(s, str) for s in keys), keys
  1414. _bs = len(next(iter(batch.values())))
  1415. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  1416. # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
  1417. # N-best list of (text, token, token_int, hyp_object)
  1418. try:
  1419. results = speech2text(**batch)
  1420. except TooShortUttError as e:
  1421. logging.warning(f"Utterance {keys} {e}")
  1422. hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
  1423. results = [[" ", ["sil"], [2], hyp]] * nbest
  1424. # Only supporting batch_size==1
  1425. key = keys[0]
  1426. for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
  1427. # Create a directory: outdir/{n}best_recog
  1428. if writer is not None:
  1429. ibest_writer = writer[f"{n}best_recog"]
  1430. # Write the result to each file
  1431. ibest_writer["token"][key] = " ".join(token)
  1432. ibest_writer["token_int"][key] = " ".join(map(str, token_int))
  1433. ibest_writer["score"][key] = str(hyp.score)
  1434. ibest_writer["text_id"][key] = text_id
  1435. if text is not None:
  1436. text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
  1437. item = {'key': key, 'value': text_postprocessed}
  1438. asr_result_list.append(item)
  1439. finish_count += 1
  1440. asr_utils.print_progress(finish_count / file_count)
  1441. if writer is not None:
  1442. ibest_writer["text"][key] = text
  1443. logging.info("uttid: {}".format(key))
  1444. logging.info("text predictions: {}".format(text))
  1445. logging.info("text_id predictions: {}\n".format(text_id))
  1446. return asr_result_list
  1447. return _forward
  1448. def inference_launch(**kwargs):
  1449. if 'mode' in kwargs:
  1450. mode = kwargs['mode']
  1451. else:
  1452. logging.info("Unknown decoding mode.")
  1453. return None
  1454. if mode == "asr":
  1455. return inference_asr(**kwargs)
  1456. elif mode == "uniasr":
  1457. return inference_uniasr(**kwargs)
  1458. elif mode == "paraformer":
  1459. return inference_paraformer(**kwargs)
  1460. elif mode == "paraformer_fake_streaming":
  1461. return inference_paraformer(**kwargs)
  1462. elif mode == "paraformer_streaming":
  1463. return inference_paraformer_online(**kwargs)
  1464. elif mode.startswith("paraformer_vad"):
  1465. return inference_paraformer_vad_punc(**kwargs)
  1466. elif mode == "mfcca":
  1467. return inference_mfcca(**kwargs)
  1468. elif mode == "rnnt":
  1469. return inference_transducer(**kwargs)
  1470. elif mode == "bat":
  1471. return inference_transducer(**kwargs)
  1472. elif mode == "sa_asr":
  1473. return inference_sa_asr(**kwargs)
  1474. else:
  1475. logging.info("Unknown decoding mode: {}".format(mode))
  1476. return None
  1477. def get_parser():
  1478. parser = config_argparse.ArgumentParser(
  1479. description="ASR Decoding",
  1480. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  1481. )
  1482. # Note(kamo): Use '_' instead of '-' as separator.
  1483. # '-' is confusing if written in yaml.
  1484. parser.add_argument(
  1485. "--log_level",
  1486. type=lambda x: x.upper(),
  1487. default="INFO",
  1488. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  1489. help="The verbose level of logging",
  1490. )
  1491. parser.add_argument("--output_dir", type=str, required=True)
  1492. parser.add_argument(
  1493. "--ngpu",
  1494. type=int,
  1495. default=0,
  1496. help="The number of gpus. 0 indicates CPU mode",
  1497. )
  1498. parser.add_argument(
  1499. "--njob",
  1500. type=int,
  1501. default=1,
  1502. help="The number of jobs for each gpu",
  1503. )
  1504. parser.add_argument(
  1505. "--gpuid_list",
  1506. type=str,
  1507. default="",
  1508. help="The visible gpus",
  1509. )
  1510. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  1511. parser.add_argument(
  1512. "--dtype",
  1513. default="float32",
  1514. choices=["float16", "float32", "float64"],
  1515. help="Data type",
  1516. )
  1517. parser.add_argument(
  1518. "--num_workers",
  1519. type=int,
  1520. default=1,
  1521. help="The number of workers used for DataLoader",
  1522. )
  1523. group = parser.add_argument_group("Input data related")
  1524. group.add_argument(
  1525. "--data_path_and_name_and_type",
  1526. type=str2triple_str,
  1527. required=True,
  1528. action="append",
  1529. )
  1530. group.add_argument("--key_file", type=str_or_none)
  1531. parser.add_argument(
  1532. "--hotword",
  1533. type=str_or_none,
  1534. default=None,
  1535. help="hotword file path or hotwords seperated by space"
  1536. )
  1537. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  1538. group.add_argument(
  1539. "--mc",
  1540. type=bool,
  1541. default=False,
  1542. help="MultiChannel input",
  1543. )
  1544. group = parser.add_argument_group("The model configuration related")
  1545. group.add_argument(
  1546. "--vad_infer_config",
  1547. type=str,
  1548. help="VAD infer configuration",
  1549. )
  1550. group.add_argument(
  1551. "--vad_model_file",
  1552. type=str,
  1553. help="VAD model parameter file",
  1554. )
  1555. group.add_argument(
  1556. "--cmvn_file",
  1557. type=str,
  1558. help="Global CMVN file",
  1559. )
  1560. group.add_argument(
  1561. "--asr_train_config",
  1562. type=str,
  1563. help="ASR training configuration",
  1564. )
  1565. group.add_argument(
  1566. "--asr_model_file",
  1567. type=str,
  1568. help="ASR model parameter file",
  1569. )
  1570. group.add_argument(
  1571. "--lm_train_config",
  1572. type=str,
  1573. help="LM training configuration",
  1574. )
  1575. group.add_argument(
  1576. "--lm_file",
  1577. type=str,
  1578. help="LM parameter file",
  1579. )
  1580. group.add_argument(
  1581. "--word_lm_train_config",
  1582. type=str,
  1583. help="Word LM training configuration",
  1584. )
  1585. group.add_argument(
  1586. "--word_lm_file",
  1587. type=str,
  1588. help="Word LM parameter file",
  1589. )
  1590. group.add_argument(
  1591. "--ngram_file",
  1592. type=str,
  1593. help="N-gram parameter file",
  1594. )
  1595. group.add_argument(
  1596. "--model_tag",
  1597. type=str,
  1598. help="Pretrained model tag. If specify this option, *_train_config and "
  1599. "*_file will be overwritten",
  1600. )
  1601. group.add_argument(
  1602. "--beam_search_config",
  1603. default={},
  1604. help="The keyword arguments for transducer beam search.",
  1605. )
  1606. group = parser.add_argument_group("Beam-search related")
  1607. group.add_argument(
  1608. "--batch_size",
  1609. type=int,
  1610. default=1,
  1611. help="The batch size for inference",
  1612. )
  1613. group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
  1614. group.add_argument("--beam_size", type=int, default=20, help="Beam size")
  1615. group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
  1616. group.add_argument(
  1617. "--maxlenratio",
  1618. type=float,
  1619. default=0.0,
  1620. help="Input length ratio to obtain max output length. "
  1621. "If maxlenratio=0.0 (default), it uses a end-detect "
  1622. "function "
  1623. "to automatically find maximum hypothesis lengths."
  1624. "If maxlenratio<0.0, its absolute value is interpreted"
  1625. "as a constant max output length",
  1626. )
  1627. group.add_argument(
  1628. "--minlenratio",
  1629. type=float,
  1630. default=0.0,
  1631. help="Input length ratio to obtain min output length",
  1632. )
  1633. group.add_argument(
  1634. "--ctc_weight",
  1635. type=float,
  1636. default=0.0,
  1637. help="CTC weight in joint decoding",
  1638. )
  1639. group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
  1640. group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
  1641. group.add_argument("--streaming", type=str2bool, default=False)
  1642. group.add_argument("--simu_streaming", type=str2bool, default=False)
  1643. group.add_argument("--chunk_size", type=int, default=16)
  1644. group.add_argument("--left_context", type=int, default=16)
  1645. group.add_argument("--right_context", type=int, default=0)
  1646. group.add_argument(
  1647. "--display_partial_hypotheses",
  1648. type=bool,
  1649. default=False,
  1650. help="Whether to display partial hypotheses during chunk-by-chunk inference.",
  1651. )
  1652. group = parser.add_argument_group("Dynamic quantization related")
  1653. group.add_argument(
  1654. "--quantize_asr_model",
  1655. type=bool,
  1656. default=False,
  1657. help="Apply dynamic quantization to ASR model.",
  1658. )
  1659. group.add_argument(
  1660. "--quantize_modules",
  1661. nargs="*",
  1662. default=None,
  1663. help="""Module names to apply dynamic quantization on.
  1664. The module names are provided as a list, where each name is separated
  1665. by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
  1666. Each specified name should be an attribute of 'torch.nn', e.g.:
  1667. torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
  1668. )
  1669. group.add_argument(
  1670. "--quantize_dtype",
  1671. type=str,
  1672. default="qint8",
  1673. choices=["float16", "qint8"],
  1674. help="Dtype for dynamic quantization.",
  1675. )
  1676. group = parser.add_argument_group("Text converter related")
  1677. group.add_argument(
  1678. "--token_type",
  1679. type=str_or_none,
  1680. default=None,
  1681. choices=["char", "bpe", None],
  1682. help="The token type for ASR model. "
  1683. "If not given, refers from the training args",
  1684. )
  1685. group.add_argument(
  1686. "--bpemodel",
  1687. type=str_or_none,
  1688. default=None,
  1689. help="The model path of sentencepiece. "
  1690. "If not given, refers from the training args",
  1691. )
  1692. group.add_argument("--token_num_relax", type=int, default=1, help="")
  1693. group.add_argument("--decoding_ind", type=int, default=0, help="")
  1694. group.add_argument("--decoding_mode", type=str, default="model1", help="")
  1695. group.add_argument(
  1696. "--ctc_weight2",
  1697. type=float,
  1698. default=0.0,
  1699. help="CTC weight in joint decoding",
  1700. )
  1701. return parser
  1702. def main(cmd=None):
  1703. print(get_commandline_args(), file=sys.stderr)
  1704. parser = get_parser()
  1705. parser.add_argument(
  1706. "--mode",
  1707. type=str,
  1708. default="asr",
  1709. help="The decoding mode",
  1710. )
  1711. args = parser.parse_args(cmd)
  1712. kwargs = vars(args)
  1713. kwargs.pop("config", None)
  1714. # set logging messages
  1715. logging.basicConfig(
  1716. level=args.log_level,
  1717. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  1718. )
  1719. logging.info("Decoding args: {}".format(kwargs))
  1720. # gpu setting
  1721. if args.ngpu > 0:
  1722. jobid = int(args.output_dir.split(".")[-1])
  1723. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  1724. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  1725. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  1726. inference_pipeline = inference_launch(**kwargs)
  1727. return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
# Allow running this module directly as a decoding script.
if __name__ == "__main__":
    main()