# sequence_iter_factory.py
from typing import Any
from typing import Optional
from typing import Sequence
from typing import Union

import numpy as np
from torch.utils.data import DataLoader

from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.samplers.abs_sampler import AbsSampler
  8. class RawSampler(AbsSampler):
  9. def __init__(self, batches):
  10. self.batches = batches
  11. def __len__(self):
  12. return len(self.batches)
  13. def __iter__(self):
  14. return iter(self.batches)
  15. def generate(self, seed):
  16. return list(self.batches)
  17. class SequenceIterFactory(AbsIterFactory):
  18. """Build iterator for each epoch.
  19. This class simply creates pytorch DataLoader except for the following points:
  20. - The random seed is decided according to the number of epochs. This feature
  21. guarantees reproducibility when resuming from middle of training process.
  22. - Enable to restrict the number of samples for one epoch. This features
  23. controls the interval number between training and evaluation.
  24. """
  25. def __init__(
  26. self,
  27. dataset,
  28. batches: Union[AbsSampler, Sequence[Sequence[Any]]],
  29. num_iters_per_epoch: int = None,
  30. seed: int = 0,
  31. shuffle: bool = False,
  32. num_workers: int = 0,
  33. collate_fn=None,
  34. pin_memory: bool = False,
  35. ):
  36. if not isinstance(batches, AbsSampler):
  37. self.sampler = RawSampler(batches)
  38. else:
  39. self.sampler = batches
  40. self.dataset = dataset
  41. self.num_iters_per_epoch = num_iters_per_epoch
  42. self.shuffle = shuffle
  43. self.seed = seed
  44. self.num_workers = num_workers
  45. self.collate_fn = collate_fn
  46. # https://discuss.pytorch.org/t/what-is-the-disadvantage-of-using-pin-memory/1702
  47. self.pin_memory = pin_memory
  48. def build_iter(self, epoch: int, shuffle: bool = None) -> DataLoader:
  49. if shuffle is None:
  50. shuffle = self.shuffle
  51. if self.num_iters_per_epoch is not None:
  52. N = len(self.sampler)
  53. # If corpus size is larger than the num_per_epoch
  54. if self.num_iters_per_epoch < N:
  55. N = len(self.sampler)
  56. real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N)
  57. if offset >= self.num_iters_per_epoch:
  58. current_batches = self.sampler.generate(real_epoch + self.seed)
  59. if shuffle:
  60. np.random.RandomState(real_epoch + self.seed).shuffle(
  61. current_batches
  62. )
  63. batches = current_batches[
  64. offset - self.num_iters_per_epoch : offset
  65. ]
  66. else:
  67. prev_batches = self.sampler.generate(real_epoch - 1 + self.seed)
  68. current_batches = self.sampler.generate(real_epoch + self.seed)
  69. if shuffle:
  70. np.random.RandomState(real_epoch - 1 + self.seed).shuffle(
  71. prev_batches
  72. )
  73. np.random.RandomState(real_epoch + self.seed).shuffle(
  74. current_batches
  75. )
  76. batches = (
  77. prev_batches[offset - self.num_iters_per_epoch :]
  78. + current_batches[:offset]
  79. )
  80. # If corpus size is less than the num_per_epoch
  81. else:
  82. _epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N)
  83. _remain = self.num_iters_per_epoch
  84. batches = []
  85. current_batches = self.sampler.generate(_epoch + self.seed)
  86. if shuffle:
  87. np.random.RandomState(_epoch + self.seed).shuffle(current_batches)
  88. while _remain > 0:
  89. _batches = current_batches[_cursor : _cursor + _remain]
  90. batches += _batches
  91. if _cursor + _remain >= N:
  92. _epoch += 1
  93. _cursor = 0
  94. current_batches = self.sampler.generate(_epoch + self.seed)
  95. if shuffle:
  96. np.random.RandomState(_epoch + self.seed).shuffle(
  97. current_batches
  98. )
  99. else:
  100. _cursor = _cursor + _remain
  101. _remain -= len(_batches)
  102. assert len(batches) == self.num_iters_per_epoch
  103. else:
  104. batches = self.sampler.generate(epoch + self.seed)
  105. if shuffle:
  106. np.random.RandomState(epoch + self.seed).shuffle(batches)
  107. # For backward compatibility for pytorch DataLoader
  108. if self.collate_fn is not None:
  109. kwargs = dict(collate_fn=self.collate_fn)
  110. else:
  111. kwargs = {}
  112. return DataLoader(
  113. dataset=self.dataset,
  114. batch_sampler=batches,
  115. num_workers=self.num_workers,
  116. pin_memory=self.pin_memory,
  117. **kwargs,
  118. )