forward_adaptor.py 971 B

12345678910111213141516171819202122232425262728293031
  1. import torch
  2. class ForwardAdaptor(torch.nn.Module):
  3. """Wrapped module to parallelize specified method
  4. torch.nn.DataParallel parallelizes only "forward()"
  5. and, maybe, the method having the other name can't be applied
  6. except for wrapping the module just like this class.
  7. Examples:
  8. >>> class A(torch.nn.Module):
  9. ... def foo(self, x):
  10. ... ...
  11. >>> model = A()
  12. >>> model = ForwardAdaptor(model, "foo")
  13. >>> model = torch.nn.DataParallel(model, device_ids=[0, 1])
  14. >>> x = torch.randn(2, 10)
  15. >>> model(x)
  16. """
  17. def __init__(self, module: torch.nn.Module, name: str):
  18. super().__init__()
  19. self.module = module
  20. self.name = name
  21. if not hasattr(module, name):
  22. raise ValueError(f"{module} doesn't have {name}")
  23. def forward(self, *args, **kwargs):
  24. func = getattr(self.module, self.name)
  25. return func(*args, **kwargs)