@@ -39,7 +39,7 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
self.batch_mode = batch_mode
def set_epoch(self, epoch):
- self.epoch = epoch
+ self.datapipe.set_epoch(epoch)
def __iter__(self):
buffer = []
@@ -13,7 +13,7 @@ class FilterIterDataPipe(IterableDataset):
self.fn = fn
assert callable(self.fn)
@@ -21,4 +21,4 @@ class FilterIterDataPipe(IterableDataset):
if self.fn(data):
yield data
else:
- continue
+ continue
@@ -14,7 +14,7 @@ class MapperIterDataPipe(IterableDataset):