import logging

import numpy as np
import torch
from torch.utils.data import DataLoader

from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
from funasr.datasets.small_datasets.dataset import ESPnetDataset
from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler
from funasr.datasets.small_datasets.preprocessor import build_preprocess
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.samplers.abs_sampler import AbsSampler


class RawSampler(AbsSampler):
    """Wrap a precomputed list of batches behind the AbsSampler interface."""

    def __init__(self, batches):
        self.batches = batches

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        return iter(self.batches)

    def generate(self, seed):
        # the batch list is fixed, so the seed is ignored; return a copy
        # that callers may shuffle in place
        return list(self.batches)
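

# A minimal illustration of RawSampler (hypothetical utterance-id batches):
#     sampler = RawSampler([["utt1", "utt2"], ["utt3"]])
#     len(sampler)         -> 2
#     sampler.generate(0)  -> [["utt1", "utt2"], ["utt3"]]  (seed is ignored)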


class SequenceIterFactory(AbsIterFactory):
    """Build the mini-batch iterator for each epoch; modified from ESPnet."""

    def __init__(self, args, mode="train"):
        # preprocess
        preprocess_fn = build_preprocess(args, train=mode == "train")

        # collate: text-only tasks (punctuation, LM) pad token ids with 0;
        # speech tasks pad features with 0.0 and token ids with -1
        if args.task_name in ["punc", "lm"]:
            collate_fn = CommonCollateFn(int_pad_value=0)
        else:
            collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

        # dataset: resample audio to the frontend sample rate (default 16 kHz)
        dest_sample_rate = (
            args.frontend_conf["fs"]
            if (args.frontend_conf is not None and "fs" in args.frontend_conf)
            else 16000
        )

        if mode == "train":
            data_path_and_name_and_type = args.train_data_path_and_name_and_type
            shape_files = args.train_shape_file
        elif mode == "valid":
            data_path_and_name_and_type = args.valid_data_path_and_name_and_type
            shape_files = args.valid_shape_file
        else:
            raise NotImplementedError(f"mode={mode}")

        dataset = ESPnetDataset(
            data_path_and_name_and_type,
            preprocess=preprocess_fn,
            dest_sample_rate=dest_sample_rate,
            speed_perturb=args.speed_perturb if mode == "train" else None,
        )

        # sampler: batch_bins is scaled by ngpu because each batch is later
        # sharded across ranks
        dataset_conf = args.dataset_conf
        batch_sampler = LengthBatchSampler(
            batch_bins=dataset_conf["batch_conf"]["batch_size"] * args.ngpu,
            shape_files=shape_files,
            sort_in_batch=dataset_conf.get("sort_in_batch", "descending"),
            sort_batch=dataset_conf.get("sort_batch", "ascending"),
            drop_last=False,
            padding=True,
        )
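        # LengthBatchSampler packs utterances by length so each mini-batch
        # stays roughly within batch_bins (lengths come from the shape files);
        # short utterances thus form larger batches, as the summary below shows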
        batches = list(batch_sampler)
        bs_list = [len(batch) for batch in batches]

        logging.info(f"[{mode}] dataset:\n{dataset}")
        logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
        logging.info(
            f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
            f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
        )

        # the tri_stage LR scheduler needs the total number of updates up front
        if args.scheduler == "tri_stage" and mode == "train":
            args.max_update = len(bs_list) * args.max_epoch
            logging.info("Max update: {}".format(args.max_update))

        # distributed training: every rank takes a round-robin slice of each
        # batch, so each batch must hold at least world_size samples
        if args.distributed and mode == "train":
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            for batch in batches:
                if len(batch) < world_size:
                    raise RuntimeError(
                        f"The batch-size must be equal to or greater than "
                        f"world_size: {len(batch)} < {world_size}"
                    )
            batches = [batch[rank::world_size] for batch in batches]
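            # e.g. with world_size=2, a batch [b0, b1, b2, b3] becomes
            # [b0, b2] on rank 0 and [b1, b3] on rank 1, so all ranks keep
            # the same number of mini-batches per epoch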

        # list(batch_sampler) yields plain lists of batches, so wrap them;
        # an AbsSampler instance would pass through unchanged
        if not isinstance(batches, AbsSampler):
            self.sampler = RawSampler(batches)
        else:
            self.sampler = batches

        self.dataset = dataset
        self.num_iters_per_epoch = None
        self.shuffle = mode == "train"
        self.seed = args.seed
        self.num_workers = args.dataset_conf.get("num_workers", 8)
        self.collate_fn = collate_fn
        self.pin_memory = args.ngpu > 0

    def build_iter(self, epoch: int, shuffle: bool = None) -> DataLoader:
        if shuffle is None:
            shuffle = self.shuffle

        if self.num_iters_per_epoch is not None:
            N = len(self.sampler)
            # If the corpus holds more batches than num_iters_per_epoch,
            # slide a window of num_iters_per_epoch batches across successive
            # (deterministically shuffled) passes over the corpus
            if self.num_iters_per_epoch < N:
                real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N)
                if offset >= self.num_iters_per_epoch:
                    # this epoch's window lies entirely within one pass
                    current_batches = self.sampler.generate(real_epoch + self.seed)
                    if shuffle:
                        np.random.RandomState(real_epoch + self.seed).shuffle(
                            current_batches
                        )
                    batches = current_batches[
                        offset - self.num_iters_per_epoch : offset
                    ]
                else:
                    # the window straddles two passes: stitch the tail of the
                    # previous pass to the head of the current one
                    prev_batches = self.sampler.generate(real_epoch - 1 + self.seed)
                    current_batches = self.sampler.generate(real_epoch + self.seed)
                    if shuffle:
                        np.random.RandomState(real_epoch - 1 + self.seed).shuffle(
                            prev_batches
                        )
                        np.random.RandomState(real_epoch + self.seed).shuffle(
                            current_batches
                        )
                    batches = (
                        prev_batches[offset - self.num_iters_per_epoch :]
                        + current_batches[:offset]
                    )
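                # worked example: N=10, num_iters_per_epoch=4, epoch=3 gives
                # divmod(12, 10) -> (1, 2): the last two batches of pass 0
                # are stitched to the first two batches of pass 1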
            # If the corpus holds fewer batches than num_iters_per_epoch,
            # keep cycling through fresh passes until the epoch is filled
            else:
                _epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N)
                _remain = self.num_iters_per_epoch
                batches = []
                current_batches = self.sampler.generate(_epoch + self.seed)
                if shuffle:
                    np.random.RandomState(_epoch + self.seed).shuffle(current_batches)
                while _remain > 0:
                    # a slice past the end of the list simply truncates
                    _batches = current_batches[_cursor : _cursor + _remain]
                    batches += _batches
                    if _cursor + _remain >= N:
                        # this pass is exhausted; start a fresh, reseeded one
                        _epoch += 1
                        _cursor = 0
                        current_batches = self.sampler.generate(_epoch + self.seed)
                        if shuffle:
                            np.random.RandomState(_epoch + self.seed).shuffle(
                                current_batches
                            )
                    else:
                        _cursor = _cursor + _remain
                    _remain -= len(_batches)

                assert len(batches) == self.num_iters_per_epoch
        else:
            # one full pass per epoch; seeding with (epoch + seed) makes the
            # shuffle reproducible for any given epoch
            batches = self.sampler.generate(epoch + self.seed)
            if shuffle:
                np.random.RandomState(epoch + self.seed).shuffle(batches)

        # For backward compatibility for pytorch DataLoader
        if self.collate_fn is not None:
            kwargs = dict(collate_fn=self.collate_fn)
        else:
            kwargs = {}

        return DataLoader(
            dataset=self.dataset,
            batch_sampler=batches,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            **kwargs,
        )
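

# Typical usage, as a sketch (`args` is the parsed training configuration
# from FunASR's entry point; the field names below assume its conventions):
#
#     factory = SequenceIterFactory(args, mode="train")
#     for epoch in range(1, args.max_epoch + 1):
#         loader = factory.build_iter(epoch)
#         for batch in loader:
#             ...  # one training step per mini-batch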