| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213 |
- import random
- from itertools import count
- from functools import partial
- from torch.utils.data import IterableDataset
- from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
- tiebreaker = count()
- def _default_len_fn(token):
- return len(token), next(tiebreaker)
- def _token_len_fn(token, len_fn):
- return len_fn(token), next(tiebreaker), token
class MaxTokenBucketizerIterDataPipe(IterableDataset):
    """Group samples of a datapipe into batches bounded by a size budget.

    Items read from the wrapped datapipe are unpacked as
    ``(length, tiebreak, sample)``.  Items whose length exceeds ``batch_size``
    are dropped.  The remaining ones are packed greedily: the current batch is
    emitted right before adding the next sample would make
    ``reference_length * number_of_samples`` exceed ``batch_size``.

    The reference length is the longest length seen in the batch, or the
    shortest one when ``batch_mode == "clipping"``.

    Args:
        datapipe: Upstream iterable of samples.
        batch_size: Budget for ``reference_length * number_of_samples``.
        len_fn: Callable returning the length of a sample.  The length is
            compared with ``batch_size`` and multiplied by a sample count, so
            it has to behave like a number.
            NOTE(review): verify that the value produced by the default
            ``len_fn`` meets this requirement.
        buffer_size: ``-1`` sorts the whole stream by length and shuffles the
            resulting full batches; ``0`` packs samples in arrival order; any
            other value shuffles windows of that many samples before they are
            sorted in groups of ``sort_size``.
        sort_size: Number of shuffled samples that are sorted together before
            being packed.
        batch_mode: ``"clipping"`` budgets on the shortest sample of a batch;
            every other value budgets on the longest one.
    """

    def __init__(
        self,
        datapipe,
        batch_size=8000,
        len_fn=_default_len_fn,
        buffer_size=10240,
        sort_size=500,
        batch_mode="padding",
    ):
        super().__init__()
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        assert buffer_size >= -1, "Buffer size is required to be at least -1!"
        assert sort_size > 0, "Sort size is required to be larger than 0!"
        # Presumably the mapper applies the function to every upstream sample,
        # yielding the (length, tiebreak, sample) items that __iter__ unpacks.
        self.datapipe = MapperIterDataPipe(datapipe, fn=partial(_token_len_fn, len_fn=len_fn))
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.sort_size = sort_size
        self.batch_mode = batch_mode

    def set_epoch(self, epoch):
        """Store ``epoch`` on the instance; nothing in this class reads it."""
        self.epoch = epoch

    def __iter__(self):
        """Yield batches (lists of samples) that respect the size budget.

        Raises:
            AssertionError: If ``batch_mode`` is ``"clipping"`` and
                ``buffer_size`` is not positive.
        """
        limit = self.batch_size
        window = self.buffer_size
        sort_size = self.sort_size
        clipping = self.batch_mode == "clipping"
        if clipping:
            assert window > 0, "for clipping batch_mode, buffer_size must be > 0"

        # Clipping tracks the shortest length of a batch, padding the longest.
        # Infinity (rather than a large magic number) lets the first sample
        # always define the minimum, whatever its length.
        pick = min if clipping else max
        ref_len = float("inf") if clipping else 0
        batch = []  # samples of the batch being filled; survives across buckets
        bucket = []  # shuffled samples waiting to be sorted together

        def eligible():
            """Yield upstream items whose length fits the budget on its own."""
            for item in self.datapipe:
                if item[0] > limit:
                    continue
                yield item

        def pack(items):
            """Add items to the running batch, yielding it whenever it is full."""
            nonlocal batch, ref_len
            for length, _, sample in items:
                ref_len = pick(ref_len, length)
                if ref_len * (len(batch) + 1) > limit:
                    yield batch
                    batch = []
                    ref_len = length
                batch.append(sample)

        def drain(buffer):
            """Shuffle a window and pack every bucket of ``sort_size`` it fills."""
            random.shuffle(buffer)
            for item in buffer:
                bucket.append(item)
                if len(bucket) == sort_size:
                    bucket.sort()
                    yield from pack(bucket)
                    bucket.clear()

        if not clipping and window == -1:
            # Global sort: build every full batch first, then shuffle the
            # batches; the trailing partial batch is emitted last, below.
            full_batches = list(pack(sorted(eligible())))
            random.shuffle(full_batches)
            yield from full_batches
        elif not clipping and window == 0:
            # Streaming: pack in arrival order, no shuffling or sorting.
            yield from pack(eligible())
        else:
            buffer = []
            for item in eligible():
                buffer.append(item)
                if len(buffer) == window:
                    yield from drain(buffer)
                    buffer = []
            if buffer:
                yield from drain(buffer)
            if bucket:
                # Fewer than sort_size samples are left over: sort and pack them.
                bucket.sort()
                yield from pack(bucket)

        if batch:
            yield batch
|