# abs_sampler.py (425 B)
from abc import ABC
from abc import abstractmethod
from typing import Iterator
from typing import List
from typing import Tuple

from torch.utils.data import Sampler
  6. class AbsSampler(Sampler, ABC):
  7. @abstractmethod
  8. def __len__(self) -> int:
  9. raise NotImplementedError
  10. @abstractmethod
  11. def __iter__(self) -> Iterator[Tuple[str, ...]]:
  12. raise NotImplementedError
  13. def generate(self, seed):
  14. return list(self)