# batch.py — max-token bucketizing batch sampler for FunASR large datasets.
  1. import random
  2. from itertools import count
  3. from functools import partial
  4. from torch.utils.data import IterableDataset
  5. from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
  6. tiebreaker = count()
  7. def _default_len_fn(token):
  8. return len(token), next(tiebreaker)
  9. def _token_len_fn(token, len_fn):
  10. return len_fn(token), next(tiebreaker), token
  11. class MaxTokenBucketizerIterDataPipe(IterableDataset):
  12. def __init__(
  13. self,
  14. datapipe,
  15. batch_size=8000,
  16. len_fn=_default_len_fn,
  17. buffer_size=10240,
  18. sort_size=500,
  19. batch_mode="padding",
  20. ):
  21. assert batch_size > 0, "Batch size is required to be larger than 0!"
  22. assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
  23. assert sort_size > 0, "Sort size is required to be larger than 0!"
  24. datapipe = MapperIterDataPipe(datapipe, fn=partial(_token_len_fn, len_fn=len_fn))
  25. self.datapipe = datapipe
  26. self.batch_size = batch_size
  27. self.buffer_size = buffer_size
  28. self.sort_size = sort_size
  29. self.batch_mode = batch_mode
  30. def set_epoch(self, epoch):
  31. self.epoch = epoch
  32. def __iter__(self):
  33. buffer = []
  34. batch = []
  35. bucket = []
  36. max_lengths = 0
  37. min_lengths = 999999
  38. batch_lengths = 0
  39. if self.batch_mode == "clipping":
  40. assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1"
  41. for d in self.datapipe:
  42. if d[0] > self.batch_size:
  43. continue
  44. buffer.append(d)
  45. if len(buffer) == self.buffer_size:
  46. random.shuffle(buffer)
  47. for sample in buffer:
  48. bucket.append(sample)
  49. if len(bucket) == self.sort_size:
  50. bucket.sort()
  51. for x in bucket:
  52. length, _, token = x
  53. if length < min_lengths:
  54. min_lengths = length
  55. batch_lengths = min_lengths * (len(batch) + 1)
  56. if batch_lengths > self.batch_size:
  57. yield batch
  58. batch = []
  59. min_lengths = length
  60. batch.append(token)
  61. bucket = []
  62. buffer = []
  63. if buffer:
  64. random.shuffle(buffer)
  65. for sample in buffer:
  66. bucket.append(sample)
  67. if len(bucket) == self.sort_size:
  68. bucket.sort()
  69. for x in bucket:
  70. length, _, token = x
  71. if length < min_lengths:
  72. min_lengths = length
  73. batch_lengths = min_lengths * (len(batch) + 1)
  74. if batch_lengths > self.batch_size:
  75. yield batch
  76. batch = []
  77. min_lengths = length
  78. batch.append(token)
  79. bucket = []
  80. buffer = []
  81. if bucket:
  82. bucket.sort()
  83. for x in bucket:
  84. length, _, token = x
  85. if length < min_lengths:
  86. min_lengths = length
  87. batch_lengths = min_lengths * (len(batch) + 1)
  88. if batch_lengths > self.batch_size:
  89. yield batch
  90. batch = []
  91. min_lengths = length
  92. batch.append(token)
  93. bucket = []
  94. if batch:
  95. yield batch
  96. else:
  97. if self.buffer_size == -1:
  98. for d in self.datapipe:
  99. if d[0] > self.batch_size:
  100. continue
  101. buffer.append(d)
  102. buffer.sort()
  103. for sample in buffer:
  104. length, _, token = sample
  105. if length > max_lengths:
  106. max_lengths = length
  107. batch_lengths = max_lengths * (len(batch) + 1)
  108. if batch_lengths > self.batch_size:
  109. bucket.append(batch)
  110. batch = []
  111. max_lengths = length
  112. batch.append(token)
  113. random.shuffle(bucket)
  114. if bucket:
  115. for batch_sample in bucket:
  116. yield batch_sample
  117. if batch:
  118. yield batch
  119. elif self.buffer_size == 0:
  120. for d in self.datapipe:
  121. if d[0] > self.batch_size:
  122. continue
  123. length, _, token = d
  124. if length > self.batch_size:
  125. continue
  126. if length > max_lengths:
  127. max_lengths = length
  128. batch_lengths = max_lengths * (len(batch) + 1)
  129. if batch_lengths > self.batch_size:
  130. yield batch
  131. batch = []
  132. max_lengths = length
  133. batch.append(token)
  134. if batch:
  135. yield batch
  136. else:
  137. for d in self.datapipe:
  138. if d[0] > self.batch_size:
  139. continue
  140. buffer.append(d)
  141. if len(buffer) == self.buffer_size:
  142. random.shuffle(buffer)
  143. for sample in buffer:
  144. bucket.append(sample)
  145. if len(bucket) == self.sort_size:
  146. bucket.sort()
  147. for x in bucket:
  148. length, _, token = x
  149. if length > max_lengths:
  150. max_lengths = length
  151. batch_lengths = max_lengths * (len(batch) + 1)
  152. if batch_lengths > self.batch_size:
  153. yield batch
  154. batch = []
  155. max_lengths = length
  156. batch.append(token)
  157. bucket = []
  158. buffer = []
  159. if buffer:
  160. random.shuffle(buffer)
  161. for sample in buffer:
  162. bucket.append(sample)
  163. if len(bucket) == self.sort_size:
  164. bucket.sort()
  165. for x in bucket:
  166. length, _, token = x
  167. if length > max_lengths:
  168. max_lengths = length
  169. batch_lengths = max_lengths * (len(batch) + 1)
  170. if batch_lengths > self.batch_size:
  171. yield batch
  172. batch = []
  173. max_lengths = length
  174. batch.append(token)
  175. bucket = []
  176. buffer = []
  177. if bucket:
  178. bucket.sort()
  179. for x in bucket:
  180. length, _, token = x
  181. if length > max_lengths:
  182. max_lengths = length
  183. batch_lengths = max_lengths * (len(batch) + 1)
  184. if batch_lengths > self.batch_size:
  185. yield batch
  186. batch = []
  187. max_lengths = length
  188. batch.append(token)
  189. bucket = []
  190. if batch:
  191. yield batch