Explorar o código

space between tokens

aky15 %!s(int64=3) %!d(string=hai) anos
pai
achega
3e333c0abf

+ 0 - 1
funasr/models_transducer/error_calculator.py

@@ -137,7 +137,6 @@ class ErrorCalculator:
         for i, char_pred_i in enumerate(char_pred):
             pred = char_pred_i.replace(" ", "")
             target = char_target[i].replace(" ", "")
-
             distances.append(editdistance.eval(pred, target))
             lens.append(len(target))
 

+ 2 - 2
funasr/models_transducer/espnet_transducer_model_unified.py

@@ -455,7 +455,8 @@ class ESPnetASRUnifiedTransducerModel(AbsESPnetModel):
                 gather=True,
         )
 
-        if not self.training and (self.report_cer or self.report_wer):
+        #if not self.training and (self.report_cer or self.report_wer):
+        if self.report_cer or self.report_wer:
             if self.error_calculator is None:
                 self.error_calculator = ErrorCalculator(
                     self.decoder,
@@ -468,7 +469,6 @@ class ESPnetASRUnifiedTransducerModel(AbsESPnetModel):
                 )
 
             cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
-
             return loss_transducer, cer_transducer, wer_transducer
 
         return loss_transducer, None, None

+ 7 - 0
funasr/tasks/asr_transducer.py

@@ -137,6 +137,12 @@ class ASRTransducerTask(AbsTask):
             default=None,
             help="Integer-string mapper for tokens.",
         )
+        group.add_argument(
+            "--split_with_space",
+            type=str2bool,
+            default=True,
+            help="whether to split text using <space>",
+        )
         group.add_argument(
             "--input_size",
             type=int_or_none,
@@ -289,6 +295,7 @@ class ASRTransducerTask(AbsTask):
                 non_linguistic_symbols=args.non_linguistic_symbols,
                 text_cleaner=args.cleaner,
                 g2p_type=args.g2p,
+                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
                 rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
                 rir_apply_prob=args.rir_apply_prob
                 if hasattr(args, "rir_apply_prob")