|
|
@@ -386,7 +386,7 @@ class TransducerModel(AbsESPnetModel):
|
|
|
|
|
|
if not self.training and (self.report_cer or self.report_wer):
|
|
|
if self.error_calculator is None:
|
|
|
- from espnet2.asr_transducer.error_calculator import ErrorCalculator
|
|
|
+ from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
|
|
|
|
|
|
self.error_calculator = ErrorCalculator(
|
|
|
self.decoder,
|
|
|
@@ -398,7 +398,7 @@ class TransducerModel(AbsESPnetModel):
|
|
|
report_wer=self.report_wer,
|
|
|
)
|
|
|
|
|
|
- cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
|
|
|
+ cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
|
|
|
|
|
|
return loss_transducer, cer_transducer, wer_transducer
|
|
|
|
|
|
@@ -889,6 +889,8 @@ class UnifiedTransducerModel(AbsESPnetModel):
|
|
|
|
|
|
if not self.training and (self.report_cer or self.report_wer):
|
|
|
if self.error_calculator is None:
|
|
|
+ from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
|
|
|
+
|
|
|
self.error_calculator = ErrorCalculator(
|
|
|
self.decoder,
|
|
|
self.joint_network,
|
|
|
@@ -899,7 +901,7 @@ class UnifiedTransducerModel(AbsESPnetModel):
|
|
|
report_wer=self.report_wer,
|
|
|
)
|
|
|
|
|
|
- cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
|
|
|
+ cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
|
|
|
return loss_transducer, cer_transducer, wer_transducer
|
|
|
|
|
|
return loss_transducer, None, None
|