#!/usr/bin/env python3
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import importlib
import logging
import os
import sys
from typing import Union, Dict, Any

from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
  14. def get_parser():
  15. parser = config_argparse.ArgumentParser(
  16. description="ASR Decoding",
  17. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  18. )
  19. # Note(kamo): Use '_' instead of '-' as separator.
  20. # '-' is confusing if written in yaml.
  21. parser.add_argument(
  22. "--log_level",
  23. type=lambda x: x.upper(),
  24. default="INFO",
  25. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  26. help="The verbose level of logging",
  27. )
  28. parser.add_argument("--output_dir", type=str, required=True)
  29. parser.add_argument(
  30. "--ngpu",
  31. type=int,
  32. default=0,
  33. help="The number of gpus. 0 indicates CPU mode",
  34. )
  35. parser.add_argument(
  36. "--njob",
  37. type=int,
  38. default=1,
  39. help="The number of jobs for each gpu",
  40. )
  41. parser.add_argument(
  42. "--gpuid_list",
  43. type=str,
  44. default="",
  45. help="The visible gpus",
  46. )
  47. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  48. parser.add_argument(
  49. "--dtype",
  50. default="float32",
  51. choices=["float16", "float32", "float64"],
  52. help="Data type",
  53. )
  54. parser.add_argument(
  55. "--num_workers",
  56. type=int,
  57. default=1,
  58. help="The number of workers used for DataLoader",
  59. )
  60. group = parser.add_argument_group("Input data related")
  61. group.add_argument(
  62. "--data_path_and_name_and_type",
  63. type=str2triple_str,
  64. required=True,
  65. action="append",
  66. )
  67. group.add_argument("--key_file", type=str_or_none)
  68. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  69. group = parser.add_argument_group("The model configuration related")
  70. group.add_argument(
  71. "--vad_infer_config",
  72. type=str,
  73. help="VAD infer configuration",
  74. )
  75. group.add_argument(
  76. "--vad_model_file",
  77. type=str,
  78. help="VAD model parameter file",
  79. )
  80. group.add_argument(
  81. "--cmvn_file",
  82. type=str,
  83. help="Global CMVN file",
  84. )
  85. group.add_argument(
  86. "--asr_train_config",
  87. type=str,
  88. help="ASR training configuration",
  89. )
  90. group.add_argument(
  91. "--asr_model_file",
  92. type=str,
  93. help="ASR model parameter file",
  94. )
  95. group.add_argument(
  96. "--lm_train_config",
  97. type=str,
  98. help="LM training configuration",
  99. )
  100. group.add_argument(
  101. "--lm_file",
  102. type=str,
  103. help="LM parameter file",
  104. )
  105. group.add_argument(
  106. "--word_lm_train_config",
  107. type=str,
  108. help="Word LM training configuration",
  109. )
  110. group.add_argument(
  111. "--word_lm_file",
  112. type=str,
  113. help="Word LM parameter file",
  114. )
  115. group.add_argument(
  116. "--ngram_file",
  117. type=str,
  118. help="N-gram parameter file",
  119. )
  120. group.add_argument(
  121. "--model_tag",
  122. type=str,
  123. help="Pretrained model tag. If specify this option, *_train_config and "
  124. "*_file will be overwritten",
  125. )
  126. group = parser.add_argument_group("Beam-search related")
  127. group.add_argument(
  128. "--batch_size",
  129. type=int,
  130. default=1,
  131. help="The batch size for inference",
  132. )
  133. group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
  134. group.add_argument("--beam_size", type=int, default=20, help="Beam size")
  135. group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
  136. group.add_argument(
  137. "--maxlenratio",
  138. type=float,
  139. default=0.0,
  140. help="Input length ratio to obtain max output length. "
  141. "If maxlenratio=0.0 (default), it uses a end-detect "
  142. "function "
  143. "to automatically find maximum hypothesis lengths."
  144. "If maxlenratio<0.0, its absolute value is interpreted"
  145. "as a constant max output length",
  146. )
  147. group.add_argument(
  148. "--minlenratio",
  149. type=float,
  150. default=0.0,
  151. help="Input length ratio to obtain min output length",
  152. )
  153. group.add_argument(
  154. "--ctc_weight",
  155. type=float,
  156. default=0.0,
  157. help="CTC weight in joint decoding",
  158. )
  159. group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
  160. group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
  161. group.add_argument("--streaming", type=str2bool, default=False)
  162. group = parser.add_argument_group("Text converter related")
  163. group.add_argument(
  164. "--token_type",
  165. type=str_or_none,
  166. default=None,
  167. choices=["char", "bpe", None],
  168. help="The token type for ASR model. "
  169. "If not given, refers from the training args",
  170. )
  171. group.add_argument(
  172. "--bpemodel",
  173. type=str_or_none,
  174. default=None,
  175. help="The model path of sentencepiece. "
  176. "If not given, refers from the training args",
  177. )
  178. group.add_argument("--token_num_relax", type=int, default=1, help="")
  179. group.add_argument("--decoding_ind", type=int, default=0, help="")
  180. group.add_argument("--decoding_mode", type=str, default="model1", help="")
  181. group.add_argument(
  182. "--ctc_weight2",
  183. type=float,
  184. default=0.0,
  185. help="CTC weight in joint decoding",
  186. )
  187. return parser
  188. def inference_launch(**kwargs):
  189. if 'mode' in kwargs:
  190. mode = kwargs['mode']
  191. else:
  192. logging.info("Unknown decoding mode.")
  193. return None
  194. if mode == "asr":
  195. from funasr.bin.asr_inference import inference_modelscope
  196. return inference_modelscope(**kwargs)
  197. elif mode == "uniasr":
  198. from funasr.bin.asr_inference_uniasr import inference_modelscope
  199. return inference_modelscope(**kwargs)
  200. elif mode == "uniasr_vad":
  201. from funasr.bin.asr_inference_uniasr_vad import inference_modelscope
  202. return inference_modelscope(**kwargs)
  203. elif mode == "paraformer":
  204. from funasr.bin.asr_inference_paraformer import inference_modelscope
  205. return inference_modelscope(**kwargs)
  206. elif mode == "paraformer_vad":
  207. from funasr.bin.asr_inference_paraformer_vad import inference_modelscope
  208. return inference_modelscope(**kwargs)
  209. elif mode == "paraformer_punc":
  210. logging.info("Unknown decoding mode: {}".format(mode))
  211. return None
  212. elif mode == "paraformer_vad_punc":
  213. from funasr.bin.asr_inference_paraformer_vad_punc import inference_modelscope
  214. return inference_modelscope(**kwargs)
  215. elif mode == "vad":
  216. from funasr.bin.vad_inference import inference_modelscope
  217. return inference_modelscope(**kwargs)
  218. elif mode == "mfcca":
  219. from funasr.bin.asr_inference_mfcca import inference_modelscope
  220. return inference_modelscope(**kwargs)
  221. else:
  222. logging.info("Unknown decoding mode: {}".format(mode))
  223. return None
  224. def inference_launch_funasr(**kwargs):
  225. if 'mode' in kwargs:
  226. mode = kwargs['mode']
  227. else:
  228. logging.info("Unknown decoding mode.")
  229. return None
  230. if mode == "asr":
  231. from funasr.bin.asr_inference import inference
  232. return inference(**kwargs)
  233. elif mode == "uniasr":
  234. from funasr.bin.asr_inference_uniasr import inference
  235. return inference(**kwargs)
  236. elif mode == "paraformer":
  237. from funasr.bin.asr_inference_paraformer import inference
  238. return inference(**kwargs)
  239. elif mode == "paraformer_vad_punc":
  240. from funasr.bin.asr_inference_paraformer_vad_punc import inference
  241. return inference(**kwargs)
  242. elif mode == "vad":
  243. from funasr.bin.vad_inference import inference
  244. return inference(**kwargs)
  245. elif mode == "mfcca":
  246. from funasr.bin.asr_inference_mfcca import inference_modelscope
  247. return inference_modelscope(**kwargs)
  248. else:
  249. logging.info("Unknown decoding mode: {}".format(mode))
  250. return None
  251. def main(cmd=None):
  252. print(get_commandline_args(), file=sys.stderr)
  253. parser = get_parser()
  254. parser.add_argument(
  255. "--mode",
  256. type=str,
  257. default="asr",
  258. help="The decoding mode",
  259. )
  260. args = parser.parse_args(cmd)
  261. kwargs = vars(args)
  262. kwargs.pop("config", None)
  263. # set logging messages
  264. logging.basicConfig(
  265. level=args.log_level,
  266. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  267. )
  268. logging.info("Decoding args: {}".format(kwargs))
  269. # gpu setting
  270. if args.ngpu > 0:
  271. jobid = int(args.output_dir.split(".")[-1])
  272. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  273. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  274. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  275. inference_launch_funasr(**kwargs)
  276. if __name__ == "__main__":
  277. main()