build_vad_model.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import torch
  2. from funasr.models.e2e_vad import E2EVadModel
  3. from funasr.models.encoder.fsmn_encoder import FSMN
  4. from funasr.models.frontend.default import DefaultFrontend
  5. from funasr.models.frontend.fused import FusedFrontends
  6. from funasr.models.frontend.s3prl import S3prlFrontend
  7. from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
  8. from funasr.models.frontend.windowing import SlidingWindow
  9. from funasr.torch_utils.initialize import initialize
  10. from funasr.train.class_choices import ClassChoices
  11. frontend_choices = ClassChoices(
  12. name="frontend",
  13. classes=dict(
  14. default=DefaultFrontend,
  15. sliding_window=SlidingWindow,
  16. s3prl=S3prlFrontend,
  17. fused=FusedFrontends,
  18. wav_frontend=WavFrontend,
  19. wav_frontend_online=WavFrontendOnline,
  20. ),
  21. default="default",
  22. )
  23. encoder_choices = ClassChoices(
  24. "encoder",
  25. classes=dict(
  26. fsmn=FSMN,
  27. ),
  28. type_check=torch.nn.Module,
  29. default="fsmn",
  30. )
  31. model_choices = ClassChoices(
  32. "model",
  33. classes=dict(
  34. e2evad=E2EVadModel,
  35. ),
  36. default="e2evad",
  37. )
  38. class_choices_list = [
  39. # --frontend and --frontend_conf
  40. frontend_choices,
  41. # --encoder and --encoder_conf
  42. encoder_choices,
  43. # --model and --model_conf
  44. model_choices,
  45. ]
  46. def build_vad_model(args):
  47. # frontend
  48. if not hasattr(args, "cmvn_file"):
  49. args.cmvn_file = None
  50. if not hasattr(args, "init"):
  51. args.init = None
  52. if args.input_size is None:
  53. frontend_class = frontend_choices.get_class(args.frontend)
  54. if args.frontend == 'wav_frontend':
  55. frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
  56. else:
  57. frontend = frontend_class(**args.frontend_conf)
  58. input_size = frontend.output_size()
  59. else:
  60. args.frontend = None
  61. args.frontend_conf = {}
  62. frontend = None
  63. input_size = args.input_size
  64. # encoder
  65. encoder_class = encoder_choices.get_class(args.encoder)
  66. encoder = encoder_class(**args.encoder_conf)
  67. model_class = model_choices.get_class(args.model)
  68. model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
  69. # initialize
  70. if args.init is not None:
  71. initialize(model, args.init)
  72. return model