abs_normalize.py 358 B

1234567891011121314
from abc import ABC
from abc import abstractmethod
from typing import Optional
from typing import Tuple

import torch
  5. class AbsNormalize(torch.nn.Module, ABC):
  6. @abstractmethod
  7. def forward(
  8. self, input: torch.Tensor, input_lengths: torch.Tensor = None
  9. ) -> Tuple[torch.Tensor, torch.Tensor]:
  10. # return output, output_lengths
  11. raise NotImplementedError