| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- import collections
- import sys
- from torch import multiprocessing
- def get_size(obj, seen=None):
- """Recursively finds size of objects
- Taken from https://github.com/bosswissam/pysize
- """
- size = sys.getsizeof(obj)
- if seen is None:
- seen = set()
- obj_id = id(obj)
- if obj_id in seen:
- return 0
- # Important mark as seen *before* entering recursion to gracefully handle
- # self-referential objects
- seen.add(obj_id)
- if isinstance(obj, dict):
- size += sum([get_size(v, seen) for v in obj.values()])
- size += sum([get_size(k, seen) for k in obj.keys()])
- elif hasattr(obj, "__dict__"):
- size += get_size(obj.__dict__, seen)
- elif isinstance(obj, (list, set, tuple)):
- size += sum([get_size(i, seen) for i in obj])
- return size
- class SizedDict(collections.abc.MutableMapping):
- def __init__(self, shared: bool = False, data: dict = None):
- if data is None:
- data = {}
- if shared:
- # NOTE(kamo): Don't set manager as a field because Manager, which includes
- # weakref object, causes following error with method="spawn",
- # "TypeError: can't pickle weakref objects"
- self.cache = multiprocessing.Manager().dict(**data)
- else:
- self.manager = None
- self.cache = dict(**data)
- self.size = 0
- def __setitem__(self, key, value):
- if key in self.cache:
- self.size -= get_size(self.cache[key])
- else:
- self.size += sys.getsizeof(key)
- self.size += get_size(value)
- self.cache[key] = value
- def __getitem__(self, key):
- return self.cache[key]
- def __delitem__(self, key):
- self.size -= get_size(self.cache[key])
- self.size -= sys.getsizeof(key)
- del self.cache[key]
- def __iter__(self):
- return iter(self.cache)
- def __contains__(self, key):
- return key in self.cache
- def __len__(self):
- return len(self.cache)
|