# num_elements_batch_sampler.py
  1. from typing import Iterator
  2. from typing import List
  3. from typing import Tuple
  4. from typing import Union
  5. import numpy as np
  6. from typeguard import check_argument_types
  7. from funasr.fileio.read_text import load_num_sequence_text
  8. from funasr.samplers.abs_sampler import AbsSampler
class NumElementsBatchSampler(AbsSampler):
    """Mini-batch sampler that groups utterances under an element-count budget.

    Utterance keys are sorted by length (the first value of each shape) and
    greedily packed into batches so that each batch's total element count
    stays near ``batch_bins``.  Note a batch is only closed *after* the key
    that overflows the budget has been added, so batches may exceed
    ``batch_bins``.

    Args:
        batch_bins: Element budget per mini-batch (length times feature
            dims, summed over all shape files); must be positive.
        shape_files: Text files mapping each utterance key to its shape as
            comma-separated ints, e.g. ``uttA 100,80`` (loaded with
            ``loader_type="csv_int"``); all files must share one key set.
        min_batch_size: Lower bound on the number of utterances per batch.
        sort_in_batch: ``"ascending"`` or ``"descending"`` ordering of keys
            inside each batch.
        sort_batch: ``"ascending"`` or ``"descending"`` ordering of the
            batches themselves (by utterance length).
        drop_last: If True, drop a trailing under-budget batch — unless it
            would be the only batch.
        padding: If True, cost each batch as if padded to its longest
            utterance; this requires the feature dims (shape[1:]) to be
            identical across the whole corpus.

    Raises:
        ValueError: If ``sort_batch`` or ``sort_in_batch`` is invalid.
        RuntimeError: If shape files disagree on keys, are empty, violate
            the constant-feature-dim requirement under padding, or no batch
            can be formed.
    """

    def __init__(
        self,
        batch_bins: int,
        shape_files: Union[Tuple[str, ...], List[str]],
        min_batch_size: int = 1,
        sort_in_batch: str = "descending",
        sort_batch: str = "ascending",
        drop_last: bool = False,
        padding: bool = True,
    ):
        assert check_argument_types()
        assert batch_bins > 0
        if sort_batch != "ascending" and sort_batch != "descending":
            raise ValueError(
                f"sort_batch must be ascending or descending: {sort_batch}"
            )
        if sort_in_batch != "descending" and sort_in_batch != "ascending":
            raise ValueError(
                f"sort_in_batch must be ascending or descending: {sort_in_batch}"
            )

        # Configuration kept for __repr__ and the batching logic below.
        self.batch_bins = batch_bins
        self.shape_files = shape_files
        self.sort_in_batch = sort_in_batch
        self.sort_batch = sort_batch
        self.drop_last = drop_last

        # utt2shape: (Length, ...)
        #    uttA 100,...
        #    uttB 201,...
        utt2shapes = [
            load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
        ]

        # All shape files must describe exactly the same set of utterances.
        first_utt2shape = utt2shapes[0]
        for s, d in zip(shape_files, utt2shapes):
            if set(d) != set(first_utt2shape):
                raise RuntimeError(
                    f"keys are mismatched between {s} != {shape_files[0]}"
                )

        # Sort samples in ascending order by length
        # (shape order should be like (Length, Dim))
        keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0])
        if len(keys) == 0:
            raise RuntimeError(f"0 lines found: {shape_files[0]}")
        if padding:
            # If padding case, the feat-dim must be same over whole corpus,
            # therefore the first sample is referred
            feat_dims = [np.prod(d[keys[0]][1:]) for d in utt2shapes]
        else:
            feat_dims = None

        # Decide batch-sizes: greedily grow a batch until its element count
        # exceeds batch_bins (and min_batch_size is satisfied), then close it.
        batch_sizes = []
        current_batch_keys = []
        for key in keys:
            current_batch_keys.append(key)
            # shape: (Length, dim1, dim2, ...)
            if padding:
                # Enforce the corpus-wide constant feature dim assumed above.
                for d, s in zip(utt2shapes, shape_files):
                    if tuple(d[key][1:]) != tuple(d[keys[0]][1:]):
                        raise RuntimeError(
                            "If padding=True, the "
                            f"feature dimension must be unified: {s}",
                        )
                # Keys are sorted ascending, so `key` is the longest in the
                # batch: padded cost = count * longest_length * feat_dim,
                # summed over all shape files.
                bins = sum(
                    len(current_batch_keys) * sh[key][0] * d
                    for sh, d in zip(utt2shapes, feat_dims)
                )
            else:
                # Unpadded cost: exact element count of every utterance in
                # the batch, over all shape files (recomputed each step).
                bins = sum(
                    np.prod(d[k]) for k in current_batch_keys for d in utt2shapes
                )

            if bins > batch_bins and len(current_batch_keys) >= min_batch_size:
                # The overflowing key stays in the batch being closed.
                batch_sizes.append(len(current_batch_keys))
                current_batch_keys = []
        else:
            # Keep the leftover partial batch unless drop_last is set —
            # but always keep it if it would be the only batch.
            if len(current_batch_keys) != 0 and (
                not self.drop_last or len(batch_sizes) == 0
            ):
                batch_sizes.append(len(current_batch_keys))

        if len(batch_sizes) == 0:
            # Maybe we can't reach here
            raise RuntimeError("0 batches")

        # If the last batch-size is smaller than minimum batch_size,
        # the samples are redistributed to the other mini-batches
        if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size:
            for i in range(batch_sizes.pop(-1)):
                # Hand the orphaned samples to the trailing batches,
                # round-robin from the end.
                batch_sizes[-(i % len(batch_sizes)) - 1] += 1

        if not self.drop_last:
            # Bug check
            assert sum(batch_sizes) == len(keys), f"{sum(batch_sizes)} != {len(keys)}"

        # Set mini-batch: slice the sorted key list into the decided sizes.
        # If drop_last trimmed the tail, trailing keys are simply unused.
        self.batch_list = []
        iter_bs = iter(batch_sizes)
        bs = next(iter_bs)
        minibatch_keys = []
        for key in keys:
            minibatch_keys.append(key)
            if len(minibatch_keys) == bs:
                if sort_in_batch == "descending":
                    minibatch_keys.reverse()
                elif sort_in_batch == "ascending":
                    # Key are already sorted in ascending
                    pass
                else:
                    # Defensive: already validated at the top of __init__.
                    raise ValueError(
                        "sort_in_batch must be ascending"
                        f" or descending: {sort_in_batch}"
                    )
                self.batch_list.append(tuple(minibatch_keys))
                minibatch_keys = []
                try:
                    bs = next(iter_bs)
                except StopIteration:
                    break

        if sort_batch == "ascending":
            pass
        elif sort_batch == "descending":
            self.batch_list.reverse()
        else:
            # Defensive: already validated at the top of __init__.
            raise ValueError(
                f"sort_batch must be ascending or descending: {sort_batch}"
            )

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"N-batch={len(self)}, "
            f"batch_bins={self.batch_bins}, "
            f"sort_in_batch={self.sort_in_batch}, "
            f"sort_batch={self.sort_batch})"
        )

    def __len__(self) -> int:
        # Number of mini-batches decided at construction time.
        return len(self.batch_list)

    def __iter__(self) -> Iterator[Tuple[str, ...]]:
        # Yields tuples of utterance keys, one tuple per mini-batch.
        return iter(self.batch_list)