# length_batch_sampler.py
  1. from typing import Iterator
  2. from typing import List
  3. from typing import Dict
  4. from typing import Tuple
  5. from typing import Union
  6. from funasr.fileio.read_text import load_num_sequence_text
  7. from funasr.samplers.abs_sampler import AbsSampler
  8. class LengthBatchSampler(AbsSampler):
  9. def __init__(
  10. self,
  11. batch_bins: int,
  12. shape_files: Union[Tuple[str, ...], List[str], Dict],
  13. min_batch_size: int = 1,
  14. sort_in_batch: str = "descending",
  15. sort_batch: str = "ascending",
  16. drop_last: bool = False,
  17. padding: bool = True,
  18. ):
  19. assert batch_bins > 0
  20. if sort_batch != "ascending" and sort_batch != "descending":
  21. raise ValueError(
  22. f"sort_batch must be ascending or descending: {sort_batch}"
  23. )
  24. if sort_in_batch != "descending" and sort_in_batch != "ascending":
  25. raise ValueError(
  26. f"sort_in_batch must be ascending or descending: {sort_in_batch}"
  27. )
  28. self.batch_bins = batch_bins
  29. self.shape_files = shape_files
  30. self.sort_in_batch = sort_in_batch
  31. self.sort_batch = sort_batch
  32. self.drop_last = drop_last
  33. # utt2shape: (Length, ...)
  34. # uttA 100,...
  35. # uttB 201,...
  36. if isinstance(shape_files, dict):
  37. utt2shapes = [shape_files]
  38. else:
  39. utt2shapes = [
  40. load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
  41. ]
  42. first_utt2shape = utt2shapes[0]
  43. for s, d in zip(shape_files, utt2shapes):
  44. if set(d) != set(first_utt2shape):
  45. raise RuntimeError(
  46. f"keys are mismatched between {s} != {shape_files[0]}"
  47. )
  48. # Sort samples in ascending order
  49. # (shape order should be like (Length, Dim))
  50. keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0])
  51. if len(keys) == 0:
  52. raise RuntimeError(f"0 lines found: {shape_files[0]}")
  53. # Decide batch-sizes
  54. batch_sizes = []
  55. current_batch_keys = []
  56. for key in keys:
  57. current_batch_keys.append(key)
  58. # shape: (Length, dim1, dim2, ...)
  59. if padding:
  60. # bins = bs x max_length
  61. bins = sum(len(current_batch_keys) * sh[key][0] for sh in utt2shapes)
  62. else:
  63. # bins = sum of lengths
  64. bins = sum(d[k][0] for k in current_batch_keys for d in utt2shapes)
  65. if bins > batch_bins and len(current_batch_keys) >= min_batch_size:
  66. batch_sizes.append(len(current_batch_keys))
  67. current_batch_keys = []
  68. else:
  69. if len(current_batch_keys) != 0 and (
  70. not self.drop_last or len(batch_sizes) == 0
  71. ):
  72. batch_sizes.append(len(current_batch_keys))
  73. if len(batch_sizes) == 0:
  74. # Maybe we can't reach here
  75. raise RuntimeError("0 batches")
  76. # If the last batch-size is smaller than minimum batch_size,
  77. # the samples are redistributed to the other mini-batches
  78. if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size:
  79. for i in range(batch_sizes.pop(-1)):
  80. batch_sizes[-(i % len(batch_sizes)) - 1] += 1
  81. if not self.drop_last:
  82. # Bug check
  83. assert sum(batch_sizes) == len(keys), f"{sum(batch_sizes)} != {len(keys)}"
  84. # Set mini-batch
  85. self.batch_list = []
  86. iter_bs = iter(batch_sizes)
  87. bs = next(iter_bs)
  88. minibatch_keys = []
  89. for key in keys:
  90. minibatch_keys.append(key)
  91. if len(minibatch_keys) == bs:
  92. if sort_in_batch == "descending":
  93. minibatch_keys.reverse()
  94. elif sort_in_batch == "ascending":
  95. # Key are already sorted in ascending
  96. pass
  97. else:
  98. raise ValueError(
  99. "sort_in_batch must be ascending"
  100. f" or descending: {sort_in_batch}"
  101. )
  102. self.batch_list.append(tuple(minibatch_keys))
  103. minibatch_keys = []
  104. try:
  105. bs = next(iter_bs)
  106. except StopIteration:
  107. break
  108. if sort_batch == "ascending":
  109. pass
  110. elif sort_batch == "descending":
  111. self.batch_list.reverse()
  112. else:
  113. raise ValueError(
  114. f"sort_batch must be ascending or descending: {sort_batch}"
  115. )
  116. def __repr__(self):
  117. return (
  118. f"{self.__class__.__name__}("
  119. f"N-batch={len(self)}, "
  120. f"batch_bins={self.batch_bins}, "
  121. f"sort_in_batch={self.sort_in_batch}, "
  122. f"sort_batch={self.sort_batch})"
  123. )
  124. def __len__(self):
  125. return len(self.batch_list)
  126. def __iter__(self) -> Iterator[Tuple[str, ...]]:
  127. return iter(self.batch_list)