recursive_op.py

  1. """Torch utility module."""
  2. import torch
  3. if torch.distributed.is_available():
  4. from torch.distributed import ReduceOp
  5. def recursive_sum(obj, weight: torch.Tensor, distributed: bool = False):
  6. assert weight.dim() == 1, weight.size()
  7. if isinstance(obj, (tuple, list)):
  8. return type(obj)(recursive_sum(v, weight, distributed) for v in obj)
  9. elif isinstance(obj, dict):
  10. return {k: recursive_sum(v, weight, distributed) for k, v in obj.items()}
  11. elif isinstance(obj, torch.Tensor):
  12. assert obj.size() == weight.size(), (obj.size(), weight.size())
  13. obj = (obj * weight.type(obj.dtype)).sum()
  14. if distributed:
  15. torch.distributed.all_reduce(obj, op=ReduceOp.SUM)
  16. return obj
  17. elif obj is None:
  18. return None
  19. else:
  20. raise ValueError(type(obj))
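
# Illustrative usage sketch for recursive_sum: the batch of three samples and
# the values below are hypothetical, chosen only to show how each tensor leaf
# in a nested container is reduced to a 0-dim weighted sum while None leaves
# pass through unchanged.
def _example_recursive_sum():
    weight = torch.tensor([0.5, 1.0, 1.5])  # one weight per sample
    stats = {
        "loss": torch.tensor([2.0, 4.0, 6.0]),
        "acc": (torch.tensor([1.0, 0.0, 1.0]), None),
    }
    summed = recursive_sum(stats, weight)
    # summed["loss"] == 0.5*2 + 1.0*4 + 1.5*6 == tensor(14.0)
    # summed["acc"] == (tensor(2.0), None)
    return summed
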

def recursive_divide(a, b: torch.Tensor):
    """Divide every tensor leaf of ``a`` element-wise by ``b`` (same shape)."""
    if isinstance(a, (tuple, list)):
        return type(a)(recursive_divide(v, b) for v in a)
    elif isinstance(a, dict):
        return {k: recursive_divide(v, b) for k, v in a.items()}
    elif isinstance(a, torch.Tensor):
        assert a.size() == b.size(), (a.size(), b.size())
        return a / b.type(a.dtype)
    elif a is None:
        return None
    else:
        raise ValueError(type(a))
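
# Illustrative usage sketch for recursive_divide, with hypothetical 0-dim
# values: every tensor leaf is divided by ``b``, so the shapes must match
# (here both sides are scalars).
def _example_recursive_divide():
    totals = {"loss": torch.tensor(14.0), "count": torch.tensor(3.0)}
    denom = torch.tensor(2.0)
    return recursive_divide(totals, denom)
    # -> {"loss": tensor(7.0), "count": tensor(1.5)}
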

def recursive_average(obj, weight: torch.Tensor, distributed: bool = False):
    """Weighted average of ``obj`` over the batch (and across processes if distributed)."""
    obj = recursive_sum(obj, weight, distributed)
    weight = weight.sum()
    if distributed:
        torch.distributed.all_reduce(weight, op=ReduceOp.SUM)
    # Normalize each weighted sum by the total weight
    obj = recursive_divide(obj, weight)
    return obj, weight
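

# Illustrative single-process sketch of recursive_average, using hypothetical
# values. With distributed=False no all_reduce is performed, so the result is
# the per-leaf weighted mean together with the total weight; in a distributed
# job the same call would additionally sum across ranks via ReduceOp.SUM.
if __name__ == "__main__":
    weight = torch.tensor([1.0, 1.0, 2.0])
    stats = {"loss": torch.tensor([3.0, 5.0, 8.0])}
    avg, total_weight = recursive_average(stats, weight)
    # avg["loss"] == (1*3 + 1*5 + 2*8) / (1 + 1 + 2) == tensor(6.0)
    # total_weight == tensor(4.0)
    print(avg, total_weight)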