Sfoglia il codice sorgente

joint network bug fix

aky15 2 anni fa
parent
commit
6f10b7bc41
1 ha cambiato i file con 10 aggiunte e 1 eliminazioni
  1. 10 1
      funasr/tasks/asr.py

+ 10 - 1
funasr/tasks/asr.py

@@ -224,6 +224,15 @@ rnnt_decoder_choices = ClassChoices(
     default="rnnt",
 )
 
+joint_network_choices = ClassChoices(
+    name="joint_network",
+    classes=dict(
+        joint_network=JointNetwork,
+    ),
+    default="joint_network",
+    optional=True,
+)
+
 predictor_choices = ClassChoices(
     name="predictor",
     classes=dict(
@@ -353,7 +362,7 @@ class ASRTask(AbsTask):
             help="The keyword arguments for CTC class.",
         )
         group.add_argument(
-            "--joint_net_conf",
+            "--joint_network_conf",
             action=NestedDictAction,
             default=None,
             help="The keyword arguments for joint network class.",