multiple_iter_factory.py 1.0 KB

1234567891011121314151617181920212223242526272829303132333435
import logging
from typing import Callable
from typing import Collection
from typing import Iterator
from typing import Optional

import numpy as np

from funasr.iterators.abs_iter_factory import AbsIterFactory
  7. class MultipleIterFactory(AbsIterFactory):
  8. def __init__(
  9. self,
  10. build_funcs: Collection[Callable[[], AbsIterFactory]],
  11. seed: int = 0,
  12. shuffle: bool = False,
  13. ):
  14. self.build_funcs = list(build_funcs)
  15. self.seed = seed
  16. self.shuffle = shuffle
  17. def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator:
  18. if shuffle is None:
  19. shuffle = self.shuffle
  20. build_funcs = list(self.build_funcs)
  21. if shuffle:
  22. np.random.RandomState(epoch + self.seed).shuffle(build_funcs)
  23. for i, build_func in enumerate(build_funcs):
  24. logging.info(f"Building {i}th iter-factory...")
  25. iter_factory = build_func()
  26. assert isinstance(iter_factory, AbsIterFactory), type(iter_factory)
  27. yield from iter_factory.build_iter(epoch, shuffle)