| 123456789101112131415161718192021222324252627282930313233 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- # Copyright 2019 Shigeki Karita
- # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
- """Repeat the same layer definition."""
- import torch
class MultiSequential(torch.nn.Sequential):
    """Sequential container whose layers take and return multiple values.

    Each contained module is called with the previous module's output
    unpacked as positional arguments, so every module is expected to
    return an iterable (typically a tuple) matching the next module's
    ``forward`` signature.
    """

    def forward(self, *args):
        """Run the inputs through every contained module in order.

        Args:
            *args: Positional inputs handed to the first module.

        Returns:
            Whatever the last module returns. When the container holds
            no modules, the input tuple ``args`` is returned unchanged.

        """
        outputs = args
        for layer in self:
            # Unpack so that multi-value outputs become separate inputs.
            outputs = layer(*outputs)
        return outputs
def repeat(N, fn):
    """Build a MultiSequential out of N generated modules.

    Args:
        N (int): Number of modules to generate.
        fn (Callable): Factory called as ``fn(n)`` for each index
            ``n`` in ``range(N)``; it returns the module for that position.

    Returns:
        MultiSequential: Container holding the generated modules in
            index order.

    """
    layers = (fn(idx) for idx in range(N))
    return MultiSequential(*layers)
|