|
|
@@ -209,7 +209,7 @@ def build_diar_model(args):
|
|
|
encoder_class = encoder_choices.get_class(args.encoder)
|
|
|
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
|
|
|
|
|
|
- if args.model_name == "sond":
|
|
|
+ if args.model == "sond":
|
|
|
# data augmentation for spectrogram
|
|
|
if args.specaug is not None:
|
|
|
specaug_class = specaug_choices.get_class(args.specaug)
|
|
|
@@ -247,11 +247,7 @@ def build_diar_model(args):
|
|
|
|
|
|
# decoder
|
|
|
decoder_class = decoder_choices.get_class(args.decoder)
|
|
|
- decoder = decoder_class(
|
|
|
- vocab_size=vocab_size,
|
|
|
- encoder_output_size=encoder.output_size(),
|
|
|
- **args.decoder_conf,
|
|
|
- )
|
|
|
+ decoder = decoder_class(**args.decoder_conf)
|
|
|
|
|
|
# logger aggregator
|
|
|
if getattr(args, "label_aggregator", None) is not None:
|