Ver Fonte

Update error calculator for RNN-T

aky15 há 2 anos atrás
pai
commit
bdb8a99da4
2 ficheiros alterados com 11 adições e 5 exclusões
  1. +5 -3
      funasr/models/e2e_asr_transducer.py
  2. +6 -2
      funasr/modules/e2e_asr_common.py

+ 5 - 3
funasr/models/e2e_asr_transducer.py

@@ -386,7 +386,7 @@ class TransducerModel(AbsESPnetModel):
 
         if not self.training and (self.report_cer or self.report_wer):
             if self.error_calculator is None:
-                from espnet2.asr_transducer.error_calculator import ErrorCalculator
+                from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
 
                 self.error_calculator = ErrorCalculator(
                     self.decoder,
@@ -398,7 +398,7 @@ class TransducerModel(AbsESPnetModel):
                     report_wer=self.report_wer,
                 )
 
-            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
+            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
 
             return loss_transducer, cer_transducer, wer_transducer
 
@@ -889,6 +889,8 @@ class UnifiedTransducerModel(AbsESPnetModel):
 
         if not self.training and (self.report_cer or self.report_wer):
             if self.error_calculator is None:
+                from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
+
                 self.error_calculator = ErrorCalculator(
                     self.decoder,
                     self.joint_network,
@@ -899,7 +901,7 @@ class UnifiedTransducerModel(AbsESPnetModel):
                     report_wer=self.report_wer,
                 )
 
-            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
+            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
             return loss_transducer, cer_transducer, wer_transducer
 
         return loss_transducer, None, None

+ 6 - 2
funasr/modules/e2e_asr_common.py

@@ -296,12 +296,13 @@ class ErrorCalculatorTransducer:
         self.report_wer = report_wer
 
     def __call__(
-        self, encoder_out: torch.Tensor, target: torch.Tensor
+        self, encoder_out: torch.Tensor, target: torch.Tensor, encoder_out_lens: torch.Tensor,
     ) -> Tuple[Optional[float], Optional[float]]:
         """Calculate sentence-level WER or/and CER score for Transducer model.
         Args:
             encoder_out: Encoder output sequences. (B, T, D_enc)
             target: Target label ID sequences. (B, L)
+            encoder_out_lens: Encoder output sequences length. (B,)
         Returns:
             : Sentence-level CER score.
             : Sentence-level WER score.
@@ -312,7 +313,10 @@ class ErrorCalculatorTransducer:
 
         encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
 
-        batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
+        batch_nbest = [
+            self.beam_search(encoder_out[b][: encoder_out_lens[b]])
+            for b in range(batchsize)
+        ]
         pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
 
         char_pred, char_target = self.convert_to_char(pred, target)