|
|
@@ -17,6 +17,7 @@ from funasr.losses.label_smoothing_loss import (
|
|
|
LabelSmoothingLoss, # noqa: H301
|
|
|
)
|
|
|
from funasr.models.ctc import CTC
|
|
|
+from funasr.models.frontend.abs_frontend import AbsFrontend
|
|
|
from funasr.models.encoder.abs_encoder import AbsEncoder
|
|
|
from funasr.models.decoder.abs_decoder import AbsDecoder
|
|
|
from funasr.models.base_model import FunASRModel
|
|
|
@@ -41,7 +42,7 @@ class ESPnetASRModel(FunASRModel):
|
|
|
self,
|
|
|
vocab_size: int,
|
|
|
token_list: Union[Tuple[str, ...], List[str]],
|
|
|
- frontend: Optional[torch.nn.Module],
|
|
|
+ frontend: Optional[AbsFrontend],
|
|
|
specaug: Optional[torch.nn.Module],
|
|
|
normalize: Optional[torch.nn.Module],
|
|
|
encoder: AbsEncoder,
|