build_model.py 1.1 KB

12345678910111213141516171819202122232425262728
  1. from funasr.build_utils.build_asr_model import build_asr_model
  2. from funasr.build_utils.build_diar_model import build_diar_model
  3. from funasr.build_utils.build_lm_model import build_lm_model
  4. from funasr.build_utils.build_pretrain_model import build_pretrain_model
  5. from funasr.build_utils.build_punc_model import build_punc_model
  6. from funasr.build_utils.build_sv_model import build_sv_model
  7. from funasr.build_utils.build_vad_model import build_vad_model
  8. def build_model(args):
  9. if args.task_name == "asr":
  10. model = build_asr_model(args)
  11. elif args.task_name == "pretrain":
  12. model = build_pretrain_model(args)
  13. elif args.task_name == "lm":
  14. model = build_lm_model(args)
  15. elif args.task_name == "punc":
  16. model = build_punc_model(args)
  17. elif args.task_name == "vad":
  18. model = build_vad_model(args)
  19. elif args.task_name == "diar":
  20. model = build_diar_model(args)
  21. elif args.task_name == "sv":
  22. model = build_sv_model(args)
  23. else:
  24. raise NotImplementedError("Not supported task: {}".format(args.task_name))
  25. return model