| 123456789101112131415161718192021 |
- from abc import ABC
- from abc import abstractmethod
- from typing import Optional
- from typing import Tuple
- import torch
class AbsEncoder(torch.nn.Module, ABC):
    """Abstract interface for encoder modules.

    Concrete encoders subclass this and implement both ``output_size`` and
    ``forward``. Both are marked ``@abstractmethod``, so this class itself
    cannot be instantiated.
    """

    @abstractmethod
    def output_size(self) -> int:
        """Return the encoder's output size.

        NOTE(review): presumably the feature dimension of the encoded output
        produced by ``forward`` — confirm against concrete implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode a batch of inputs.

        Args:
            xs_pad: Input tensor; presumably a padded batch such as
                (batch, time, feat) — TODO confirm with callers.
            ilens: Presumably the per-sequence valid lengths of ``xs_pad``.
            prev_states: Optional state from a previous call, or ``None``.
                Annotated ``Optional`` explicitly because the default is
                ``None`` (PEP 484 no longer allows implicit Optional).

        Returns:
            A 3-tuple of two tensors and an optional tensor; presumably
            (encoded output, output lengths, next states) — confirm.
        """
        raise NotImplementedError
|