|
|
@@ -0,0 +1,147 @@
|
|
|
+from typing import Iterator
|
|
|
+from typing import List
|
|
|
+from typing import Dict
|
|
|
+from typing import Tuple
|
|
|
+from typing import Union
|
|
|
+
|
|
|
+from typeguard import check_argument_types
|
|
|
+
|
|
|
+from funasr.fileio.read_text import load_num_sequence_text
|
|
|
+from funasr.samplers.abs_sampler import AbsSampler
|
|
|
+
|
|
|
+
|
|
|
class LengthBatchSampler(AbsSampler):
    """Batch sampler that packs utterances into variable-size mini-batches.

    Utterances are sorted by length (the first element of their shape) and
    greedily grouped so that each mini-batch costs at most ``batch_bins``
    "bins":

    - ``padding=True``: bins = batch_size * max_length, summed over all
      shape files (the cost of a padded batch tensor).
    - ``padding=False``: bins = sum of all utterance lengths in the batch.

    Args:
        batch_bins: Upper bound on bins per mini-batch (must be > 0).
        shape_files: Paths to shape files (``utt length,dim,...`` per line,
            loaded as ``csv_int``), or a single pre-loaded
            ``{utt_id: [length, ...]}`` dict.
        min_batch_size: Minimum number of utterances per mini-batch.
        sort_in_batch: "ascending" or "descending" key order inside a batch.
        sort_batch: "ascending" or "descending" order of the batch list.
        drop_last: If True, drop a trailing under-full batch (it is kept
            anyway when it would be the only batch).
        padding: Selects the bin-counting rule described above.
    """

    def __init__(
        self,
        batch_bins: int,
        shape_files: Union[Tuple[str, ...], List[str], Dict],
        min_batch_size: int = 1,
        sort_in_batch: str = "descending",
        sort_batch: str = "ascending",
        drop_last: bool = False,
        padding: bool = True,
    ):
        assert check_argument_types()
        assert batch_bins > 0
        if sort_batch not in ("ascending", "descending"):
            raise ValueError(
                f"sort_batch must be ascending or descending: {sort_batch}"
            )
        if sort_in_batch not in ("ascending", "descending"):
            raise ValueError(
                f"sort_in_batch must be ascending or descending: {sort_in_batch}"
            )

        self.batch_bins = batch_bins
        self.shape_files = shape_files
        self.sort_in_batch = sort_in_batch
        self.sort_batch = sort_batch
        self.drop_last = drop_last

        # utt2shape: {utt_id: (Length, ...)}, e.g.
        #   uttA 100,...
        #   uttB 201,...
        if isinstance(shape_files, dict):
            # A single pre-loaded mapping was passed directly.
            utt2shapes = [shape_files]
            # Fix: the original used shape_files[0] in error messages below,
            # which raises KeyError when shape_files is a dict.
            shape_names = ["<dict>"]
        else:
            utt2shapes = [
                load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
            ]
            shape_names = list(shape_files)

        # All shape files must describe exactly the same utterance set.
        first_utt2shape = utt2shapes[0]
        for name, d in zip(shape_names, utt2shapes):
            if set(d) != set(first_utt2shape):
                raise RuntimeError(
                    f"keys are mismatched between {name} != {shape_names[0]}"
                )

        # Sort samples in ascending order of length
        # (shape order should be like (Length, Dim)).
        keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0])
        if len(keys) == 0:
            raise RuntimeError(f"0 lines found: {shape_names[0]}")

        # Greedily decide the size of each mini-batch.
        batch_sizes = []
        current_batch_keys = []
        for key in keys:
            current_batch_keys.append(key)
            # shape: (Length, dim1, dim2, ...)
            if padding:
                # Keys are globally ascending, so the current key holds the
                # max length of the batch: bins = batch_size * max_length.
                bins = sum(len(current_batch_keys) * sh[key][0] for sh in utt2shapes)
            else:
                # bins = total number of frames/tokens in the batch.
                bins = sum(d[k][0] for k in current_batch_keys for d in utt2shapes)

            if bins > batch_bins and len(current_batch_keys) >= min_batch_size:
                batch_sizes.append(len(current_batch_keys))
                current_batch_keys = []

        # Flush the trailing partial batch unless drop_last is requested
        # (always kept when it would be the only batch).
        if len(current_batch_keys) != 0 and (
            not self.drop_last or len(batch_sizes) == 0
        ):
            batch_sizes.append(len(current_batch_keys))

        if len(batch_sizes) == 0:
            # Unreachable in practice because keys is non-empty.
            raise RuntimeError("0 batches")

        # If the last batch-size is smaller than minimum batch_size,
        # the samples are redistributed to the other mini-batches.
        if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size:
            for i in range(batch_sizes.pop(-1)):
                batch_sizes[-(i % len(batch_sizes)) - 1] += 1

        if not self.drop_last:
            # Sanity check: every key must be assigned exactly once.
            assert sum(batch_sizes) == len(keys), f"{sum(batch_sizes)} != {len(keys)}"

        # Materialize the mini-batches as tuples of utterance keys.
        self.batch_list = []
        iter_bs = iter(batch_sizes)
        bs = next(iter_bs)
        minibatch_keys = []
        for key in keys:
            minibatch_keys.append(key)
            if len(minibatch_keys) == bs:
                if sort_in_batch == "descending":
                    # Keys are globally ascending; reverse within the batch.
                    # (sort_in_batch was validated above, so no other case.)
                    minibatch_keys.reverse()
                self.batch_list.append(tuple(minibatch_keys))
                minibatch_keys = []
                try:
                    bs = next(iter_bs)
                except StopIteration:
                    break

        # sort_batch was validated above; "ascending" is the natural order.
        if sort_batch == "descending":
            self.batch_list.reverse()

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"N-batch={len(self)}, "
            f"batch_bins={self.batch_bins}, "
            f"sort_in_batch={self.sort_in_batch}, "
            f"sort_batch={self.sort_batch})"
        )

    def __len__(self):
        # Number of mini-batches.
        return len(self.batch_list)

    def __iter__(self) -> Iterator[Tuple[str, ...]]:
        return iter(self.batch_list)
|