|
|
@@ -57,39 +57,39 @@ class_choices_list = [
|
|
|
|
|
|
|
|
|
def build_pretrain_model(args):
|
|
|
- if args.model_name == "data2vec":
|
|
|
- # frontend
|
|
|
- if args.input_size is None:
|
|
|
- frontend_class = frontend_choices.get_class(args.frontend)
|
|
|
- frontend = frontend_class(**args.frontend_conf)
|
|
|
- input_size = frontend.output_size()
|
|
|
- else:
|
|
|
- args.frontend = None
|
|
|
- args.frontend_conf = {}
|
|
|
- frontend = None
|
|
|
- input_size = args.input_size
|
|
|
+ # frontend
|
|
|
+ if args.input_size is None:
|
|
|
+ frontend_class = frontend_choices.get_class(args.frontend)
|
|
|
+ frontend = frontend_class(**args.frontend_conf)
|
|
|
+ input_size = frontend.output_size()
|
|
|
+ else:
|
|
|
+ args.frontend = None
|
|
|
+ args.frontend_conf = {}
|
|
|
+ frontend = None
|
|
|
+ input_size = args.input_size
|
|
|
|
|
|
- # data augmentation for spectrogram
|
|
|
- if args.specaug is not None:
|
|
|
- specaug_class = specaug_choices.get_class(args.specaug)
|
|
|
- specaug = specaug_class(**args.specaug_conf)
|
|
|
- else:
|
|
|
- specaug = None
|
|
|
+ # data augmentation for spectrogram
|
|
|
+ if args.specaug is not None:
|
|
|
+ specaug_class = specaug_choices.get_class(args.specaug)
|
|
|
+ specaug = specaug_class(**args.specaug_conf)
|
|
|
+ else:
|
|
|
+ specaug = None
|
|
|
|
|
|
- # normalization layer
|
|
|
- if args.normalize is not None:
|
|
|
- normalize_class = normalize_choices.get_class(args.normalize)
|
|
|
- normalize = normalize_class(**args.normalize_conf)
|
|
|
- else:
|
|
|
- normalize = None
|
|
|
+ # normalization layer
|
|
|
+ if args.normalize is not None:
|
|
|
+ normalize_class = normalize_choices.get_class(args.normalize)
|
|
|
+ normalize = normalize_class(**args.normalize_conf)
|
|
|
+ else:
|
|
|
+ normalize = None
|
|
|
|
|
|
- # encoder
|
|
|
- encoder_class = encoder_choices.get_class(args.encoder)
|
|
|
- encoder = encoder_class(
|
|
|
- input_size=input_size,
|
|
|
- **args.encoder_conf,
|
|
|
- )
|
|
|
+ # encoder
|
|
|
+ encoder_class = encoder_choices.get_class(args.encoder)
|
|
|
+ encoder = encoder_class(
|
|
|
+ input_size=input_size,
|
|
|
+ **args.encoder_conf,
|
|
|
+ )
|
|
|
|
|
|
+ if args.model_name == "data2vec":
|
|
|
model_class = model_choices.get_class("data2vec")
|
|
|
model = model_class(
|
|
|
frontend=frontend,
|
|
|
@@ -97,9 +97,11 @@ def build_pretrain_model(args):
|
|
|
normalize=normalize,
|
|
|
encoder=encoder,
|
|
|
)
|
|
|
+ else:
|
|
|
+ raise NotImplementedError("Not supported model: {}".format(args.model))
|
|
|
|
|
|
- # 7. Initialize
|
|
|
- if args.init is not None:
|
|
|
- initialize(model, args.init)
|
|
|
+ # initialize
|
|
|
+ if args.init is not None:
|
|
|
+ initialize(model, args.init)
|
|
|
|
|
|
- return model
|
|
|
+ return model
|