abs_encoder.py 502 B

123456789101112131415161718192021
  1. from abc import ABC
  2. from abc import abstractmethod
  3. from typing import Optional
  4. from typing import Tuple
  5. import torch
  6. class AbsEncoder(torch.nn.Module, ABC):
  7. @abstractmethod
  8. def output_size(self) -> int:
  9. raise NotImplementedError
  10. @abstractmethod
  11. def forward(
  12. self,
  13. xs_pad: torch.Tensor,
  14. ilens: torch.Tensor,
  15. prev_states: torch.Tensor = None,
  16. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  17. raise NotImplementedError