|
|
@@ -24,7 +24,8 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
|
|
|
batch_size=8000,
|
|
|
len_fn=_default_len_fn,
|
|
|
buffer_size=10240,
|
|
|
- sort_size=500
|
|
|
+ sort_size=500,
|
|
|
+ batch_mode="padding",
|
|
|
):
|
|
|
assert batch_size > 0, "Batch size is required to be larger than 0!"
|
|
|
assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
|
|
|
@@ -35,6 +36,7 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
|
|
|
self.batch_size = batch_size
|
|
|
self.buffer_size = buffer_size
|
|
|
self.sort_size = sort_size
|
|
|
+ self.batch_mode = batch_mode
|
|
|
|
|
|
def set_epoch(self, epoch):
|
|
|
self.epoch = epoch
|
|
|
@@ -46,48 +48,8 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
|
|
|
max_lengths = 0
|
|
|
batch_lengths = 0
|
|
|
|
|
|
- if self.buffer_size == -1:
|
|
|
- for d in self.datapipe:
|
|
|
- if d[0] > self.batch_size:
|
|
|
- continue
|
|
|
- buffer.append(d)
|
|
|
- buffer.sort()
|
|
|
- for sample in buffer:
|
|
|
- length, _, token = sample
|
|
|
- if length > max_lengths:
|
|
|
- max_lengths = length
|
|
|
- batch_lengths = max_lengths * (len(batch) + 1)
|
|
|
- if batch_lengths > self.batch_size:
|
|
|
- bucket.append(batch)
|
|
|
- batch = []
|
|
|
- max_lengths = length
|
|
|
- batch.append(token)
|
|
|
- random.shuffle(bucket)
|
|
|
- if bucket:
|
|
|
- for batch_sample in bucket:
|
|
|
- yield batch_sample
|
|
|
- if batch:
|
|
|
- yield batch
|
|
|
-
|
|
|
- elif self.buffer_size == 0:
|
|
|
- for d in self.datapipe:
|
|
|
- if d[0] > self.batch_size:
|
|
|
- continue
|
|
|
- length, _, token = d
|
|
|
- if length > self.batch_size:
|
|
|
- continue
|
|
|
- if length > max_lengths:
|
|
|
- max_lengths = length
|
|
|
- batch_lengths = max_lengths * (len(batch) + 1)
|
|
|
- if batch_lengths > self.batch_size:
|
|
|
- yield batch
|
|
|
- batch = []
|
|
|
- max_lengths = length
|
|
|
- batch.append(token)
|
|
|
- if batch:
|
|
|
- yield batch
|
|
|
-
|
|
|
- else:
|
|
|
+ if self.batch_mode == "clipping":
|
|
|
+ assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1"
|
|
|
for d in self.datapipe:
|
|
|
if d[0] > self.batch_size:
|
|
|
continue
|
|
|
@@ -100,13 +62,13 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
|
|
|
bucket.sort()
|
|
|
for x in bucket:
|
|
|
length, _, token = x
|
|
|
- if length > max_lengths:
|
|
|
- max_lengths = length
|
|
|
- batch_lengths = max_lengths * (len(batch) + 1)
|
|
|
+ if length < min_lengths:
|
|
|
+ min_lengths = length
|
|
|
+ batch_lengths = min_lengths * (len(batch) + 1)
|
|
|
if batch_lengths > self.batch_size:
|
|
|
yield batch
|
|
|
batch = []
|
|
|
- max_lengths = length
|
|
|
+ min_lengths = length
|
|
|
batch.append(token)
|
|
|
bucket = []
|
|
|
buffer = []
|
|
|
@@ -119,13 +81,13 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
|
|
|
bucket.sort()
|
|
|
for x in bucket:
|
|
|
length, _, token = x
|
|
|
- if length > max_lengths:
|
|
|
- max_lengths = length
|
|
|
- batch_lengths = max_lengths * (len(batch) + 1)
|
|
|
+ if length < min_lengths:
|
|
|
+ min_lengths = length
|
|
|
+ batch_lengths = min_lengths * (len(batch) + 1)
|
|
|
if batch_lengths > self.batch_size:
|
|
|
yield batch
|
|
|
batch = []
|
|
|
- max_lengths = length
|
|
|
+ min_lengths = length
|
|
|
batch.append(token)
|
|
|
bucket = []
|
|
|
buffer = []
|
|
|
@@ -134,6 +96,50 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
|
|
|
bucket.sort()
|
|
|
for x in bucket:
|
|
|
length, _, token = x
|
|
|
+ if length < min_lengths:
|
|
|
+ min_lengths = length
|
|
|
+ batch_lengths = min_lengths * (len(batch) + 1)
|
|
|
+ if batch_lengths > self.batch_size:
|
|
|
+ yield batch
|
|
|
+ batch = []
|
|
|
+ min_lengths = length
|
|
|
+ batch.append(token)
|
|
|
+ bucket = []
|
|
|
+
|
|
|
+ if batch:
|
|
|
+ yield batch
|
|
|
+
|
|
|
+ else:
|
|
|
+ if self.buffer_size == -1:
|
|
|
+ for d in self.datapipe:
|
|
|
+ if d[0] > self.batch_size:
|
|
|
+ continue
|
|
|
+ buffer.append(d)
|
|
|
+ buffer.sort()
|
|
|
+ for sample in buffer:
|
|
|
+ length, _, token = sample
|
|
|
+ if length > max_lengths:
|
|
|
+ max_lengths = length
|
|
|
+ batch_lengths = max_lengths * (len(batch) + 1)
|
|
|
+ if batch_lengths > self.batch_size:
|
|
|
+ bucket.append(batch)
|
|
|
+ batch = []
|
|
|
+ max_lengths = length
|
|
|
+ batch.append(token)
|
|
|
+ random.shuffle(bucket)
|
|
|
+ if bucket:
|
|
|
+ for batch_sample in bucket:
|
|
|
+ yield batch_sample
|
|
|
+ if batch:
|
|
|
+ yield batch
|
|
|
+
|
|
|
+ elif self.buffer_size == 0:
|
|
|
+ for d in self.datapipe:
|
|
|
+ if d[0] > self.batch_size:
|
|
|
+ continue
|
|
|
+ length, _, token = d
|
|
|
+ if length > self.batch_size:
|
|
|
+ continue
|
|
|
if length > max_lengths:
|
|
|
max_lengths = length
|
|
|
batch_lengths = max_lengths * (len(batch) + 1)
|
|
|
@@ -142,7 +148,65 @@ class MaxTokenBucketizerIterDataPipe(IterableDataset):
|
|
|
batch = []
|
|
|
max_lengths = length
|
|
|
batch.append(token)
|
|
|
- bucket = []
|
|
|
+ if batch:
|
|
|
+ yield batch
|
|
|
|
|
|
- if batch:
|
|
|
- yield batch
|
|
|
+ else:
|
|
|
+ for d in self.datapipe:
|
|
|
+ if d[0] > self.batch_size:
|
|
|
+ continue
|
|
|
+ buffer.append(d)
|
|
|
+ if len(buffer) == self.buffer_size:
|
|
|
+ random.shuffle(buffer)
|
|
|
+ for sample in buffer:
|
|
|
+ bucket.append(sample)
|
|
|
+ if len(bucket) == self.sort_size:
|
|
|
+ bucket.sort()
|
|
|
+ for x in bucket:
|
|
|
+ length, _, token = x
|
|
|
+ if length > max_lengths:
|
|
|
+ max_lengths = length
|
|
|
+ batch_lengths = max_lengths * (len(batch) + 1)
|
|
|
+ if batch_lengths > self.batch_size:
|
|
|
+ yield batch
|
|
|
+ batch = []
|
|
|
+ max_lengths = length
|
|
|
+ batch.append(token)
|
|
|
+ bucket = []
|
|
|
+ buffer = []
|
|
|
+
|
|
|
+ if buffer:
|
|
|
+ random.shuffle(buffer)
|
|
|
+ for sample in buffer:
|
|
|
+ bucket.append(sample)
|
|
|
+ if len(bucket) == self.sort_size:
|
|
|
+ bucket.sort()
|
|
|
+ for x in bucket:
|
|
|
+ length, _, token = x
|
|
|
+ if length > max_lengths:
|
|
|
+ max_lengths = length
|
|
|
+ batch_lengths = max_lengths * (len(batch) + 1)
|
|
|
+ if batch_lengths > self.batch_size:
|
|
|
+ yield batch
|
|
|
+ batch = []
|
|
|
+ max_lengths = length
|
|
|
+ batch.append(token)
|
|
|
+ bucket = []
|
|
|
+ buffer = []
|
|
|
+
|
|
|
+ if bucket:
|
|
|
+ bucket.sort()
|
|
|
+ for x in bucket:
|
|
|
+ length, _, token = x
|
|
|
+ if length > max_lengths:
|
|
|
+ max_lengths = length
|
|
|
+ batch_lengths = max_lengths * (len(batch) + 1)
|
|
|
+ if batch_lengths > self.batch_size:
|
|
|
+ yield batch
|
|
|
+ batch = []
|
|
|
+ max_lengths = length
|
|
|
+ batch.append(token)
|
|
|
+ bucket = []
|
|
|
+
|
|
|
+ if batch:
|
|
|
+ yield batch
|