unsorted_batch_sampler.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import logging
  2. from typing import Iterator
  3. from typing import Tuple
  4. from funasr.fileio.read_text import read_2column_text
  5. from funasr.samplers.abs_sampler import AbsSampler
  6. class UnsortedBatchSampler(AbsSampler):
  7. """BatchSampler with constant batch-size.
  8. Any sorting is not done in this class,
  9. so no length information is required,
  10. This class is convenient for decoding mode,
  11. or not seq2seq learning e.g. classification.
  12. Args:
  13. batch_size:
  14. key_file:
  15. """
  16. def __init__(
  17. self,
  18. batch_size: int,
  19. key_file: str,
  20. drop_last: bool = False,
  21. utt2category_file: str = None,
  22. ):
  23. assert batch_size > 0
  24. self.batch_size = batch_size
  25. self.key_file = key_file
  26. self.drop_last = drop_last
  27. # utt2shape:
  28. # uttA <anything is o.k>
  29. # uttB <anything is o.k>
  30. utt2any = read_2column_text(key_file)
  31. if len(utt2any) == 0:
  32. logging.warning(f"{key_file} is empty")
  33. # In this case the, the first column in only used
  34. keys = list(utt2any)
  35. if len(keys) == 0:
  36. raise RuntimeError(f"0 lines found: {key_file}")
  37. category2utt = {}
  38. if utt2category_file is not None:
  39. utt2category = read_2column_text(utt2category_file)
  40. if set(utt2category) != set(keys):
  41. raise RuntimeError(
  42. f"keys are mismatched between {utt2category_file} != {key_file}"
  43. )
  44. for k, v in utt2category.items():
  45. category2utt.setdefault(v, []).append(k)
  46. else:
  47. category2utt["default_category"] = keys
  48. self.batch_list = []
  49. for d, v in category2utt.items():
  50. category_keys = v
  51. # Apply max(, 1) to avoid 0-batches
  52. N = max(len(category_keys) // batch_size, 1)
  53. if not self.drop_last:
  54. # Split keys evenly as possible as. Note that If N != 1,
  55. # the these batches always have size of batch_size at minimum.
  56. cur_batch_list = [
  57. category_keys[i * len(keys) // N : (i + 1) * len(keys) // N]
  58. for i in range(N)
  59. ]
  60. else:
  61. cur_batch_list = [
  62. tuple(category_keys[i * batch_size : (i + 1) * batch_size])
  63. for i in range(N)
  64. ]
  65. self.batch_list.extend(cur_batch_list)
  66. def __repr__(self):
  67. return (
  68. f"{self.__class__.__name__}("
  69. f"N-batch={len(self)}, "
  70. f"batch_size={self.batch_size}, "
  71. f"key_file={self.key_file}, "
  72. )
  73. def __len__(self):
  74. return len(self.batch_list)
  75. def __iter__(self) -> Iterator[Tuple[str, ...]]:
  76. return iter(self.batch_list)