# asr_inference_launch.py
# NOTE(review): the original paste carried a copy of the viewer's line-number
# gutter here (thousands of concatenated integers); removed as extraction junk.
  1. # -*- encoding: utf-8 -*-
  2. #!/usr/bin/env python3
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import argparse
  6. import logging
  7. import os
  8. import sys
  9. from typing import Union, Dict, Any
  10. from funasr.utils import config_argparse
  11. from funasr.utils.cli_utils import get_commandline_args
  12. from funasr.utils.types import str2bool
  13. from funasr.utils.types import str2triple_str
  14. from funasr.utils.types import str_or_none
  15. #!/usr/bin/env python3
  16. import argparse
  17. import logging
  18. import sys
  19. import time
  20. import copy
  21. import os
  22. import codecs
  23. import tempfile
  24. import requests
  25. from pathlib import Path
  26. from typing import Optional
  27. from typing import Sequence
  28. from typing import Tuple
  29. from typing import Union
  30. from typing import Dict
  31. from typing import Any
  32. from typing import List
  33. import yaml
  34. import numpy as np
  35. import torch
  36. import torchaudio
  37. from typeguard import check_argument_types
  38. from typeguard import check_return_type
  39. from funasr.fileio.datadir_writer import DatadirWriter
  40. from funasr.modules.beam_search.beam_search import BeamSearch
  41. # from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
  42. from funasr.modules.beam_search.beam_search import Hypothesis
  43. from funasr.modules.scorers.ctc import CTCPrefixScorer
  44. from funasr.modules.scorers.length_bonus import LengthBonus
  45. from funasr.modules.subsampling import TooShortUttError
  46. from funasr.tasks.asr import ASRTask
  47. from funasr.tasks.lm import LMTask
  48. from funasr.text.build_tokenizer import build_tokenizer
  49. from funasr.text.token_id_converter import TokenIDConverter
  50. from funasr.torch_utils.device_funcs import to_device
  51. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  52. from funasr.utils import config_argparse
  53. from funasr.utils.cli_utils import get_commandline_args
  54. from funasr.utils.types import str2bool
  55. from funasr.utils.types import str2triple_str
  56. from funasr.utils.types import str_or_none
  57. from funasr.utils import asr_utils, wav_utils, postprocess_utils
  58. from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
  59. from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
  60. from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
  61. from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
  62. from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
  63. from funasr.utils.vad_utils import slice_padding_fbank
  64. from funasr.tasks.vad import VADTask
  65. from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
  66. from funasr.bin.asr_infer import Speech2Text
  67. from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
  68. from funasr.bin.asr_infer import Speech2TextUniASR
  69. from funasr.bin.asr_infer import Speech2TextMFCCA
  70. from funasr.bin.vad_infer import Speech2VadSegment
  71. from funasr.bin.punc_infer import Text2Punc
  72. from funasr.bin.tp_infer import Speech2Timestamp
  73. from funasr.bin.asr_infer import Speech2TextTransducer
  74. from funasr.bin.asr_infer import Speech2TextSAASR
  75. def inference_asr(
  76. maxlenratio: float,
  77. minlenratio: float,
  78. batch_size: int,
  79. beam_size: int,
  80. ngpu: int,
  81. ctc_weight: float,
  82. lm_weight: float,
  83. penalty: float,
  84. log_level: Union[int, str],
  85. # data_path_and_name_and_type,
  86. asr_train_config: Optional[str],
  87. asr_model_file: Optional[str],
  88. cmvn_file: Optional[str] = None,
  89. lm_train_config: Optional[str] = None,
  90. lm_file: Optional[str] = None,
  91. token_type: Optional[str] = None,
  92. key_file: Optional[str] = None,
  93. word_lm_train_config: Optional[str] = None,
  94. bpemodel: Optional[str] = None,
  95. allow_variable_data_keys: bool = False,
  96. streaming: bool = False,
  97. output_dir: Optional[str] = None,
  98. dtype: str = "float32",
  99. seed: int = 0,
  100. ngram_weight: float = 0.9,
  101. nbest: int = 1,
  102. num_workers: int = 1,
  103. mc: bool = False,
  104. param_dict: dict = None,
  105. **kwargs,
  106. ):
  107. assert check_argument_types()
  108. ncpu = kwargs.get("ncpu", 1)
  109. torch.set_num_threads(ncpu)
  110. if batch_size > 1:
  111. raise NotImplementedError("batch decoding is not implemented")
  112. if word_lm_train_config is not None:
  113. raise NotImplementedError("Word LM is not implemented")
  114. if ngpu > 1:
  115. raise NotImplementedError("only single GPU decoding is supported")
  116. for handler in logging.root.handlers[:]:
  117. logging.root.removeHandler(handler)
  118. logging.basicConfig(
  119. level=log_level,
  120. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  121. )
  122. if ngpu >= 1 and torch.cuda.is_available():
  123. device = "cuda"
  124. else:
  125. device = "cpu"
  126. # 1. Set random-seed
  127. set_all_random_seed(seed)
  128. # 2. Build speech2text
  129. speech2text_kwargs = dict(
  130. asr_train_config=asr_train_config,
  131. asr_model_file=asr_model_file,
  132. cmvn_file=cmvn_file,
  133. lm_train_config=lm_train_config,
  134. lm_file=lm_file,
  135. token_type=token_type,
  136. bpemodel=bpemodel,
  137. device=device,
  138. maxlenratio=maxlenratio,
  139. minlenratio=minlenratio,
  140. dtype=dtype,
  141. beam_size=beam_size,
  142. ctc_weight=ctc_weight,
  143. lm_weight=lm_weight,
  144. ngram_weight=ngram_weight,
  145. penalty=penalty,
  146. nbest=nbest,
  147. streaming=streaming,
  148. )
  149. logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
  150. speech2text = Speech2Text(**speech2text_kwargs)
  151. def _forward(data_path_and_name_and_type,
  152. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  153. output_dir_v2: Optional[str] = None,
  154. fs: dict = None,
  155. param_dict: dict = None,
  156. **kwargs,
  157. ):
  158. # 3. Build data-iterator
  159. if data_path_and_name_and_type is None and raw_inputs is not None:
  160. if isinstance(raw_inputs, torch.Tensor):
  161. raw_inputs = raw_inputs.numpy()
  162. data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
  163. loader = ASRTask.build_streaming_iterator(
  164. data_path_and_name_and_type,
  165. dtype=dtype,
  166. fs=fs,
  167. mc=mc,
  168. batch_size=batch_size,
  169. key_file=key_file,
  170. num_workers=num_workers,
  171. preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
  172. collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
  173. allow_variable_data_keys=allow_variable_data_keys,
  174. inference=True,
  175. )
  176. finish_count = 0
  177. file_count = 1
  178. # 7 .Start for-loop
  179. # FIXME(kamo): The output format should be discussed about
  180. asr_result_list = []
  181. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  182. if output_path is not None:
  183. writer = DatadirWriter(output_path)
  184. else:
  185. writer = None
  186. for keys, batch in loader:
  187. assert isinstance(batch, dict), type(batch)
  188. assert all(isinstance(s, str) for s in keys), keys
  189. _bs = len(next(iter(batch.values())))
  190. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  191. # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
  192. # N-best list of (text, token, token_int, hyp_object)
  193. try:
  194. results = speech2text(**batch)
  195. except TooShortUttError as e:
  196. logging.warning(f"Utterance {keys} {e}")
  197. hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
  198. results = [[" ", ["sil"], [2], hyp]] * nbest
  199. # Only supporting batch_size==1
  200. key = keys[0]
  201. for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
  202. # Create a directory: outdir/{n}best_recog
  203. if writer is not None:
  204. ibest_writer = writer[f"{n}best_recog"]
  205. # Write the result to each file
  206. ibest_writer["token"][key] = " ".join(token)
  207. ibest_writer["token_int"][key] = " ".join(map(str, token_int))
  208. ibest_writer["score"][key] = str(hyp.score)
  209. if text is not None:
  210. text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
  211. item = {'key': key, 'value': text_postprocessed}
  212. asr_result_list.append(item)
  213. finish_count += 1
  214. asr_utils.print_progress(finish_count / file_count)
  215. if writer is not None:
  216. ibest_writer["text"][key] = text
  217. logging.info("uttid: {}".format(key))
  218. logging.info("text predictions: {}\n".format(text))
  219. return asr_result_list
  220. return _forward
  221. def inference_paraformer(
  222. maxlenratio: float,
  223. minlenratio: float,
  224. batch_size: int,
  225. beam_size: int,
  226. ngpu: int,
  227. ctc_weight: float,
  228. lm_weight: float,
  229. penalty: float,
  230. log_level: Union[int, str],
  231. # data_path_and_name_and_type,
  232. asr_train_config: Optional[str],
  233. asr_model_file: Optional[str],
  234. cmvn_file: Optional[str] = None,
  235. lm_train_config: Optional[str] = None,
  236. lm_file: Optional[str] = None,
  237. token_type: Optional[str] = None,
  238. key_file: Optional[str] = None,
  239. word_lm_train_config: Optional[str] = None,
  240. bpemodel: Optional[str] = None,
  241. allow_variable_data_keys: bool = False,
  242. dtype: str = "float32",
  243. seed: int = 0,
  244. ngram_weight: float = 0.9,
  245. nbest: int = 1,
  246. num_workers: int = 1,
  247. output_dir: Optional[str] = None,
  248. timestamp_infer_config: Union[Path, str] = None,
  249. timestamp_model_file: Union[Path, str] = None,
  250. param_dict: dict = None,
  251. **kwargs,
  252. ):
  253. assert check_argument_types()
  254. ncpu = kwargs.get("ncpu", 1)
  255. torch.set_num_threads(ncpu)
  256. if word_lm_train_config is not None:
  257. raise NotImplementedError("Word LM is not implemented")
  258. if ngpu > 1:
  259. raise NotImplementedError("only single GPU decoding is supported")
  260. logging.basicConfig(
  261. level=log_level,
  262. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  263. )
  264. export_mode = False
  265. if param_dict is not None:
  266. hotword_list_or_file = param_dict.get('hotword')
  267. export_mode = param_dict.get("export_mode", False)
  268. else:
  269. hotword_list_or_file = None
  270. if kwargs.get("device", None) == "cpu":
  271. ngpu = 0
  272. if ngpu >= 1 and torch.cuda.is_available():
  273. device = "cuda"
  274. else:
  275. device = "cpu"
  276. batch_size = 1
  277. # 1. Set random-seed
  278. set_all_random_seed(seed)
  279. # 2. Build speech2text
  280. speech2text_kwargs = dict(
  281. asr_train_config=asr_train_config,
  282. asr_model_file=asr_model_file,
  283. cmvn_file=cmvn_file,
  284. lm_train_config=lm_train_config,
  285. lm_file=lm_file,
  286. token_type=token_type,
  287. bpemodel=bpemodel,
  288. device=device,
  289. maxlenratio=maxlenratio,
  290. minlenratio=minlenratio,
  291. dtype=dtype,
  292. beam_size=beam_size,
  293. ctc_weight=ctc_weight,
  294. lm_weight=lm_weight,
  295. ngram_weight=ngram_weight,
  296. penalty=penalty,
  297. nbest=nbest,
  298. hotword_list_or_file=hotword_list_or_file,
  299. )
  300. speech2text = Speech2TextParaformer(**speech2text_kwargs)
  301. if timestamp_model_file is not None:
  302. speechtext2timestamp = Speech2Timestamp(
  303. timestamp_cmvn_file=cmvn_file,
  304. timestamp_model_file=timestamp_model_file,
  305. timestamp_infer_config=timestamp_infer_config,
  306. )
  307. else:
  308. speechtext2timestamp = None
  309. def _forward(
  310. data_path_and_name_and_type,
  311. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  312. output_dir_v2: Optional[str] = None,
  313. fs: dict = None,
  314. param_dict: dict = None,
  315. **kwargs,
  316. ):
  317. hotword_list_or_file = None
  318. if param_dict is not None:
  319. hotword_list_or_file = param_dict.get('hotword')
  320. if 'hotword' in kwargs and kwargs['hotword'] is not None:
  321. hotword_list_or_file = kwargs['hotword']
  322. if hotword_list_or_file is not None or 'hotword' in kwargs:
  323. speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
  324. # 3. Build data-iterator
  325. if data_path_and_name_and_type is None and raw_inputs is not None:
  326. if isinstance(raw_inputs, torch.Tensor):
  327. raw_inputs = raw_inputs.numpy()
  328. data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
  329. loader = ASRTask.build_streaming_iterator(
  330. data_path_and_name_and_type,
  331. dtype=dtype,
  332. fs=fs,
  333. batch_size=batch_size,
  334. key_file=key_file,
  335. num_workers=num_workers,
  336. preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
  337. collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
  338. allow_variable_data_keys=allow_variable_data_keys,
  339. inference=True,
  340. )
  341. if param_dict is not None:
  342. use_timestamp = param_dict.get('use_timestamp', True)
  343. else:
  344. use_timestamp = True
  345. forward_time_total = 0.0
  346. length_total = 0.0
  347. finish_count = 0
  348. file_count = 1
  349. # 7 .Start for-loop
  350. # FIXME(kamo): The output format should be discussed about
  351. asr_result_list = []
  352. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  353. if output_path is not None:
  354. writer = DatadirWriter(output_path)
  355. else:
  356. writer = None
  357. for keys, batch in loader:
  358. assert isinstance(batch, dict), type(batch)
  359. assert all(isinstance(s, str) for s in keys), keys
  360. _bs = len(next(iter(batch.values())))
  361. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  362. # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
  363. logging.info("decoding, utt_id: {}".format(keys))
  364. # N-best list of (text, token, token_int, hyp_object)
  365. time_beg = time.time()
  366. results = speech2text(**batch)
  367. if len(results) < 1:
  368. hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
  369. results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
  370. time_end = time.time()
  371. forward_time = time_end - time_beg
  372. lfr_factor = results[0][-1]
  373. length = results[0][-2]
  374. forward_time_total += forward_time
  375. length_total += length
  376. rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time,
  377. 100 * forward_time / (
  378. length * lfr_factor))
  379. logging.info(rtf_cur)
  380. for batch_id in range(_bs):
  381. result = [results[batch_id][:-2]]
  382. key = keys[batch_id]
  383. for n, result in zip(range(1, nbest + 1), result):
  384. text, token, token_int, hyp = result[0], result[1], result[2], result[3]
  385. timestamp = result[4] if len(result[4]) > 0 else None
  386. # conduct timestamp prediction here
  387. # timestamp inference requires token length
  388. # thus following inference cannot be conducted in batch
  389. if timestamp is None and speechtext2timestamp:
  390. ts_batch = {}
  391. ts_batch['speech'] = batch['speech'][batch_id].unsqueeze(0)
  392. ts_batch['speech_lengths'] = torch.tensor([batch['speech_lengths'][batch_id]])
  393. ts_batch['text_lengths'] = torch.tensor([len(token)])
  394. us_alphas, us_peaks = speechtext2timestamp(**ts_batch)
  395. ts_str, timestamp = ts_prediction_lfr6_standard(us_alphas[0], us_peaks[0], token,
  396. force_time_shift=-3.0)
  397. # Create a directory: outdir/{n}best_recog
  398. if writer is not None:
  399. ibest_writer = writer[f"{n}best_recog"]
  400. # Write the result to each file
  401. ibest_writer["token"][key] = " ".join(token)
  402. # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
  403. ibest_writer["score"][key] = str(hyp.score)
  404. ibest_writer["rtf"][key] = rtf_cur
  405. if text is not None:
  406. if use_timestamp and timestamp is not None:
  407. postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
  408. else:
  409. postprocessed_result = postprocess_utils.sentence_postprocess(token)
  410. timestamp_postprocessed = ""
  411. if len(postprocessed_result) == 3:
  412. text_postprocessed, timestamp_postprocessed, word_lists = postprocessed_result[0], \
  413. postprocessed_result[1], \
  414. postprocessed_result[2]
  415. else:
  416. text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
  417. item = {'key': key, 'value': text_postprocessed}
  418. if timestamp_postprocessed != "":
  419. item['timestamp'] = timestamp_postprocessed
  420. asr_result_list.append(item)
  421. finish_count += 1
  422. # asr_utils.print_progress(finish_count / file_count)
  423. if writer is not None:
  424. ibest_writer["text"][key] = " ".join(word_lists)
  425. logging.info("decoding, utt: {}, predictions: {}".format(key, text))
  426. rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
  427. forward_time_total,
  428. 100 * forward_time_total / (
  429. length_total * lfr_factor))
  430. logging.info(rtf_avg)
  431. if writer is not None:
  432. ibest_writer["rtf"]["rtf_avf"] = rtf_avg
  433. return asr_result_list
  434. return _forward
  435. def inference_paraformer_vad_punc(
  436. maxlenratio: float,
  437. minlenratio: float,
  438. batch_size: int,
  439. beam_size: int,
  440. ngpu: int,
  441. ctc_weight: float,
  442. lm_weight: float,
  443. penalty: float,
  444. log_level: Union[int, str],
  445. # data_path_and_name_and_type,
  446. asr_train_config: Optional[str],
  447. asr_model_file: Optional[str],
  448. cmvn_file: Optional[str] = None,
  449. lm_train_config: Optional[str] = None,
  450. lm_file: Optional[str] = None,
  451. token_type: Optional[str] = None,
  452. key_file: Optional[str] = None,
  453. word_lm_train_config: Optional[str] = None,
  454. bpemodel: Optional[str] = None,
  455. allow_variable_data_keys: bool = False,
  456. output_dir: Optional[str] = None,
  457. dtype: str = "float32",
  458. seed: int = 0,
  459. ngram_weight: float = 0.9,
  460. nbest: int = 1,
  461. num_workers: int = 1,
  462. vad_infer_config: Optional[str] = None,
  463. vad_model_file: Optional[str] = None,
  464. vad_cmvn_file: Optional[str] = None,
  465. time_stamp_writer: bool = True,
  466. punc_infer_config: Optional[str] = None,
  467. punc_model_file: Optional[str] = None,
  468. outputs_dict: Optional[bool] = True,
  469. param_dict: dict = None,
  470. **kwargs,
  471. ):
  472. assert check_argument_types()
  473. ncpu = kwargs.get("ncpu", 1)
  474. torch.set_num_threads(ncpu)
  475. if word_lm_train_config is not None:
  476. raise NotImplementedError("Word LM is not implemented")
  477. if ngpu > 1:
  478. raise NotImplementedError("only single GPU decoding is supported")
  479. logging.basicConfig(
  480. level=log_level,
  481. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  482. )
  483. if param_dict is not None:
  484. hotword_list_or_file = param_dict.get('hotword')
  485. else:
  486. hotword_list_or_file = None
  487. if ngpu >= 1 and torch.cuda.is_available():
  488. device = "cuda"
  489. else:
  490. device = "cpu"
  491. # 1. Set random-seed
  492. set_all_random_seed(seed)
  493. # 2. Build speech2vadsegment
  494. speech2vadsegment_kwargs = dict(
  495. vad_infer_config=vad_infer_config,
  496. vad_model_file=vad_model_file,
  497. vad_cmvn_file=vad_cmvn_file,
  498. device=device,
  499. dtype=dtype,
  500. )
  501. # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
  502. speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
  503. # 3. Build speech2text
  504. speech2text_kwargs = dict(
  505. asr_train_config=asr_train_config,
  506. asr_model_file=asr_model_file,
  507. cmvn_file=cmvn_file,
  508. lm_train_config=lm_train_config,
  509. lm_file=lm_file,
  510. token_type=token_type,
  511. bpemodel=bpemodel,
  512. device=device,
  513. maxlenratio=maxlenratio,
  514. minlenratio=minlenratio,
  515. dtype=dtype,
  516. beam_size=beam_size,
  517. ctc_weight=ctc_weight,
  518. lm_weight=lm_weight,
  519. ngram_weight=ngram_weight,
  520. penalty=penalty,
  521. nbest=nbest,
  522. hotword_list_or_file=hotword_list_or_file,
  523. )
  524. speech2text = Speech2TextParaformer(**speech2text_kwargs)
  525. text2punc = None
  526. if punc_model_file is not None:
  527. text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
  528. if output_dir is not None:
  529. writer = DatadirWriter(output_dir)
  530. ibest_writer = writer[f"1best_recog"]
  531. ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
  532. def _forward(data_path_and_name_and_type,
  533. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  534. output_dir_v2: Optional[str] = None,
  535. fs: dict = None,
  536. param_dict: dict = None,
  537. **kwargs,
  538. ):
  539. hotword_list_or_file = None
  540. if param_dict is not None:
  541. hotword_list_or_file = param_dict.get('hotword')
  542. if 'hotword' in kwargs:
  543. hotword_list_or_file = kwargs['hotword']
  544. batch_size_token = kwargs.get("batch_size_token", 6000)
  545. print("batch_size_token: ", batch_size_token)
  546. if speech2text.hotword_list is None:
  547. speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
  548. # 3. Build data-iterator
  549. if data_path_and_name_and_type is None and raw_inputs is not None:
  550. if isinstance(raw_inputs, torch.Tensor):
  551. raw_inputs = raw_inputs.numpy()
  552. data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
  553. loader = ASRTask.build_streaming_iterator(
  554. data_path_and_name_and_type,
  555. dtype=dtype,
  556. fs=fs,
  557. batch_size=1,
  558. key_file=key_file,
  559. num_workers=num_workers,
  560. preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
  561. collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
  562. allow_variable_data_keys=allow_variable_data_keys,
  563. inference=True,
  564. )
  565. if param_dict is not None:
  566. use_timestamp = param_dict.get('use_timestamp', True)
  567. else:
  568. use_timestamp = True
  569. finish_count = 0
  570. file_count = 1
  571. lfr_factor = 6
  572. # 7 .Start for-loop
  573. asr_result_list = []
  574. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  575. writer = None
  576. if output_path is not None:
  577. writer = DatadirWriter(output_path)
  578. ibest_writer = writer[f"1best_recog"]
  579. for keys, batch in loader:
  580. assert isinstance(batch, dict), type(batch)
  581. assert all(isinstance(s, str) for s in keys), keys
  582. _bs = len(next(iter(batch.values())))
  583. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  584. beg_vad = time.time()
  585. vad_results = speech2vadsegment(**batch)
  586. end_vad = time.time()
  587. print("time cost vad: ", end_vad-beg_vad)
  588. _, vadsegments = vad_results[0], vad_results[1][0]
  589. speech, speech_lengths = batch["speech"], batch["speech_lengths"]
  590. n = len(vadsegments)
  591. data_with_index = [(vadsegments[i], i) for i in range(n)]
  592. sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
  593. results_sorted = []
  594. batch_size_token_ms = batch_size_token*60
  595. batch_size_token_ms_cum = 0
  596. beg_idx = 0
  597. for j, _ in enumerate(range(0, n)):
  598. batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
  599. if j < n-1 and (batch_size_token_ms_cum + sorted_data[j+1][0][1] - sorted_data[j+1][0][0])<batch_size_token_ms:
  600. continue
  601. batch_size_token_ms_cum = 0
  602. end_idx = j + 1
  603. speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
  604. beg_idx = end_idx
  605. batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
  606. batch = to_device(batch, device=device)
  607. print("batch: ", speech_j.shape[0])
  608. beg_asr = time.time()
  609. results = speech2text(**batch)
  610. end_asr = time.time()
  611. print("time cost asr: ", end_asr - beg_asr)
  612. if len(results) < 1:
  613. results = [["", [], [], [], [], [], []]]
  614. results_sorted.extend(results)
  615. restored_data = [0] * n
  616. for j in range(n):
  617. index = sorted_data[j][1]
  618. restored_data[index] = results_sorted[j]
  619. result = ["", [], [], [], [], [], []]
  620. for j in range(n):
  621. result[0] += restored_data[j][0]
  622. result[1] += restored_data[j][1]
  623. result[2] += restored_data[j][2]
  624. if len(restored_data[j][4]) > 0:
  625. for t in restored_data[j][4]:
  626. t[0] += vadsegments[j][0]
  627. t[1] += vadsegments[j][0]
  628. result[4] += restored_data[j][4]
  629. # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
  630. key = keys[0]
  631. # result = result_segments[0]
  632. text, token, token_int = result[0], result[1], result[2]
  633. time_stamp = result[4] if len(result[4]) > 0 else None
  634. if use_timestamp and time_stamp is not None:
  635. postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
  636. else:
  637. postprocessed_result = postprocess_utils.sentence_postprocess(token)
  638. text_postprocessed = ""
  639. time_stamp_postprocessed = ""
  640. text_postprocessed_punc = postprocessed_result
  641. if len(postprocessed_result) == 3:
  642. text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
  643. postprocessed_result[1], \
  644. postprocessed_result[2]
  645. else:
  646. text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
  647. text_postprocessed_punc = text_postprocessed
  648. punc_id_list = []
  649. if len(word_lists) > 0 and text2punc is not None:
  650. beg_punc = time.time()
  651. text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
  652. end_punc = time.time()
  653. print("time cost punc: ", end_punc-beg_punc)
  654. item = {'key': key, 'value': text_postprocessed_punc}
  655. if text_postprocessed != "":
  656. item['text_postprocessed'] = text_postprocessed
  657. if time_stamp_postprocessed != "":
  658. item['time_stamp'] = time_stamp_postprocessed
  659. item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
  660. asr_result_list.append(item)
  661. finish_count += 1
  662. # asr_utils.print_progress(finish_count / file_count)
  663. if writer is not None:
  664. # Write the result to each file
  665. ibest_writer["token"][key] = " ".join(token)
  666. ibest_writer["token_int"][key] = " ".join(map(str, token_int))
  667. ibest_writer["vad"][key] = "{}".format(vadsegments)
  668. ibest_writer["text"][key] = " ".join(word_lists)
  669. ibest_writer["text_with_punc"][key] = text_postprocessed_punc
  670. if time_stamp_postprocessed is not None:
  671. ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
  672. logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
  673. return asr_result_list
  674. return _forward
  675. def inference_paraformer_online(
  676. maxlenratio: float,
  677. minlenratio: float,
  678. batch_size: int,
  679. beam_size: int,
  680. ngpu: int,
  681. ctc_weight: float,
  682. lm_weight: float,
  683. penalty: float,
  684. log_level: Union[int, str],
  685. # data_path_and_name_and_type,
  686. asr_train_config: Optional[str],
  687. asr_model_file: Optional[str],
  688. cmvn_file: Optional[str] = None,
  689. lm_train_config: Optional[str] = None,
  690. lm_file: Optional[str] = None,
  691. token_type: Optional[str] = None,
  692. key_file: Optional[str] = None,
  693. word_lm_train_config: Optional[str] = None,
  694. bpemodel: Optional[str] = None,
  695. allow_variable_data_keys: bool = False,
  696. dtype: str = "float32",
  697. seed: int = 0,
  698. ngram_weight: float = 0.9,
  699. nbest: int = 1,
  700. num_workers: int = 1,
  701. output_dir: Optional[str] = None,
  702. param_dict: dict = None,
  703. **kwargs,
  704. ):
  705. assert check_argument_types()
  706. if word_lm_train_config is not None:
  707. raise NotImplementedError("Word LM is not implemented")
  708. if ngpu > 1:
  709. raise NotImplementedError("only single GPU decoding is supported")
  710. logging.basicConfig(
  711. level=log_level,
  712. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  713. )
  714. export_mode = False
  715. if ngpu >= 1 and torch.cuda.is_available():
  716. device = "cuda"
  717. else:
  718. device = "cpu"
  719. batch_size = 1
  720. # 1. Set random-seed
  721. set_all_random_seed(seed)
  722. # 2. Build speech2text
  723. speech2text_kwargs = dict(
  724. asr_train_config=asr_train_config,
  725. asr_model_file=asr_model_file,
  726. cmvn_file=cmvn_file,
  727. lm_train_config=lm_train_config,
  728. lm_file=lm_file,
  729. token_type=token_type,
  730. bpemodel=bpemodel,
  731. device=device,
  732. maxlenratio=maxlenratio,
  733. minlenratio=minlenratio,
  734. dtype=dtype,
  735. beam_size=beam_size,
  736. ctc_weight=ctc_weight,
  737. lm_weight=lm_weight,
  738. ngram_weight=ngram_weight,
  739. penalty=penalty,
  740. nbest=nbest,
  741. )
  742. speech2text = Speech2TextParaformerOnline(**speech2text_kwargs)
  743. def _load_bytes(input):
  744. middle_data = np.frombuffer(input, dtype=np.int16)
  745. middle_data = np.asarray(middle_data)
  746. if middle_data.dtype.kind not in 'iu':
  747. raise TypeError("'middle_data' must be an array of integers")
  748. dtype = np.dtype('float32')
  749. if dtype.kind != 'f':
  750. raise TypeError("'dtype' must be a floating point type")
  751. i = np.iinfo(middle_data.dtype)
  752. abs_max = 2 ** (i.bits - 1)
  753. offset = i.min + abs_max
  754. array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
  755. return array
  756. def _read_yaml(yaml_path: Union[str, Path]) -> Dict:
  757. if not Path(yaml_path).exists():
  758. raise FileExistsError(f'The {yaml_path} does not exist.')
  759. with open(str(yaml_path), 'rb') as f:
  760. data = yaml.load(f, Loader=yaml.Loader)
  761. return data
  762. def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
  763. if len(cache) > 0:
  764. return cache
  765. config = _read_yaml(asr_train_config)
  766. enc_output_size = config["encoder_conf"]["output_size"]
  767. feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
  768. cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
  769. "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
  770. "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
  771. cache["encoder"] = cache_en
  772. cache_de = {"decode_fsmn": None}
  773. cache["decoder"] = cache_de
  774. return cache
  775. def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
  776. if len(cache) > 0:
  777. config = _read_yaml(asr_train_config)
  778. enc_output_size = config["encoder_conf"]["output_size"]
  779. feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
  780. cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
  781. "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
  782. "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
  783. cache["encoder"] = cache_en
  784. cache_de = {"decode_fsmn": None}
  785. cache["decoder"] = cache_de
  786. return cache
  787. def _forward(
  788. data_path_and_name_and_type,
  789. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  790. output_dir_v2: Optional[str] = None,
  791. fs: dict = None,
  792. param_dict: dict = None,
  793. **kwargs,
  794. ):
  795. # 3. Build data-iterator
  796. if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
  797. raw_inputs = _load_bytes(data_path_and_name_and_type[0])
  798. raw_inputs = torch.tensor(raw_inputs)
  799. if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
  800. raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
  801. if data_path_and_name_and_type is None and raw_inputs is not None:
  802. if isinstance(raw_inputs, np.ndarray):
  803. raw_inputs = torch.tensor(raw_inputs)
  804. is_final = False
  805. cache = {}
  806. chunk_size = [5, 10, 5]
  807. if param_dict is not None and "cache" in param_dict:
  808. cache = param_dict["cache"]
  809. if param_dict is not None and "is_final" in param_dict:
  810. is_final = param_dict["is_final"]
  811. if param_dict is not None and "chunk_size" in param_dict:
  812. chunk_size = param_dict["chunk_size"]
  813. # 7 .Start for-loop
  814. # FIXME(kamo): The output format should be discussed about
  815. raw_inputs = torch.unsqueeze(raw_inputs, axis=0)
  816. asr_result_list = []
  817. cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
  818. item = {}
  819. if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
  820. sample_offset = 0
  821. speech_length = raw_inputs.shape[1]
  822. stride_size = chunk_size[1] * 960
  823. cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
  824. final_result = ""
  825. for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
  826. if sample_offset + stride_size >= speech_length - 1:
  827. stride_size = speech_length - sample_offset
  828. cache["encoder"]["is_final"] = True
  829. else:
  830. cache["encoder"]["is_final"] = False
  831. input_lens = torch.tensor([stride_size])
  832. asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
  833. if len(asr_result) != 0:
  834. final_result += " ".join(asr_result) + " "
  835. item = {'key': "utt", 'value': final_result.strip()}
  836. else:
  837. input_lens = torch.tensor([raw_inputs.shape[1]])
  838. cache["encoder"]["is_final"] = is_final
  839. asr_result = speech2text(cache, raw_inputs, input_lens)
  840. item = {'key': "utt", 'value': " ".join(asr_result)}
  841. asr_result_list.append(item)
  842. if is_final:
  843. cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1)
  844. return asr_result_list
  845. return _forward
def inference_uniasr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    ngram_file: Optional[str] = None,
    cmvn_file: Optional[str] = None,
    # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    token_num_relax: int = 1,
    decoding_ind: int = 0,
    decoding_mode: str = "model1",
    param_dict: dict = None,
    **kwargs,
):
    """Build a UniASR inference closure.

    Loads the UniASR model once and returns a ``_forward`` callable that
    decodes a dataset (or a raw waveform) and returns a list of
    ``{'key', 'value'}`` result dicts.  ``param_dict["decoding_model"]``
    ("fast" / "normal" / "offline") overrides ``decoding_ind`` and
    ``decoding_mode``.

    Raises:
        NotImplementedError: for batch decoding, word LMs, multi-GPU use,
            or an unknown ``decoding_model`` value.
    """
    assert check_argument_types()

    # Limit intra-op CPU threads (defaults to 1 unless "ncpu" is passed).
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # Map the user-facing "decoding_model" name onto (decoding_ind, decoding_mode).
    if param_dict is not None and "decoding_model" in param_dict:
        if param_dict["decoding_model"] == "fast":
            decoding_ind = 0
            decoding_mode = "model1"
        elif param_dict["decoding_model"] == "normal":
            decoding_ind = 0
            decoding_mode = "model2"
        elif param_dict["decoding_model"] == "offline":
            decoding_ind = 1
            decoding_mode = "model2"
        else:
            raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text (the model is loaded once and reused per call)
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        ngram_file=ngram_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
        token_num_relax=token_num_relax,
        decoding_ind=decoding_ind,
        decoding_mode=decoding_mode,
    )
    speech2text = Speech2TextUniASR(**speech2text_kwargs)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        """Decode the given dataset (or raw waveform) and return result dicts."""
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            # Raw waveform input: wrap it so it looks like a one-entry dataset.
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        # A per-call output dir overrides the one captured at build time.
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                # Too-short utterances get a silent placeholder hypothesis.
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest

            # Only supporting batch_size==1
            key = keys[0]
            logging.info(f"Utterance: {key}")
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = " ".join(word_lists)
        return asr_result_list

    return _forward
def inference_mfcca(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    param_dict: dict = None,
    **kwargs,
):
    """Build an MFCCA (multi-channel) ASR inference closure.

    Loads the MFCCA model once and returns a ``_forward`` callable that
    decodes a dataset (or raw waveform) and returns ``{'key', 'value'}``
    result dicts.  The data iterator is built with ``mc=True`` so
    multi-channel audio is kept.

    Raises:
        NotImplementedError: for batch decoding, word LMs, or multi-GPU use.
    """
    assert check_argument_types()

    # Limit intra-op CPU threads (defaults to 1 unless "ncpu" is passed).
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text (the model is loaded once and reused per call)
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2TextMFCCA(**speech2text_kwargs)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        """Decode the given dataset (or raw waveform) and return result dicts."""
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            # Raw waveform input: wrap it so it looks like a one-entry dataset.
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        # mc=True keeps all channels for the multi-channel frontend.
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            fs=fs,
            mc=True,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        # A per-call output dir overrides the one captured at build time.
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                # Too-short utterances get a placeholder hypothesis.
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["<space>"], [2], hyp]] * nbest

            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    text_postprocessed = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
        return asr_result_list

    return _forward
  1148. def inference_transducer(
  1149. output_dir: str,
  1150. batch_size: int,
  1151. dtype: str,
  1152. beam_size: int,
  1153. ngpu: int,
  1154. seed: int,
  1155. lm_weight: float,
  1156. nbest: int,
  1157. num_workers: int,
  1158. log_level: Union[int, str],
  1159. data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
  1160. asr_train_config: Optional[str],
  1161. asr_model_file: Optional[str],
  1162. cmvn_file: Optional[str],
  1163. beam_search_config: Optional[dict],
  1164. lm_train_config: Optional[str],
  1165. lm_file: Optional[str],
  1166. model_tag: Optional[str],
  1167. token_type: Optional[str],
  1168. bpemodel: Optional[str],
  1169. key_file: Optional[str],
  1170. allow_variable_data_keys: bool,
  1171. quantize_asr_model: Optional[bool],
  1172. quantize_modules: Optional[List[str]],
  1173. quantize_dtype: Optional[str],
  1174. streaming: Optional[bool],
  1175. simu_streaming: Optional[bool],
  1176. chunk_size: Optional[int],
  1177. left_context: Optional[int],
  1178. right_context: Optional[int],
  1179. display_partial_hypotheses: bool,
  1180. **kwargs,
  1181. ) -> None:
  1182. """Transducer model inference.
  1183. Args:
  1184. output_dir: Output directory path.
  1185. batch_size: Batch decoding size.
  1186. dtype: Data type.
  1187. beam_size: Beam size.
  1188. ngpu: Number of GPUs.
  1189. seed: Random number generator seed.
  1190. lm_weight: Weight of language model.
  1191. nbest: Number of final hypothesis.
  1192. num_workers: Number of workers.
  1193. log_level: Level of verbose for logs.
  1194. data_path_and_name_and_type:
  1195. asr_train_config: ASR model training config path.
  1196. asr_model_file: ASR model path.
  1197. beam_search_config: Beam search config path.
  1198. lm_train_config: Language Model training config path.
  1199. lm_file: Language Model path.
  1200. model_tag: Model tag.
  1201. token_type: Type of token units.
  1202. bpemodel: BPE model path.
  1203. key_file: File key.
  1204. allow_variable_data_keys: Whether to allow variable data keys.
  1205. quantize_asr_model: Whether to apply dynamic quantization to ASR model.
  1206. quantize_modules: List of module names to apply dynamic quantization on.
  1207. quantize_dtype: Dynamic quantization data type.
  1208. streaming: Whether to perform chunk-by-chunk inference.
  1209. chunk_size: Number of frames in chunk AFTER subsampling.
  1210. left_context: Number of frames in left context AFTER subsampling.
  1211. right_context: Number of frames in right context AFTER subsampling.
  1212. display_partial_hypotheses: Whether to display partial hypotheses.
  1213. """
  1214. assert check_argument_types()
  1215. if batch_size > 1:
  1216. raise NotImplementedError("batch decoding is not implemented")
  1217. if ngpu > 1:
  1218. raise NotImplementedError("only single GPU decoding is supported")
  1219. logging.basicConfig(
  1220. level=log_level,
  1221. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  1222. )
  1223. if ngpu >= 1:
  1224. device = "cuda"
  1225. else:
  1226. device = "cpu"
  1227. # 1. Set random-seed
  1228. set_all_random_seed(seed)
  1229. # 2. Build speech2text
  1230. speech2text_kwargs = dict(
  1231. asr_train_config=asr_train_config,
  1232. asr_model_file=asr_model_file,
  1233. cmvn_file=cmvn_file,
  1234. beam_search_config=beam_search_config,
  1235. lm_train_config=lm_train_config,
  1236. lm_file=lm_file,
  1237. token_type=token_type,
  1238. bpemodel=bpemodel,
  1239. device=device,
  1240. dtype=dtype,
  1241. beam_size=beam_size,
  1242. lm_weight=lm_weight,
  1243. nbest=nbest,
  1244. quantize_asr_model=quantize_asr_model,
  1245. quantize_modules=quantize_modules,
  1246. quantize_dtype=quantize_dtype,
  1247. streaming=streaming,
  1248. simu_streaming=simu_streaming,
  1249. chunk_size=chunk_size,
  1250. left_context=left_context,
  1251. right_context=right_context,
  1252. )
  1253. speech2text = Speech2TextTransducer.from_pretrained(
  1254. model_tag=model_tag,
  1255. **speech2text_kwargs,
  1256. )
  1257. def _forward(data_path_and_name_and_type,
  1258. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  1259. output_dir_v2: Optional[str] = None,
  1260. fs: dict = None,
  1261. param_dict: dict = None,
  1262. **kwargs,
  1263. ):
  1264. # 3. Build data-iterator
  1265. loader = ASRTask.build_streaming_iterator(
  1266. data_path_and_name_and_type,
  1267. dtype=dtype,
  1268. batch_size=batch_size,
  1269. key_file=key_file,
  1270. num_workers=num_workers,
  1271. preprocess_fn=ASRTask.build_preprocess_fn(
  1272. speech2text.asr_train_args, False
  1273. ),
  1274. collate_fn=ASRTask.build_collate_fn(
  1275. speech2text.asr_train_args, False
  1276. ),
  1277. allow_variable_data_keys=allow_variable_data_keys,
  1278. inference=True,
  1279. )
  1280. # 4 .Start for-loop
  1281. with DatadirWriter(output_dir) as writer:
  1282. for keys, batch in loader:
  1283. assert isinstance(batch, dict), type(batch)
  1284. assert all(isinstance(s, str) for s in keys), keys
  1285. _bs = len(next(iter(batch.values())))
  1286. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  1287. batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
  1288. assert len(batch.keys()) == 1
  1289. try:
  1290. if speech2text.streaming:
  1291. speech = batch["speech"]
  1292. _steps = len(speech) // speech2text._ctx
  1293. _end = 0
  1294. for i in range(_steps):
  1295. _end = (i + 1) * speech2text._ctx
  1296. speech2text.streaming_decode(
  1297. speech[i * speech2text._ctx : _end], is_final=False
  1298. )
  1299. final_hyps = speech2text.streaming_decode(
  1300. speech[_end : len(speech)], is_final=True
  1301. )
  1302. elif speech2text.simu_streaming:
  1303. final_hyps = speech2text.simu_streaming_decode(**batch)
  1304. else:
  1305. final_hyps = speech2text(**batch)
  1306. results = speech2text.hypotheses_to_results(final_hyps)
  1307. except TooShortUttError as e:
  1308. logging.warning(f"Utterance {keys} {e}")
  1309. hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
  1310. results = [[" ", ["<space>"], [2], hyp]] * nbest
  1311. key = keys[0]
  1312. for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
  1313. ibest_writer = writer[f"{n}best_recog"]
  1314. ibest_writer["token"][key] = " ".join(token)
  1315. ibest_writer["token_int"][key] = " ".join(map(str, token_int))
  1316. ibest_writer["score"][key] = str(hyp.score)
  1317. if text is not None:
  1318. ibest_writer["text"][key] = text
  1319. return _forward
def inference_sa_asr(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
):
    """Build a speaker-attributed ASR (SA-ASR) inference closure.

    Loads the SA-ASR model once and returns a ``_forward`` callable that
    decodes a dataset (or raw waveform); each hypothesis carries both the
    transcription (``text``) and speaker attributions (``text_id``).

    Raises:
        NotImplementedError: for batch decoding, word LMs, or multi-GPU use.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    # Remove any pre-existing root handlers so basicConfig below takes effect.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text (the model is loaded once and reused per call)
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2TextSAASR(**speech2text_kwargs)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        """Decode the given dataset (or raw waveform) and return result dicts."""
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            # Raw waveform input: wrap it so it looks like a one-entry dataset.
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        # A per-call output dir overrides the one captured at build time.
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                # Too-short utterances get a silent placeholder hypothesis.
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest

            # Only supporting batch_size==1
            key = keys[0]
            # Each result is (text, speaker-id sequence, tokens, token ids, hyp).
            for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                    ibest_writer["text_id"][key] = text_id
                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
                logging.info("uttid: {}".format(key))
                logging.info("text predictions: {}".format(text))
                logging.info("text_id predictions: {}\n".format(text_id))
        return asr_result_list

    return _forward
  1466. def inference_launch(**kwargs):
  1467. if 'mode' in kwargs:
  1468. mode = kwargs['mode']
  1469. else:
  1470. logging.info("Unknown decoding mode.")
  1471. return None
  1472. if mode == "asr":
  1473. return inference_asr(**kwargs)
  1474. elif mode == "uniasr":
  1475. return inference_uniasr(**kwargs)
  1476. elif mode == "paraformer":
  1477. return inference_paraformer(**kwargs)
  1478. elif mode == "paraformer_fake_streaming":
  1479. return inference_paraformer(**kwargs)
  1480. elif mode == "paraformer_streaming":
  1481. return inference_paraformer_online(**kwargs)
  1482. elif mode.startswith("paraformer_vad"):
  1483. return inference_paraformer_vad_punc(**kwargs)
  1484. elif mode == "mfcca":
  1485. return inference_mfcca(**kwargs)
  1486. elif mode == "rnnt":
  1487. return inference_transducer(**kwargs)
  1488. elif mode == "sa_asr":
  1489. return inference_sa_asr(**kwargs)
  1490. else:
  1491. logging.info("Unknown decoding mode: {}".format(mode))
  1492. return None
  1493. def get_parser():
  1494. parser = config_argparse.ArgumentParser(
  1495. description="ASR Decoding",
  1496. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  1497. )
  1498. # Note(kamo): Use '_' instead of '-' as separator.
  1499. # '-' is confusing if written in yaml.
  1500. parser.add_argument(
  1501. "--log_level",
  1502. type=lambda x: x.upper(),
  1503. default="INFO",
  1504. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  1505. help="The verbose level of logging",
  1506. )
  1507. parser.add_argument("--output_dir", type=str, required=True)
  1508. parser.add_argument(
  1509. "--ngpu",
  1510. type=int,
  1511. default=0,
  1512. help="The number of gpus. 0 indicates CPU mode",
  1513. )
  1514. parser.add_argument(
  1515. "--njob",
  1516. type=int,
  1517. default=1,
  1518. help="The number of jobs for each gpu",
  1519. )
  1520. parser.add_argument(
  1521. "--gpuid_list",
  1522. type=str,
  1523. default="",
  1524. help="The visible gpus",
  1525. )
  1526. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  1527. parser.add_argument(
  1528. "--dtype",
  1529. default="float32",
  1530. choices=["float16", "float32", "float64"],
  1531. help="Data type",
  1532. )
  1533. parser.add_argument(
  1534. "--num_workers",
  1535. type=int,
  1536. default=1,
  1537. help="The number of workers used for DataLoader",
  1538. )
  1539. group = parser.add_argument_group("Input data related")
  1540. group.add_argument(
  1541. "--data_path_and_name_and_type",
  1542. type=str2triple_str,
  1543. required=True,
  1544. action="append",
  1545. )
  1546. group.add_argument("--key_file", type=str_or_none)
  1547. parser.add_argument(
  1548. "--hotword",
  1549. type=str_or_none,
  1550. default=None,
  1551. help="hotword file path or hotwords seperated by space"
  1552. )
  1553. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  1554. group.add_argument(
  1555. "--mc",
  1556. type=bool,
  1557. default=False,
  1558. help="MultiChannel input",
  1559. )
  1560. group = parser.add_argument_group("The model configuration related")
  1561. group.add_argument(
  1562. "--vad_infer_config",
  1563. type=str,
  1564. help="VAD infer configuration",
  1565. )
  1566. group.add_argument(
  1567. "--vad_model_file",
  1568. type=str,
  1569. help="VAD model parameter file",
  1570. )
  1571. group.add_argument(
  1572. "--cmvn_file",
  1573. type=str,
  1574. help="Global CMVN file",
  1575. )
  1576. group.add_argument(
  1577. "--asr_train_config",
  1578. type=str,
  1579. help="ASR training configuration",
  1580. )
  1581. group.add_argument(
  1582. "--asr_model_file",
  1583. type=str,
  1584. help="ASR model parameter file",
  1585. )
  1586. group.add_argument(
  1587. "--lm_train_config",
  1588. type=str,
  1589. help="LM training configuration",
  1590. )
  1591. group.add_argument(
  1592. "--lm_file",
  1593. type=str,
  1594. help="LM parameter file",
  1595. )
  1596. group.add_argument(
  1597. "--word_lm_train_config",
  1598. type=str,
  1599. help="Word LM training configuration",
  1600. )
  1601. group.add_argument(
  1602. "--word_lm_file",
  1603. type=str,
  1604. help="Word LM parameter file",
  1605. )
  1606. group.add_argument(
  1607. "--ngram_file",
  1608. type=str,
  1609. help="N-gram parameter file",
  1610. )
  1611. group.add_argument(
  1612. "--model_tag",
  1613. type=str,
  1614. help="Pretrained model tag. If specify this option, *_train_config and "
  1615. "*_file will be overwritten",
  1616. )
  1617. group.add_argument(
  1618. "--beam_search_config",
  1619. default={},
  1620. help="The keyword arguments for transducer beam search.",
  1621. )
  1622. group = parser.add_argument_group("Beam-search related")
  1623. group.add_argument(
  1624. "--batch_size",
  1625. type=int,
  1626. default=1,
  1627. help="The batch size for inference",
  1628. )
  1629. group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
  1630. group.add_argument("--beam_size", type=int, default=20, help="Beam size")
  1631. group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
  1632. group.add_argument(
  1633. "--maxlenratio",
  1634. type=float,
  1635. default=0.0,
  1636. help="Input length ratio to obtain max output length. "
  1637. "If maxlenratio=0.0 (default), it uses a end-detect "
  1638. "function "
  1639. "to automatically find maximum hypothesis lengths."
  1640. "If maxlenratio<0.0, its absolute value is interpreted"
  1641. "as a constant max output length",
  1642. )
  1643. group.add_argument(
  1644. "--minlenratio",
  1645. type=float,
  1646. default=0.0,
  1647. help="Input length ratio to obtain min output length",
  1648. )
  1649. group.add_argument(
  1650. "--ctc_weight",
  1651. type=float,
  1652. default=0.0,
  1653. help="CTC weight in joint decoding",
  1654. )
  1655. group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
  1656. group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
  1657. group.add_argument("--streaming", type=str2bool, default=False)
  1658. group.add_argument("--simu_streaming", type=str2bool, default=False)
  1659. group.add_argument("--chunk_size", type=int, default=16)
  1660. group.add_argument("--left_context", type=int, default=16)
  1661. group.add_argument("--right_context", type=int, default=0)
  1662. group.add_argument(
  1663. "--display_partial_hypotheses",
  1664. type=bool,
  1665. default=False,
  1666. help="Whether to display partial hypotheses during chunk-by-chunk inference.",
  1667. )
  1668. group = parser.add_argument_group("Dynamic quantization related")
  1669. group.add_argument(
  1670. "--quantize_asr_model",
  1671. type=bool,
  1672. default=False,
  1673. help="Apply dynamic quantization to ASR model.",
  1674. )
  1675. group.add_argument(
  1676. "--quantize_modules",
  1677. nargs="*",
  1678. default=None,
  1679. help="""Module names to apply dynamic quantization on.
  1680. The module names are provided as a list, where each name is separated
  1681. by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
  1682. Each specified name should be an attribute of 'torch.nn', e.g.:
  1683. torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
  1684. )
  1685. group.add_argument(
  1686. "--quantize_dtype",
  1687. type=str,
  1688. default="qint8",
  1689. choices=["float16", "qint8"],
  1690. help="Dtype for dynamic quantization.",
  1691. )
  1692. group = parser.add_argument_group("Text converter related")
  1693. group.add_argument(
  1694. "--token_type",
  1695. type=str_or_none,
  1696. default=None,
  1697. choices=["char", "bpe", None],
  1698. help="The token type for ASR model. "
  1699. "If not given, refers from the training args",
  1700. )
  1701. group.add_argument(
  1702. "--bpemodel",
  1703. type=str_or_none,
  1704. default=None,
  1705. help="The model path of sentencepiece. "
  1706. "If not given, refers from the training args",
  1707. )
  1708. group.add_argument("--token_num_relax", type=int, default=1, help="")
  1709. group.add_argument("--decoding_ind", type=int, default=0, help="")
  1710. group.add_argument("--decoding_mode", type=str, default="model1", help="")
  1711. group.add_argument(
  1712. "--ctc_weight2",
  1713. type=float,
  1714. default=0.0,
  1715. help="CTC weight in joint decoding",
  1716. )
  1717. return parser
  1718. def main(cmd=None):
  1719. print(get_commandline_args(), file=sys.stderr)
  1720. parser = get_parser()
  1721. parser.add_argument(
  1722. "--mode",
  1723. type=str,
  1724. default="asr",
  1725. help="The decoding mode",
  1726. )
  1727. args = parser.parse_args(cmd)
  1728. kwargs = vars(args)
  1729. kwargs.pop("config", None)
  1730. # set logging messages
  1731. logging.basicConfig(
  1732. level=args.log_level,
  1733. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  1734. )
  1735. logging.info("Decoding args: {}".format(kwargs))
  1736. # gpu setting
  1737. if args.ngpu > 0:
  1738. jobid = int(args.output_dir.split(".")[-1])
  1739. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  1740. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  1741. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  1742. inference_pipeline = inference_launch(**kwargs)
  1743. return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
# Script entry point: run decoding when executed directly (not on import).
if __name__ == "__main__":
    main()