| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- import dataclasses
- import warnings
- import numpy as np
- import torch
def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
    """Change the device of object recursively.

    Containers are rebuilt with every element converted; leaves that are
    neither ``numpy.ndarray`` nor ``torch.Tensor`` are returned unchanged.

    Args:
        data: A dict, dataclass instance, namedtuple, list, tuple,
            ``numpy.ndarray``, ``torch.Tensor`` or any other object.
        device: Forwarded to ``torch.Tensor.to``.
        dtype: Forwarded to ``torch.Tensor.to``.
        non_blocking: Forwarded to ``torch.Tensor.to``.
        copy: Forwarded to ``torch.Tensor.to``.

    Returns:
        An object with the same structure as ``data``:

        - dict -> a new plain ``dict`` with converted values,
        - dataclass instance -> a new instance of the same class,
        - namedtuple / list / tuple -> a new object of the same type,
        - ``numpy.ndarray`` -> ``torch.Tensor`` (via ``torch.from_numpy``),
        - ``torch.Tensor`` -> result of ``Tensor.to``,
        - anything else -> ``data`` itself.
    """
    if isinstance(data, dict):
        return {
            k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items()
        }
    elif dataclasses.is_dataclass(data) and not isinstance(data, type):
        # Walk the fields directly instead of using dataclasses.astuple():
        # astuple() flattens nested dataclasses into plain tuples (losing
        # their type) and deep-copies every leaf value before it can be
        # converted. Rebuilding by keyword also keeps keyword-only fields
        # working and skips fields that the constructor does not accept.
        return type(data)(
            **{
                f.name: to_device(
                    getattr(data, f.name), device, dtype, non_blocking, copy
                )
                for f in dataclasses.fields(data)
                if f.init
            }
        )
    # A tuple subclass is treated as a namedtuple: its constructor takes the
    # items as separate positional arguments rather than a single iterable.
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(
            *[to_device(o, device, dtype, non_blocking, copy) for o in data]
        )
    elif isinstance(data, (list, tuple)):
        return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data)
    elif isinstance(data, np.ndarray):
        return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
    elif isinstance(data, torch.Tensor):
        return data.to(device, dtype, non_blocking, copy)
    else:
        return data
def force_gatherable(data, device):
    """Recursively convert an object so torch.nn.DataParallel can gather it.

    Unlike to_device(), plain Python ``float`` and ``int`` values are also
    turned into one-element tensors.

    What DataParallel accepts as a returned value:
        - a torch.cuda.Tensor with one or more dimensions
          (a 0-dimensional tensor triggers a warning), or
        - a list, tuple or dict of such objects.

    Args:
        data: Object to convert (dict, namedtuple, list, tuple, set,
            ``numpy.ndarray``, ``torch.Tensor``, float, int, None, ...).
        device: Target device for every tensor that is produced.

    Returns:
        The converted object. ``None`` is passed through; any unsupported
        type is returned unchanged after emitting a warning.
    """
    if isinstance(data, dict):
        return {key: force_gatherable(value, device) for key, value in data.items()}

    # DataParallel can't handle NamedTuple well, so rebuild it field by field.
    if isinstance(data, tuple) and type(data) is not tuple:
        converted = [force_gatherable(item, device) for item in data]
        return type(data)(*converted)

    if isinstance(data, (list, tuple, set)):
        container = type(data)
        return container(force_gatherable(item, device) for item in data)

    if isinstance(data, np.ndarray):
        as_tensor = torch.from_numpy(data)
        return force_gatherable(as_tensor, device)

    if isinstance(data, torch.Tensor):
        # Promote a scalar tensor to shape (1,) so it can be gathered.
        gatherable = data[None] if data.dim() == 0 else data
        return gatherable.to(device)

    if isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)

    if isinstance(data, int):
        return torch.tensor([data], dtype=torch.long, device=device)

    if data is None:
        return None

    warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
    return data
|