# multiple_iter_factory.py (1.1 KB)

import logging
from typing import Callable
from typing import Collection
from typing import Iterator
from typing import Optional

import numpy as np
from typeguard import check_argument_types

from funasr.iterators.abs_iter_factory import AbsIterFactory
  8. class MultipleIterFactory(AbsIterFactory):
  9. def __init__(
  10. self,
  11. build_funcs: Collection[Callable[[], AbsIterFactory]],
  12. seed: int = 0,
  13. shuffle: bool = False,
  14. ):
  15. assert check_argument_types()
  16. self.build_funcs = list(build_funcs)
  17. self.seed = seed
  18. self.shuffle = shuffle
  19. def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator:
  20. if shuffle is None:
  21. shuffle = self.shuffle
  22. build_funcs = list(self.build_funcs)
  23. if shuffle:
  24. np.random.RandomState(epoch + self.seed).shuffle(build_funcs)
  25. for i, build_func in enumerate(build_funcs):
  26. logging.info(f"Building {i}th iter-factory...")
  27. iter_factory = build_func()
  28. assert isinstance(iter_factory, AbsIterFactory), type(iter_factory)
  29. yield from iter_factory.build_iter(epoch, shuffle)