abs_model.py 887 B

12345678910111213141516171819202122232425262728293031
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from typing import Tuple
  4. import torch
  5. from funasr.modules.scorers.scorer_interface import BatchScorerInterface
  6. class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
  7. """The abstract class
  8. To share the loss calculation way among different models,
  9. We uses delegate pattern here:
  10. The instance of this class should be passed to "LanguageModel"
  11. >>> from funasr.punctuation.abs_model import AbsPunctuation
  12. >>> punc = AbsPunctuation()
  13. >>> model = ESPnetPunctuationModel(punc=punc)
  14. This "model" is one of mediator objects for "Task" class.
  15. """
  16. @abstractmethod
  17. def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  18. raise NotImplementedError
  19. @abstractmethod
  20. def with_vad(self) -> bool:
  21. raise NotImplementedError