repeat.py 703 B

123456789101112131415161718192021222324252627282930313233
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2019 Shigeki Karita
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. """Repeat the same layer definition."""
  6. import torch
  7. class MultiSequential(torch.nn.Sequential):
  8. """Multi-input multi-output torch.nn.Sequential."""
  9. def forward(self, *args):
  10. """Repeat."""
  11. for m in self:
  12. args = m(*args)
  13. return args
  14. def repeat(N, fn):
  15. """Repeat module N times.
  16. Args:
  17. N (int): Number of repeat time.
  18. fn (Callable): Function to generate module.
  19. Returns:
  20. MultiSequential: Repeated model instance.
  21. """
  22. return MultiSequential(*[fn(n) for n in range(N)])