forward_adaptor.py 1.0 KB

123456789101112131415161718192021222324252627282930313233
  1. import torch
  2. from typeguard import check_argument_types
  3. class ForwardAdaptor(torch.nn.Module):
  4. """Wrapped module to parallelize specified method
  5. torch.nn.DataParallel parallelizes only "forward()"
  6. and, maybe, the method having the other name can't be applied
  7. except for wrapping the module just like this class.
  8. Examples:
  9. >>> class A(torch.nn.Module):
  10. ... def foo(self, x):
  11. ... ...
  12. >>> model = A()
  13. >>> model = ForwardAdaptor(model, "foo")
  14. >>> model = torch.nn.DataParallel(model, device_ids=[0, 1])
  15. >>> x = torch.randn(2, 10)
  16. >>> model(x)
  17. """
  18. def __init__(self, module: torch.nn.Module, name: str):
  19. assert check_argument_types()
  20. super().__init__()
  21. self.module = module
  22. self.name = name
  23. if not hasattr(module, name):
  24. raise ValueError(f"{module} doesn't have {name}")
  25. def forward(self, *args, **kwargs):
  26. func = getattr(self.module, self.name)
  27. return func(*args, **kwargs)