  1. #!/usr/bin/env python3
  2. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  3. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  4. import torch
  5. torch.set_num_threads(1)
  6. import argparse
  7. import logging
  8. import os
  9. import sys
  10. from typing import Union, Dict, Any
  11. from funasr.utils import config_argparse
  12. from funasr.utils.cli_utils import get_commandline_args
  13. from funasr.utils.types import str2bool
  14. from funasr.utils.types import str2triple_str
  15. from funasr.utils.types import str_or_none
  16. def get_parser():
  17. parser = config_argparse.ArgumentParser(
  18. description="ASR Decoding",
  19. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  20. )
  21. # Note(kamo): Use '_' instead of '-' as separator.
  22. # '-' is confusing if written in yaml.
  23. parser.add_argument(
  24. "--log_level",
  25. type=lambda x: x.upper(),
  26. default="INFO",
  27. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  28. help="The verbose level of logging",
  29. )
  30. parser.add_argument("--output_dir", type=str, required=True)
  31. parser.add_argument(
  32. "--ngpu",
  33. type=int,
  34. default=0,
  35. help="The number of gpus. 0 indicates CPU mode",
  36. )
  37. parser.add_argument(
  38. "--njob",
  39. type=int,
  40. default=1,
  41. help="The number of jobs for each gpu",
  42. )
  43. parser.add_argument(
  44. "--gpuid_list",
  45. type=str,
  46. default="",
  47. help="The visible gpus",
  48. )
  49. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  50. parser.add_argument(
  51. "--dtype",
  52. default="float32",
  53. choices=["float16", "float32", "float64"],
  54. help="Data type",
  55. )
  56. parser.add_argument(
  57. "--num_workers",
  58. type=int,
  59. default=1,
  60. help="The number of workers used for DataLoader",
  61. )
  62. group = parser.add_argument_group("Input data related")
  63. group.add_argument(
  64. "--data_path_and_name_and_type",
  65. type=str2triple_str,
  66. required=True,
  67. action="append",
  68. )
  69. group.add_argument("--key_file", type=str_or_none)
  70. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  71. group = parser.add_argument_group("The model configuration related")
  72. group.add_argument(
  73. "--vad_infer_config",
  74. type=str,
  75. help="VAD infer configuration",
  76. )
  77. group.add_argument(
  78. "--vad_model_file",
  79. type=str,
  80. help="VAD model parameter file",
  81. )
  82. group.add_argument(
  83. "--cmvn_file",
  84. type=str,
  85. help="Global CMVN file",
  86. )
  87. group.add_argument(
  88. "--asr_train_config",
  89. type=str,
  90. help="ASR training configuration",
  91. )
  92. group.add_argument(
  93. "--asr_model_file",
  94. type=str,
  95. help="ASR model parameter file",
  96. )
  97. group.add_argument(
  98. "--lm_train_config",
  99. type=str,
  100. help="LM training configuration",
  101. )
  102. group.add_argument(
  103. "--lm_file",
  104. type=str,
  105. help="LM parameter file",
  106. )
  107. group.add_argument(
  108. "--word_lm_train_config",
  109. type=str,
  110. help="Word LM training configuration",
  111. )
  112. group.add_argument(
  113. "--word_lm_file",
  114. type=str,
  115. help="Word LM parameter file",
  116. )
  117. group.add_argument(
  118. "--ngram_file",
  119. type=str,
  120. help="N-gram parameter file",
  121. )
  122. group.add_argument(
  123. "--model_tag",
  124. type=str,
  125. help="Pretrained model tag. If specify this option, *_train_config and "
  126. "*_file will be overwritten",
  127. )
  128. group = parser.add_argument_group("Beam-search related")
  129. group.add_argument(
  130. "--batch_size",
  131. type=int,
  132. default=1,
  133. help="The batch size for inference",
  134. )
  135. group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
  136. group.add_argument("--beam_size", type=int, default=20, help="Beam size")
  137. group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
  138. group.add_argument(
  139. "--maxlenratio",
  140. type=float,
  141. default=0.0,
  142. help="Input length ratio to obtain max output length. "
  143. "If maxlenratio=0.0 (default), it uses a end-detect "
  144. "function "
  145. "to automatically find maximum hypothesis lengths."
  146. "If maxlenratio<0.0, its absolute value is interpreted"
  147. "as a constant max output length",
  148. )
  149. group.add_argument(
  150. "--minlenratio",
  151. type=float,
  152. default=0.0,
  153. help="Input length ratio to obtain min output length",
  154. )
  155. group.add_argument(
  156. "--ctc_weight",
  157. type=float,
  158. default=0.0,
  159. help="CTC weight in joint decoding",
  160. )
  161. group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
  162. group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
  163. group.add_argument("--streaming", type=str2bool, default=False)
  164. group = parser.add_argument_group("Text converter related")
  165. group.add_argument(
  166. "--token_type",
  167. type=str_or_none,
  168. default=None,
  169. choices=["char", "bpe", None],
  170. help="The token type for ASR model. "
  171. "If not given, refers from the training args",
  172. )
  173. group.add_argument(
  174. "--bpemodel",
  175. type=str_or_none,
  176. default=None,
  177. help="The model path of sentencepiece. "
  178. "If not given, refers from the training args",
  179. )
  180. group.add_argument("--token_num_relax", type=int, default=1, help="")
  181. group.add_argument("--decoding_ind", type=int, default=0, help="")
  182. group.add_argument("--decoding_mode", type=str, default="model1", help="")
  183. group.add_argument(
  184. "--ctc_weight2",
  185. type=float,
  186. default=0.0,
  187. help="CTC weight in joint decoding",
  188. )
  189. return parser
  190. def inference_launch(**kwargs):
  191. if 'mode' in kwargs:
  192. mode = kwargs['mode']
  193. else:
  194. logging.info("Unknown decoding mode.")
  195. return None
  196. if mode == "asr":
  197. from funasr.bin.asr_inference import inference_modelscope
  198. return inference_modelscope(**kwargs)
  199. elif mode == "uniasr":
  200. from funasr.bin.asr_inference_uniasr import inference_modelscope
  201. return inference_modelscope(**kwargs)
  202. elif mode == "uniasr_vad":
  203. from funasr.bin.asr_inference_uniasr_vad import inference_modelscope
  204. return inference_modelscope(**kwargs)
  205. elif mode == "paraformer":
  206. from funasr.bin.asr_inference_paraformer import inference_modelscope
  207. return inference_modelscope(**kwargs)
  208. elif mode == "paraformer_streaming":
  209. from funasr.bin.asr_inference_paraformer_streaming import inference_modelscope
  210. return inference_modelscope(**kwargs)
  211. elif mode == "paraformer_vad":
  212. from funasr.bin.asr_inference_paraformer_vad import inference_modelscope
  213. return inference_modelscope(**kwargs)
  214. elif mode == "paraformer_punc":
  215. logging.info("Unknown decoding mode: {}".format(mode))
  216. return None
  217. elif mode == "paraformer_vad_punc":
  218. from funasr.bin.asr_inference_paraformer_vad_punc import inference_modelscope
  219. return inference_modelscope(**kwargs)
  220. elif mode == "vad":
  221. from funasr.bin.vad_inference import inference_modelscope
  222. return inference_modelscope(**kwargs)
  223. elif mode == "mfcca":
  224. from funasr.bin.asr_inference_mfcca import inference_modelscope
  225. return inference_modelscope(**kwargs)
  226. elif mode == "rnnt":
  227. from funasr.bin.asr_inference_rnnt import inference_modelscope
  228. return inference_modelscope(**kwargs)
  229. else:
  230. logging.info("Unknown decoding mode: {}".format(mode))
  231. return None
  232. def inference_launch_funasr(**kwargs):
  233. if 'mode' in kwargs:
  234. mode = kwargs['mode']
  235. else:
  236. logging.info("Unknown decoding mode.")
  237. return None
  238. if mode == "asr":
  239. from funasr.bin.asr_inference import inference
  240. return inference(**kwargs)
  241. elif mode == "uniasr":
  242. from funasr.bin.asr_inference_uniasr import inference
  243. return inference(**kwargs)
  244. elif mode == "paraformer":
  245. from funasr.bin.asr_inference_paraformer import inference
  246. return inference(**kwargs)
  247. elif mode == "paraformer_vad_punc":
  248. from funasr.bin.asr_inference_paraformer_vad_punc import inference
  249. return inference(**kwargs)
  250. elif mode == "vad":
  251. from funasr.bin.vad_inference import inference
  252. return inference(**kwargs)
  253. elif mode == "mfcca":
  254. from funasr.bin.asr_inference_mfcca import inference_modelscope
  255. return inference_modelscope(**kwargs)
  256. else:
  257. logging.info("Unknown decoding mode: {}".format(mode))
  258. return None
  259. def main(cmd=None):
  260. print(get_commandline_args(), file=sys.stderr)
  261. parser = get_parser()
  262. parser.add_argument(
  263. "--mode",
  264. type=str,
  265. default="asr",
  266. help="The decoding mode",
  267. )
  268. args = parser.parse_args(cmd)
  269. kwargs = vars(args)
  270. kwargs.pop("config", None)
  271. # set logging messages
  272. logging.basicConfig(
  273. level=args.log_level,
  274. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  275. )
  276. logging.info("Decoding args: {}".format(kwargs))
  277. # gpu setting
  278. if args.ngpu > 0:
  279. jobid = int(args.output_dir.split(".")[-1])
  280. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  281. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  282. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  283. inference_launch_funasr(**kwargs)
  284. if __name__ == "__main__":
  285. main()