|
|
@@ -15,7 +15,7 @@ from typeguard import check_argument_types
|
|
|
|
|
|
from funasr.modules.nets_utils import make_pad_mask
|
|
|
from funasr.torch_utils.device_funcs import force_gatherable
|
|
|
-from funasr.train.abs_espnet_model import AbsESPnetModel
|
|
|
+from funasr.models.base_model import FunASRModel
|
|
|
|
|
|
class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
|
|
|
"""The abstract LM class
|
|
|
@@ -39,7 +39,7 @@ class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
-class LanguageModel(AbsESPnetModel):
|
|
|
+class LanguageModel(FunASRModel):
|
|
|
def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
|
|
|
assert check_argument_types()
|
|
|
super().__init__()
|