# sequence_iter_factory.py (7.3 KB)
  1. import logging
  2. import numpy as np
  3. import torch
  4. from torch.utils.data import DataLoader
  5. from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
  6. from funasr.datasets.small_datasets.dataset import ESPnetDataset
  7. from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler
  8. from funasr.datasets.small_datasets.preprocessor import build_preprocess
  9. from funasr.iterators.abs_iter_factory import AbsIterFactory
  10. from funasr.samplers.abs_sampler import AbsSampler
  11. class RawSampler(AbsSampler):
  12. def __init__(self, batches):
  13. self.batches = batches
  14. def __len__(self):
  15. return len(self.batches)
  16. def __iter__(self):
  17. return iter(self.batches)
  18. def generate(self, seed):
  19. return list(self.batches)
  20. class SequenceIterFactory(AbsIterFactory):
  21. """Build iterator for each epoch, modified from ESPnet
  22. """
  23. def __init__(self, args, mode="train"):
  24. # preprocess
  25. preprocess_fn = build_preprocess(args, train=mode == "train")
  26. # collate
  27. if args.task_name in ["punc", "lm"]:
  28. collate_fn = CommonCollateFn(int_pad_value=0)
  29. else:
  30. collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
  31. # dataset
  32. dest_sample_rate = args.frontend_conf["fs"] if (
  33. args.frontend_conf is not None and "fs" in args.frontend_conf) else 16000
  34. if mode == "train":
  35. data_path_and_name_and_type = args.train_data_path_and_name_and_type
  36. shape_files = args.train_shape_file
  37. elif mode == "valid":
  38. data_path_and_name_and_type = args.valid_data_path_and_name_and_type
  39. shape_files = args.valid_shape_file
  40. else:
  41. raise NotImplementedError(f"mode={mode}")
  42. dataset = ESPnetDataset(
  43. data_path_and_name_and_type,
  44. preprocess=preprocess_fn,
  45. dest_sample_rate=dest_sample_rate,
  46. speed_perturb=args.speed_perturb if mode=="train" else None,
  47. )
  48. # sampler
  49. dataset_conf = args.dataset_conf
  50. batch_sampler = LengthBatchSampler(
  51. batch_bins=dataset_conf["batch_conf"]["batch_size"] * args.ngpu,
  52. shape_files=shape_files,
  53. sort_in_batch=dataset_conf["sort_in_batch"] if hasattr(dataset_conf, "sort_in_batch") else "descending",
  54. sort_batch=dataset_conf["sort_batch"] if hasattr(dataset_conf, "sort_batch") else "ascending",
  55. drop_last=False,
  56. padding=True,
  57. )
  58. batches = list(batch_sampler)
  59. bs_list = [len(batch) for batch in batches]
  60. logging.info(f"[{mode}] dataset:\n{dataset}")
  61. logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
  62. logging.info(
  63. f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
  64. f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
  65. )
  66. if args.scheduler == "tri_stage" and mode == "train":
  67. args.max_update = len(bs_list) * args.max_epoch
  68. logging.info("Max update: {}".format(args.max_update))
  69. if args.distributed and mode=="train":
  70. world_size = torch.distributed.get_world_size()
  71. rank = torch.distributed.get_rank()
  72. for batch in batches:
  73. if len(batch) < world_size:
  74. raise RuntimeError(
  75. f"The batch-size must be equal or more than world_size: "
  76. f"{len(batch)} < {world_size}"
  77. )
  78. batches = [batch[rank::world_size] for batch in batches]
  79. if not isinstance(batches, AbsSampler):
  80. self.sampler = RawSampler(batches)
  81. else:
  82. self.sampler = batches
  83. self.dataset = dataset
  84. self.num_iters_per_epoch = None
  85. self.shuffle = mode == "train"
  86. self.seed = args.seed
  87. self.num_workers = args.dataset_conf.get("num_workers", 8)
  88. self.collate_fn = collate_fn
  89. self.pin_memory = args.ngpu > 0
  90. def build_iter(self, epoch: int, shuffle: bool = None) -> DataLoader:
  91. if shuffle is None:
  92. shuffle = self.shuffle
  93. if self.num_iters_per_epoch is not None:
  94. N = len(self.sampler)
  95. # If corpus size is larger than the num_per_epoch
  96. if self.num_iters_per_epoch < N:
  97. N = len(self.sampler)
  98. real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N)
  99. if offset >= self.num_iters_per_epoch:
  100. current_batches = self.sampler.generate(real_epoch + self.seed)
  101. if shuffle:
  102. np.random.RandomState(real_epoch + self.seed).shuffle(
  103. current_batches
  104. )
  105. batches = current_batches[
  106. offset - self.num_iters_per_epoch: offset
  107. ]
  108. else:
  109. prev_batches = self.sampler.generate(real_epoch - 1 + self.seed)
  110. current_batches = self.sampler.generate(real_epoch + self.seed)
  111. if shuffle:
  112. np.random.RandomState(real_epoch - 1 + self.seed).shuffle(
  113. prev_batches
  114. )
  115. np.random.RandomState(real_epoch + self.seed).shuffle(
  116. current_batches
  117. )
  118. batches = (
  119. prev_batches[offset - self.num_iters_per_epoch:]
  120. + current_batches[:offset]
  121. )
  122. # If corpus size is less than the num_per_epoch
  123. else:
  124. _epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N)
  125. _remain = self.num_iters_per_epoch
  126. batches = []
  127. current_batches = self.sampler.generate(_epoch + self.seed)
  128. if shuffle:
  129. np.random.RandomState(_epoch + self.seed).shuffle(current_batches)
  130. while _remain > 0:
  131. _batches = current_batches[_cursor: _cursor + _remain]
  132. batches += _batches
  133. if _cursor + _remain >= N:
  134. _epoch += 1
  135. _cursor = 0
  136. current_batches = self.sampler.generate(_epoch + self.seed)
  137. if shuffle:
  138. np.random.RandomState(_epoch + self.seed).shuffle(
  139. current_batches
  140. )
  141. else:
  142. _cursor = _cursor + _remain
  143. _remain -= len(_batches)
  144. assert len(batches) == self.num_iters_per_epoch
  145. else:
  146. batches = self.sampler.generate(epoch + self.seed)
  147. if shuffle:
  148. np.random.RandomState(epoch + self.seed).shuffle(batches)
  149. # For backward compatibility for pytorch DataLoader
  150. if self.collate_fn is not None:
  151. kwargs = dict(collate_fn=self.collate_fn)
  152. else:
  153. kwargs = {}
  154. return DataLoader(
  155. dataset=self.dataset,
  156. batch_sampler=batches,
  157. num_workers=self.num_workers,
  158. pin_memory=self.pin_memory,
  159. **kwargs,
  160. )