sized_dict.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import collections
  2. import sys
  3. from torch import multiprocessing
  4. def get_size(obj, seen=None):
  5. """Recursively finds size of objects
  6. Taken from https://github.com/bosswissam/pysize
  7. """
  8. size = sys.getsizeof(obj)
  9. if seen is None:
  10. seen = set()
  11. obj_id = id(obj)
  12. if obj_id in seen:
  13. return 0
  14. # Important mark as seen *before* entering recursion to gracefully handle
  15. # self-referential objects
  16. seen.add(obj_id)
  17. if isinstance(obj, dict):
  18. size += sum([get_size(v, seen) for v in obj.values()])
  19. size += sum([get_size(k, seen) for k in obj.keys()])
  20. elif hasattr(obj, "__dict__"):
  21. size += get_size(obj.__dict__, seen)
  22. elif isinstance(obj, (list, set, tuple)):
  23. size += sum([get_size(i, seen) for i in obj])
  24. return size
  25. class SizedDict(collections.abc.MutableMapping):
  26. def __init__(self, shared: bool = False, data: dict = None):
  27. if data is None:
  28. data = {}
  29. if shared:
  30. # NOTE(kamo): Don't set manager as a field because Manager, which includes
  31. # weakref object, causes following error with method="spawn",
  32. # "TypeError: can't pickle weakref objects"
  33. self.cache = multiprocessing.Manager().dict(**data)
  34. else:
  35. self.manager = None
  36. self.cache = dict(**data)
  37. self.size = 0
  38. def __setitem__(self, key, value):
  39. if key in self.cache:
  40. self.size -= get_size(self.cache[key])
  41. else:
  42. self.size += sys.getsizeof(key)
  43. self.size += get_size(value)
  44. self.cache[key] = value
  45. def __getitem__(self, key):
  46. return self.cache[key]
  47. def __delitem__(self, key):
  48. self.size -= get_size(self.cache[key])
  49. self.size -= sys.getsizeof(key)
  50. del self.cache[key]
  51. def __iter__(self):
  52. return iter(self.cache)
  53. def __contains__(self, key):
  54. return key in self.cache
  55. def __len__(self):
  56. return len(self.cache)