# build_args.py
import importlib

from funasr.models.ctc import CTC
from funasr.utils import config_argparse
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
  8. def build_args(args, parser, extra_task_params):
  9. task_parser = config_argparse.ArgumentParser("Task related config")
  10. if args.task_name == "asr":
  11. from funasr.build_utils.build_asr_model import class_choices_list
  12. for class_choices in class_choices_list:
  13. class_choices.add_arguments(task_parser)
  14. task_parser.add_argument(
  15. "--split_with_space",
  16. type=str2bool,
  17. default=True,
  18. help="whether to split text using <space>",
  19. )
  20. task_parser.add_argument(
  21. "--seg_dict_file",
  22. type=str,
  23. default=None,
  24. help="seg_dict_file for text processing",
  25. )
  26. task_parser.add_argument(
  27. "--input_size",
  28. type=int_or_none,
  29. default=None,
  30. help="The number of input dimension of the feature",
  31. )
  32. task_parser.add_argument(
  33. "--ctc_conf",
  34. action=NestedDictAction,
  35. default=get_default_kwargs(CTC),
  36. help="The keyword arguments for CTC class.",
  37. )
  38. task_parser.add_argument(
  39. "--cmvn_file",
  40. type=str_or_none,
  41. default=None,
  42. help="The path of cmvn file.",
  43. )
  44. elif args.task_name == "pretrain":
  45. from funasr.build_utils.build_pretrain_model import class_choices_list
  46. for class_choices in class_choices_list:
  47. class_choices.add_arguments(task_parser)
  48. task_parser.add_argument(
  49. "--input_size",
  50. type=int_or_none,
  51. default=None,
  52. help="The number of input dimension of the feature",
  53. )
  54. task_parser.add_argument(
  55. "--cmvn_file",
  56. type=str_or_none,
  57. default=None,
  58. help="The path of cmvn file.",
  59. )
  60. elif args.task_name == "lm":
  61. from funasr.build_utils.build_lm_model import class_choices_list
  62. for class_choices in class_choices_list:
  63. class_choices.add_arguments(task_parser)
  64. elif args.task_name == "punc":
  65. from funasr.build_utils.build_punc_model import class_choices_list
  66. for class_choices in class_choices_list:
  67. class_choices.add_arguments(task_parser)
  68. elif args.task_name == "vad":
  69. from funasr.build_utils.build_vad_model import class_choices_list
  70. for class_choices in class_choices_list:
  71. class_choices.add_arguments(task_parser)
  72. task_parser.add_argument(
  73. "--input_size",
  74. type=int_or_none,
  75. default=None,
  76. help="The number of input dimension of the feature",
  77. )
  78. task_parser.add_argument(
  79. "--cmvn_file",
  80. type=str_or_none,
  81. default=None,
  82. help="The path of cmvn file.",
  83. )
  84. elif args.task_name == "diar":
  85. from funasr.build_utils.build_diar_model import class_choices_list
  86. for class_choices in class_choices_list:
  87. class_choices.add_arguments(task_parser)
  88. task_parser.add_argument(
  89. "--input_size",
  90. type=int_or_none,
  91. default=None,
  92. help="The number of input dimension of the feature",
  93. )
  94. elif args.task_name == "sv":
  95. from funasr.build_utils.build_sv_model import class_choices_list
  96. for class_choices in class_choices_list:
  97. class_choices.add_arguments(task_parser)
  98. task_parser.add_argument(
  99. "--input_size",
  100. type=int_or_none,
  101. default=None,
  102. help="The number of input dimension of the feature",
  103. )
  104. else:
  105. raise NotImplementedError("Not supported task: {}".format(args.task_name))
  106. for action in parser._actions:
  107. if not any(action.dest == a.dest for a in task_parser._actions):
  108. task_parser._add_action(action)
  109. task_parser.set_defaults(**vars(args))
  110. task_args = task_parser.parse_args(extra_task_params)
  111. return task_args