unsorted_batch_sampler.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
import logging
from typing import Iterator
from typing import Optional
from typing import Tuple

from typeguard import check_argument_types

from funasr.fileio.read_text import read_2column_text
from funasr.samplers.abs_sampler import AbsSampler


  7. class UnsortedBatchSampler(AbsSampler):
  8. """BatchSampler with constant batch-size.
  9. Any sorting is not done in this class,
  10. so no length information is required,
  11. This class is convenient for decoding mode,
  12. or not seq2seq learning e.g. classification.
  13. Args:
  14. batch_size:
  15. key_file:
  16. """
  17. def __init__(
  18. self,
  19. batch_size: int,
  20. key_file: str,
  21. drop_last: bool = False,
  22. utt2category_file: str = None,
  23. ):
  24. assert check_argument_types()
  25. assert batch_size > 0
  26. self.batch_size = batch_size
  27. self.key_file = key_file
  28. self.drop_last = drop_last
  29. # utt2shape:
  30. # uttA <anything is o.k>
  31. # uttB <anything is o.k>
  32. utt2any = read_2column_text(key_file)
  33. if len(utt2any) == 0:
  34. logging.warning(f"{key_file} is empty")
  35. # In this case the, the first column in only used
  36. keys = list(utt2any)
  37. if len(keys) == 0:
  38. raise RuntimeError(f"0 lines found: {key_file}")
  39. category2utt = {}
  40. if utt2category_file is not None:
  41. utt2category = read_2column_text(utt2category_file)
  42. if set(utt2category) != set(keys):
  43. raise RuntimeError(
  44. f"keys are mismatched between {utt2category_file} != {key_file}"
  45. )
  46. for k, v in utt2category.items():
  47. category2utt.setdefault(v, []).append(k)
  48. else:
  49. category2utt["default_category"] = keys
  50. self.batch_list = []
  51. for d, v in category2utt.items():
  52. category_keys = v
  53. # Apply max(, 1) to avoid 0-batches
  54. N = max(len(category_keys) // batch_size, 1)
  55. if not self.drop_last:
  56. # Split keys evenly as possible as. Note that If N != 1,
  57. # the these batches always have size of batch_size at minimum.
  58. cur_batch_list = [
  59. category_keys[i * len(keys) // N : (i + 1) * len(keys) // N]
  60. for i in range(N)
  61. ]
  62. else:
  63. cur_batch_list = [
  64. tuple(category_keys[i * batch_size : (i + 1) * batch_size])
  65. for i in range(N)
  66. ]
  67. self.batch_list.extend(cur_batch_list)
  68. def __repr__(self):
  69. return (
  70. f"{self.__class__.__name__}("
  71. f"N-batch={len(self)}, "
  72. f"batch_size={self.batch_size}, "
  73. f"key_file={self.key_file}, "
  74. )
  75. def __len__(self):
  76. return len(self.batch_list)
  77. def __iter__(self) -> Iterator[Tuple[str, ...]]:
  78. return iter(self.batch_list)