build_model.py 968 B

12345678910111213141516171819202122232425
  1. from funasr.build_utils.build_asr_model import build_asr_model
  2. from funasr.build_utils.build_lm_model import build_lm_model
  3. from funasr.build_utils.build_pretrain_model import build_pretrain_model
  4. from funasr.build_utils.build_punc_model import build_punc_model
  5. from funasr.build_utils.build_vad_model import build_vad_model
  6. from funasr.build_utils.build_diar_model import build_diar_model
  7. def build_model(args):
  8. if args.task_name == "asr":
  9. model = build_asr_model(args)
  10. elif args.task_name == "pretrain":
  11. model = build_pretrain_model(args)
  12. elif args.task_name == "lm":
  13. model = build_lm_model(args)
  14. elif args.task_name == "punc":
  15. model = build_punc_model(args)
  16. elif args.task_name == "vad":
  17. model = build_vad_model(args)
  18. elif args.task_name == "diar":
  19. model = build_diar_model(args)
  20. else:
  21. raise NotImplementedError("Not supported task: {}".format(args.task_name))
  22. return model