|
@@ -0,0 +1,77 @@
|
|
|
|
|
+import torch
|
|
|
|
|
+
|
|
|
|
|
+from funasr.models.e2e_vad import E2EVadModel
|
|
|
|
|
+from funasr.models.encoder.fsmn_encoder import FSMN
|
|
|
|
|
+from funasr.models.frontend.default import DefaultFrontend
|
|
|
|
|
+from funasr.models.frontend.fused import FusedFrontends
|
|
|
|
|
+from funasr.models.frontend.s3prl import S3prlFrontend
|
|
|
|
|
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
|
|
|
|
|
+from funasr.models.frontend.windowing import SlidingWindow
|
|
|
|
|
+from funasr.torch_utils.initialize import initialize
|
|
|
|
|
+from funasr.train.class_choices import ClassChoices
|
|
|
|
|
+
|
|
|
|
|
+frontend_choices = ClassChoices(
|
|
|
|
|
+ name="frontend",
|
|
|
|
|
+ classes=dict(
|
|
|
|
|
+ default=DefaultFrontend,
|
|
|
|
|
+ sliding_window=SlidingWindow,
|
|
|
|
|
+ s3prl=S3prlFrontend,
|
|
|
|
|
+ fused=FusedFrontends,
|
|
|
|
|
+ wav_frontend=WavFrontend,
|
|
|
|
|
+ wav_frontend_online=WavFrontendOnline,
|
|
|
|
|
+ ),
|
|
|
|
|
+ default="default",
|
|
|
|
|
+)
|
|
|
|
|
+encoder_choices = ClassChoices(
|
|
|
|
|
+ "encoder",
|
|
|
|
|
+ classes=dict(
|
|
|
|
|
+ fsmn=FSMN,
|
|
|
|
|
+ ),
|
|
|
|
|
+ type_check=torch.nn.Module,
|
|
|
|
|
+ default="fsmn",
|
|
|
|
|
+)
|
|
|
|
|
+model_choices = ClassChoices(
|
|
|
|
|
+ "model",
|
|
|
|
|
+ classes=dict(
|
|
|
|
|
+ e2evad=E2EVadModel,
|
|
|
|
|
+ ),
|
|
|
|
|
+ default="e2evad",
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+class_choices_list = [
|
|
|
|
|
+ # --frontend and --frontend_conf
|
|
|
|
|
+ frontend_choices,
|
|
|
|
|
+ # --encoder and --encoder_conf
|
|
|
|
|
+ encoder_choices,
|
|
|
|
|
+ # --model and --model_conf
|
|
|
|
|
+ model_choices,
|
|
|
|
|
+]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def build_vad_model(args):
|
|
|
|
|
+ # frontend
|
|
|
|
|
+ if args.input_size is None:
|
|
|
|
|
+ frontend_class = frontend_choices.get_class(args.frontend)
|
|
|
|
|
+ if args.frontend == 'wav_frontend':
|
|
|
|
|
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
|
|
|
|
|
+ else:
|
|
|
|
|
+ frontend = frontend_class(**args.frontend_conf)
|
|
|
|
|
+ input_size = frontend.output_size()
|
|
|
|
|
+ else:
|
|
|
|
|
+ args.frontend = None
|
|
|
|
|
+ args.frontend_conf = {}
|
|
|
|
|
+ frontend = None
|
|
|
|
|
+ input_size = args.input_size
|
|
|
|
|
+
|
|
|
|
|
+ # encoder
|
|
|
|
|
+ encoder_class = encoder_choices.get_class(args.encoder)
|
|
|
|
|
+ encoder = encoder_class(**args.encoder_conf)
|
|
|
|
|
+
|
|
|
|
|
+ model_class = model_choices.get_class(args.model)
|
|
|
|
|
+ model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
|
|
|
|
|
+
|
|
|
|
|
+ # initialize
|
|
|
|
|
+ if args.init is not None:
|
|
|
|
|
+ initialize(model, args.init)
|
|
|
|
|
+
|
|
|
|
|
+ return model
|