argument.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import sys
  6. from funasr.utils.types import str2bool
  7. from funasr.utils.types import str2triple_str
  8. from funasr.utils.types import str_or_none
  9. from funasr.utils import config_argparse
  10. import argparse
  11. def get_parser():
  12. parser = config_argparse.ArgumentParser(
  13. description="ASR Decoding",
  14. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  15. )
  16. # Note(kamo): Use '_' instead of '-' as separator.
  17. # '-' is confusing if written in yaml.
  18. parser.add_argument(
  19. "--log_level",
  20. type=lambda x: x.upper(),
  21. default="INFO",
  22. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  23. help="The verbose level of logging",
  24. )
  25. parser.add_argument("--output_dir", type=str, default=None)
  26. parser.add_argument(
  27. "--ngpu",
  28. type=int,
  29. default=1,
  30. help="The number of gpus. 0 indicates CPU mode",
  31. )
  32. parser.add_argument(
  33. "--njob",
  34. type=int,
  35. default=1,
  36. help="The number of jobs for each gpu",
  37. )
  38. parser.add_argument(
  39. "--gpuid_list",
  40. type=str,
  41. default="",
  42. help="The visible gpus",
  43. )
  44. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  45. parser.add_argument(
  46. "--dtype",
  47. default="float32",
  48. choices=["float16", "float32", "float64"],
  49. help="Data type",
  50. )
  51. parser.add_argument(
  52. "--num_workers",
  53. type=int,
  54. default=1,
  55. help="The number of workers used for DataLoader",
  56. )
  57. group = parser.add_argument_group("Input data related")
  58. group.add_argument(
  59. "--data_path_and_name_and_type",
  60. type=str2triple_str,
  61. required=False,
  62. action="append",
  63. )
  64. group.add_argument("--key_file", type=str_or_none)
  65. parser.add_argument(
  66. "--hotword",
  67. type=str_or_none,
  68. default=None,
  69. help="hotword file path or hotwords seperated by space"
  70. )
  71. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  72. group.add_argument(
  73. "--mc",
  74. type=bool,
  75. default=False,
  76. help="MultiChannel input",
  77. )
  78. group = parser.add_argument_group("The model configuration related")
  79. group.add_argument(
  80. "--vad_infer_config",
  81. type=str,
  82. help="VAD infer configuration",
  83. )
  84. group.add_argument(
  85. "--vad_model_file",
  86. type=str,
  87. help="VAD model parameter file",
  88. )
  89. group.add_argument(
  90. "--punc_infer_config",
  91. type=str,
  92. help="PUNC infer configuration",
  93. )
  94. group.add_argument(
  95. "--punc_model_file",
  96. type=str,
  97. help="PUNC model parameter file",
  98. )
  99. group.add_argument(
  100. "--cmvn_file",
  101. type=str,
  102. help="Global CMVN file",
  103. )
  104. group.add_argument(
  105. "--asr_train_config",
  106. type=str,
  107. help="ASR training configuration",
  108. )
  109. group.add_argument(
  110. "--asr_model_file",
  111. type=str,
  112. help="ASR model parameter file",
  113. )
  114. group.add_argument(
  115. "--sv_model_file",
  116. type=str,
  117. help="SV model parameter file",
  118. )
  119. group.add_argument(
  120. "--lm_train_config",
  121. type=str,
  122. help="LM training configuration",
  123. )
  124. group.add_argument(
  125. "--lm_file",
  126. type=str,
  127. help="LM parameter file",
  128. )
  129. group.add_argument(
  130. "--word_lm_train_config",
  131. type=str,
  132. help="Word LM training configuration",
  133. )
  134. group.add_argument(
  135. "--word_lm_file",
  136. type=str,
  137. help="Word LM parameter file",
  138. )
  139. group.add_argument(
  140. "--ngram_file",
  141. type=str,
  142. help="N-gram parameter file",
  143. )
  144. group.add_argument(
  145. "--model_tag",
  146. type=str,
  147. help="Pretrained model tag. If specify this option, *_train_config and "
  148. "*_file will be overwritten",
  149. )
  150. group.add_argument(
  151. "--beam_search_config",
  152. default={},
  153. help="The keyword arguments for transducer beam search.",
  154. )
  155. group = parser.add_argument_group("Beam-search related")
  156. group.add_argument(
  157. "--batch_size",
  158. type=int,
  159. default=1,
  160. help="The batch size for inference",
  161. )
  162. group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
  163. group.add_argument("--beam_size", type=int, default=20, help="Beam size")
  164. group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
  165. group.add_argument(
  166. "--maxlenratio",
  167. type=float,
  168. default=0.0,
  169. help="Input length ratio to obtain max output length. "
  170. "If maxlenratio=0.0 (default), it uses a end-detect "
  171. "function "
  172. "to automatically find maximum hypothesis lengths."
  173. "If maxlenratio<0.0, its absolute value is interpreted"
  174. "as a constant max output length",
  175. )
  176. group.add_argument(
  177. "--minlenratio",
  178. type=float,
  179. default=0.0,
  180. help="Input length ratio to obtain min output length",
  181. )
  182. group.add_argument(
  183. "--ctc_weight",
  184. type=float,
  185. default=0.0,
  186. help="CTC weight in joint decoding",
  187. )
  188. group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
  189. group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
  190. group.add_argument("--streaming", type=str2bool, default=False)
  191. group.add_argument("--fake_streaming", type=str2bool, default=False)
  192. group.add_argument("--full_utt", type=str2bool, default=False)
  193. group.add_argument("--chunk_size", type=int, default=16)
  194. group.add_argument("--left_context", type=int, default=16)
  195. group.add_argument("--right_context", type=int, default=0)
  196. group.add_argument(
  197. "--display_partial_hypotheses",
  198. type=bool,
  199. default=False,
  200. help="Whether to display partial hypotheses during chunk-by-chunk inference.",
  201. )
  202. group = parser.add_argument_group("Dynamic quantization related")
  203. group.add_argument(
  204. "--quantize_asr_model",
  205. type=bool,
  206. default=False,
  207. help="Apply dynamic quantization to ASR model.",
  208. )
  209. group.add_argument(
  210. "--quantize_modules",
  211. nargs="*",
  212. default=None,
  213. help="""Module names to apply dynamic quantization on.
  214. The module names are provided as a list, where each name is separated
  215. by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
  216. Each specified name should be an attribute of 'torch.nn', e.g.:
  217. torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
  218. )
  219. group.add_argument(
  220. "--quantize_dtype",
  221. type=str,
  222. default="qint8",
  223. choices=["float16", "qint8"],
  224. help="Dtype for dynamic quantization.",
  225. )
  226. group = parser.add_argument_group("Text converter related")
  227. group.add_argument(
  228. "--token_type",
  229. type=str_or_none,
  230. default=None,
  231. choices=["char", "bpe", None],
  232. help="The token type for ASR model. "
  233. "If not given, refers from the training args",
  234. )
  235. group.add_argument(
  236. "--bpemodel",
  237. type=str_or_none,
  238. default=None,
  239. help="The model path of sentencepiece. "
  240. "If not given, refers from the training args",
  241. )
  242. group.add_argument("--token_num_relax", type=int, default=1, help="")
  243. group.add_argument("--decoding_ind", type=int, default=0, help="")
  244. group.add_argument("--decoding_mode", type=str, default="model1", help="")
  245. group.add_argument(
  246. "--ctc_weight2",
  247. type=float,
  248. default=0.0,
  249. help="CTC weight in joint decoding",
  250. )
  251. return parser