# sorted_batch_sampler.py (3.1 KB)
  1. import logging
  2. from typing import Iterator
  3. from typing import Tuple
  4. from typeguard import check_argument_types
  5. from funasr.fileio.read_text import load_num_sequence_text
  6. from funasr.samplers.abs_sampler import AbsSampler
  7. class SortedBatchSampler(AbsSampler):
  8. """BatchSampler with sorted samples by length.
  9. Args:
  10. batch_size:
  11. shape_file:
  12. sort_in_batch: 'descending', 'ascending' or None.
  13. sort_batch:
  14. """
  15. def __init__(
  16. self,
  17. batch_size: int,
  18. shape_file: str,
  19. sort_in_batch: str = "descending",
  20. sort_batch: str = "ascending",
  21. drop_last: bool = False,
  22. ):
  23. assert check_argument_types()
  24. assert batch_size > 0
  25. self.batch_size = batch_size
  26. self.shape_file = shape_file
  27. self.sort_in_batch = sort_in_batch
  28. self.sort_batch = sort_batch
  29. self.drop_last = drop_last
  30. # utt2shape: (Length, ...)
  31. # uttA 100,...
  32. # uttB 201,...
  33. utt2shape = load_num_sequence_text(shape_file, loader_type="csv_int")
  34. if sort_in_batch == "descending":
  35. # Sort samples in descending order (required by RNN)
  36. keys = sorted(utt2shape, key=lambda k: -utt2shape[k][0])
  37. elif sort_in_batch == "ascending":
  38. # Sort samples in ascending order
  39. keys = sorted(utt2shape, key=lambda k: utt2shape[k][0])
  40. else:
  41. raise ValueError(
  42. f"sort_in_batch must be either one of "
  43. f"ascending, descending, or None: {sort_in_batch}"
  44. )
  45. if len(keys) == 0:
  46. raise RuntimeError(f"0 lines found: {shape_file}")
  47. # Apply max(, 1) to avoid 0-batches
  48. N = max(len(keys) // batch_size, 1)
  49. if not self.drop_last:
  50. # Split keys evenly as possible as. Note that If N != 1,
  51. # the these batches always have size of batch_size at minimum.
  52. self.batch_list = [
  53. keys[i * len(keys) // N : (i + 1) * len(keys) // N] for i in range(N)
  54. ]
  55. else:
  56. self.batch_list = [
  57. tuple(keys[i * batch_size : (i + 1) * batch_size]) for i in range(N)
  58. ]
  59. if len(self.batch_list) == 0:
  60. logging.warning(f"{shape_file} is empty")
  61. if sort_in_batch != sort_batch:
  62. if sort_batch not in ("ascending", "descending"):
  63. raise ValueError(
  64. f"sort_batch must be ascending or descending: {sort_batch}"
  65. )
  66. self.batch_list.reverse()
  67. if len(self.batch_list) == 0:
  68. raise RuntimeError("0 batches")
  69. def __repr__(self):
  70. return (
  71. f"{self.__class__.__name__}("
  72. f"N-batch={len(self)}, "
  73. f"batch_size={self.batch_size}, "
  74. f"shape_file={self.shape_file}, "
  75. f"sort_in_batch={self.sort_in_batch}, "
  76. f"sort_batch={self.sort_batch})"
  77. )
  78. def __len__(self):
  79. return len(self.batch_list)
  80. def __iter__(self) -> Iterator[Tuple[str, ...]]:
  81. return iter(self.batch_list)