|
@@ -224,6 +224,15 @@ rnnt_decoder_choices = ClassChoices(
|
|
|
default="rnnt",
|
|
default="rnnt",
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
+joint_network_choices = ClassChoices(
|
|
|
|
|
+ name="joint_network",
|
|
|
|
|
+ classes=dict(
|
|
|
|
|
+ joint_network=JointNetwork,
|
|
|
|
|
+ ),
|
|
|
|
|
+ default="joint_network",
|
|
|
|
|
+ optional=True,
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
predictor_choices = ClassChoices(
|
|
predictor_choices = ClassChoices(
|
|
|
name="predictor",
|
|
name="predictor",
|
|
|
classes=dict(
|
|
classes=dict(
|
|
@@ -353,7 +362,7 @@ class ASRTask(AbsTask):
|
|
|
help="The keyword arguments for CTC class.",
|
|
help="The keyword arguments for CTC class.",
|
|
|
)
|
|
)
|
|
|
group.add_argument(
|
|
group.add_argument(
|
|
|
- "--joint_net_conf",
|
|
|
|
|
|
|
+ "--joint_network_conf",
|
|
|
action=NestedDictAction,
|
|
action=NestedDictAction,
|
|
|
default=None,
|
|
default=None,
|
|
|
help="The keyword arguments for joint network class.",
|
|
help="The keyword arguments for joint network class.",
|