# asr_inference_launch.py
  1. #!/usr/bin/env python3
  2. import argparse
  3. import logging
  4. import os
  5. import sys
  6. from typing import Union, Dict, Any
  7. from funasr.utils import config_argparse
  8. from funasr.utils.cli_utils import get_commandline_args
  9. from funasr.utils.types import str2bool
  10. from funasr.utils.types import str2triple_str
  11. from funasr.utils.types import str_or_none
  12. def get_parser():
  13. parser = config_argparse.ArgumentParser(
  14. description="ASR Decoding",
  15. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  16. )
  17. # Note(kamo): Use '_' instead of '-' as separator.
  18. # '-' is confusing if written in yaml.
  19. parser.add_argument(
  20. "--log_level",
  21. type=lambda x: x.upper(),
  22. default="INFO",
  23. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  24. help="The verbose level of logging",
  25. )
  26. parser.add_argument("--output_dir", type=str, required=True)
  27. parser.add_argument(
  28. "--ngpu",
  29. type=int,
  30. default=0,
  31. help="The number of gpus. 0 indicates CPU mode",
  32. )
  33. parser.add_argument(
  34. "--njob",
  35. type=int,
  36. default=1,
  37. help="The number of jobs for each gpu",
  38. )
  39. parser.add_argument(
  40. "--gpuid_list",
  41. type=str,
  42. default="",
  43. help="The visible gpus",
  44. )
  45. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  46. parser.add_argument(
  47. "--dtype",
  48. default="float32",
  49. choices=["float16", "float32", "float64"],
  50. help="Data type",
  51. )
  52. parser.add_argument(
  53. "--num_workers",
  54. type=int,
  55. default=1,
  56. help="The number of workers used for DataLoader",
  57. )
  58. group = parser.add_argument_group("Input data related")
  59. group.add_argument(
  60. "--data_path_and_name_and_type",
  61. type=str2triple_str,
  62. required=True,
  63. action="append",
  64. )
  65. group.add_argument("--key_file", type=str_or_none)
  66. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  67. group = parser.add_argument_group("The model configuration related")
  68. group.add_argument(
  69. "--vad_infer_config",
  70. type=str,
  71. help="VAD infer configuration",
  72. )
  73. group.add_argument(
  74. "--vad_model_file",
  75. type=str,
  76. help="VAD model parameter file",
  77. )
  78. group.add_argument(
  79. "--cmvn_file",
  80. type=str,
  81. help="Global CMVN file",
  82. )
  83. group.add_argument(
  84. "--asr_train_config",
  85. type=str,
  86. help="ASR training configuration",
  87. )
  88. group.add_argument(
  89. "--asr_model_file",
  90. type=str,
  91. help="ASR model parameter file",
  92. )
  93. group.add_argument(
  94. "--lm_train_config",
  95. type=str,
  96. help="LM training configuration",
  97. )
  98. group.add_argument(
  99. "--lm_file",
  100. type=str,
  101. help="LM parameter file",
  102. )
  103. group.add_argument(
  104. "--word_lm_train_config",
  105. type=str,
  106. help="Word LM training configuration",
  107. )
  108. group.add_argument(
  109. "--word_lm_file",
  110. type=str,
  111. help="Word LM parameter file",
  112. )
  113. group.add_argument(
  114. "--ngram_file",
  115. type=str,
  116. help="N-gram parameter file",
  117. )
  118. group.add_argument(
  119. "--model_tag",
  120. type=str,
  121. help="Pretrained model tag. If specify this option, *_train_config and "
  122. "*_file will be overwritten",
  123. )
  124. group.add_argument(
  125. "--beam_search_config",
  126. default={},
  127. help="The keyword arguments for transducer beam search.",
  128. )
  129. group = parser.add_argument_group("Beam-search related")
  130. group.add_argument(
  131. "--batch_size",
  132. type=int,
  133. default=1,
  134. help="The batch size for inference",
  135. )
  136. group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
  137. group.add_argument("--beam_size", type=int, default=20, help="Beam size")
  138. group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
  139. group.add_argument(
  140. "--maxlenratio",
  141. type=float,
  142. default=0.0,
  143. help="Input length ratio to obtain max output length. "
  144. "If maxlenratio=0.0 (default), it uses a end-detect "
  145. "function "
  146. "to automatically find maximum hypothesis lengths."
  147. "If maxlenratio<0.0, its absolute value is interpreted"
  148. "as a constant max output length",
  149. )
  150. group.add_argument(
  151. "--minlenratio",
  152. type=float,
  153. default=0.0,
  154. help="Input length ratio to obtain min output length",
  155. )
  156. group.add_argument(
  157. "--ctc_weight",
  158. type=float,
  159. default=0.0,
  160. help="CTC weight in joint decoding",
  161. )
  162. group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
  163. group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
  164. group.add_argument("--streaming", type=str2bool, default=False)
  165. group.add_argument("--simu_streaming", type=str2bool, default=False)
  166. group.add_argument("--chunk_size", type=int, default=16)
  167. group.add_argument("--left_context", type=int, default=16)
  168. group.add_argument("--right_context", type=int, default=0)
  169. group.add_argument(
  170. "--display_partial_hypotheses",
  171. type=bool,
  172. default=False,
  173. help="Whether to display partial hypotheses during chunk-by-chunk inference.",
  174. )
  175. group = parser.add_argument_group("Dynamic quantization related")
  176. group.add_argument(
  177. "--quantize_asr_model",
  178. type=bool,
  179. default=False,
  180. help="Apply dynamic quantization to ASR model.",
  181. )
  182. group.add_argument(
  183. "--quantize_modules",
  184. nargs="*",
  185. default=None,
  186. help="""Module names to apply dynamic quantization on.
  187. The module names are provided as a list, where each name is separated
  188. by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
  189. Each specified name should be an attribute of 'torch.nn', e.g.:
  190. torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
  191. )
  192. group.add_argument(
  193. "--quantize_dtype",
  194. type=str,
  195. default="qint8",
  196. choices=["float16", "qint8"],
  197. help="Dtype for dynamic quantization.",
  198. )
  199. group = parser.add_argument_group("Text converter related")
  200. group.add_argument(
  201. "--token_type",
  202. type=str_or_none,
  203. default=None,
  204. choices=["char", "bpe", None],
  205. help="The token type for ASR model. "
  206. "If not given, refers from the training args",
  207. )
  208. group.add_argument(
  209. "--bpemodel",
  210. type=str_or_none,
  211. default=None,
  212. help="The model path of sentencepiece. "
  213. "If not given, refers from the training args",
  214. )
  215. group.add_argument("--token_num_relax", type=int, default=1, help="")
  216. group.add_argument("--decoding_ind", type=int, default=0, help="")
  217. group.add_argument("--decoding_mode", type=str, default="model1", help="")
  218. group.add_argument(
  219. "--ctc_weight2",
  220. type=float,
  221. default=0.0,
  222. help="CTC weight in joint decoding",
  223. )
  224. return parser
  225. def inference_launch(**kwargs):
  226. if 'mode' in kwargs:
  227. mode = kwargs['mode']
  228. else:
  229. logging.info("Unknown decoding mode.")
  230. return None
  231. if mode == "asr":
  232. from funasr.bin.asr_inference import inference_modelscope
  233. return inference_modelscope(**kwargs)
  234. elif mode == "uniasr":
  235. from funasr.bin.asr_inference_uniasr import inference_modelscope
  236. return inference_modelscope(**kwargs)
  237. elif mode == "uniasr_vad":
  238. from funasr.bin.asr_inference_uniasr_vad import inference_modelscope
  239. return inference_modelscope(**kwargs)
  240. elif mode == "paraformer":
  241. from funasr.bin.asr_inference_paraformer import inference_modelscope
  242. return inference_modelscope(**kwargs)
  243. elif mode == "paraformer_streaming":
  244. from funasr.bin.asr_inference_paraformer_streaming import inference_modelscope
  245. return inference_modelscope(**kwargs)
  246. elif mode == "paraformer_vad":
  247. from funasr.bin.asr_inference_paraformer_vad import inference_modelscope
  248. return inference_modelscope(**kwargs)
  249. elif mode == "paraformer_punc":
  250. logging.info("Unknown decoding mode: {}".format(mode))
  251. return None
  252. elif mode == "paraformer_vad_punc":
  253. from funasr.bin.asr_inference_paraformer_vad_punc import inference_modelscope
  254. return inference_modelscope(**kwargs)
  255. elif mode == "vad":
  256. from funasr.bin.vad_inference import inference_modelscope
  257. return inference_modelscope(**kwargs)
  258. elif mode == "mfcca":
  259. from funasr.bin.asr_inference_mfcca import inference_modelscope
  260. return inference_modelscope(**kwargs)
  261. elif mode == "rnnt":
  262. from funasr.bin.asr_inference_rnnt import inference_modelscope
  263. return inference_modelscope(**kwargs)
  264. else:
  265. logging.info("Unknown decoding mode: {}".format(mode))
  266. return None
  267. def inference_launch_funasr(**kwargs):
  268. if 'mode' in kwargs:
  269. mode = kwargs['mode']
  270. else:
  271. logging.info("Unknown decoding mode.")
  272. return None
  273. if mode == "asr":
  274. from funasr.bin.asr_inference import inference
  275. return inference(**kwargs)
  276. elif mode == "uniasr":
  277. from funasr.bin.asr_inference_uniasr import inference
  278. return inference(**kwargs)
  279. elif mode == "paraformer":
  280. from funasr.bin.asr_inference_paraformer import inference
  281. return inference(**kwargs)
  282. elif mode == "paraformer_vad_punc":
  283. from funasr.bin.asr_inference_paraformer_vad_punc import inference
  284. return inference(**kwargs)
  285. elif mode == "vad":
  286. from funasr.bin.vad_inference import inference
  287. return inference(**kwargs)
  288. elif mode == "mfcca":
  289. from funasr.bin.asr_inference_mfcca import inference_modelscope
  290. return inference_modelscope(**kwargs)
  291. elif mode == "rnnt":
  292. from funasr.bin.asr_inference_rnnt import inference
  293. return inference(**kwargs)
  294. else:
  295. logging.info("Unknown decoding mode: {}".format(mode))
  296. return None
  297. def main(cmd=None):
  298. print(get_commandline_args(), file=sys.stderr)
  299. parser = get_parser()
  300. parser.add_argument(
  301. "--mode",
  302. type=str,
  303. default="asr",
  304. help="The decoding mode",
  305. )
  306. args = parser.parse_args(cmd)
  307. kwargs = vars(args)
  308. kwargs.pop("config", None)
  309. # set logging messages
  310. logging.basicConfig(
  311. level=args.log_level,
  312. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  313. )
  314. logging.info("Decoding args: {}".format(kwargs))
  315. # gpu setting
  316. if args.ngpu > 0:
  317. jobid = int(args.output_dir.split(".")[-1])
  318. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  319. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  320. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  321. inference_launch_funasr(**kwargs)
  322. if __name__ == "__main__":
  323. main()