|
|
@@ -132,6 +132,8 @@ model_choices = ClassChoices(
|
|
|
neatcontextual_paraformer=NeatContextualParaformer,
|
|
|
mfcca=MFCCA,
|
|
|
timestamp_prediction=TimestampPredictor,
|
|
|
+ rnnt=TransducerModel,
|
|
|
+ rnnt_unified=UnifiedTransducerModel,
|
|
|
),
|
|
|
type_check=FunASRModel,
|
|
|
default="asr",
|
|
|
@@ -1453,7 +1455,7 @@ class ASRTransducerTask(ASRTask):
|
|
|
decoder_output_size = decoder.output_size
|
|
|
|
|
|
if getattr(args, "decoder", None) is not None:
|
|
|
- att_decoder_class = decoder_choices.get_class(args.att_decoder)
|
|
|
+ att_decoder_class = decoder_choices.get_class(args.decoder)
|
|
|
|
|
|
att_decoder = att_decoder_class(
|
|
|
vocab_size=vocab_size,
|
|
|
@@ -1471,35 +1473,23 @@ class ASRTransducerTask(ASRTask):
|
|
|
)
|
|
|
|
|
|
# 7. Build model
|
|
|
+ try:
|
|
|
+ model_class = model_choices.get_class(args.model)
|
|
|
+ except AttributeError:
|
|
|
+ model_class = model_choices.get_class("asr")
|
|
|
|
|
|
- if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
|
|
|
- model = UnifiedTransducerModel(
|
|
|
- vocab_size=vocab_size,
|
|
|
- token_list=token_list,
|
|
|
- frontend=frontend,
|
|
|
- specaug=specaug,
|
|
|
- normalize=normalize,
|
|
|
- encoder=encoder,
|
|
|
- decoder=decoder,
|
|
|
- att_decoder=att_decoder,
|
|
|
- joint_network=joint_network,
|
|
|
- **args.model_conf,
|
|
|
- )
|
|
|
-
|
|
|
- else:
|
|
|
- model = TransducerModel(
|
|
|
- vocab_size=vocab_size,
|
|
|
- token_list=token_list,
|
|
|
- frontend=frontend,
|
|
|
- specaug=specaug,
|
|
|
- normalize=normalize,
|
|
|
- encoder=encoder,
|
|
|
- decoder=decoder,
|
|
|
- att_decoder=att_decoder,
|
|
|
- joint_network=joint_network,
|
|
|
- **args.model_conf,
|
|
|
- )
|
|
|
-
|
|
|
+ model = model_class(
|
|
|
+ vocab_size=vocab_size,
|
|
|
+ token_list=token_list,
|
|
|
+ frontend=frontend,
|
|
|
+ specaug=specaug,
|
|
|
+ normalize=normalize,
|
|
|
+ encoder=encoder,
|
|
|
+ decoder=decoder,
|
|
|
+ att_decoder=att_decoder,
|
|
|
+ joint_network=joint_network,
|
|
|
+ **args.model_conf,
|
|
|
+ )
|
|
|
# 8. Initialize model
|
|
|
if args.init is not None:
|
|
|
raise NotImplementedError(
|