build_batch_sampler.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. from typing import List
  2. from typing import Dict
  3. from typing import Sequence
  4. from typing import Tuple
  5. from typing import Union
  6. from funasr.samplers.abs_sampler import AbsSampler
  7. from funasr.samplers.folded_batch_sampler import FoldedBatchSampler
  8. from funasr.samplers.length_batch_sampler import LengthBatchSampler
  9. from funasr.samplers.num_elements_batch_sampler import NumElementsBatchSampler
  10. from funasr.samplers.sorted_batch_sampler import SortedBatchSampler
  11. from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
  12. BATCH_TYPES = dict(
  13. unsorted="UnsortedBatchSampler has nothing in particular feature and "
  14. "just creates mini-batches which has constant batch_size. "
  15. "This sampler doesn't require any length "
  16. "information for each feature. "
  17. "'key_file' is just a text file which describes each sample name."
  18. "\n\n"
  19. " utterance_id_a\n"
  20. " utterance_id_b\n"
  21. " utterance_id_c\n"
  22. "\n"
  23. "The fist column is referred, so 'shape file' can be used, too.\n\n"
  24. " utterance_id_a 100,80\n"
  25. " utterance_id_b 400,80\n"
  26. " utterance_id_c 512,80\n",
  27. sorted="SortedBatchSampler sorts samples by the length of the first input "
  28. " in order to make each sample in a mini-batch has close length. "
  29. "This sampler requires a text file which describes the length for each sample "
  30. "\n\n"
  31. " utterance_id_a 1000\n"
  32. " utterance_id_b 1453\n"
  33. " utterance_id_c 1241\n"
  34. "\n"
  35. "The first element of feature dimensions is referred, "
  36. "so 'shape_file' can be also used.\n\n"
  37. " utterance_id_a 1000,80\n"
  38. " utterance_id_b 1453,80\n"
  39. " utterance_id_c 1241,80\n",
  40. folded="FoldedBatchSampler supports variable batch_size. "
  41. "The batch_size is decided by\n"
  42. " batch_size = base_batch_size // (L // fold_length)\n"
  43. "L is referred to the largest length of samples in the mini-batch. "
  44. "This samples requires length information as same as SortedBatchSampler\n",
  45. length="LengthBatchSampler supports variable batch_size. "
  46. "This sampler makes mini-batches which have same number of 'bins' as possible "
  47. "counting by the total lengths of each feature in the mini-batch. "
  48. "This sampler requires a text file which describes the length for each sample. "
  49. "\n\n"
  50. " utterance_id_a 1000\n"
  51. " utterance_id_b 1453\n"
  52. " utterance_id_c 1241\n"
  53. "\n"
  54. "The first element of feature dimensions is referred, "
  55. "so 'shape_file' can be also used.\n\n"
  56. " utterance_id_a 1000,80\n"
  57. " utterance_id_b 1453,80\n"
  58. " utterance_id_c 1241,80\n",
  59. numel="NumElementsBatchSampler supports variable batch_size. "
  60. "Just like LengthBatchSampler, this sampler makes mini-batches"
  61. " which have same number of 'bins' as possible "
  62. "counting by the total number of elements of each feature "
  63. "instead of the length. "
  64. "Thus this sampler requires the full information of the dimension of the features. "
  65. "\n\n"
  66. " utterance_id_a 1000,80\n"
  67. " utterance_id_b 1453,80\n"
  68. " utterance_id_c 1241,80\n",
  69. )
  70. def build_batch_sampler(
  71. type: str,
  72. batch_size: int,
  73. batch_bins: int,
  74. shape_files: Union[Tuple[str, ...], List[str], Dict],
  75. sort_in_batch: str = "descending",
  76. sort_batch: str = "ascending",
  77. drop_last: bool = False,
  78. min_batch_size: int = 1,
  79. fold_lengths: Sequence[int] = (),
  80. padding: bool = True,
  81. utt2category_file: str = None,
  82. ) -> AbsSampler:
  83. """Helper function to instantiate BatchSampler.
  84. Args:
  85. type: mini-batch type. "unsorted", "sorted", "folded", "numel", or, "length"
  86. batch_size: The mini-batch size. Used for "unsorted", "sorted", "folded" mode
  87. batch_bins: Used for "numel" model
  88. shape_files: Text files describing the length and dimension
  89. of each features. e.g. uttA 1330,80
  90. sort_in_batch:
  91. sort_batch:
  92. drop_last:
  93. min_batch_size: Used for "numel" or "folded" mode
  94. fold_lengths: Used for "folded" mode
  95. padding: Whether sequences are input as a padded tensor or not.
  96. used for "numel" mode
  97. """
  98. if len(shape_files) == 0:
  99. raise ValueError("No shape file are given")
  100. if type == "unsorted":
  101. retval = UnsortedBatchSampler(
  102. batch_size=batch_size, key_file=shape_files[0], drop_last=drop_last
  103. )
  104. elif type == "sorted":
  105. retval = SortedBatchSampler(
  106. batch_size=batch_size,
  107. shape_file=shape_files[0],
  108. sort_in_batch=sort_in_batch,
  109. sort_batch=sort_batch,
  110. drop_last=drop_last,
  111. )
  112. elif type == "folded":
  113. if len(fold_lengths) != len(shape_files):
  114. raise ValueError(
  115. f"The number of fold_lengths must be equal to "
  116. f"the number of shape_files: "
  117. f"{len(fold_lengths)} != {len(shape_files)}"
  118. )
  119. retval = FoldedBatchSampler(
  120. batch_size=batch_size,
  121. shape_files=shape_files,
  122. fold_lengths=fold_lengths,
  123. sort_in_batch=sort_in_batch,
  124. sort_batch=sort_batch,
  125. drop_last=drop_last,
  126. min_batch_size=min_batch_size,
  127. utt2category_file=utt2category_file,
  128. )
  129. elif type == "numel":
  130. retval = NumElementsBatchSampler(
  131. batch_bins=batch_bins,
  132. shape_files=shape_files,
  133. sort_in_batch=sort_in_batch,
  134. sort_batch=sort_batch,
  135. drop_last=drop_last,
  136. padding=padding,
  137. min_batch_size=min_batch_size,
  138. )
  139. elif type == "length":
  140. retval = LengthBatchSampler(
  141. batch_bins=batch_bins,
  142. shape_files=shape_files,
  143. sort_in_batch=sort_in_batch,
  144. sort_batch=sort_batch,
  145. drop_last=drop_last,
  146. padding=padding,
  147. min_batch_size=min_batch_size,
  148. )
  149. else:
  150. raise ValueError(f"Not supported: {type}")
  151. return retval