|
|
@@ -12,12 +12,12 @@ from typing import Tuple
|
|
|
import torch
|
|
|
|
|
|
from funasr.layers.abs_normalize import AbsNormalize
|
|
|
+from funasr.models.base_model import FunASRModel
|
|
|
from funasr.models.encoder.abs_encoder import AbsEncoder
|
|
|
from funasr.models.frontend.abs_frontend import AbsFrontend
|
|
|
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
|
|
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
|
|
from funasr.torch_utils.device_funcs import force_gatherable
|
|
|
-from funasr.models.base_model import FunASRModel
|
|
|
|
|
|
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
|
|
from torch.cuda.amp import autocast
|
|
|
@@ -36,8 +36,8 @@ class Data2VecPretrainModel(FunASRModel):
|
|
|
frontend: Optional[AbsFrontend],
|
|
|
specaug: Optional[AbsSpecAug],
|
|
|
normalize: Optional[AbsNormalize],
|
|
|
- preencoder: Optional[AbsPreEncoder],
|
|
|
encoder: AbsEncoder,
|
|
|
+ preencoder: Optional[AbsPreEncoder] = None,
|
|
|
):
|
|
|
|
|
|
super().__init__()
|