# samplers.py

import torch
import numpy as np
import logging
import torch.distributed as dist

from funasr.register import tables


@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
class BatchSampler(torch.utils.data.BatchSampler):
    """Locally shuffled dynamic batch sampler.

    Indices are shuffled globally, grouped into windows of ``buffer_size``,
    sorted by sample length within each window, and packed into batches either
    by example count (``batch_type="example"``) or by a padded-token budget
    (any other ``batch_type``), with ``batch_size`` as the corresponding limit.
    """

    def __init__(
        self,
        dataset,
        batch_type: str = "example",
        batch_size: int = 100,
        buffer_size: int = 30,
        drop_last: bool = False,
        shuffle: bool = True,
        is_training: bool = True,
        **kwargs,
    ):
        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 5000)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
    def __len__(self):
        return (self.total_samples - 1) // self.batch_size + 1

    def set_epoch(self, epoch):
        np.random.seed(epoch)
    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        for buffer_idx in range(self.pre_idx + 1, iter_num):
            # collect the lengths of all samples in the current buffer window
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = buffer_idx * self.buffer_size + i
                if idx >= self.total_samples:
                    continue
                idx_map = self.shuffle_idx[idx]
                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == "length" else 0.0
                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
                sample_len_cur = source_len + target_len
                datalen_with_index.append([idx, sample_len_cur])

            # sort the window by length so that similarly sized samples share a batch
            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    # skip over-long samples entirely
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                if self.batch_type != "example":
                    # budget counts padded tokens: (num_sample + 1) * longest sample
                    max_token_padding *= max_token_cur
                if max_token_padding <= self.batch_size:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    yield batch
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1

        # flush the final (possibly partial) batch unless drop_last is requested
        if batch and not self.drop_last:
            yield batch
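
# A minimal usage sketch (illustrative assumptions, not part of this file):
# the dataset is expected to implement __len__, get_source_len(idx) and
# get_target_len(idx); ``my_dataset``, ``my_collate_fn`` and ``num_epochs`` are
# hypothetical names.
#
#   sampler = BatchSampler(my_dataset, batch_type="length", batch_size=6000)
#   loader = torch.utils.data.DataLoader(my_dataset, batch_sampler=sampler,
#                                        collate_fn=my_collate_fn)
#   for epoch in range(num_epochs):
#       sampler.set_epoch(epoch)  # reseed the per-epoch shuffle
#       for batch in loader:
#           ...
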
@tables.register("batch_sampler_classes", "BatchSampler")
@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):
    """Fixed-size batch sampler for distributed training.

    Every rank iterates over the full, locally shuffled index list; batches of
    ``batch_size * world_size`` examples are packed, and each rank keeps its own
    contiguous slice of ``batch_size`` examples from every packed batch.
    """

    def __init__(
        self,
        dataset,
        batch_type: str = "example",
        batch_size: int = 100,
        buffer_size: int = 30,
        drop_last: bool = True,
        shuffle: bool = True,
        is_training: bool = True,
        **kwargs,
    ):
        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)

        if dist.is_available() and dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
    def __len__(self):
        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1

    def set_epoch(self, epoch):
        np.random.seed(epoch)
    def __iter__(self):
        batch_size_total = self.batch_size * self.world_size
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        for buffer_idx in range(self.pre_idx + 1, iter_num):
            # collect the lengths of all samples in the current buffer window
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = buffer_idx * self.buffer_size + i
                if idx >= self.total_samples:
                    continue
                idx_map = self.shuffle_idx[idx]
                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == "length" else 0.0
                sample_len_cur = source_len + target_len
                datalen_with_index.append([idx, sample_len_cur])

            # sort the window by length so that similarly sized samples share a batch
            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    # skip over-long samples entirely
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                # this sampler always packs by example count; length-based packing is disabled
                max_token_padding = 1 + num_sample
                if max_token_padding <= batch_size_total:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    # a full batch of batch_size * world_size examples: keep this rank's slice
                    batch_rank = batch[self.rank * self.batch_size: (self.rank + 1) * self.batch_size]
                    yield batch_rank
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
        # the trailing partial batch is dropped, matching drop_last=True by default
@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):
    """Dynamic batch sampler for distributed training.

    Batches are packed per rank, by example count or by a padded-token budget
    (``batch_size`` is the per-rank limit). Once one batch has been packed for
    every rank, each rank yields its own batch.
    """

    def __init__(
        self,
        dataset,
        batch_type: str = "example",
        batch_size: int = 100,
        buffer_size: int = 30,
        drop_last: bool = True,
        shuffle: bool = True,
        is_training: bool = True,
        **kwargs,
    ):
        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle and is_training
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)

        if dist.is_available() and dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
    def __len__(self):
        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1

    def set_epoch(self, epoch):
        np.random.seed(epoch)
    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)

        batch_list_all_rank = []  # one packed batch per rank
        batch_list_cur = []
        max_token = 0
        num_sample = 0

        iter_num = (self.total_samples - 1) // self.buffer_size + 1
        for buffer_idx in range(self.pre_idx + 1, iter_num):
            # collect the lengths of all samples in the current buffer window
            datalen_with_index = []
            for i in range(self.buffer_size):
                idx = buffer_idx * self.buffer_size + i
                if idx >= self.total_samples:
                    continue
                idx_map = self.shuffle_idx[idx]
                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == "length" else 0.0
                sample_len_cur = source_len + target_len
                datalen_with_index.append([idx, sample_len_cur])

            # sort the window by length so that similarly sized samples share a batch
            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_token_length:
                    # skip over-long samples entirely
                    continue

                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                if self.batch_type != "example":
                    # budget counts padded tokens: (num_sample + 1) * longest sample
                    max_token_padding *= max_token_cur

                if max_token_padding <= self.batch_size:
                    batch_list_cur.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    # the current per-rank batch is full: store it and start a new one
                    batch_list_all_rank.append(batch_list_cur)
                    batch_list_cur = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
                    if len(batch_list_all_rank) >= self.world_size:
                        # every rank has a batch: emit this rank's one and reset
                        yield batch_list_all_rank[self.rank]
                        batch_list_all_rank = []
        # leftover batches at the end are dropped, matching drop_last=True by default
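
# A minimal distributed usage sketch (illustrative assumptions, not part of this
# file): ``my_dataset``, ``my_collate_fn`` and ``num_epochs`` are hypothetical,
# and the process group is assumed to be initialized via torch.distributed
# before construction so that rank and world_size are picked up automatically.
#
#   sampler = RankFullLocalShuffleDynamicBatchSampler(my_dataset,
#                                                     batch_type="length",
#                                                     batch_size=6000)
#   loader = torch.utils.data.DataLoader(my_dataset, batch_sampler=sampler,
#                                        collate_fn=my_collate_fn)
#   for epoch in range(num_epochs):
#       sampler.set_epoch(epoch)  # keep the shuffle in sync across ranks
#       for batch in loader:
#           ...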