abs_frontend.py 399 B

1234567891011121314151617
from abc import ABC
from abc import abstractmethod
from typing import Tuple

import torch
  5. class AbsFrontend(torch.nn.Module, ABC):
  6. @abstractmethod
  7. def output_size(self) -> int:
  8. raise NotImplementedError
  9. @abstractmethod
  10. def forward(
  11. self, input: torch.Tensor, input_lengths: torch.Tensor
  12. ) -> Tuple[torch.Tensor, torch.Tensor]:
  13. raise NotImplementedError