# sequence_iter_factory.py

from typing import Any
from typing import Sequence
from typing import Union

import numpy as np
from torch.utils.data import DataLoader
from typeguard import check_argument_types

from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.samplers.abs_sampler import AbsSampler


class RawSampler(AbsSampler):
    """Sampler that wraps a pre-computed list of batches."""

    def __init__(self, batches):
        self.batches = batches

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        return iter(self.batches)

    def generate(self, seed):
        # The seed is ignored: the wrapped batches are returned as-is.
        return list(self.batches)


class SequenceIterFactory(AbsIterFactory):
    """Build an iterator for each epoch.

    This class simply creates a pytorch DataLoader, except for the following points:
    - The random seed is derived from the epoch number, which guarantees
      reproducibility when resuming from the middle of a training process.
    - The number of batches for one epoch can be restricted, which controls the
      interval between training and evaluation.
    """

    def __init__(
        self,
        dataset,
        batches: Union[AbsSampler, Sequence[Sequence[Any]]],
        num_iters_per_epoch: int = None,
        seed: int = 0,
        shuffle: bool = False,
        num_workers: int = 0,
        collate_fn=None,
        pin_memory: bool = False,
    ):
        assert check_argument_types()

        if not isinstance(batches, AbsSampler):
            self.sampler = RawSampler(batches)
        else:
            self.sampler = batches

        self.dataset = dataset
        self.num_iters_per_epoch = num_iters_per_epoch
        self.shuffle = shuffle
        self.seed = seed
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        # https://discuss.pytorch.org/t/what-is-the-disadvantage-of-using-pin-memory/1702
        self.pin_memory = pin_memory

    def build_iter(self, epoch: int, shuffle: bool = None) -> DataLoader:
        if shuffle is None:
            shuffle = self.shuffle

        if self.num_iters_per_epoch is not None:
            N = len(self.sampler)
            # If the corpus size is larger than num_iters_per_epoch
            if self.num_iters_per_epoch < N:
                # real_epoch: the pass over the corpus in which this epoch's
                # window of batches ends; offset: the end position of that
                # window within the pass.
                real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N)

                if offset >= self.num_iters_per_epoch:
                    # The window lies entirely within a single pass.
                    current_batches = self.sampler.generate(real_epoch + self.seed)
                    if shuffle:
                        np.random.RandomState(real_epoch + self.seed).shuffle(
                            current_batches
                        )
                    batches = current_batches[
                        offset - self.num_iters_per_epoch : offset
                    ]
                else:
                    # The window straddles two passes: take the tail of the
                    # previous pass and the head of the current one.
                    prev_batches = self.sampler.generate(real_epoch - 1 + self.seed)
                    current_batches = self.sampler.generate(real_epoch + self.seed)
                    if shuffle:
                        np.random.RandomState(real_epoch - 1 + self.seed).shuffle(
                            prev_batches
                        )
                        np.random.RandomState(real_epoch + self.seed).shuffle(
                            current_batches
                        )
                    batches = (
                        prev_batches[offset - self.num_iters_per_epoch :]
                        + current_batches[:offset]
                    )

            # If the corpus size is less than num_iters_per_epoch, stitch
            # together consecutive passes over the corpus until enough
            # batches have been collected.
            else:
                _epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N)
                _remain = self.num_iters_per_epoch
                batches = []
                current_batches = self.sampler.generate(_epoch + self.seed)
                if shuffle:
                    np.random.RandomState(_epoch + self.seed).shuffle(current_batches)
                while _remain > 0:
                    _batches = current_batches[_cursor : _cursor + _remain]
                    batches += _batches
                    if _cursor + _remain >= N:
                        # The current pass is exhausted: regenerate (and
                        # reshuffle) with the next pass's seed.
                        _epoch += 1
                        _cursor = 0
                        current_batches = self.sampler.generate(_epoch + self.seed)
                        if shuffle:
                            np.random.RandomState(_epoch + self.seed).shuffle(
                                current_batches
                            )
                    else:
                        _cursor = _cursor + _remain
                    _remain -= len(_batches)

                assert len(batches) == self.num_iters_per_epoch

        else:
            batches = self.sampler.generate(epoch + self.seed)
            if shuffle:
                np.random.RandomState(epoch + self.seed).shuffle(batches)

        # For backward compatibility with the pytorch DataLoader
        if self.collate_fn is not None:
            kwargs = dict(collate_fn=self.collate_fn)
        else:
            kwargs = {}

        return DataLoader(
            dataset=self.dataset,
            batch_sampler=batches,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            **kwargs,
        )
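

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): demonstrates
# the two docstring features above, i.e. that build_iter(epoch) is
# reproducible because the shuffle seed is derived from the epoch number, and
# that num_iters_per_epoch fixes the epoch length independently of corpus
# size. `ToyDataset` and the batch list are hypothetical stand-ins for
# FunASR's real dataset and sampler objects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    class ToyDataset:
        """Map-style dataset that simply returns its own index."""

        def __getitem__(self, idx):
            return idx

        def __len__(self):
            return 8

    # Four pre-computed batches of two sample indices each.
    toy_batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
    factory = SequenceIterFactory(ToyDataset(), toy_batches, seed=0, shuffle=True)

    first = list(factory.build_iter(epoch=1))
    again = list(factory.build_iter(epoch=1))
    other = list(factory.build_iter(epoch=2))

    # The same epoch always yields the same batch order, so a resumed run
    # sees exactly the data it would have seen without the interruption...
    assert all(torch.equal(a, b) for a, b in zip(first, again))
    # ...while a different epoch reseeds the shuffle.
    print("epoch 1:", [t.tolist() for t in first])
    print("epoch 2:", [t.tolist() for t in other])

    # With num_iters_per_epoch set, each epoch is a fixed-length window over
    # the stream of batches, decoupling evaluation intervals from corpus size.
    short = SequenceIterFactory(ToyDataset(), toy_batches, num_iters_per_epoch=3, seed=0)
    assert len(list(short.build_iter(epoch=1))) == 3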