# num_elements_batch_sampler.py
  1. from typing import Iterator
  2. from typing import List
  3. from typing import Tuple
  4. from typing import Union
  5. import numpy as np
  6. from funasr.fileio.read_text import load_num_sequence_text
  7. from funasr.samplers.abs_sampler import AbsSampler
  8. class NumElementsBatchSampler(AbsSampler):
  9. def __init__(
  10. self,
  11. batch_bins: int,
  12. shape_files: Union[Tuple[str, ...], List[str]],
  13. min_batch_size: int = 1,
  14. sort_in_batch: str = "descending",
  15. sort_batch: str = "ascending",
  16. drop_last: bool = False,
  17. padding: bool = True,
  18. ):
  19. assert batch_bins > 0
  20. if sort_batch != "ascending" and sort_batch != "descending":
  21. raise ValueError(
  22. f"sort_batch must be ascending or descending: {sort_batch}"
  23. )
  24. if sort_in_batch != "descending" and sort_in_batch != "ascending":
  25. raise ValueError(
  26. f"sort_in_batch must be ascending or descending: {sort_in_batch}"
  27. )
  28. self.batch_bins = batch_bins
  29. self.shape_files = shape_files
  30. self.sort_in_batch = sort_in_batch
  31. self.sort_batch = sort_batch
  32. self.drop_last = drop_last
  33. # utt2shape: (Length, ...)
  34. # uttA 100,...
  35. # uttB 201,...
  36. utt2shapes = [
  37. load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
  38. ]
  39. first_utt2shape = utt2shapes[0]
  40. for s, d in zip(shape_files, utt2shapes):
  41. if set(d) != set(first_utt2shape):
  42. raise RuntimeError(
  43. f"keys are mismatched between {s} != {shape_files[0]}"
  44. )
  45. # Sort samples in ascending order
  46. # (shape order should be like (Length, Dim))
  47. keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0])
  48. if len(keys) == 0:
  49. raise RuntimeError(f"0 lines found: {shape_files[0]}")
  50. if padding:
  51. # If padding case, the feat-dim must be same over whole corpus,
  52. # therefore the first sample is referred
  53. feat_dims = [np.prod(d[keys[0]][1:]) for d in utt2shapes]
  54. else:
  55. feat_dims = None
  56. # Decide batch-sizes
  57. batch_sizes = []
  58. current_batch_keys = []
  59. for key in keys:
  60. current_batch_keys.append(key)
  61. # shape: (Length, dim1, dim2, ...)
  62. if padding:
  63. for d, s in zip(utt2shapes, shape_files):
  64. if tuple(d[key][1:]) != tuple(d[keys[0]][1:]):
  65. raise RuntimeError(
  66. "If padding=True, the "
  67. f"feature dimension must be unified: {s}",
  68. )
  69. bins = sum(
  70. len(current_batch_keys) * sh[key][0] * d
  71. for sh, d in zip(utt2shapes, feat_dims)
  72. )
  73. else:
  74. bins = sum(
  75. np.prod(d[k]) for k in current_batch_keys for d in utt2shapes
  76. )
  77. if bins > batch_bins and len(current_batch_keys) >= min_batch_size:
  78. batch_sizes.append(len(current_batch_keys))
  79. current_batch_keys = []
  80. else:
  81. if len(current_batch_keys) != 0 and (
  82. not self.drop_last or len(batch_sizes) == 0
  83. ):
  84. batch_sizes.append(len(current_batch_keys))
  85. if len(batch_sizes) == 0:
  86. # Maybe we can't reach here
  87. raise RuntimeError("0 batches")
  88. # If the last batch-size is smaller than minimum batch_size,
  89. # the samples are redistributed to the other mini-batches
  90. if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size:
  91. for i in range(batch_sizes.pop(-1)):
  92. batch_sizes[-(i % len(batch_sizes)) - 1] += 1
  93. if not self.drop_last:
  94. # Bug check
  95. assert sum(batch_sizes) == len(keys), f"{sum(batch_sizes)} != {len(keys)}"
  96. # Set mini-batch
  97. self.batch_list = []
  98. iter_bs = iter(batch_sizes)
  99. bs = next(iter_bs)
  100. minibatch_keys = []
  101. for key in keys:
  102. minibatch_keys.append(key)
  103. if len(minibatch_keys) == bs:
  104. if sort_in_batch == "descending":
  105. minibatch_keys.reverse()
  106. elif sort_in_batch == "ascending":
  107. # Key are already sorted in ascending
  108. pass
  109. else:
  110. raise ValueError(
  111. "sort_in_batch must be ascending"
  112. f" or descending: {sort_in_batch}"
  113. )
  114. self.batch_list.append(tuple(minibatch_keys))
  115. minibatch_keys = []
  116. try:
  117. bs = next(iter_bs)
  118. except StopIteration:
  119. break
  120. if sort_batch == "ascending":
  121. pass
  122. elif sort_batch == "descending":
  123. self.batch_list.reverse()
  124. else:
  125. raise ValueError(
  126. f"sort_batch must be ascending or descending: {sort_batch}"
  127. )
  128. def __repr__(self):
  129. return (
  130. f"{self.__class__.__name__}("
  131. f"N-batch={len(self)}, "
  132. f"batch_bins={self.batch_bins}, "
  133. f"sort_in_batch={self.sort_in_batch}, "
  134. f"sort_batch={self.sort_batch})"
  135. )
  136. def __len__(self):
  137. return len(self.batch_list)
  138. def __iter__(self) -> Iterator[Tuple[str, ...]]:
  139. return iter(self.batch_list)