|
|
@@ -17,9 +17,8 @@ from funasr.losses.label_smoothing_loss import (
|
|
|
LabelSmoothingLoss, # noqa: H301
|
|
|
)
|
|
|
from funasr.models.ctc import CTC
|
|
|
+from funasr.models.encoder.abs_encoder import AbsEncoder
|
|
|
from funasr.models.decoder.abs_decoder import AbsDecoder
|
|
|
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
|
|
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
|
|
from funasr.models.base_model import FunASRModel
|
|
|
from funasr.modules.add_sos_eos import add_sos_eos
|
|
|
from funasr.modules.e2e_asr_common import ErrorCalculator
|
|
|
@@ -45,9 +44,7 @@ class ESPnetASRModel(FunASRModel):
|
|
|
frontend: Optional[torch.nn.Module],
|
|
|
specaug: Optional[torch.nn.Module],
|
|
|
normalize: Optional[torch.nn.Module],
|
|
|
- preencoder: Optional[AbsPreEncoder],
|
|
|
- encoder: torch.nn.Module,
|
|
|
- postencoder: Optional[AbsPostEncoder],
|
|
|
+ encoder: AbsEncoder,
|
|
|
decoder: AbsDecoder,
|
|
|
ctc: CTC,
|
|
|
ctc_weight: float = 0.5,
|