| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- from funasr.layers.global_mvn import GlobalMVN
- from funasr.layers.utterance_mvn import UtteranceMVN
- from funasr.models.data2vec import Data2VecPretrainModel
- from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
- from funasr.models.frontend.default import DefaultFrontend
- from funasr.models.frontend.windowing import SlidingWindow
- from funasr.models.specaug.specaug import SpecAug
- from funasr.torch_utils.initialize import initialize
- from funasr.train.class_choices import ClassChoices
# Registry of selectable front-ends (--frontend / --frontend_conf).
frontend_choices = ClassChoices(
    name="frontend",
    classes={
        "default": DefaultFrontend,
        "sliding_window": SlidingWindow,
    },
    default="default",
)

# Registry of spectrogram augmentations (--specaug / --specaug_conf).
# Optional: with no default, the component may be left unset.
specaug_choices = ClassChoices(
    name="specaug",
    classes={"specaug": SpecAug},
    default=None,
    optional=True,
)

# Registry of feature normalization layers (--normalize / --normalize_conf).
# Optional: with no default, the component may be left unset.
normalize_choices = ClassChoices(
    name="normalize",
    classes={
        "global_mvn": GlobalMVN,
        "utterance_mvn": UtteranceMVN,
    },
    default=None,
    optional=True,
)

# Registry of encoders (--encoder / --encoder_conf).
encoder_choices = ClassChoices(
    name="encoder",
    classes={"data2vec_encoder": Data2VecEncoder},
    default="data2vec_encoder",
)

# Registry of top-level pretraining models (--model / --model_conf).
model_choices = ClassChoices(
    name="model",
    classes={"data2vec": Data2VecPretrainModel},
    default="data2vec",
)

# All registries, in the order: frontend, specaug, normalize, encoder, model.
class_choices_list = [
    frontend_choices,  # --frontend and --frontend_conf
    specaug_choices,  # --specaug and --specaug_conf
    normalize_choices,  # --normalize and --normalize_conf
    encoder_choices,  # --encoder and --encoder_conf
    model_choices,  # --model and --model_conf
]
def build_pretrain_model(args):
    """Assemble the data2vec pretraining model described by ``args``.

    Components are looked up in the module-level ``*_choices`` registries
    and instantiated with the matching ``*_conf`` keyword dictionaries.

    Args:
        args: Configuration namespace. The attributes read are
            ``input_size``, ``frontend``/``frontend_conf``,
            ``specaug``/``specaug_conf``, ``normalize``/``normalize_conf``,
            ``encoder``/``encoder_conf``, ``model`` and ``init``.
            When ``input_size`` is not ``None``, ``args.frontend`` is reset
            to ``None`` and ``args.frontend_conf`` to ``{}`` in place.

    Returns:
        The constructed model, initialized via ``initialize`` when
        ``args.init`` is not ``None``.

    Raises:
        NotImplementedError: If ``args.model`` is not ``"data2vec"``.
    """
    # Frontend: only built when the input size is not given explicitly.
    if args.input_size is not None:
        # An explicit input size disables the frontend; the args namespace
        # is updated so it reflects that no frontend is in use.
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size
    else:
        frontend = frontend_choices.get_class(args.frontend)(**args.frontend_conf)
        input_size = frontend.output_size()

    # Optional spectrogram augmentation.
    specaug = None
    if args.specaug is not None:
        specaug = specaug_choices.get_class(args.specaug)(**args.specaug_conf)

    # Optional normalization layer.
    normalize = None
    if args.normalize is not None:
        normalize = normalize_choices.get_class(args.normalize)(**args.normalize_conf)

    # Encoder, sized to whatever the frontend (or the explicit setting) yields.
    encoder = encoder_choices.get_class(args.encoder)(
        input_size=input_size,
        **args.encoder_conf,
    )

    # Only the data2vec model is wired up here; the check is done after the
    # sub-modules have been constructed.
    if args.model != "data2vec":
        raise NotImplementedError("Not supported model: {}".format(args.model))

    model = model_choices.get_class("data2vec")(
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        encoder=encoder,
    )

    # Parameter initialization, if a scheme was requested.
    if args.init is not None:
        initialize(model, args.init)

    return model
|