# length_batch_sampler.py
  1. from typing import Iterator
  2. from typing import List
  3. from typing import Dict
  4. from typing import Tuple
  5. from typing import Union
  6. from typeguard import check_argument_types
  7. from funasr.fileio.read_text import load_num_sequence_text
  8. from funasr.samplers.abs_sampler import AbsSampler
  9. class LengthBatchSampler(AbsSampler):
  10. def __init__(
  11. self,
  12. batch_bins: int,
  13. shape_files: Union[Tuple[str, ...], List[str], Dict],
  14. min_batch_size: int = 1,
  15. sort_in_batch: str = "descending",
  16. sort_batch: str = "ascending",
  17. drop_last: bool = False,
  18. padding: bool = True,
  19. ):
  20. assert check_argument_types()
  21. assert batch_bins > 0
  22. if sort_batch != "ascending" and sort_batch != "descending":
  23. raise ValueError(
  24. f"sort_batch must be ascending or descending: {sort_batch}"
  25. )
  26. if sort_in_batch != "descending" and sort_in_batch != "ascending":
  27. raise ValueError(
  28. f"sort_in_batch must be ascending or descending: {sort_in_batch}"
  29. )
  30. self.batch_bins = batch_bins
  31. self.shape_files = shape_files
  32. self.sort_in_batch = sort_in_batch
  33. self.sort_batch = sort_batch
  34. self.drop_last = drop_last
  35. # utt2shape: (Length, ...)
  36. # uttA 100,...
  37. # uttB 201,...
  38. if isinstance(shape_files, dict):
  39. utt2shapes = [shape_files]
  40. else:
  41. utt2shapes = [
  42. load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
  43. ]
  44. first_utt2shape = utt2shapes[0]
  45. for s, d in zip(shape_files, utt2shapes):
  46. if set(d) != set(first_utt2shape):
  47. raise RuntimeError(
  48. f"keys are mismatched between {s} != {shape_files[0]}"
  49. )
  50. # Sort samples in ascending order
  51. # (shape order should be like (Length, Dim))
  52. keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0])
  53. if len(keys) == 0:
  54. raise RuntimeError(f"0 lines found: {shape_files[0]}")
  55. # Decide batch-sizes
  56. batch_sizes = []
  57. current_batch_keys = []
  58. for key in keys:
  59. current_batch_keys.append(key)
  60. # shape: (Length, dim1, dim2, ...)
  61. if padding:
  62. # bins = bs x max_length
  63. bins = sum(len(current_batch_keys) * sh[key][0] for sh in utt2shapes)
  64. else:
  65. # bins = sum of lengths
  66. bins = sum(d[k][0] for k in current_batch_keys for d in utt2shapes)
  67. if bins > batch_bins and len(current_batch_keys) >= min_batch_size:
  68. batch_sizes.append(len(current_batch_keys))
  69. current_batch_keys = []
  70. else:
  71. if len(current_batch_keys) != 0 and (
  72. not self.drop_last or len(batch_sizes) == 0
  73. ):
  74. batch_sizes.append(len(current_batch_keys))
  75. if len(batch_sizes) == 0:
  76. # Maybe we can't reach here
  77. raise RuntimeError("0 batches")
  78. # If the last batch-size is smaller than minimum batch_size,
  79. # the samples are redistributed to the other mini-batches
  80. if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size:
  81. for i in range(batch_sizes.pop(-1)):
  82. batch_sizes[-(i % len(batch_sizes)) - 1] += 1
  83. if not self.drop_last:
  84. # Bug check
  85. assert sum(batch_sizes) == len(keys), f"{sum(batch_sizes)} != {len(keys)}"
  86. # Set mini-batch
  87. self.batch_list = []
  88. iter_bs = iter(batch_sizes)
  89. bs = next(iter_bs)
  90. minibatch_keys = []
  91. for key in keys:
  92. minibatch_keys.append(key)
  93. if len(minibatch_keys) == bs:
  94. if sort_in_batch == "descending":
  95. minibatch_keys.reverse()
  96. elif sort_in_batch == "ascending":
  97. # Key are already sorted in ascending
  98. pass
  99. else:
  100. raise ValueError(
  101. "sort_in_batch must be ascending"
  102. f" or descending: {sort_in_batch}"
  103. )
  104. self.batch_list.append(tuple(minibatch_keys))
  105. minibatch_keys = []
  106. try:
  107. bs = next(iter_bs)
  108. except StopIteration:
  109. break
  110. if sort_batch == "ascending":
  111. pass
  112. elif sort_batch == "descending":
  113. self.batch_list.reverse()
  114. else:
  115. raise ValueError(
  116. f"sort_batch must be ascending or descending: {sort_batch}"
  117. )
  118. def __repr__(self):
  119. return (
  120. f"{self.__class__.__name__}("
  121. f"N-batch={len(self)}, "
  122. f"batch_bins={self.batch_bins}, "
  123. f"sort_in_batch={self.sort_in_batch}, "
  124. f"sort_batch={self.sort_batch})"
  125. )
  126. def __len__(self):
  127. return len(self.batch_list)
  128. def __iter__(self) -> Iterator[Tuple[str, ...]]:
  129. return iter(self.batch_list)