abs_decoder.py 473 B

12345678910111213141516171819
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from typing import Tuple
  4. import torch
  5. from funasr.modules.scorers.scorer_interface import ScorerInterface
  6. class AbsDecoder(torch.nn.Module, ScorerInterface, ABC):
  7. @abstractmethod
  8. def forward(
  9. self,
  10. hs_pad: torch.Tensor,
  11. hlens: torch.Tensor,
  12. ys_in_pad: torch.Tensor,
  13. ys_in_lens: torch.Tensor,
  14. ) -> Tuple[torch.Tensor, torch.Tensor]:
  15. raise NotImplementedError