folded_batch_sampler.py

from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

from funasr.fileio.read_text import load_num_sequence_text
from funasr.fileio.read_text import read_2column_text
from funasr.samplers.abs_sampler import AbsSampler


class FoldedBatchSampler(AbsSampler):
    def __init__(
        self,
        batch_size: int,
        shape_files: Union[Tuple[str, ...], List[str]],
        fold_lengths: Sequence[int],
        min_batch_size: int = 1,
        sort_in_batch: str = "descending",
        sort_batch: str = "ascending",
        drop_last: bool = False,
        utt2category_file: Optional[str] = None,
    ):
        assert batch_size > 0
        if sort_batch != "ascending" and sort_batch != "descending":
            raise ValueError(
                f"sort_batch must be ascending or descending: {sort_batch}"
            )
        if sort_in_batch != "descending" and sort_in_batch != "ascending":
            raise ValueError(
                f"sort_in_batch must be ascending or descending: {sort_in_batch}"
            )

        self.batch_size = batch_size
        self.shape_files = shape_files
        self.sort_in_batch = sort_in_batch
        self.sort_batch = sort_batch
        self.drop_last = drop_last

        # utt2shape: (Length, ...)
        #    uttA 100,...
        #    uttB 201,...
        utt2shapes = [
            load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
        ]
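        # Illustrative result of loading (values are made up, not from the source):
        #     utt2shapes[0] == {"uttA": [100, 80], "uttB": [201, 80]}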

        first_utt2shape = utt2shapes[0]
        for s, d in zip(shape_files, utt2shapes):
            if set(d) != set(first_utt2shape):
                raise RuntimeError(
                    f"keys are mismatched between {s} != {shape_files[0]}"
                )

        # Sort samples in ascending order
        # (shape order should be like (Length, Dim))
        keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0])
        if len(keys) == 0:
            raise RuntimeError(f"0 lines found: {shape_files[0]}")

        category2utt = {}
        if utt2category_file is not None:
            utt2category = read_2column_text(utt2category_file)
            if set(utt2category) != set(first_utt2shape):
                raise RuntimeError(
                    "keys are mismatched between "
                    f"{utt2category_file} != {shape_files[0]}"
                )
            for k in keys:
                category2utt.setdefault(utt2category[k], []).append(k)
        else:
            category2utt["default_category"] = keys
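
        # At this point category2utt maps each category label to its utterance
        # keys, still sorted by length. Illustrative shape (labels made up):
        #     {"16k": ["uttA", "uttC"], "8k": ["uttB"]}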

        self.batch_list = []
        for d, v in category2utt.items():
            category_keys = v
            # Decide batch-sizes
            start = 0
            batch_sizes = []
            while True:
                k = category_keys[start]
                # Note: "d" in the generator expression below shadows the loop
                # variable above; it refers to each shape dict in utt2shapes.
                factor = max(
                    int(d[k][0] / m) for d, m in zip(utt2shapes, fold_lengths)
                )
                bs = max(min_batch_size, int(batch_size / (1 + factor)))
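                # Illustrative numbers (not from the source): with batch_size=20,
                # fold_lengths=[512] and a shortest remaining utterance of length
                # 1200, factor = 2 and bs = max(min_batch_size, 20 // 3) = 6,
                # i.e. longer utterances yield smaller mini-batches.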
                if self.drop_last and start + bs > len(category_keys):
                    # This if-block avoids 0-batches
                    if len(self.batch_list) > 0:
                        break

                bs = min(len(category_keys) - start, bs)
                batch_sizes.append(bs)
                start += bs
                if start >= len(category_keys):
                    break

            if len(batch_sizes) == 0:
                # This should be unreachable
                raise RuntimeError("0 batches")

            # If the last batch-size is smaller than minimum batch_size,
            # the samples are redistributed to the other mini-batches
            if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size:
                for i in range(batch_sizes.pop(-1)):
                    batch_sizes[-(i % len(batch_sizes)) - 2] += 1
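            # Illustrative example (values made up): with min_batch_size=4,
            # batch_sizes == [8, 8, 8, 2] becomes [9, 9, 8] after the trailing
            # undersized batch is folded back into the earlier ones.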

            if not self.drop_last:
                # Bug check
                assert sum(batch_sizes) == len(
                    category_keys
                ), f"{sum(batch_sizes)} != {len(category_keys)}"

            # Set mini-batch
            cur_batch_list = []
            start = 0
            for bs in batch_sizes:
                assert len(category_keys) >= start + bs, "Bug"
                minibatch_keys = category_keys[start : start + bs]
                start += bs
                if sort_in_batch == "descending":
                    minibatch_keys.reverse()
                elif sort_in_batch == "ascending":
                    # Keys are already sorted in ascending order
                    pass
                else:
                    raise ValueError(
                        "sort_in_batch must be ascending or "
                        f"descending: {sort_in_batch}"
                    )
                cur_batch_list.append(tuple(minibatch_keys))

            if sort_batch == "ascending":
                pass
            elif sort_batch == "descending":
                cur_batch_list.reverse()
            else:
                raise ValueError(
                    f"sort_batch must be ascending or descending: {sort_batch}"
                )
            self.batch_list.extend(cur_batch_list)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"N-batch={len(self)}, "
            f"batch_size={self.batch_size}, "
            f"shape_files={self.shape_files}, "
            f"sort_in_batch={self.sort_in_batch}, "
            f"sort_batch={self.sort_batch})"
        )

    def __len__(self):
        return len(self.batch_list)

    def __iter__(self) -> Iterator[Tuple[str, ...]]:
        return iter(self.batch_list)
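

# Usage sketch (the paths and fold lengths below are hypothetical, not taken
# from the source). Each shape file lists one "<utt-id> <length>[,<dim>...]"
# entry per line, matching loader_type="csv_int" above:
#
#     sampler = FoldedBatchSampler(
#         batch_size=20,
#         shape_files=["exp/train/speech_shape", "exp/train/text_shape"],
#         fold_lengths=[512, 150],
#         sort_in_batch="descending",
#         sort_batch="ascending",
#     )
#     for minibatch_keys in sampler:
#         ...  # each item is a tuple of utterance keys, e.g. ("uttB", "uttA")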