asr_inference_launch.py 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import argparse
  6. import logging
  7. import os
  8. import sys
  9. import time
  10. from pathlib import Path
  11. from typing import Dict
  12. from typing import List
  13. from typing import Optional
  14. from typing import Sequence
  15. from typing import Tuple
  16. from typing import Union
  17. import numpy as np
  18. import torch
  19. import torchaudio
  20. import soundfile
  21. import yaml
  22. from funasr.bin.asr_infer import Speech2Text
  23. from funasr.bin.asr_infer import Speech2TextMFCCA
  24. from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
  25. from funasr.bin.asr_infer import Speech2TextSAASR
  26. from funasr.bin.asr_infer import Speech2TextTransducer
  27. from funasr.bin.asr_infer import Speech2TextUniASR
  28. from funasr.bin.punc_infer import Text2Punc
  29. from funasr.bin.tp_infer import Speech2Timestamp
  30. from funasr.bin.vad_infer import Speech2VadSegment
  31. from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
  32. from funasr.fileio.datadir_writer import DatadirWriter
  33. from funasr.modules.beam_search.beam_search import Hypothesis
  34. from funasr.modules.subsampling import TooShortUttError
  35. from funasr.torch_utils.device_funcs import to_device
  36. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  37. from funasr.utils import asr_utils, postprocess_utils
  38. from funasr.utils import config_argparse
  39. from funasr.utils.cli_utils import get_commandline_args
  40. from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
  41. from funasr.utils.types import str2bool
  42. from funasr.utils.types import str2triple_str
  43. from funasr.utils.types import str_or_none
  44. from funasr.utils.vad_utils import slice_padding_fbank
def inference_asr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
):
    """Build a beam-search ASR decoder (``Speech2Text``) and return an inference closure.

    The heavy model construction happens once here; the returned ``_forward``
    can then be called repeatedly on file lists or raw waveforms.  Only
    ``batch_size == 1`` and single-GPU (or CPU) decoding are supported.

    Returns:
        _forward: callable(data_path_and_name_and_type, raw_inputs=None,
            output_dir_v2=None, fs=None, param_dict=None, **kwargs)
            -> list of {'key': utt_id, 'value': postprocessed text} dicts.
    """
    # Limit intra-op CPU threads; ``ncpu`` arrives through **kwargs (default 1).
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    # Drop any pre-existing root handlers so basicConfig below takes effect
    # even when the caller already configured logging.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2Text(**speech2text_kwargs)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        """Decode the given data list (or raw waveform) and return result dicts.

        ``output_dir_v2`` overrides the factory-level ``output_dir`` for this
        call only; when neither is set, nothing is written to disk.
        """
        # 3. Build data-iterator.  A raw ndarray/Tensor input is wrapped into
        # the same (data, name, type) triple shape a scp entry would have.
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
        )
        # NOTE(review): file_count is fixed at 1, so print_progress below
        # reports finished-utterance count rather than a true ratio — confirm
        # intended semantics of asr_utils.print_progress.
        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                # Utterance too short for the frontend subsampling: emit a
                # silence placeholder instead of aborting the whole run.
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest
            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
                logging.info("uttid: {}".format(key))
                logging.info("text predictions: {}\n".format(text))
        return asr_result_list

    return _forward
  188. def inference_paraformer(
  189. maxlenratio: float,
  190. minlenratio: float,
  191. batch_size: int,
  192. beam_size: int,
  193. ngpu: int,
  194. ctc_weight: float,
  195. lm_weight: float,
  196. penalty: float,
  197. log_level: Union[int, str],
  198. # data_path_and_name_and_type,
  199. asr_train_config: Optional[str],
  200. asr_model_file: Optional[str],
  201. cmvn_file: Optional[str] = None,
  202. lm_train_config: Optional[str] = None,
  203. lm_file: Optional[str] = None,
  204. token_type: Optional[str] = None,
  205. key_file: Optional[str] = None,
  206. word_lm_train_config: Optional[str] = None,
  207. bpemodel: Optional[str] = None,
  208. allow_variable_data_keys: bool = False,
  209. dtype: str = "float32",
  210. seed: int = 0,
  211. ngram_weight: float = 0.9,
  212. nbest: int = 1,
  213. num_workers: int = 1,
  214. output_dir: Optional[str] = None,
  215. timestamp_infer_config: Union[Path, str] = None,
  216. timestamp_model_file: Union[Path, str] = None,
  217. param_dict: dict = None,
  218. **kwargs,
  219. ):
  220. ncpu = kwargs.get("ncpu", 1)
  221. torch.set_num_threads(ncpu)
  222. if word_lm_train_config is not None:
  223. raise NotImplementedError("Word LM is not implemented")
  224. if ngpu > 1:
  225. raise NotImplementedError("only single GPU decoding is supported")
  226. logging.basicConfig(
  227. level=log_level,
  228. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  229. )
  230. export_mode = False
  231. if param_dict is not None:
  232. hotword_list_or_file = param_dict.get('hotword')
  233. export_mode = param_dict.get("export_mode", False)
  234. clas_scale = param_dict.get('clas_scale', 1.0)
  235. else:
  236. hotword_list_or_file = None
  237. clas_scale = 1.0
  238. if kwargs.get("device", None) == "cpu":
  239. ngpu = 0
  240. if ngpu >= 1 and torch.cuda.is_available():
  241. device = "cuda"
  242. else:
  243. device = "cpu"
  244. batch_size = 1
  245. # 1. Set random-seed
  246. set_all_random_seed(seed)
  247. # 2. Build speech2text
  248. speech2text_kwargs = dict(
  249. asr_train_config=asr_train_config,
  250. asr_model_file=asr_model_file,
  251. cmvn_file=cmvn_file,
  252. lm_train_config=lm_train_config,
  253. lm_file=lm_file,
  254. token_type=token_type,
  255. bpemodel=bpemodel,
  256. device=device,
  257. maxlenratio=maxlenratio,
  258. minlenratio=minlenratio,
  259. dtype=dtype,
  260. beam_size=beam_size,
  261. ctc_weight=ctc_weight,
  262. lm_weight=lm_weight,
  263. ngram_weight=ngram_weight,
  264. penalty=penalty,
  265. nbest=nbest,
  266. hotword_list_or_file=hotword_list_or_file,
  267. clas_scale=clas_scale,
  268. )
  269. speech2text = Speech2TextParaformer(**speech2text_kwargs)
  270. if timestamp_model_file is not None:
  271. speechtext2timestamp = Speech2Timestamp(
  272. timestamp_cmvn_file=cmvn_file,
  273. timestamp_model_file=timestamp_model_file,
  274. timestamp_infer_config=timestamp_infer_config,
  275. )
  276. else:
  277. speechtext2timestamp = None
  278. def _forward(
  279. data_path_and_name_and_type,
  280. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  281. output_dir_v2: Optional[str] = None,
  282. fs: dict = None,
  283. param_dict: dict = None,
  284. **kwargs,
  285. ):
  286. hotword_list_or_file = None
  287. if param_dict is not None:
  288. hotword_list_or_file = param_dict.get('hotword')
  289. if 'hotword' in kwargs and kwargs['hotword'] is not None:
  290. hotword_list_or_file = kwargs['hotword']
  291. if hotword_list_or_file is not None or 'hotword' in kwargs:
  292. speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
  293. # 3. Build data-iterator
  294. if data_path_and_name_and_type is None and raw_inputs is not None:
  295. if isinstance(raw_inputs, torch.Tensor):
  296. raw_inputs = raw_inputs.numpy()
  297. data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
  298. loader = build_streaming_iterator(
  299. task_name="asr",
  300. preprocess_args=speech2text.asr_train_args,
  301. data_path_and_name_and_type=data_path_and_name_and_type,
  302. dtype=dtype,
  303. fs=fs,
  304. batch_size=batch_size,
  305. key_file=key_file,
  306. num_workers=num_workers,
  307. )
  308. if param_dict is not None:
  309. use_timestamp = param_dict.get('use_timestamp', True)
  310. else:
  311. use_timestamp = True
  312. forward_time_total = 0.0
  313. length_total = 0.0
  314. finish_count = 0
  315. file_count = 1
  316. # 7 .Start for-loop
  317. # FIXME(kamo): The output format should be discussed about
  318. asr_result_list = []
  319. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  320. if output_path is not None:
  321. writer = DatadirWriter(output_path)
  322. else:
  323. writer = None
  324. for keys, batch in loader:
  325. assert isinstance(batch, dict), type(batch)
  326. assert all(isinstance(s, str) for s in keys), keys
  327. _bs = len(next(iter(batch.values())))
  328. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  329. # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
  330. logging.info("decoding, utt_id: {}".format(keys))
  331. # N-best list of (text, token, token_int, hyp_object)
  332. time_beg = time.time()
  333. results = speech2text(**batch)
  334. if len(results) < 1:
  335. hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
  336. results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
  337. time_end = time.time()
  338. forward_time = time_end - time_beg
  339. lfr_factor = results[0][-1]
  340. length = results[0][-2]
  341. forward_time_total += forward_time
  342. length_total += length
  343. rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time,
  344. 100 * forward_time / (
  345. length * lfr_factor))
  346. logging.info(rtf_cur)
  347. for batch_id in range(_bs):
  348. result = [results[batch_id][:-2]]
  349. key = keys[batch_id]
  350. for n, result in zip(range(1, nbest + 1), result):
  351. text, token, token_int, hyp = result[0], result[1], result[2], result[3]
  352. timestamp = result[4] if len(result[4]) > 0 else None
  353. # conduct timestamp prediction here
  354. # timestamp inference requires token length
  355. # thus following inference cannot be conducted in batch
  356. if timestamp is None and speechtext2timestamp:
  357. ts_batch = {}
  358. ts_batch['speech'] = batch['speech'][batch_id].unsqueeze(0)
  359. ts_batch['speech_lengths'] = torch.tensor([batch['speech_lengths'][batch_id]])
  360. ts_batch['text_lengths'] = torch.tensor([len(token)])
  361. us_alphas, us_peaks = speechtext2timestamp(**ts_batch)
  362. ts_str, timestamp = ts_prediction_lfr6_standard(us_alphas[0], us_peaks[0], token,
  363. force_time_shift=-3.0)
  364. # Create a directory: outdir/{n}best_recog
  365. if writer is not None:
  366. ibest_writer = writer[f"{n}best_recog"]
  367. # Write the result to each file
  368. ibest_writer["token"][key] = " ".join(token)
  369. # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
  370. ibest_writer["score"][key] = str(hyp.score)
  371. ibest_writer["rtf"][key] = rtf_cur
  372. if text is not None:
  373. if use_timestamp and timestamp is not None:
  374. postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
  375. else:
  376. postprocessed_result = postprocess_utils.sentence_postprocess(token)
  377. timestamp_postprocessed = ""
  378. if len(postprocessed_result) == 3:
  379. text_postprocessed, timestamp_postprocessed, word_lists = postprocessed_result[0], \
  380. postprocessed_result[1], \
  381. postprocessed_result[2]
  382. else:
  383. text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
  384. item = {'key': key, 'value': text_postprocessed}
  385. if timestamp_postprocessed != "":
  386. item['timestamp'] = timestamp_postprocessed
  387. asr_result_list.append(item)
  388. finish_count += 1
  389. # asr_utils.print_progress(finish_count / file_count)
  390. if writer is not None:
  391. ibest_writer["text"][key] = " ".join(word_lists)
  392. logging.info("decoding, utt: {}, predictions: {}".format(key, text))
  393. rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
  394. forward_time_total,
  395. 100 * forward_time_total / (
  396. length_total * lfr_factor))
  397. logging.info(rtf_avg)
  398. if writer is not None:
  399. ibest_writer["rtf"]["rtf_avf"] = rtf_avg
  400. return asr_result_list
  401. return _forward
def inference_paraformer_vad_punc(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    vad_infer_config: Optional[str] = None,
    vad_model_file: Optional[str] = None,
    vad_cmvn_file: Optional[str] = None,
    time_stamp_writer: bool = True,
    punc_infer_config: Optional[str] = None,
    punc_model_file: Optional[str] = None,
    outputs_dict: Optional[bool] = True,
    param_dict: dict = None,
    **kwargs,
):
    """Build a VAD -> Paraformer ASR -> punctuation pipeline and return an inference closure.

    Loads a VAD segmenter, a Paraformer recognizer and (optionally) a
    punctuation model once.  The returned ``_forward`` segments each input
    utterance with VAD, decodes the segments with dynamic batching, restores
    the original segment order, merges tokens/timestamps, restores punctuation
    and returns one result dict per utterance.

    Returns:
        _forward: callable(data_path_and_name_and_type, raw_inputs=None,
            output_dir_v2=None, fs=None, param_dict=None, **kwargs)
            -> list of {'key', 'value'[, 'text_postprocessed', 'time_stamp',
            'sentences']} dicts.
    """
    # Limit intra-op CPU threads; ``ncpu`` arrives through **kwargs (default 1).
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if param_dict is not None:
        hotword_list_or_file = param_dict.get('hotword')
    else:
        hotword_list_or_file = None
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2vadsegment
    speech2vadsegment_kwargs = dict(
        vad_infer_config=vad_infer_config,
        vad_model_file=vad_model_file,
        vad_cmvn_file=vad_cmvn_file,
        device=device,
        dtype=dtype,
    )
    # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
    # 3. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
    )
    speech2text = Speech2TextParaformer(**speech2text_kwargs)
    # Optional punctuation-restoration model.
    text2punc = None
    if punc_model_file is not None:
        text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
    # Dump the model's token list once up front when an output dir is given.
    if output_dir is not None:
        writer = DatadirWriter(output_dir)
        ibest_writer = writer[f"1best_recog"]
        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        """Run VAD + ASR + punctuation on the given data list (or raw waveform)."""
        # Per-call hotword override: kwargs beats param_dict.
        hotword_list_or_file = None
        if param_dict is not None:
            hotword_list_or_file = param_dict.get('hotword')
        if 'hotword' in kwargs:
            hotword_list_or_file = kwargs['hotword']
        # Dynamic-batching budget for grouping VAD segments into ASR batches.
        # Presumably a frame/token budget scaled to milliseconds below
        # (``* 60``) — TODO confirm units against the model config.
        batch_size_token = kwargs.get("batch_size_token", 6000)
        print("batch_size_token: ", batch_size_token)
        if speech2text.hotword_list is None:
            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
        # 3. Build data-iterator.  A raw ndarray/Tensor is wrapped into the
        # (data, name, type) triple shape a scp entry would have.
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=1,
            key_file=key_file,
            num_workers=num_workers,
        )
        if param_dict is not None:
            use_timestamp = param_dict.get('use_timestamp', True)
        else:
            use_timestamp = True
        finish_count = 0
        file_count = 1
        lfr_factor = 6
        # 7 .Start for-loop
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        writer = None
        if output_path is not None:
            writer = DatadirWriter(output_path)
            ibest_writer = writer[f"1best_recog"]
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            beg_vad = time.time()
            vad_results = speech2vadsegment(**batch)
            end_vad = time.time()
            print("time cost vad: ", end_vad - beg_vad)
            # vad_results[1][0]: list of [start, end] segments (presumably in
            # milliseconds — TODO confirm) for the single utterance per batch.
            _, vadsegments = vad_results[0], vad_results[1][0]
            speech, speech_lengths = batch["speech"], batch["speech_lengths"]
            n = len(vadsegments)
            # Sort segments by duration so similarly sized segments get
            # batched together; keep original indices for order restoration.
            data_with_index = [(vadsegments[i], i) for i in range(n)]
            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
            results_sorted = []
            if not len(sorted_data):
                key = keys[0]
                # no active segments after VAD
                if writer is not None:
                    # Write empty results
                    ibest_writer["token"][key] = ""
                    ibest_writer["token_int"][key] = ""
                    ibest_writer["vad"][key] = ""
                    ibest_writer["text"][key] = ""
                    ibest_writer["text_with_punc"][key] = ""
                    if use_timestamp:
                        ibest_writer["time_stamp"][key] = ""
                logging.info("decoding, utt: {}, empty speech".format(key))
                continue
            batch_size_token_ms = batch_size_token*60
            # On CPU, zeroing the budget (before the max below) effectively
            # forces one segment per ASR batch.
            if speech2text.device == "cpu":
                batch_size_token_ms = 0
            # Budget must admit at least the shortest segment.
            batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
            batch_size_token_ms_cum = 0
            beg_idx = 0
            # Greedily accumulate sorted segments until adding the next one
            # would exceed the budget, then decode that group as one batch.
            for j, _ in enumerate(range(0, n)):
                batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][
                    0]) < batch_size_token_ms:
                    continue
                batch_size_token_ms_cum = 0
                end_idx = j + 1
                speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
                beg_idx = end_idx
                batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
                batch = to_device(batch, device=device)
                print("batch: ", speech_j.shape[0])
                beg_asr = time.time()
                results = speech2text(**batch)
                end_asr = time.time()
                print("time cost asr: ", end_asr - beg_asr)
                if len(results) < 1:
                    # NOTE(review): only ONE placeholder is appended even when
                    # this sub-batch held several segments; if speech2text can
                    # return empty for a multi-segment batch, results_sorted[j]
                    # below would go out of range — verify that contract.
                    results = [["", [], [], [], [], [], []]]
                results_sorted.extend(results)
            # Restore results to the original (temporal) segment order.
            restored_data = [0] * n
            for j in range(n):
                index = sorted_data[j][1]
                restored_data[index] = results_sorted[j]
            # Merge per-segment results into one utterance-level result:
            # slots are [text, tokens, token_ints, _, timestamps, _, _].
            result = ["", [], [], [], [], [], []]
            for j in range(n):
                result[0] += restored_data[j][0]
                result[1] += restored_data[j][1]
                result[2] += restored_data[j][2]
                if len(restored_data[j][4]) > 0:
                    # Shift per-segment timestamps by the segment start so
                    # they become absolute within the utterance (mutates the
                    # per-segment timestamp lists in place).
                    for t in restored_data[j][4]:
                        t[0] += vadsegments[j][0]
                        t[1] += vadsegments[j][0]
                    result[4] += restored_data[j][4]
            # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
            key = keys[0]
            # result = result_segments[0]
            text, token, token_int = result[0], result[1], result[2]
            time_stamp = result[4] if len(result[4]) > 0 else None
            if use_timestamp and time_stamp is not None:
                postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
            else:
                postprocessed_result = postprocess_utils.sentence_postprocess(token)
            text_postprocessed = ""
            time_stamp_postprocessed = ""
            text_postprocessed_punc = postprocessed_result
            # sentence_postprocess returns 3 elements when it also produced
            # timestamps, otherwise 2.
            if len(postprocessed_result) == 3:
                text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
                                                                           postprocessed_result[1], \
                                                                           postprocessed_result[2]
            else:
                text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
                text_postprocessed_punc = text_postprocessed
            punc_id_list = []
            if len(word_lists) > 0 and text2punc is not None:
                beg_punc = time.time()
                # 20: punctuation model split size (words per chunk) —
                # presumably; confirm against Text2Punc.__call__.
                text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
                end_punc = time.time()
                print("time cost punc: ", end_punc - beg_punc)
            item = {'key': key, 'value': text_postprocessed_punc}
            if text_postprocessed != "":
                item['text_postprocessed'] = text_postprocessed
            if time_stamp_postprocessed != "":
                item['time_stamp'] = time_stamp_postprocessed
                item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
            asr_result_list.append(item)
            finish_count += 1
            # asr_utils.print_progress(finish_count / file_count)
            if writer is not None:
                # Write the result to each file
                ibest_writer["token"][key] = " ".join(token)
                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                ibest_writer["vad"][key] = "{}".format(vadsegments)
                ibest_writer["text"][key] = " ".join(word_lists)
                ibest_writer["text_with_punc"][key] = text_postprocessed_punc
                if time_stamp_postprocessed is not None:
                    ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
            logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
        return asr_result_list

    return _forward
  657. def inference_paraformer_online(
  658. maxlenratio: float,
  659. minlenratio: float,
  660. batch_size: int,
  661. beam_size: int,
  662. ngpu: int,
  663. ctc_weight: float,
  664. lm_weight: float,
  665. penalty: float,
  666. log_level: Union[int, str],
  667. # data_path_and_name_and_type,
  668. asr_train_config: Optional[str],
  669. asr_model_file: Optional[str],
  670. cmvn_file: Optional[str] = None,
  671. lm_train_config: Optional[str] = None,
  672. lm_file: Optional[str] = None,
  673. token_type: Optional[str] = None,
  674. key_file: Optional[str] = None,
  675. word_lm_train_config: Optional[str] = None,
  676. bpemodel: Optional[str] = None,
  677. allow_variable_data_keys: bool = False,
  678. dtype: str = "float32",
  679. seed: int = 0,
  680. ngram_weight: float = 0.9,
  681. nbest: int = 1,
  682. num_workers: int = 1,
  683. output_dir: Optional[str] = None,
  684. param_dict: dict = None,
  685. **kwargs,
  686. ):
  687. if word_lm_train_config is not None:
  688. raise NotImplementedError("Word LM is not implemented")
  689. if ngpu > 1:
  690. raise NotImplementedError("only single GPU decoding is supported")
  691. logging.basicConfig(
  692. level=log_level,
  693. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  694. )
  695. export_mode = False
  696. if ngpu >= 1 and torch.cuda.is_available():
  697. device = "cuda"
  698. else:
  699. device = "cpu"
  700. batch_size = 1
  701. # 1. Set random-seed
  702. set_all_random_seed(seed)
  703. # 2. Build speech2text
  704. speech2text_kwargs = dict(
  705. asr_train_config=asr_train_config,
  706. asr_model_file=asr_model_file,
  707. cmvn_file=cmvn_file,
  708. lm_train_config=lm_train_config,
  709. lm_file=lm_file,
  710. token_type=token_type,
  711. bpemodel=bpemodel,
  712. device=device,
  713. maxlenratio=maxlenratio,
  714. minlenratio=minlenratio,
  715. dtype=dtype,
  716. beam_size=beam_size,
  717. ctc_weight=ctc_weight,
  718. lm_weight=lm_weight,
  719. ngram_weight=ngram_weight,
  720. penalty=penalty,
  721. nbest=nbest,
  722. )
  723. speech2text = Speech2TextParaformerOnline(**speech2text_kwargs)
  724. def _load_bytes(input):
  725. middle_data = np.frombuffer(input, dtype=np.int16)
  726. middle_data = np.asarray(middle_data)
  727. if middle_data.dtype.kind not in 'iu':
  728. raise TypeError("'middle_data' must be an array of integers")
  729. dtype = np.dtype('float32')
  730. if dtype.kind != 'f':
  731. raise TypeError("'dtype' must be a floating point type")
  732. i = np.iinfo(middle_data.dtype)
  733. abs_max = 2 ** (i.bits - 1)
  734. offset = i.min + abs_max
  735. array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
  736. return array
  737. def _read_yaml(yaml_path: Union[str, Path]) -> Dict:
  738. if not Path(yaml_path).exists():
  739. raise FileExistsError(f'The {yaml_path} does not exist.')
  740. with open(str(yaml_path), 'rb') as f:
  741. data = yaml.load(f, Loader=yaml.Loader)
  742. return data
  743. def _prepare_cache(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
  744. if len(cache) > 0:
  745. return cache
  746. config = _read_yaml(asr_train_config)
  747. enc_output_size = config["encoder_conf"]["output_size"]
  748. feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
  749. cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
  750. "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
  751. "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
  752. cache["encoder"] = cache_en
  753. cache_de = {"decode_fsmn": None}
  754. cache["decoder"] = cache_de
  755. return cache
  756. def _cache_reset(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
  757. if len(cache) > 0:
  758. config = _read_yaml(asr_train_config)
  759. enc_output_size = config["encoder_conf"]["output_size"]
  760. feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
  761. cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
  762. "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
  763. "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
  764. "tail_chunk": False}
  765. cache["encoder"] = cache_en
  766. cache_de = {"decode_fsmn": None}
  767. cache["decoder"] = cache_de
  768. return cache
  769. def _forward(
  770. data_path_and_name_and_type,
  771. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  772. output_dir_v2: Optional[str] = None,
  773. fs: dict = None,
  774. param_dict: dict = None,
  775. **kwargs,
  776. ):
  777. # 3. Build data-iterator
  778. if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
  779. raw_inputs = _load_bytes(data_path_and_name_and_type[0])
  780. raw_inputs = torch.tensor(raw_inputs)
  781. if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
  782. try:
  783. raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
  784. except:
  785. raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
  786. if raw_inputs.ndim == 2:
  787. raw_inputs = raw_inputs[:, 0]
  788. raw_inputs = torch.tensor(raw_inputs)
  789. if data_path_and_name_and_type is None and raw_inputs is not None:
  790. if isinstance(raw_inputs, np.ndarray):
  791. raw_inputs = torch.tensor(raw_inputs)
  792. is_final = False
  793. cache = {}
  794. chunk_size = [5, 10, 5]
  795. if param_dict is not None and "cache" in param_dict:
  796. cache = param_dict["cache"]
  797. if param_dict is not None and "is_final" in param_dict:
  798. is_final = param_dict["is_final"]
  799. if param_dict is not None and "chunk_size" in param_dict:
  800. chunk_size = param_dict["chunk_size"]
  801. # 7 .Start for-loop
  802. # FIXME(kamo): The output format should be discussed about
  803. raw_inputs = torch.unsqueeze(raw_inputs, axis=0)
  804. asr_result_list = []
  805. cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
  806. item = {}
  807. if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
  808. sample_offset = 0
  809. speech_length = raw_inputs.shape[1]
  810. stride_size = chunk_size[1] * 960
  811. cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
  812. final_result = ""
  813. for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
  814. if sample_offset + stride_size >= speech_length - 1:
  815. stride_size = speech_length - sample_offset
  816. cache["encoder"]["is_final"] = True
  817. else:
  818. cache["encoder"]["is_final"] = False
  819. input_lens = torch.tensor([stride_size])
  820. asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
  821. if len(asr_result) != 0:
  822. final_result += " ".join(asr_result) + " "
  823. item = {'key': "utt", 'value': final_result.strip()}
  824. else:
  825. input_lens = torch.tensor([raw_inputs.shape[1]])
  826. cache["encoder"]["is_final"] = is_final
  827. asr_result = speech2text(cache, raw_inputs, input_lens)
  828. item = {'key': "utt", 'value': " ".join(asr_result)}
  829. asr_result_list.append(item)
  830. if is_final:
  831. cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1)
  832. return asr_result_list
  833. return _forward
def inference_uniasr(
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        ngram_file: Optional[str] = None,
        cmvn_file: Optional[str] = None,
        # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        streaming: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        token_num_relax: int = 1,
        decoding_ind: int = 0,
        decoding_mode: str = "model1",
        param_dict: dict = None,
        **kwargs,
):
    """Build a UniASR recognizer and return a decoding closure.

    Constructs ``Speech2TextUniASR`` once from the given model/config paths
    and returns an inner ``_forward(data_path_and_name_and_type,
    raw_inputs=None, output_dir_v2=None, fs=None, param_dict=None, **kwargs)``
    that decodes and returns a list of ``{'key': ..., 'value': ...}`` items.

    ``param_dict['decoding_model']`` ("fast" | "normal" | "offline") overrides
    ``decoding_ind``/``decoding_mode``; any other value raises
    NotImplementedError.
    """
    # Limit intra-op CPU threads; callers may pass ncpu via **kwargs.
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    # Only single-utterance, single-GPU decoding is supported here.
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Map user-facing "decoding_model" names onto (decoding_ind, decoding_mode).
    if param_dict is not None and "decoding_model" in param_dict:
        if param_dict["decoding_model"] == "fast":
            decoding_ind = 0
            decoding_mode = "model1"
        elif param_dict["decoding_model"] == "normal":
            decoding_ind = 0
            decoding_mode = "model2"
        elif param_dict["decoding_model"] == "offline":
            decoding_ind = 1
            decoding_mode = "model2"
        else:
            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        ngram_file=ngram_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
        token_num_relax=token_num_relax,
        decoding_ind=decoding_ind,
        decoding_mode=decoding_mode,
    )
    speech2text = Speech2TextUniASR(**speech2text_kwargs)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        # 3. Build data-iterator
        # In-memory input: wrap the waveform so it looks like a loader entry.
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
        )

        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            # NOTE(review): DatadirWriter is never closed here; consider the
            # context-manager form used by inference_transducer.
            writer = DatadirWriter(output_path)
        else:
            writer = None

        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                # Fall back to a silence hypothesis so the loop can continue.
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest

            # Only supporting batch_size==1
            key = keys[0]
            logging.info(f"Utterance: {key}")
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    # file_count stays 1, so this reports utterances decoded.
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = " ".join(word_lists)
        return asr_result_list

    return _forward
def inference_mfcca(
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        streaming: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        param_dict: dict = None,
        **kwargs,
):
    """Build a multi-channel (MFCCA) recognizer and return a decoding closure.

    Constructs ``Speech2TextMFCCA`` once and returns an inner
    ``_forward(data_path_and_name_and_type, raw_inputs=None, ...)`` that
    decodes and returns a list of ``{'key': ..., 'value': ...}`` items.
    The data iterator is built with ``mc=True`` (multi-channel input).
    """
    # Limit intra-op CPU threads; callers may pass ncpu via **kwargs.
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    # Only single-utterance, single-GPU decoding is supported here.
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2TextMFCCA(**speech2text_kwargs)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        # 3. Build data-iterator
        # In-memory input: wrap the waveform so it looks like a loader entry.
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            fs=fs,
            mc=True,
            key_file=key_file,
            num_workers=num_workers,
        )

        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            # NOTE(review): DatadirWriter is never closed here; consider the
            # context-manager form used by inference_transducer.
            writer = DatadirWriter(output_path)
        else:
            writer = None

        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                # Fall back to a space hypothesis so the loop can continue.
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["<space>"], [2], hyp]] * nbest

            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    # NOTE(review): sibling builders unpack a (text, word_lists)
                    # pair from sentence_postprocess; here the raw return value
                    # is stored as 'value' — confirm the intended return arity.
                    text_postprocessed = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    # file_count stays 1, so this reports utterances decoded.
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
        return asr_result_list

    return _forward
def inference_transducer(
        output_dir: str,
        batch_size: int,
        dtype: str,
        beam_size: int,
        ngpu: int,
        seed: int,
        lm_weight: float,
        nbest: int,
        num_workers: int,
        log_level: Union[int, str],
        data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str],
        beam_search_config: Optional[dict],
        lm_train_config: Optional[str],
        lm_file: Optional[str],
        model_tag: Optional[str],
        token_type: Optional[str],
        bpemodel: Optional[str],
        key_file: Optional[str],
        allow_variable_data_keys: bool,
        quantize_asr_model: Optional[bool],
        quantize_modules: Optional[List[str]],
        quantize_dtype: Optional[str],
        streaming: Optional[bool],
        simu_streaming: Optional[bool],
        chunk_size: Optional[int],
        left_context: Optional[int],
        right_context: Optional[int],
        display_partial_hypotheses: bool,
        **kwargs,
) -> None:
    """Transducer model inference.

    Builds ``Speech2TextTransducer`` once and returns an inner ``_forward``
    closure that decodes each utterance (chunked streaming, simulated
    streaming, or full-utterance) and writes n-best results under
    ``output_dir`` via DatadirWriter.

    Args:
        output_dir: Output directory path.
        batch_size: Batch decoding size.
        dtype: Data type.
        beam_size: Beam size.
        ngpu: Number of GPUs.
        seed: Random number generator seed.
        lm_weight: Weight of language model.
        nbest: Number of final hypothesis.
        num_workers: Number of workers.
        log_level: Level of verbose for logs.
        data_path_and_name_and_type:
        asr_train_config: ASR model training config path.
        asr_model_file: ASR model path.
        cmvn_file: Global CMVN statistics file path.
        beam_search_config: Beam search config path.
        lm_train_config: Language Model training config path.
        lm_file: Language Model path.
        model_tag: Model tag.
        token_type: Type of token units.
        bpemodel: BPE model path.
        key_file: File key.
        allow_variable_data_keys: Whether to allow variable data keys.
        quantize_asr_model: Whether to apply dynamic quantization to ASR model.
        quantize_modules: List of module names to apply dynamic quantization on.
        quantize_dtype: Dynamic quantization data type.
        streaming: Whether to perform chunk-by-chunk inference.
        simu_streaming: Whether to simulate streaming over a full utterance.
        chunk_size: Number of frames in chunk AFTER subsampling.
        left_context: Number of frames in left context AFTER subsampling.
        right_context: Number of frames in right context AFTER subsampling.
        display_partial_hypotheses: Whether to display partial hypotheses.
    """
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    # NOTE(review): unlike the sibling builders, this does not check
    # torch.cuda.is_available(); ngpu >= 1 on a CPU-only host fails later.
    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        beam_search_config=beam_search_config,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        dtype=dtype,
        beam_size=beam_size,
        lm_weight=lm_weight,
        nbest=nbest,
        quantize_asr_model=quantize_asr_model,
        quantize_modules=quantize_modules,
        quantize_dtype=quantize_dtype,
        streaming=streaming,
        simu_streaming=simu_streaming,
        chunk_size=chunk_size,
        left_context=left_context,
        right_context=right_context,
    )
    speech2text = Speech2TextTransducer(**speech2text_kwargs)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        # 3. Build data-iterator
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
        )
        # 4 .Start for-loop
        with DatadirWriter(output_dir) as writer:
            for keys, batch in loader:
                assert isinstance(batch, dict), type(batch)
                assert all(isinstance(s, str) for s in keys), keys
                _bs = len(next(iter(batch.values())))
                assert len(keys) == _bs, f"{len(keys)} != {_bs}"
                # Drop *_lengths entries and unbatch (batch_size is 1).
                batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
                assert len(batch.keys()) == 1
                try:
                    if speech2text.streaming:
                        # Feed fixed-size chunks of _ctx samples; flush the
                        # remainder with is_final=True.
                        speech = batch["speech"]
                        _steps = len(speech) // speech2text._ctx
                        _end = 0
                        for i in range(_steps):
                            _end = (i + 1) * speech2text._ctx
                            speech2text.streaming_decode(
                                speech[i * speech2text._ctx: _end], is_final=False
                            )
                        final_hyps = speech2text.streaming_decode(
                            speech[_end: len(speech)], is_final=True
                        )
                    elif speech2text.simu_streaming:
                        final_hyps = speech2text.simu_streaming_decode(**batch)
                    else:
                        final_hyps = speech2text(**batch)
                    results = speech2text.hypotheses_to_results(final_hyps)
                except TooShortUttError as e:
                    # Fall back to a space hypothesis so decoding continues.
                    logging.warning(f"Utterance {keys} {e}")
                    hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
                    results = [[" ", ["<space>"], [2], hyp]] * nbest

                key = keys[0]
                for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                    # Write the n-best result files under outdir/{n}best_recog.
                    ibest_writer = writer[f"{n}best_recog"]
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                    if text is not None:
                        ibest_writer["text"][key] = text

    return _forward
def inference_sa_asr(
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        streaming: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        mc: bool = False,
        param_dict: dict = None,
        **kwargs,
):
    """Build a speaker-attributed ASR (SA-ASR) recognizer and return a closure.

    Constructs ``Speech2TextSAASR`` once and returns an inner
    ``_forward(data_path_and_name_and_type, raw_inputs=None, ...)`` that
    decodes and returns a list of ``{'key': ..., 'value': ...}`` items.
    Results additionally carry speaker labels (``text_id``) per hypothesis.
    """
    # Only single-utterance, single-GPU decoding is supported here.
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    # Remove any handlers installed by an earlier basicConfig call so this
    # configuration takes effect.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2TextSAASR(**speech2text_kwargs)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        # 3. Build data-iterator
        # In-memory input: wrap the waveform so it looks like a loader entry.
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=speech2text.asr_train_args,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
        )

        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            # NOTE(review): DatadirWriter is never closed here; consider the
            # context-manager form used by inference_transducer.
            writer = DatadirWriter(output_path)
        else:
            writer = None

        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                # Fall back to a silence hypothesis so the loop can continue.
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest

            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                    ibest_writer["text_id"][key] = text_id
                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    # file_count stays 1, so this reports utterances decoded.
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
                    logging.info("uttid: {}".format(key))
                    logging.info("text predictions: {}".format(text))
                    logging.info("text_id predictions: {}\n".format(text_id))
        return asr_result_list

    return _forward
  1435. def inference_launch(**kwargs):
  1436. if 'mode' in kwargs:
  1437. mode = kwargs['mode']
  1438. else:
  1439. logging.info("Unknown decoding mode.")
  1440. return None
  1441. if mode == "asr":
  1442. return inference_asr(**kwargs)
  1443. elif mode == "uniasr":
  1444. return inference_uniasr(**kwargs)
  1445. elif mode == "paraformer":
  1446. return inference_paraformer(**kwargs)
  1447. elif mode == "paraformer_fake_streaming":
  1448. return inference_paraformer(**kwargs)
  1449. elif mode == "paraformer_streaming":
  1450. return inference_paraformer_online(**kwargs)
  1451. elif mode.startswith("paraformer_vad"):
  1452. return inference_paraformer_vad_punc(**kwargs)
  1453. elif mode == "mfcca":
  1454. return inference_mfcca(**kwargs)
  1455. elif mode == "rnnt":
  1456. return inference_transducer(**kwargs)
  1457. elif mode == "bat":
  1458. return inference_transducer(**kwargs)
  1459. elif mode == "sa_asr":
  1460. return inference_sa_asr(**kwargs)
  1461. else:
  1462. logging.info("Unknown decoding mode: {}".format(mode))
  1463. return None
  1464. def get_parser():
  1465. parser = config_argparse.ArgumentParser(
  1466. description="ASR Decoding",
  1467. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  1468. )
  1469. # Note(kamo): Use '_' instead of '-' as separator.
  1470. # '-' is confusing if written in yaml.
  1471. parser.add_argument(
  1472. "--log_level",
  1473. type=lambda x: x.upper(),
  1474. default="INFO",
  1475. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  1476. help="The verbose level of logging",
  1477. )
  1478. parser.add_argument("--output_dir", type=str, required=True)
  1479. parser.add_argument(
  1480. "--ngpu",
  1481. type=int,
  1482. default=0,
  1483. help="The number of gpus. 0 indicates CPU mode",
  1484. )
  1485. parser.add_argument(
  1486. "--njob",
  1487. type=int,
  1488. default=1,
  1489. help="The number of jobs for each gpu",
  1490. )
  1491. parser.add_argument(
  1492. "--gpuid_list",
  1493. type=str,
  1494. default="",
  1495. help="The visible gpus",
  1496. )
  1497. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  1498. parser.add_argument(
  1499. "--dtype",
  1500. default="float32",
  1501. choices=["float16", "float32", "float64"],
  1502. help="Data type",
  1503. )
  1504. parser.add_argument(
  1505. "--num_workers",
  1506. type=int,
  1507. default=1,
  1508. help="The number of workers used for DataLoader",
  1509. )
  1510. group = parser.add_argument_group("Input data related")
  1511. group.add_argument(
  1512. "--data_path_and_name_and_type",
  1513. type=str2triple_str,
  1514. required=True,
  1515. action="append",
  1516. )
  1517. group.add_argument("--key_file", type=str_or_none)
  1518. parser.add_argument(
  1519. "--hotword",
  1520. type=str_or_none,
  1521. default=None,
  1522. help="hotword file path or hotwords seperated by space"
  1523. )
  1524. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  1525. group.add_argument(
  1526. "--mc",
  1527. type=bool,
  1528. default=False,
  1529. help="MultiChannel input",
  1530. )
  1531. group = parser.add_argument_group("The model configuration related")
  1532. group.add_argument(
  1533. "--vad_infer_config",
  1534. type=str,
  1535. help="VAD infer configuration",
  1536. )
  1537. group.add_argument(
  1538. "--vad_model_file",
  1539. type=str,
  1540. help="VAD model parameter file",
  1541. )
  1542. group.add_argument(
  1543. "--cmvn_file",
  1544. type=str,
  1545. help="Global CMVN file",
  1546. )
  1547. group.add_argument(
  1548. "--asr_train_config",
  1549. type=str,
  1550. help="ASR training configuration",
  1551. )
  1552. group.add_argument(
  1553. "--asr_model_file",
  1554. type=str,
  1555. help="ASR model parameter file",
  1556. )
  1557. group.add_argument(
  1558. "--lm_train_config",
  1559. type=str,
  1560. help="LM training configuration",
  1561. )
  1562. group.add_argument(
  1563. "--lm_file",
  1564. type=str,
  1565. help="LM parameter file",
  1566. )
  1567. group.add_argument(
  1568. "--word_lm_train_config",
  1569. type=str,
  1570. help="Word LM training configuration",
  1571. )
  1572. group.add_argument(
  1573. "--word_lm_file",
  1574. type=str,
  1575. help="Word LM parameter file",
  1576. )
  1577. group.add_argument(
  1578. "--ngram_file",
  1579. type=str,
  1580. help="N-gram parameter file",
  1581. )
  1582. group.add_argument(
  1583. "--model_tag",
  1584. type=str,
  1585. help="Pretrained model tag. If specify this option, *_train_config and "
  1586. "*_file will be overwritten",
  1587. )
  1588. group.add_argument(
  1589. "--beam_search_config",
  1590. default={},
  1591. help="The keyword arguments for transducer beam search.",
  1592. )
  1593. group = parser.add_argument_group("Beam-search related")
  1594. group.add_argument(
  1595. "--batch_size",
  1596. type=int,
  1597. default=1,
  1598. help="The batch size for inference",
  1599. )
  1600. group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
  1601. group.add_argument("--beam_size", type=int, default=20, help="Beam size")
  1602. group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
  1603. group.add_argument(
  1604. "--maxlenratio",
  1605. type=float,
  1606. default=0.0,
  1607. help="Input length ratio to obtain max output length. "
  1608. "If maxlenratio=0.0 (default), it uses a end-detect "
  1609. "function "
  1610. "to automatically find maximum hypothesis lengths."
  1611. "If maxlenratio<0.0, its absolute value is interpreted"
  1612. "as a constant max output length",
  1613. )
  1614. group.add_argument(
  1615. "--minlenratio",
  1616. type=float,
  1617. default=0.0,
  1618. help="Input length ratio to obtain min output length",
  1619. )
  1620. group.add_argument(
  1621. "--ctc_weight",
  1622. type=float,
  1623. default=0.0,
  1624. help="CTC weight in joint decoding",
  1625. )
  1626. group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
  1627. group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
  1628. group.add_argument("--streaming", type=str2bool, default=False)
  1629. group.add_argument("--simu_streaming", type=str2bool, default=False)
  1630. group.add_argument("--chunk_size", type=int, default=16)
  1631. group.add_argument("--left_context", type=int, default=16)
  1632. group.add_argument("--right_context", type=int, default=0)
  1633. group.add_argument(
  1634. "--display_partial_hypotheses",
  1635. type=bool,
  1636. default=False,
  1637. help="Whether to display partial hypotheses during chunk-by-chunk inference.",
  1638. )
  1639. group = parser.add_argument_group("Dynamic quantization related")
  1640. group.add_argument(
  1641. "--quantize_asr_model",
  1642. type=bool,
  1643. default=False,
  1644. help="Apply dynamic quantization to ASR model.",
  1645. )
  1646. group.add_argument(
  1647. "--quantize_modules",
  1648. nargs="*",
  1649. default=None,
  1650. help="""Module names to apply dynamic quantization on.
  1651. The module names are provided as a list, where each name is separated
  1652. by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
  1653. Each specified name should be an attribute of 'torch.nn', e.g.:
  1654. torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
  1655. )
  1656. group.add_argument(
  1657. "--quantize_dtype",
  1658. type=str,
  1659. default="qint8",
  1660. choices=["float16", "qint8"],
  1661. help="Dtype for dynamic quantization.",
  1662. )
  1663. group = parser.add_argument_group("Text converter related")
  1664. group.add_argument(
  1665. "--token_type",
  1666. type=str_or_none,
  1667. default=None,
  1668. choices=["char", "bpe", None],
  1669. help="The token type for ASR model. "
  1670. "If not given, refers from the training args",
  1671. )
  1672. group.add_argument(
  1673. "--bpemodel",
  1674. type=str_or_none,
  1675. default=None,
  1676. help="The model path of sentencepiece. "
  1677. "If not given, refers from the training args",
  1678. )
  1679. group.add_argument("--token_num_relax", type=int, default=1, help="")
  1680. group.add_argument("--decoding_ind", type=int, default=0, help="")
  1681. group.add_argument("--decoding_mode", type=str, default="model1", help="")
  1682. group.add_argument(
  1683. "--ctc_weight2",
  1684. type=float,
  1685. default=0.0,
  1686. help="CTC weight in joint decoding",
  1687. )
  1688. return parser
  1689. def main(cmd=None):
  1690. print(get_commandline_args(), file=sys.stderr)
  1691. parser = get_parser()
  1692. parser.add_argument(
  1693. "--mode",
  1694. type=str,
  1695. default="asr",
  1696. help="The decoding mode",
  1697. )
  1698. args = parser.parse_args(cmd)
  1699. kwargs = vars(args)
  1700. kwargs.pop("config", None)
  1701. # set logging messages
  1702. logging.basicConfig(
  1703. level=args.log_level,
  1704. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  1705. )
  1706. logging.info("Decoding args: {}".format(kwargs))
  1707. # gpu setting
  1708. if args.ngpu > 0:
  1709. jobid = int(args.output_dir.split(".")[-1])
  1710. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  1711. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  1712. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  1713. inference_pipeline = inference_launch(**kwargs)
  1714. return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
# Script entry point: run decoding when executed directly (not on import).
if __name__ == "__main__":
    main()