|
|
@@ -14,25 +14,38 @@ import torch
|
|
|
class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential.

    Each contained module is called with the unpacked outputs of the previous
    one. In training mode every module may be skipped independently with
    probability ``layer_drop_rate`` (layer drop); outside training mode all
    modules are always applied.
    """

    def __init__(self, *args, layer_drop_rate=0.0):
        """Initialize MultiSequential with layer_drop.

        Args:
            *args: Modules forwarded unchanged to torch.nn.Sequential.
            layer_drop_rate (float): Probability of dropping out each fn (layer).

        """
        super(MultiSequential, self).__init__(*args)
        self.layer_drop_rate = layer_drop_rate

    def forward(self, *args):
        """Apply the contained modules in order, optionally dropping some.

        Args:
            *args: Inputs for the first module. Every module must return an
                iterable that can be unpacked as the next module's inputs.

        Returns:
            The output of the last module that was applied, or ``args``
            unchanged when no module was applied.

        """
        # Layer drop can only have an effect while training with a positive
        # rate. Skipping the random draw otherwise avoids a useless allocation
        # and keeps the global RNG stream untouched, so evaluation and
        # drop-free training stay reproducible.
        if not self.training or self.layer_drop_rate <= 0.0:
            for m in self:
                args = m(*args)
            return args

        # One uniform sample in [0, 1) per layer; a layer is kept when its
        # sample is not below the drop rate. Compare once, vectorized.
        keep_flags = (
            torch.empty(len(self)).uniform_() >= self.layer_drop_rate
        ).tolist()
        for keep, m in zip(keep_flags, self):
            if keep:
                args = m(*args)
        return args
|
|
|
|
|
|
|
|
|
def repeat(N, fn, layer_drop_rate=0.0):
    """Build a MultiSequential out of N generated modules.

    Args:
        N (int): How many modules to generate.
        fn (Callable): Factory called as ``fn(n)`` for each ``n`` in ``range(N)``.
        layer_drop_rate (float): Probability of dropping out each fn (layer).

    Returns:
        MultiSequential: Repeated model instance.

    """
    layers = []
    for n in range(N):
        layers.append(fn(n))
    return MultiSequential(*layers, layer_drop_rate=layer_drop_rate)
|
|
|
|
|
|
|
|
|
class MultiBlocks(torch.nn.Module):
|