# folded_batch_sampler.py
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

from typeguard import check_argument_types

from funasr.fileio.read_text import load_num_sequence_text
from funasr.fileio.read_text import read_2column_text
from funasr.samplers.abs_sampler import AbsSampler
  10. class FoldedBatchSampler(AbsSampler):
  11. def __init__(
  12. self,
  13. batch_size: int,
  14. shape_files: Union[Tuple[str, ...], List[str]],
  15. fold_lengths: Sequence[int],
  16. min_batch_size: int = 1,
  17. sort_in_batch: str = "descending",
  18. sort_batch: str = "ascending",
  19. drop_last: bool = False,
  20. utt2category_file: str = None,
  21. ):
  22. assert check_argument_types()
  23. assert batch_size > 0
  24. if sort_batch != "ascending" and sort_batch != "descending":
  25. raise ValueError(
  26. f"sort_batch must be ascending or descending: {sort_batch}"
  27. )
  28. if sort_in_batch != "descending" and sort_in_batch != "ascending":
  29. raise ValueError(
  30. f"sort_in_batch must be ascending or descending: {sort_in_batch}"
  31. )
  32. self.batch_size = batch_size
  33. self.shape_files = shape_files
  34. self.sort_in_batch = sort_in_batch
  35. self.sort_batch = sort_batch
  36. self.drop_last = drop_last
  37. # utt2shape: (Length, ...)
  38. # uttA 100,...
  39. # uttB 201,...
  40. utt2shapes = [
  41. load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
  42. ]
  43. first_utt2shape = utt2shapes[0]
  44. for s, d in zip(shape_files, utt2shapes):
  45. if set(d) != set(first_utt2shape):
  46. raise RuntimeError(
  47. f"keys are mismatched between {s} != {shape_files[0]}"
  48. )
  49. # Sort samples in ascending order
  50. # (shape order should be like (Length, Dim))
  51. keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0])
  52. if len(keys) == 0:
  53. raise RuntimeError(f"0 lines found: {shape_files[0]}")
  54. category2utt = {}
  55. if utt2category_file is not None:
  56. utt2category = read_2column_text(utt2category_file)
  57. if set(utt2category) != set(first_utt2shape):
  58. raise RuntimeError(
  59. "keys are mismatched between "
  60. f"{utt2category_file} != {shape_files[0]}"
  61. )
  62. for k in keys:
  63. category2utt.setdefault(utt2category[k], []).append(k)
  64. else:
  65. category2utt["default_category"] = keys
  66. self.batch_list = []
  67. for d, v in category2utt.items():
  68. category_keys = v
  69. # Decide batch-sizes
  70. start = 0
  71. batch_sizes = []
  72. while True:
  73. k = category_keys[start]
  74. factor = max(int(d[k][0] / m) for d, m in zip(utt2shapes, fold_lengths))
  75. bs = max(min_batch_size, int(batch_size / (1 + factor)))
  76. if self.drop_last and start + bs > len(category_keys):
  77. # This if-block avoids 0-batches
  78. if len(self.batch_list) > 0:
  79. break
  80. bs = min(len(category_keys) - start, bs)
  81. batch_sizes.append(bs)
  82. start += bs
  83. if start >= len(category_keys):
  84. break
  85. if len(batch_sizes) == 0:
  86. # Maybe we can't reach here
  87. raise RuntimeError("0 batches")
  88. # If the last batch-size is smaller than minimum batch_size,
  89. # the samples are redistributed to the other mini-batches
  90. if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size:
  91. for i in range(batch_sizes.pop(-1)):
  92. batch_sizes[-(i % len(batch_sizes)) - 2] += 1
  93. if not self.drop_last:
  94. # Bug check
  95. assert sum(batch_sizes) == len(
  96. category_keys
  97. ), f"{sum(batch_sizes)} != {len(category_keys)}"
  98. # Set mini-batch
  99. cur_batch_list = []
  100. start = 0
  101. for bs in batch_sizes:
  102. assert len(category_keys) >= start + bs, "Bug"
  103. minibatch_keys = category_keys[start : start + bs]
  104. start += bs
  105. if sort_in_batch == "descending":
  106. minibatch_keys.reverse()
  107. elif sort_in_batch == "ascending":
  108. # Key are already sorted in ascending
  109. pass
  110. else:
  111. raise ValueError(
  112. "sort_in_batch must be ascending or "
  113. f"descending: {sort_in_batch}"
  114. )
  115. cur_batch_list.append(tuple(minibatch_keys))
  116. if sort_batch == "ascending":
  117. pass
  118. elif sort_batch == "descending":
  119. cur_batch_list.reverse()
  120. else:
  121. raise ValueError(
  122. f"sort_batch must be ascending or descending: {sort_batch}"
  123. )
  124. self.batch_list.extend(cur_batch_list)
  125. def __repr__(self):
  126. return (
  127. f"{self.__class__.__name__}("
  128. f"N-batch={len(self)}, "
  129. f"batch_size={self.batch_size}, "
  130. f"shape_files={self.shape_files}, "
  131. f"sort_in_batch={self.sort_in_batch}, "
  132. f"sort_batch={self.sort_batch})"
  133. )
  134. def __len__(self):
  135. return len(self.batch_list)
  136. def __iter__(self) -> Iterator[Tuple[str, ...]]:
  137. return iter(self.batch_list)