|
|
@@ -13,12 +13,9 @@ import torch
|
|
|
from typeguard import check_argument_types
|
|
|
|
|
|
from funasr.layers.abs_normalize import AbsNormalize
|
|
|
-from funasr.models.encoder.abs_encoder import AbsEncoder
|
|
|
-from funasr.models.frontend.abs_frontend import AbsFrontend
|
|
|
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
|
|
-from funasr.models.specaug.abs_specaug import AbsSpecAug
|
|
|
from funasr.torch_utils.device_funcs import force_gatherable
|
|
|
-from funasr.train.abs_espnet_model import AbsESPnetModel
|
|
|
+from funasr.models.base_model import FunASRModel
|
|
|
|
|
|
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
|
|
from torch.cuda.amp import autocast
|
|
|
@@ -29,16 +26,16 @@ else:
|
|
|
yield
|
|
|
|
|
|
|
|
|
-class Data2VecPretrainModel(AbsESPnetModel):
|
|
|
+class Data2VecPretrainModel(FunASRModel):
|
|
|
"""Data2Vec Pretrain model"""
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
- frontend: Optional[AbsFrontend],
|
|
|
- specaug: Optional[AbsSpecAug],
|
|
|
+ frontend: Optional[torch.nn.Module],
|
|
|
+ specaug: Optional[torch.nn.Module],
|
|
|
normalize: Optional[AbsNormalize],
|
|
|
preencoder: Optional[AbsPreEncoder],
|
|
|
- encoder: AbsEncoder,
|
|
|
+ encoder: torch.nn.Module,
|
|
|
):
|
|
|
assert check_argument_types()
|
|
|
|