abs_model.py 772 B

1234567891011121314151617181920212223242526272829
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from typing import Tuple
  4. import torch
  5. from funasr.modules.scorers.scorer_interface import BatchScorerInterface
  6. class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
  7. """The abstract LM class
  8. To share the loss calculation way among different models,
  9. We uses delegate pattern here:
  10. The instance of this class should be passed to "LanguageModel"
  11. >>> from funasr.lm.abs_model import AbsLM
  12. >>> lm = AbsLM()
  13. >>> model = LanguageESPnetModel(lm=lm)
  14. This "model" is one of mediator objects for "Task" class.
  15. """
  16. @abstractmethod
  17. def forward(
  18. self, input: torch.Tensor, hidden: torch.Tensor
  19. ) -> Tuple[torch.Tensor, torch.Tensor]:
  20. raise NotImplementedError