build_args.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from funasr.models.ctc import CTC
  2. from funasr.utils import config_argparse
  3. from funasr.utils.get_default_kwargs import get_default_kwargs
  4. from funasr.utils.nested_dict_action import NestedDictAction
  5. from funasr.utils.types import int_or_none
  6. from funasr.utils.types import str2bool
  7. from funasr.utils.types import str_or_none
  8. def build_args(args, parser, extra_task_params):
  9. task_parser = config_argparse.ArgumentParser("Task related config")
  10. if args.task_name == "asr":
  11. from funasr.build_utils.build_asr_model import class_choices_list
  12. for class_choices in class_choices_list:
  13. class_choices.add_arguments(task_parser)
  14. task_parser.add_argument(
  15. "--split_with_space",
  16. type=str2bool,
  17. default=True,
  18. help="whether to split text using <space>",
  19. )
  20. task_parser.add_argument(
  21. "--seg_dict_file",
  22. type=str,
  23. default=None,
  24. help="seg_dict_file for text processing",
  25. )
  26. task_parser.add_argument(
  27. "--input_size",
  28. type=int_or_none,
  29. default=None,
  30. help="The number of input dimension of the feature",
  31. )
  32. task_parser.add_argument(
  33. "--ctc_conf",
  34. action=NestedDictAction,
  35. default=get_default_kwargs(CTC),
  36. help="The keyword arguments for CTC class.",
  37. )
  38. task_parser.add_argument(
  39. "--cmvn_file",
  40. type=str_or_none,
  41. default=None,
  42. help="The path of cmvn file.",
  43. )
  44. elif args.task_name == "pretrain":
  45. from funasr.build_utils.build_pretrain_model import class_choices_list
  46. for class_choices in class_choices_list:
  47. class_choices.add_arguments(task_parser)
  48. task_parser.add_argument(
  49. "--input_size",
  50. type=int_or_none,
  51. default=None,
  52. help="The number of input dimension of the feature",
  53. )
  54. elif args.task_name == "lm":
  55. from funasr.build_utils.build_lm_model import class_choices_list
  56. for class_choices in class_choices_list:
  57. class_choices.add_arguments(task_parser)
  58. elif args.task_name == "punc":
  59. from funasr.build_utils.build_punc_model import class_choices_list
  60. for class_choices in class_choices_list:
  61. class_choices.add_arguments(task_parser)
  62. elif args.task_name == "vad":
  63. from funasr.build_utils.build_vad_model import class_choices_list
  64. for class_choices in class_choices_list:
  65. class_choices.add_arguments(task_parser)
  66. task_parser.add_argument(
  67. "--input_size",
  68. type=int_or_none,
  69. default=None,
  70. help="The number of input dimension of the feature",
  71. )
  72. task_parser.add_argument(
  73. "--cmvn_file",
  74. type=str_or_none,
  75. default=None,
  76. help="The path of cmvn file.",
  77. )
  78. elif args.task_name == "diar":
  79. from funasr.build_utils.build_diar_model import class_choices_list
  80. for class_choices in class_choices_list:
  81. class_choices.add_arguments(task_parser)
  82. elif args.task_name == "sv":
  83. from funasr.build_utils.build_sv_model import class_choices_list
  84. for class_choices in class_choices_list:
  85. class_choices.add_arguments(task_parser)
  86. task_parser.add_argument(
  87. "--input_size",
  88. type=int_or_none,
  89. default=None,
  90. help="The number of input dimension of the feature",
  91. )
  92. else:
  93. raise NotImplementedError("Not supported task: {}".format(args.task_name))
  94. for action in parser._actions:
  95. if not any(action.dest == a.dest for a in task_parser._actions):
  96. task_parser._add_action(action)
  97. task_parser.set_defaults(**vars(args))
  98. task_args = task_parser.parse_args(extra_task_params)
  99. return task_args