asr_inference_launch.py

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import argparse
import logging
import os
import sys
import time
from pathlib import Path
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import numpy as np
import torch
import torchaudio
import soundfile
import yaml

from funasr.bin.asr_infer import Speech2Text
from funasr.bin.asr_infer import Speech2TextMFCCA
from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
from funasr.bin.asr_infer import Speech2TextSAASR
from funasr.bin.asr_infer import Speech2TextTransducer
from funasr.bin.asr_infer import Speech2TextUniASR
from funasr.bin.punc_infer import Text2Punc
from funasr.bin.tp_infer import Speech2Timestamp
from funasr.bin.vad_infer import Speech2VadSegment
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.subsampling import TooShortUttError
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import asr_utils, postprocess_utils
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.vad_utils import slice_padding_fbank


def inference_asr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
):
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2Text(**speech2text_kwargs)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,
        **kwargs,
    ):
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
        )

        finish_count = 0
        file_count = 1
        # 7. Start for-loop
        # FIXME(kamo): The output format should be discussed
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest

            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
                logging.info("uttid: {}".format(key))
                logging.info("text predictions: {}\n".format(text))
        return asr_result_list

    return _forward
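

# Usage sketch (added for illustration; not part of the upstream file): each
# inference_* function here is a factory that loads the model once and returns
# the `_forward` closure, so repeated calls reuse the loaded weights. All paths
# below are placeholders.
#
#   recognize = inference_asr(
#       maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=10,
#       ngpu=0, ctc_weight=0.5, lm_weight=0.0, penalty=0.0, log_level="INFO",
#       asr_train_config="exp/asr/config.yaml",
#       asr_model_file="exp/asr/model.pb",
#   )
#   results = recognize([("data/test/wav.scp", "speech", "sound")])
#   # -> [{'key': <utt_id>, 'value': <post-processed text>}, ...]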


def inference_paraformer(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    output_dir: Optional[str] = None,
    timestamp_infer_config: Union[Path, str] = None,
    timestamp_model_file: Union[Path, str] = None,
    param_dict: dict = None,
    **kwargs,
):
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    export_mode = False
    if param_dict is not None:
        hotword_list_or_file = param_dict.get('hotword')
        export_mode = param_dict.get("export_mode", False)
    else:
        hotword_list_or_file = None

    if kwargs.get("device", None) == "cpu":
        ngpu = 0
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    batch_size = 1

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
    )
    speech2text = Speech2TextParaformer(**speech2text_kwargs)

    if timestamp_model_file is not None:
        speechtext2timestamp = Speech2Timestamp(
            timestamp_cmvn_file=cmvn_file,
            timestamp_model_file=timestamp_model_file,
            timestamp_infer_config=timestamp_infer_config,
        )
    else:
        speechtext2timestamp = None

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,
        **kwargs,
    ):
        hotword_list_or_file = None
        if param_dict is not None:
            hotword_list_or_file = param_dict.get('hotword')
        if 'hotword' in kwargs and kwargs['hotword'] is not None:
            hotword_list_or_file = kwargs['hotword']
        if hotword_list_or_file is not None or 'hotword' in kwargs:
            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)

        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
        )

        if param_dict is not None:
            use_timestamp = param_dict.get('use_timestamp', True)
        else:
            use_timestamp = True

        forward_time_total = 0.0
        length_total = 0.0
        finish_count = 0
        file_count = 1
        # 7. Start for-loop
        # FIXME(kamo): The output format should be discussed
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
            logging.info("decoding, utt_id: {}".format(keys))

            # N-best list of (text, token, token_int, hyp_object)
            time_beg = time.time()
            results = speech2text(**batch)
            if len(results) < 1:
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
            time_end = time.time()
            forward_time = time_end - time_beg
            lfr_factor = results[0][-1]
            length = results[0][-2]
            forward_time_total += forward_time
            length_total += length
            rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(
                length, forward_time, 100 * forward_time / (length * lfr_factor)
            )
            logging.info(rtf_cur)

            for batch_id in range(_bs):
                # Drop the trailing (length, lfr_factor) entries; only top-1 is kept here.
                result = [results[batch_id][:-2]]
                key = keys[batch_id]
                for n, result in zip(range(1, nbest + 1), result):
                    text, token, token_int, hyp = result[0], result[1], result[2], result[3]
                    timestamp = result[4] if len(result[4]) > 0 else None
                    # Conduct timestamp prediction here. Timestamp inference requires
                    # the token length, so the following cannot be conducted in batch.
                    if timestamp is None and speechtext2timestamp:
                        ts_batch = {}
                        ts_batch['speech'] = batch['speech'][batch_id].unsqueeze(0)
                        ts_batch['speech_lengths'] = torch.tensor([batch['speech_lengths'][batch_id]])
                        ts_batch['text_lengths'] = torch.tensor([len(token)])
                        us_alphas, us_peaks = speechtext2timestamp(**ts_batch)
                        ts_str, timestamp = ts_prediction_lfr6_standard(
                            us_alphas[0], us_peaks[0], token, force_time_shift=-3.0
                        )

                    # Create a directory: outdir/{n}best_recog
                    if writer is not None:
                        ibest_writer = writer[f"{n}best_recog"]
                        # Write the result to each file
                        ibest_writer["token"][key] = " ".join(token)
                        # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                        ibest_writer["score"][key] = str(hyp.score)
                        ibest_writer["rtf"][key] = rtf_cur
                    if text is not None:
                        if use_timestamp and timestamp is not None:
                            postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
                        else:
                            postprocessed_result = postprocess_utils.sentence_postprocess(token)
                        timestamp_postprocessed = ""
                        if len(postprocessed_result) == 3:
                            text_postprocessed, timestamp_postprocessed, word_lists = (
                                postprocessed_result[0],
                                postprocessed_result[1],
                                postprocessed_result[2],
                            )
                        else:
                            text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
                        item = {'key': key, 'value': text_postprocessed}
                        if timestamp_postprocessed != "":
                            item['timestamp'] = timestamp_postprocessed
                        asr_result_list.append(item)
                        finish_count += 1
                        # asr_utils.print_progress(finish_count / file_count)
                        if writer is not None:
                            ibest_writer["text"][key] = " ".join(word_lists)
                    logging.info("decoding, utt: {}, predictions: {}".format(key, text))

        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(
            length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)
        )
        logging.info(rtf_avg)
        if writer is not None:
            ibest_writer["rtf"]["rtf_avg"] = rtf_avg
        return asr_result_list

    return _forward
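

# Illustrative call pattern (paths and the hotword file are placeholders):
# hotwords can be supplied via `param_dict` at build time or per call through
# the returned closure, and `use_timestamp` controls whether timestamps are
# attached to each result item when the model (or a timestamp model) provides them.
#
#   recognize = inference_paraformer(
#       maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=1,
#       ngpu=0, ctc_weight=0.0, lm_weight=0.0, penalty=0.0, log_level="INFO",
#       asr_train_config="exp/paraformer/config.yaml",
#       asr_model_file="exp/paraformer/model.pb",
#   )
#   results = recognize(
#       [("data/test/wav.scp", "speech", "sound")],
#       param_dict={"hotword": "hotwords.txt", "use_timestamp": True},
#   )
#   # Each item: {'key': ..., 'value': ..., 'timestamp': ...} when available.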


def inference_paraformer_vad_punc(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    vad_infer_config: Optional[str] = None,
    vad_model_file: Optional[str] = None,
    vad_cmvn_file: Optional[str] = None,
    time_stamp_writer: bool = True,
    punc_infer_config: Optional[str] = None,
    punc_model_file: Optional[str] = None,
    outputs_dict: Optional[bool] = True,
    param_dict: dict = None,
    **kwargs,
):
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if param_dict is not None:
        hotword_list_or_file = param_dict.get('hotword')
    else:
        hotword_list_or_file = None

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2vadsegment
    speech2vadsegment_kwargs = dict(
        vad_infer_config=vad_infer_config,
        vad_model_file=vad_model_file,
        vad_cmvn_file=vad_cmvn_file,
        device=device,
        dtype=dtype,
    )
    # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)

    # 3. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
    )
    speech2text = Speech2TextParaformer(**speech2text_kwargs)

    text2punc = None
    if punc_model_file is not None:
        text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)

    if output_dir is not None:
        writer = DatadirWriter(output_dir)
        ibest_writer = writer["1best_recog"]
        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,
        **kwargs,
    ):
        hotword_list_or_file = None
        if param_dict is not None:
            hotword_list_or_file = param_dict.get('hotword')
        if 'hotword' in kwargs:
            hotword_list_or_file = kwargs['hotword']
        batch_size_token = kwargs.get("batch_size_token", 6000)
        print("batch_size_token: ", batch_size_token)
        if speech2text.hotword_list is None:
            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)

        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=1,
            key_file=key_file,
            num_workers=num_workers,
        )

        if param_dict is not None:
            use_timestamp = param_dict.get('use_timestamp', True)
        else:
            use_timestamp = True

        finish_count = 0
        file_count = 1
        lfr_factor = 6
        # 7. Start for-loop
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        writer = None
        if output_path is not None:
            writer = DatadirWriter(output_path)
            ibest_writer = writer["1best_recog"]
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            # Run VAD first to split the utterance into segments.
            beg_vad = time.time()
            vad_results = speech2vadsegment(**batch)
            end_vad = time.time()
            print("time cost vad: ", end_vad - beg_vad)

            _, vadsegments = vad_results[0], vad_results[1][0]
            speech, speech_lengths = batch["speech"], batch["speech_lengths"]
            n = len(vadsegments)
            data_with_index = [(vadsegments[i], i) for i in range(n)]
            # Sort segments by duration so batches hold similar-length inputs.
            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
            results_sorted = []

            batch_size_token_ms = batch_size_token * 60
            if speech2text.device == "cpu":
                batch_size_token_ms = 0
            batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
            batch_size_token_ms_cum = 0
            beg_idx = 0
            for j, _ in enumerate(range(0, n)):
                batch_size_token_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
                if j < n - 1 and (
                    batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]
                ) < batch_size_token_ms:
                    continue
                batch_size_token_ms_cum = 0
                end_idx = j + 1
                speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
                beg_idx = end_idx
                batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
                batch = to_device(batch, device=device)
                print("batch: ", speech_j.shape[0])
                beg_asr = time.time()
                results = speech2text(**batch)
                end_asr = time.time()
                print("time cost asr: ", end_asr - beg_asr)
                if len(results) < 1:
                    results = [["", [], [], [], [], [], []]]
                results_sorted.extend(results)

            # Restore the original segment order.
            restored_data = [0] * n
            for j in range(n):
                index = sorted_data[j][1]
                restored_data[index] = results_sorted[j]
            result = ["", [], [], [], [], [], []]
            for j in range(n):
                result[0] += restored_data[j][0]
                result[1] += restored_data[j][1]
                result[2] += restored_data[j][2]
                if len(restored_data[j][4]) > 0:
                    # Shift per-segment timestamps back to utterance time.
                    for t in restored_data[j][4]:
                        t[0] += vadsegments[j][0]
                        t[1] += vadsegments[j][0]
                    result[4] += restored_data[j][4]
            # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]

            key = keys[0]
            # result = result_segments[0]
            text, token, token_int = result[0], result[1], result[2]
            time_stamp = result[4] if len(result[4]) > 0 else None
            if use_timestamp and time_stamp is not None:
                postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
            else:
                postprocessed_result = postprocess_utils.sentence_postprocess(token)
            text_postprocessed = ""
            time_stamp_postprocessed = ""
            text_postprocessed_punc = postprocessed_result
            if len(postprocessed_result) == 3:
                text_postprocessed, time_stamp_postprocessed, word_lists = (
                    postprocessed_result[0],
                    postprocessed_result[1],
                    postprocessed_result[2],
                )
            else:
                text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
            text_postprocessed_punc = text_postprocessed

            punc_id_list = []
            if len(word_lists) > 0 and text2punc is not None:
                beg_punc = time.time()
                text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
                end_punc = time.time()
                print("time cost punc: ", end_punc - beg_punc)

            item = {'key': key, 'value': text_postprocessed_punc}
            if text_postprocessed != "":
                item['text_postprocessed'] = text_postprocessed
            if time_stamp_postprocessed != "":
                item['time_stamp'] = time_stamp_postprocessed
                item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
            asr_result_list.append(item)
            finish_count += 1
            # asr_utils.print_progress(finish_count / file_count)

            if writer is not None:
                # Write the result to each file
                ibest_writer["token"][key] = " ".join(token)
                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                ibest_writer["vad"][key] = "{}".format(vadsegments)
                ibest_writer["text"][key] = " ".join(word_lists)
                ibest_writer["text_with_punc"][key] = text_postprocessed_punc
                if time_stamp_postprocessed is not None:
                    ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
            logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
        return asr_result_list

    return _forward
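

# Pipeline note with a hedged example (paths are placeholders): this runner
# chains VAD -> duration-sorted, batched ASR -> punctuation restoration.
# Segments are greedily packed until roughly `batch_size_token * 60` ms of
# audio per ASR batch (effectively one segment per batch on CPU), decoded, and
# restored to the original order before post-processing.
#
#   recognize = inference_paraformer_vad_punc(
#       maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=1,
#       ngpu=0, ctc_weight=0.0, lm_weight=0.0, penalty=0.0, log_level="INFO",
#       asr_train_config="exp/paraformer/config.yaml",
#       asr_model_file="exp/paraformer/model.pb",
#       vad_infer_config="exp/vad/vad.yaml", vad_model_file="exp/vad/vad.pb",
#       punc_infer_config="exp/punc/punc.yaml", punc_model_file="exp/punc/punc.pb",
#   )
#   item = recognize([("data/test/wav.scp", "speech", "sound")])[0]
#   # item keys: 'key', 'value' (punctuated text), and, when timestamps are
#   # available, 'time_stamp' and 'sentences'.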


def inference_paraformer_online(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
):
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    export_mode = False
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    batch_size = 1

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
    )
    speech2text = Speech2TextParaformerOnline(**speech2text_kwargs)

    def _load_bytes(input):
        # Interpret raw bytes as 16-bit PCM and scale to [-1.0, 1.0) float32.
        middle_data = np.frombuffer(input, dtype=np.int16)
        middle_data = np.asarray(middle_data)
        if middle_data.dtype.kind not in 'iu':
            raise TypeError("'middle_data' must be an array of integers")
        dtype = np.dtype('float32')
        if dtype.kind != 'f':
            raise TypeError("'dtype' must be a floating point type")
        i = np.iinfo(middle_data.dtype)
        abs_max = 2 ** (i.bits - 1)
        offset = i.min + abs_max
        array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
        return array

    def _read_yaml(yaml_path: Union[str, Path]) -> Dict:
        if not Path(yaml_path).exists():
            raise FileNotFoundError(f'The {yaml_path} does not exist.')
        with open(str(yaml_path), 'rb') as f:
            data = yaml.load(f, Loader=yaml.Loader)
        return data

    def _prepare_cache(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
        # Reuse an existing cache; otherwise initialize encoder/decoder states
        # sized from the training config.
        if len(cache) > 0:
            return cache
        config = _read_yaml(asr_train_config)
        enc_output_size = config["encoder_conf"]["output_size"]
        feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
        cache_en = {
            "start_idx": 0,
            "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
            "cif_alphas": torch.zeros((batch_size, 1)),
            "chunk_size": chunk_size,
            "last_chunk": False,
            "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
            "tail_chunk": False,
        }
        cache["encoder"] = cache_en
        cache_de = {"decode_fsmn": None}
        cache["decoder"] = cache_de
        return cache

    def _cache_reset(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
        # Re-initialize the states in place so a new utterance starts clean.
        if len(cache) > 0:
            config = _read_yaml(asr_train_config)
            enc_output_size = config["encoder_conf"]["output_size"]
            feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
            cache_en = {
                "start_idx": 0,
                "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
                "cif_alphas": torch.zeros((batch_size, 1)),
                "chunk_size": chunk_size,
                "last_chunk": False,
                "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
                "tail_chunk": False,
            }
            cache["encoder"] = cache_en
            cache_de = {"decode_fsmn": None}
            cache["decoder"] = cache_de
        return cache

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,
        **kwargs,
    ):
        # 3. Build data-iterator
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
            raw_inputs = _load_bytes(data_path_and_name_and_type[0])
            raw_inputs = torch.tensor(raw_inputs)
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
            try:
                raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
            except Exception:
                # Fall back to soundfile; keep only the first channel.
                raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
                if raw_inputs.ndim == 2:
                    raw_inputs = raw_inputs[:, 0]
                raw_inputs = torch.tensor(raw_inputs)
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, np.ndarray):
                raw_inputs = torch.tensor(raw_inputs)

        is_final = False
        cache = {}
        chunk_size = [5, 10, 5]
        if param_dict is not None and "cache" in param_dict:
            cache = param_dict["cache"]
        if param_dict is not None and "is_final" in param_dict:
            is_final = param_dict["is_final"]
        if param_dict is not None and "chunk_size" in param_dict:
            chunk_size = param_dict["chunk_size"]

        # 7. Start for-loop
        # FIXME(kamo): The output format should be discussed
        raw_inputs = torch.unsqueeze(raw_inputs, dim=0)
        asr_result_list = []
        cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
        item = {}
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
            # Decode a whole file chunk by chunk.
            sample_offset = 0
            speech_length = raw_inputs.shape[1]
            stride_size = chunk_size[1] * 960
            cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
            final_result = ""
            # NOTE: the range step is evaluated once, before stride_size is
            # shortened for the final chunk.
            for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
                if sample_offset + stride_size >= speech_length - 1:
                    stride_size = speech_length - sample_offset
                    cache["encoder"]["is_final"] = True
                else:
                    cache["encoder"]["is_final"] = False
                input_lens = torch.tensor([stride_size])
                asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
                if len(asr_result) != 0:
                    final_result += " ".join(asr_result) + " "
                item = {'key': "utt", 'value': final_result.strip()}
        else:
            # Decode a single streaming chunk supplied by the caller.
            input_lens = torch.tensor([raw_inputs.shape[1]])
            cache["encoder"]["is_final"] = is_final
            asr_result = speech2text(cache, raw_inputs, input_lens)
            item = {'key': "utt", 'value': " ".join(asr_result)}
        asr_result_list.append(item)
        if is_final:
            cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1)
        return asr_result_list

    return _forward
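

# Streaming sketch (illustrative; `audio_chunks` is a hypothetical list of
# waveform pieces): the closure is called once per chunk with a shared `cache`
# dict that carries encoder/decoder state across calls; `is_final=True`
# flushes and resets it. With the default chunk_size [5, 10, 5], file input is
# decoded in strides of 10 * 960 = 9600 samples (600 ms at 16 kHz).
#
#   recognize = inference_paraformer_online(
#       maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=1,
#       ngpu=0, ctc_weight=0.0, lm_weight=0.0, penalty=0.0, log_level="INFO",
#       asr_train_config="exp/paraformer_online/config.yaml",
#       asr_model_file="exp/paraformer_online/model.pb",
#   )
#   param_dict = {"cache": {}, "is_final": False, "chunk_size": [5, 10, 5]}
#   for i, chunk in enumerate(audio_chunks):
#       param_dict["is_final"] = i == len(audio_chunks) - 1
#       partial = recognize(None, raw_inputs=chunk, param_dict=param_dict)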


def inference_uniasr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    ngram_file: Optional[str] = None,
    cmvn_file: Optional[str] = None,
    # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    token_num_relax: int = 1,
    decoding_ind: int = 0,
    decoding_mode: str = "model1",
    param_dict: dict = None,
    **kwargs,
):
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Map the user-facing "decoding_model" choice onto the internal flags.
    if param_dict is not None and "decoding_model" in param_dict:
        if param_dict["decoding_model"] == "fast":
            decoding_ind = 0
            decoding_mode = "model1"
        elif param_dict["decoding_model"] == "normal":
            decoding_ind = 0
            decoding_mode = "model2"
        elif param_dict["decoding_model"] == "offline":
            decoding_ind = 1
            decoding_mode = "model2"
        else:
            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        ngram_file=ngram_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
        token_num_relax=token_num_relax,
        decoding_ind=decoding_ind,
        decoding_mode=decoding_mode,
    )
    speech2text = Speech2TextUniASR(**speech2text_kwargs)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,
        **kwargs,
    ):
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
        )

        finish_count = 0
        file_count = 1
        # 7. Start for-loop
        # FIXME(kamo): The output format should be discussed
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest

            # Only supporting batch_size==1
            key = keys[0]
            logging.info(f"Utterance: {key}")
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = " ".join(word_lists)
        return asr_result_list

    return _forward
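

# Mode selection sketch (illustrative): UniASR exposes its two-pass decoder
# through param_dict["decoding_model"], mapped above as
#   "fast"    -> decoding_ind=0, decoding_mode="model1"
#   "normal"  -> decoding_ind=0, decoding_mode="model2"
#   "offline" -> decoding_ind=1, decoding_mode="model2"
#
#   recognize = inference_uniasr(
#       ...,  # config/model paths and decoding weights as in the examples above
#       param_dict={"decoding_model": "offline"},
#   )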


def inference_mfcca(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    param_dict: dict = None,
    **kwargs,
):
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2TextMFCCA(**speech2text_kwargs)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,
        **kwargs,
    ):
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            fs=fs,
            mc=True,
            key_file=key_file,
            num_workers=num_workers,
        )

        finish_count = 0
        file_count = 1
        # 7. Start for-loop
        # FIXME(kamo): The output format should be discussed
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["<space>"], [2], hyp]] * nbest

            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
        return asr_result_list

    return _forward


def inference_transducer(
    output_dir: str,
    batch_size: int,
    dtype: str,
    beam_size: int,
    ngpu: int,
    seed: int,
    lm_weight: float,
    nbest: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str],
    beam_search_config: Optional[dict],
    lm_train_config: Optional[str],
    lm_file: Optional[str],
    model_tag: Optional[str],
    token_type: Optional[str],
    bpemodel: Optional[str],
    key_file: Optional[str],
    allow_variable_data_keys: bool,
    quantize_asr_model: Optional[bool],
    quantize_modules: Optional[List[str]],
    quantize_dtype: Optional[str],
    streaming: Optional[bool],
    simu_streaming: Optional[bool],
    chunk_size: Optional[int],
    left_context: Optional[int],
    right_context: Optional[int],
    display_partial_hypotheses: bool,
    **kwargs,
) -> None:
    """Transducer model inference.

    Args:
        output_dir: Output directory path.
        batch_size: Batch decoding size.
        dtype: Data type.
        beam_size: Beam size.
        ngpu: Number of GPUs.
        seed: Random number generator seed.
        lm_weight: Weight of the language model.
        nbest: Number of final hypotheses.
        num_workers: Number of workers.
        log_level: Verbosity level for logs.
        data_path_and_name_and_type:
        asr_train_config: ASR model training config path.
        asr_model_file: ASR model path.
        cmvn_file: CMVN file path.
        beam_search_config: Beam search config path.
        lm_train_config: Language Model training config path.
        lm_file: Language Model path.
        model_tag: Model tag.
        token_type: Type of token units.
        bpemodel: BPE model path.
        key_file: File key.
        allow_variable_data_keys: Whether to allow variable data keys.
        quantize_asr_model: Whether to apply dynamic quantization to the ASR model.
        quantize_modules: List of module names to apply dynamic quantization on.
        quantize_dtype: Dynamic quantization data type.
        streaming: Whether to perform chunk-by-chunk inference.
        simu_streaming: Whether to simulate chunk-by-chunk inference.
        chunk_size: Number of frames in a chunk AFTER subsampling.
        left_context: Number of frames in the left context AFTER subsampling.
        right_context: Number of frames in the right context AFTER subsampling.
        display_partial_hypotheses: Whether to display partial hypotheses.
    """
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        beam_search_config=beam_search_config,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        dtype=dtype,
        beam_size=beam_size,
        lm_weight=lm_weight,
        nbest=nbest,
        quantize_asr_model=quantize_asr_model,
        quantize_modules=quantize_modules,
        quantize_dtype=quantize_dtype,
        streaming=streaming,
        simu_streaming=simu_streaming,
        chunk_size=chunk_size,
        left_context=left_context,
        right_context=right_context,
    )
    speech2text = Speech2TextTransducer(**speech2text_kwargs)
  1218. def _forward(data_path_and_name_and_type,
  1219. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  1220. output_dir_v2: Optional[str] = None,
  1221. fs: dict = None,
  1222. param_dict: dict = None,
  1223. **kwargs,
  1224. ):
  1225. # 3. Build data-iterator
  1226. loader = build_streaming_iterator(
  1227. task_name="asr",
  1228. preprocess_args=speech2text.asr_train_args,
  1229. data_path_and_name_and_type=data_path_and_name_and_type,
  1230. dtype=dtype,
  1231. batch_size=batch_size,
  1232. key_file=key_file,
  1233. num_workers=num_workers,
  1234. )
        # 4. Start for-loop
        with DatadirWriter(output_dir) as writer:
            for keys, batch in loader:
                assert isinstance(batch, dict), type(batch)
                assert all(isinstance(s, str) for s in keys), keys
                _bs = len(next(iter(batch.values())))
                assert len(keys) == _bs, f"{len(keys)} != {_bs}"

                # Only batch_size==1 is supported: strip the batch dimension
                # and the "*_lengths" entries.
                batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
                assert len(batch.keys()) == 1

                try:
                    if speech2text.streaming:
                        # Feed the utterance chunk by chunk; the remainder is
                        # flushed with is_final=True.
                        speech = batch["speech"]
                        _steps = len(speech) // speech2text._ctx
                        _end = 0
                        for i in range(_steps):
                            _end = (i + 1) * speech2text._ctx
                            speech2text.streaming_decode(
                                speech[i * speech2text._ctx : _end], is_final=False
                            )
                        final_hyps = speech2text.streaming_decode(
                            speech[_end : len(speech)], is_final=True
                        )
                    elif speech2text.simu_streaming:
                        final_hyps = speech2text.simu_streaming_decode(**batch)
                    else:
                        final_hyps = speech2text(**batch)

                    results = speech2text.hypotheses_to_results(final_hyps)
                except TooShortUttError as e:
                    logging.warning(f"Utterance {keys} {e}")
                    hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
                    results = [[" ", ["<space>"], [2], hyp]] * nbest

                key = keys[0]
                for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                    ibest_writer = writer[f"{n}best_recog"]

                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)

                    if text is not None:
                        ibest_writer["text"][key] = text

    return _forward
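
# A minimal usage sketch for the transducer pipeline built above. All paths
# and the "sound" file type are illustrative assumptions, not fixed by this
# module; arguments not shown are assumed to keep their defaults:
#
#     pipeline = inference_transducer(
#         batch_size=1,
#         ngpu=0,
#         log_level="INFO",
#         output_dir="exp/decode_rnnt.1",
#         asr_train_config="exp/rnnt/config.yaml",  # hypothetical path
#         asr_model_file="exp/rnnt/model.pb",       # hypothetical path
#         beam_size=10,
#         streaming=True,
#         chunk_size=16,
#         left_context=16,
#         right_context=0,
#     )
#     pipeline([("data/test/wav.scp", "speech", "sound")])
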
def inference_sa_asr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
):
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    # Reset any handlers installed by a previous basicConfig call.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2TextSAASR(**speech2text_kwargs)
    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,
        **kwargs,
    ):
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
        )
        finish_count = 0
        file_count = 1

        # 4. Start for-loop
        # FIXME(kamo): The output format still needs to be discussed.
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            # N-best list of (text, text_id, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                # The fallback must match the 5-tuple shape unpacked below.
                results = [[" ", " ", ["sil"], [2], hyp]] * nbest

            # Only batch_size==1 is supported
            key = keys[0]
            for n, (text, text_id, token, token_int, hyp) in zip(
                range(1, nbest + 1), results
            ):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                    ibest_writer["text_id"][key] = text_id

                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    item = {"key": key, "value": text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text

                logging.info("uttid: {}".format(key))
                logging.info("text predictions: {}".format(text))
                logging.info("text_id predictions: {}\n".format(text_id))
        return asr_result_list

    return _forward
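
# Usage sketch for the SA-ASR pipeline above. The returned callable accepts
# either a data list or raw waveforms; the paths and the "sound" file type
# below are illustrative assumptions:
#
#     pipeline = inference_sa_asr(
#         maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=20,
#         ngpu=0, ctc_weight=0.0, lm_weight=0.0, penalty=0.0, log_level="INFO",
#         asr_train_config="exp/sa_asr/config.yaml",  # hypothetical path
#         asr_model_file="exp/sa_asr/model.pb",       # hypothetical path
#         output_dir="exp/decode_sa_asr.1",
#     )
#     results = pipeline([("data/test/wav.scp", "speech", "sound")])
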
def inference_launch(**kwargs):
    if "mode" in kwargs:
        mode = kwargs["mode"]
    else:
        logging.info("No decoding mode specified.")
        return None
    if mode == "asr":
        return inference_asr(**kwargs)
    elif mode == "uniasr":
        return inference_uniasr(**kwargs)
    elif mode in ("paraformer", "paraformer_fake_streaming"):
        return inference_paraformer(**kwargs)
    elif mode == "paraformer_streaming":
        return inference_paraformer_online(**kwargs)
    elif mode.startswith("paraformer_vad"):
        return inference_paraformer_vad_punc(**kwargs)
    elif mode == "mfcca":
        return inference_mfcca(**kwargs)
    elif mode in ("rnnt", "bat"):
        return inference_transducer(**kwargs)
    elif mode == "sa_asr":
        return inference_sa_asr(**kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
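
# Example of driving the dispatcher programmatically (hedged sketch: the
# mode-specific keyword arguments are forwarded verbatim to the matching
# inference_* builder, and the paths below are hypothetical):
#
#     pipeline = inference_launch(
#         mode="paraformer",
#         asr_train_config="exp/paraformer/config.yaml",
#         asr_model_file="exp/paraformer/model.pb",
#         output_dir="exp/decode.1",
#         batch_size=1,
#         ngpu=0,
#     )
#     if pipeline is not None:
#         pipeline([("data/test/wav.scp", "speech", "sound")])
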
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="ASR Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--njob",
        type=int,
        default=1,
        help="The number of jobs for each gpu",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )
    group.add_argument("--key_file", type=str_or_none)
    parser.add_argument(
        "--hotword",
        type=str_or_none,
        default=None,
        help="Hotword file path, or hotwords separated by spaces",
    )
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
    group.add_argument(
        "--mc",
        type=str2bool,
        default=False,
        help="Multi-channel input",
    )
    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--vad_infer_config",
        type=str,
        help="VAD infer configuration",
    )
    group.add_argument(
        "--vad_model_file",
        type=str,
        help="VAD model parameter file",
    )
    group.add_argument(
        "--cmvn_file",
        type=str,
        help="Global CMVN file",
    )
    group.add_argument(
        "--asr_train_config",
        type=str,
        help="ASR training configuration",
    )
    group.add_argument(
        "--asr_model_file",
        type=str,
        help="ASR model parameter file",
    )
    group.add_argument(
        "--lm_train_config",
        type=str,
        help="LM training configuration",
    )
    group.add_argument(
        "--lm_file",
        type=str,
        help="LM parameter file",
    )
    group.add_argument(
        "--word_lm_train_config",
        type=str,
        help="Word LM training configuration",
    )
    group.add_argument(
        "--word_lm_file",
        type=str,
        help="Word LM parameter file",
    )
    group.add_argument(
        "--ngram_file",
        type=str,
        help="N-gram parameter file",
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If this option is specified, "
        "*_train_config and *_file will be overwritten",
    )
    group.add_argument(
        "--beam_search_config",
        default={},
        help="The keyword arguments for transducer beam search.",
    )

    group = parser.add_argument_group("Beam-search related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
    group.add_argument(
        "--maxlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain max output length. "
        "If maxlenratio=0.0 (default), an end-detect function is used "
        "to automatically find maximum hypothesis lengths. "
        "If maxlenratio<0.0, its absolute value is interpreted "
        "as a constant max output length",
    )
    group.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain min output length",
    )
    group.add_argument(
        "--ctc_weight",
        type=float,
        default=0.0,
        help="CTC weight in joint decoding",
    )
    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
    group.add_argument("--streaming", type=str2bool, default=False)
    group.add_argument("--simu_streaming", type=str2bool, default=False)
    group.add_argument("--chunk_size", type=int, default=16)
    group.add_argument("--left_context", type=int, default=16)
    group.add_argument("--right_context", type=int, default=0)
    group.add_argument(
        "--display_partial_hypotheses",
        type=str2bool,
        default=False,
        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
    )
    group = parser.add_argument_group("Dynamic quantization related")
    group.add_argument(
        "--quantize_asr_model",
        type=str2bool,
        default=False,
        help="Apply dynamic quantization to the ASR model.",
    )
    group.add_argument(
        "--quantize_modules",
        nargs="*",
        default=None,
        help="""Module names to apply dynamic quantization on.
        The module names are given as a space-separated list
        (e.g.: --quantize_modules Linear LSTM GRU).
        Each specified name should be an attribute of 'torch.nn', e.g.:
        torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
    )
    group.add_argument(
        "--quantize_dtype",
        type=str,
        default="qint8",
        choices=["float16", "qint8"],
        help="Dtype for dynamic quantization.",
    )
    group = parser.add_argument_group("Text converter related")
    group.add_argument(
        "--token_type",
        type=str_or_none,
        default=None,
        choices=["char", "bpe", None],
        help="The token type for the ASR model. "
        "If not given, it is inferred from the training args",
    )
    group.add_argument(
        "--bpemodel",
        type=str_or_none,
        default=None,
        help="The sentencepiece model path. "
        "If not given, it is inferred from the training args",
    )
    group.add_argument("--token_num_relax", type=int, default=1, help="")
    group.add_argument("--decoding_ind", type=int, default=0, help="")
    group.add_argument("--decoding_mode", type=str, default="model1", help="")
    group.add_argument(
        "--ctc_weight2",
        type=float,
        default=0.0,
        help="CTC weight in joint decoding",
    )
    return parser
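
# Example command line (paths and the "sound" file type are illustrative):
#
#     python asr_inference_launch.py \
#         --mode rnnt \
#         --ngpu 1 --gpuid_list "0" --njob 1 \
#         --output_dir exp/decode.1 \
#         --asr_train_config exp/rnnt/config.yaml \
#         --asr_model_file exp/rnnt/model.pb \
#         --data_path_and_name_and_type data/test/wav.scp,speech,sound
#
# Note: with --ngpu > 0, main() derives the job id from the numeric suffix of
# --output_dir (here ".1") to pick a GPU from --gpuid_list.
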
def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    parser.add_argument(
        "--mode",
        type=str,
        default="asr",
        help="The decoding mode",
    )
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)

    # set logging messages
    logging.basicConfig(
        level=args.log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("Decoding args: {}".format(kwargs))

    # gpu setting: output_dir is expected to end with ".<jobid>", which is
    # mapped onto an entry of --gpuid_list.
    if args.ngpu > 0:
        jobid = int(args.output_dir.split(".")[-1])
        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid

    inference_pipeline = inference_launch(**kwargs)
    if inference_pipeline is None:
        raise ValueError("Unknown decoding mode: {}".format(kwargs.get("mode")))
    return inference_pipeline(
        kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None)
    )


if __name__ == "__main__":
    main()