build_batch_sampler.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. from typing import List
  2. from typing import Dict
  3. from typing import Sequence
  4. from typing import Tuple
  5. from typing import Union
  6. from typeguard import check_argument_types
  7. from typeguard import check_return_type
  8. from funasr.samplers.abs_sampler import AbsSampler
  9. from funasr.samplers.folded_batch_sampler import FoldedBatchSampler
  10. from funasr.samplers.length_batch_sampler import LengthBatchSampler
  11. from funasr.samplers.num_elements_batch_sampler import NumElementsBatchSampler
  12. from funasr.samplers.sorted_batch_sampler import SortedBatchSampler
  13. from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
  14. BATCH_TYPES = dict(
  15. unsorted="UnsortedBatchSampler has nothing in particular feature and "
  16. "just creates mini-batches which has constant batch_size. "
  17. "This sampler doesn't require any length "
  18. "information for each feature. "
  19. "'key_file' is just a text file which describes each sample name."
  20. "\n\n"
  21. " utterance_id_a\n"
  22. " utterance_id_b\n"
  23. " utterance_id_c\n"
  24. "\n"
  25. "The fist column is referred, so 'shape file' can be used, too.\n\n"
  26. " utterance_id_a 100,80\n"
  27. " utterance_id_b 400,80\n"
  28. " utterance_id_c 512,80\n",
  29. sorted="SortedBatchSampler sorts samples by the length of the first input "
  30. " in order to make each sample in a mini-batch has close length. "
  31. "This sampler requires a text file which describes the length for each sample "
  32. "\n\n"
  33. " utterance_id_a 1000\n"
  34. " utterance_id_b 1453\n"
  35. " utterance_id_c 1241\n"
  36. "\n"
  37. "The first element of feature dimensions is referred, "
  38. "so 'shape_file' can be also used.\n\n"
  39. " utterance_id_a 1000,80\n"
  40. " utterance_id_b 1453,80\n"
  41. " utterance_id_c 1241,80\n",
  42. folded="FoldedBatchSampler supports variable batch_size. "
  43. "The batch_size is decided by\n"
  44. " batch_size = base_batch_size // (L // fold_length)\n"
  45. "L is referred to the largest length of samples in the mini-batch. "
  46. "This samples requires length information as same as SortedBatchSampler\n",
  47. length="LengthBatchSampler supports variable batch_size. "
  48. "This sampler makes mini-batches which have same number of 'bins' as possible "
  49. "counting by the total lengths of each feature in the mini-batch. "
  50. "This sampler requires a text file which describes the length for each sample. "
  51. "\n\n"
  52. " utterance_id_a 1000\n"
  53. " utterance_id_b 1453\n"
  54. " utterance_id_c 1241\n"
  55. "\n"
  56. "The first element of feature dimensions is referred, "
  57. "so 'shape_file' can be also used.\n\n"
  58. " utterance_id_a 1000,80\n"
  59. " utterance_id_b 1453,80\n"
  60. " utterance_id_c 1241,80\n",
  61. numel="NumElementsBatchSampler supports variable batch_size. "
  62. "Just like LengthBatchSampler, this sampler makes mini-batches"
  63. " which have same number of 'bins' as possible "
  64. "counting by the total number of elements of each feature "
  65. "instead of the length. "
  66. "Thus this sampler requires the full information of the dimension of the features. "
  67. "\n\n"
  68. " utterance_id_a 1000,80\n"
  69. " utterance_id_b 1453,80\n"
  70. " utterance_id_c 1241,80\n",
  71. )
  72. def build_batch_sampler(
  73. type: str,
  74. batch_size: int,
  75. batch_bins: int,
  76. shape_files: Union[Tuple[str, ...], List[str], Dict],
  77. sort_in_batch: str = "descending",
  78. sort_batch: str = "ascending",
  79. drop_last: bool = False,
  80. min_batch_size: int = 1,
  81. fold_lengths: Sequence[int] = (),
  82. padding: bool = True,
  83. utt2category_file: str = None,
  84. ) -> AbsSampler:
  85. """Helper function to instantiate BatchSampler.
  86. Args:
  87. type: mini-batch type. "unsorted", "sorted", "folded", "numel", or, "length"
  88. batch_size: The mini-batch size. Used for "unsorted", "sorted", "folded" mode
  89. batch_bins: Used for "numel" model
  90. shape_files: Text files describing the length and dimension
  91. of each features. e.g. uttA 1330,80
  92. sort_in_batch:
  93. sort_batch:
  94. drop_last:
  95. min_batch_size: Used for "numel" or "folded" mode
  96. fold_lengths: Used for "folded" mode
  97. padding: Whether sequences are input as a padded tensor or not.
  98. used for "numel" mode
  99. """
  100. assert check_argument_types()
  101. if len(shape_files) == 0:
  102. raise ValueError("No shape file are given")
  103. if type == "unsorted":
  104. retval = UnsortedBatchSampler(
  105. batch_size=batch_size, key_file=shape_files[0], drop_last=drop_last
  106. )
  107. elif type == "sorted":
  108. retval = SortedBatchSampler(
  109. batch_size=batch_size,
  110. shape_file=shape_files[0],
  111. sort_in_batch=sort_in_batch,
  112. sort_batch=sort_batch,
  113. drop_last=drop_last,
  114. )
  115. elif type == "folded":
  116. if len(fold_lengths) != len(shape_files):
  117. raise ValueError(
  118. f"The number of fold_lengths must be equal to "
  119. f"the number of shape_files: "
  120. f"{len(fold_lengths)} != {len(shape_files)}"
  121. )
  122. retval = FoldedBatchSampler(
  123. batch_size=batch_size,
  124. shape_files=shape_files,
  125. fold_lengths=fold_lengths,
  126. sort_in_batch=sort_in_batch,
  127. sort_batch=sort_batch,
  128. drop_last=drop_last,
  129. min_batch_size=min_batch_size,
  130. utt2category_file=utt2category_file,
  131. )
  132. elif type == "numel":
  133. retval = NumElementsBatchSampler(
  134. batch_bins=batch_bins,
  135. shape_files=shape_files,
  136. sort_in_batch=sort_in_batch,
  137. sort_batch=sort_batch,
  138. drop_last=drop_last,
  139. padding=padding,
  140. min_batch_size=min_batch_size,
  141. )
  142. elif type == "length":
  143. retval = LengthBatchSampler(
  144. batch_bins=batch_bins,
  145. shape_files=shape_files,
  146. sort_in_batch=sort_in_batch,
  147. sort_batch=sort_batch,
  148. drop_last=drop_last,
  149. padding=padding,
  150. min_batch_size=min_batch_size,
  151. )
  152. else:
  153. raise ValueError(f"Not supported: {type}")
  154. assert check_return_type(retval)
  155. return retval