length_batch_sampler.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. from typing import Iterator
  2. from typing import List
  3. from typing import Tuple
  4. from typing import Union
  5. from typeguard import check_argument_types
  6. from funasr.fileio.read_text import load_num_sequence_text
  7. from funasr.samplers.abs_sampler import AbsSampler
  8. class LengthBatchSampler(AbsSampler):
  9. def __init__(
  10. self,
  11. batch_bins: int,
  12. shape_files: Union[Tuple[str, ...], List[str]],
  13. min_batch_size: int = 1,
  14. sort_in_batch: str = "descending",
  15. sort_batch: str = "ascending",
  16. drop_last: bool = False,
  17. padding: bool = True,
  18. ):
  19. assert check_argument_types()
  20. assert batch_bins > 0
  21. if sort_batch != "ascending" and sort_batch != "descending":
  22. raise ValueError(
  23. f"sort_batch must be ascending or descending: {sort_batch}"
  24. )
  25. if sort_in_batch != "descending" and sort_in_batch != "ascending":
  26. raise ValueError(
  27. f"sort_in_batch must be ascending or descending: {sort_in_batch}"
  28. )
  29. self.batch_bins = batch_bins
  30. self.shape_files = shape_files
  31. self.sort_in_batch = sort_in_batch
  32. self.sort_batch = sort_batch
  33. self.drop_last = drop_last
  34. # utt2shape: (Length, ...)
  35. # uttA 100,...
  36. # uttB 201,...
  37. utt2shapes = [
  38. load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
  39. ]
  40. first_utt2shape = utt2shapes[0]
  41. for s, d in zip(shape_files, utt2shapes):
  42. if set(d) != set(first_utt2shape):
  43. raise RuntimeError(
  44. f"keys are mismatched between {s} != {shape_files[0]}"
  45. )
  46. # Sort samples in ascending order
  47. # (shape order should be like (Length, Dim))
  48. keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0])
  49. if len(keys) == 0:
  50. raise RuntimeError(f"0 lines found: {shape_files[0]}")
  51. # Decide batch-sizes
  52. batch_sizes = []
  53. current_batch_keys = []
  54. for key in keys:
  55. current_batch_keys.append(key)
  56. # shape: (Length, dim1, dim2, ...)
  57. if padding:
  58. # bins = bs x max_length
  59. bins = sum(len(current_batch_keys) * sh[key][0] for sh in utt2shapes)
  60. else:
  61. # bins = sum of lengths
  62. bins = sum(d[k][0] for k in current_batch_keys for d in utt2shapes)
  63. if bins > batch_bins and len(current_batch_keys) >= min_batch_size:
  64. batch_sizes.append(len(current_batch_keys))
  65. current_batch_keys = []
  66. else:
  67. if len(current_batch_keys) != 0 and (
  68. not self.drop_last or len(batch_sizes) == 0
  69. ):
  70. batch_sizes.append(len(current_batch_keys))
  71. if len(batch_sizes) == 0:
  72. # Maybe we can't reach here
  73. raise RuntimeError("0 batches")
  74. # If the last batch-size is smaller than minimum batch_size,
  75. # the samples are redistributed to the other mini-batches
  76. if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size:
  77. for i in range(batch_sizes.pop(-1)):
  78. batch_sizes[-(i % len(batch_sizes)) - 1] += 1
  79. if not self.drop_last:
  80. # Bug check
  81. assert sum(batch_sizes) == len(keys), f"{sum(batch_sizes)} != {len(keys)}"
  82. # Set mini-batch
  83. self.batch_list = []
  84. iter_bs = iter(batch_sizes)
  85. bs = next(iter_bs)
  86. minibatch_keys = []
  87. for key in keys:
  88. minibatch_keys.append(key)
  89. if len(minibatch_keys) == bs:
  90. if sort_in_batch == "descending":
  91. minibatch_keys.reverse()
  92. elif sort_in_batch == "ascending":
  93. # Key are already sorted in ascending
  94. pass
  95. else:
  96. raise ValueError(
  97. "sort_in_batch must be ascending"
  98. f" or descending: {sort_in_batch}"
  99. )
  100. self.batch_list.append(tuple(minibatch_keys))
  101. minibatch_keys = []
  102. try:
  103. bs = next(iter_bs)
  104. except StopIteration:
  105. break
  106. if sort_batch == "ascending":
  107. pass
  108. elif sort_batch == "descending":
  109. self.batch_list.reverse()
  110. else:
  111. raise ValueError(
  112. f"sort_batch must be ascending or descending: {sort_batch}"
  113. )
  114. def __repr__(self):
  115. return (
  116. f"{self.__class__.__name__}("
  117. f"N-batch={len(self)}, "
  118. f"batch_bins={self.batch_bins}, "
  119. f"sort_in_batch={self.sort_in_batch}, "
  120. f"sort_batch={self.sort_batch})"
  121. )
  122. def __len__(self):
  123. return len(self.batch_list)
  124. def __iter__(self) -> Iterator[Tuple[str, ...]]:
  125. return iter(self.batch_list)