build_vad_model.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import torch
  2. from funasr.models.e2e_vad import E2EVadModel
  3. from funasr.models.encoder.fsmn_encoder import FSMN
  4. from funasr.models.frontend.default import DefaultFrontend
  5. from funasr.models.frontend.fused import FusedFrontends
  6. from funasr.models.frontend.s3prl import S3prlFrontend
  7. from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
  8. from funasr.models.frontend.windowing import SlidingWindow
  9. from funasr.torch_utils.initialize import initialize
  10. from funasr.train.class_choices import ClassChoices
  # Registry of the selectable frontend classes, keyed by the name used to
  # pick one; "default" is passed to ClassChoices as the default choice.
  frontend_choices = ClassChoices(
      name="frontend",
      classes={
          "default": DefaultFrontend,
          "sliding_window": SlidingWindow,
          "s3prl": S3prlFrontend,
          "fused": FusedFrontends,
          "wav_frontend": WavFrontend,
          "wav_frontend_online": WavFrontendOnline,
      },
      default="default",
  )
  # Registry of the selectable encoder classes. FSMN is the only entry and
  # also the default; torch.nn.Module is handed to ClassChoices as type_check.
  encoder_choices = ClassChoices(
      "encoder",
      classes={"fsmn": FSMN},
      type_check=torch.nn.Module,
      default="fsmn",
  )
  # Registry of the selectable model classes. E2EVadModel is the only entry
  # and also the default.
  model_choices = ClassChoices(
      "model",
      classes={"e2evad": E2EVadModel},
      default="e2evad",
  )
  # The three ClassChoices registries of this module, in a fixed order.
  # NOTE(review): presumably consumed by argument-parser setup elsewhere;
  # the consumer is not visible in this file.
  class_choices_list = [
      frontend_choices,  # --frontend and --frontend_conf
      encoder_choices,  # --encoder and --encoder_conf
      model_choices,  # --model and --model_conf
  ]
  def build_vad_model(args):
      """Assemble a VAD model from the attributes of ``args``.

      A frontend is instantiated only when ``args.input_size`` is ``None``;
      otherwise the model is built without one. The encoder and the model
      class are looked up in ``encoder_choices`` / ``model_choices`` and
      instantiated, and ``initialize`` is applied when ``args.init`` is set.

      Args:
          args: Object exposing ``input_size``, ``frontend``,
              ``frontend_conf``, ``encoder``, ``encoder_conf``, ``model``,
              ``vad_post_conf`` and ``init``. ``cmvn_file`` is read only
              when ``args.frontend == 'wav_frontend'``. When
              ``args.input_size`` is not ``None`` this function overwrites
              ``args.frontend`` with ``None`` and ``args.frontend_conf``
              with ``{}``.

      Returns:
          The instance created from the class selected by ``args.model``.
      """
      # frontend
      if args.input_size is not None:
          # A feature size was supplied explicitly: no frontend is built and
          # the frontend settings on ``args`` are reset.
          args.frontend = None
          args.frontend_conf = {}
          frontend = None
          input_size = args.input_size
      else:
          frontend_cls = frontend_choices.get_class(args.frontend)
          if args.frontend == 'wav_frontend':
              # Only this frontend name is additionally given the CMVN file.
              frontend = frontend_cls(
                  cmvn_file=args.cmvn_file,
                  **args.frontend_conf
              )
          else:
              frontend = frontend_cls(**args.frontend_conf)
          # NOTE: input_size is not consumed further down in this function.
          input_size = frontend.output_size()

      # encoder
      encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf)

      # model
      model = model_choices.get_class(args.model)(
          encoder=encoder,
          vad_post_args=args.vad_post_conf,
          frontend=frontend,
      )

      # initialize
      init_method = args.init
      if init_method is not None:
          initialize(model, init_method)

      return model