build_model.py 1.2 KB

12345678910111213141516171819202122232425262728293031
  1. from funasr.build_utils.build_asr_model import build_asr_model
  2. from funasr.build_utils.build_diar_model import build_diar_model
  3. from funasr.build_utils.build_lm_model import build_lm_model
  4. from funasr.build_utils.build_pretrain_model import build_pretrain_model
  5. from funasr.build_utils.build_punc_model import build_punc_model
  6. from funasr.build_utils.build_sv_model import build_sv_model
  7. from funasr.build_utils.build_vad_model import build_vad_model
  8. from funasr.build_utils.build_ss_model import build_ss_model
  9. def build_model(args):
  10. if args.task_name == "asr":
  11. model = build_asr_model(args)
  12. elif args.task_name == "pretrain":
  13. model = build_pretrain_model(args)
  14. elif args.task_name == "lm":
  15. model = build_lm_model(args)
  16. elif args.task_name == "punc":
  17. model = build_punc_model(args)
  18. elif args.task_name == "vad":
  19. model = build_vad_model(args)
  20. elif args.task_name == "diar":
  21. model = build_diar_model(args)
  22. elif args.task_name == "sv":
  23. model = build_sv_model(args)
  24. elif args.task_name == "ss":
  25. model = build_ss_model(args)
  26. else:
  27. raise NotImplementedError("Not supported task: {}".format(args.task_name))
  28. return model