# map.py
from typing import Any, Callable, Iterable, Iterator

from torch.utils.data import IterableDataset
  2. def default_fn(data):
  3. return data
  4. class MapperIterDataPipe(IterableDataset):
  5. def __init__(self,
  6. datapipe,
  7. fn=default_fn):
  8. self.datapipe = datapipe
  9. self.fn = fn
  10. def set_epoch(self, epoch):
  11. self.epoch = epoch
  12. def __iter__(self):
  13. assert callable(self.fn)
  14. for data in self.datapipe:
  15. yield self.fn(data)