device_funcs.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import dataclasses
  2. import warnings
  3. import numpy as np
  4. import torch
  5. def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
  6. """Change the device of object recursively"""
  7. if isinstance(data, dict):
  8. return {
  9. k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items()
  10. }
  11. elif dataclasses.is_dataclass(data) and not isinstance(data, type):
  12. return type(data)(
  13. *[
  14. to_device(v, device, dtype, non_blocking, copy)
  15. for v in dataclasses.astuple(data)
  16. ]
  17. )
  18. # maybe namedtuple. I don't know the correct way to judge namedtuple.
  19. elif isinstance(data, tuple) and type(data) is not tuple:
  20. return type(data)(
  21. *[to_device(o, device, dtype, non_blocking, copy) for o in data]
  22. )
  23. elif isinstance(data, (list, tuple)):
  24. return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data)
  25. elif isinstance(data, np.ndarray):
  26. return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
  27. elif isinstance(data, torch.Tensor):
  28. return data.to(device, dtype, non_blocking, copy)
  29. else:
  30. return data
  31. def force_gatherable(data, device):
  32. """Change object to gatherable in torch.nn.DataParallel recursively
  33. The difference from to_device() is changing to torch.Tensor if float or int
  34. value is found.
  35. The restriction to the returned value in DataParallel:
  36. The object must be
  37. - torch.cuda.Tensor
  38. - 1 or more dimension. 0-dimension-tensor sends warning.
  39. or a list, tuple, dict.
  40. """
  41. if isinstance(data, dict):
  42. return {k: force_gatherable(v, device) for k, v in data.items()}
  43. # DataParallel can't handle NamedTuple well
  44. elif isinstance(data, tuple) and type(data) is not tuple:
  45. return type(data)(*[force_gatherable(o, device) for o in data])
  46. elif isinstance(data, (list, tuple, set)):
  47. return type(data)(force_gatherable(v, device) for v in data)
  48. elif isinstance(data, np.ndarray):
  49. return force_gatherable(torch.from_numpy(data), device)
  50. elif isinstance(data, torch.Tensor):
  51. if data.dim() == 0:
  52. # To 1-dim array
  53. data = data[None]
  54. return data.to(device)
  55. elif isinstance(data, float):
  56. return torch.tensor([data], dtype=torch.float, device=device)
  57. elif isinstance(data, int):
  58. return torch.tensor([data], dtype=torch.long, device=device)
  59. elif data is None:
  60. return None
  61. else:
  62. warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
  63. return data