weak.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. import collections.abc as _collections_abc
  4. import weakref
  5. from collections.abc import Mapping, MutableMapping
  6. from weakref import ref
  7. from torch import Tensor
# Alias for weakref.ref; used in this file as the annotation type of
# TensorWeakRef.ref.
WeakRef = ref

# Public names of this module.
__all__ = [
    "TensorWeakRef",
    "WeakIdRef",
    "WeakIdKeyDictionary",
    "WeakTensorKeyDictionary",
]
  15. # TODO: make weakref properly thread safe following
  16. # https://github.com/python/cpython/pull/125325
  17. class _IterationGuard:
  18. # This context manager registers itself in the current iterators of the
  19. # weak container, such as to delay all removals until the context manager
  20. # exits.
  21. # This technique should be relatively thread-safe (since sets are).
  22. def __init__(self, weakcontainer) -> None:
  23. # Don't create cycles
  24. self.weakcontainer = ref(weakcontainer)
  25. def __enter__(self):
  26. w = self.weakcontainer()
  27. if w is not None:
  28. w._iterating.add(self)
  29. return self
  30. def __exit__(self, e, t, b):
  31. w = self.weakcontainer()
  32. if w is not None:
  33. s = w._iterating
  34. s.remove(self)
  35. if not s:
  36. w._commit_removals()
  37. # This file defines a variant of WeakKeyDictionary that overrides the hashing
  38. # behavior of the key to use object identity, rather than the builtin
  39. # __eq__/__hash__ functions. This is useful for Tensor weak keys, as their
  40. # __eq__ implementation returns a Tensor (elementwise equality), which means
  41. # you can't use them directly with the WeakKeyDictionary in the standard library.
  42. #
  43. # Our implementation strategy is to create a wrapper weak key object, which we
  44. # use as a key in a stock Python dictionary. This is similar to how weakref
  45. # implements WeakKeyDictionary, but instead of using weakref.ref as the
  46. # wrapper, we use a custom wrapper that has different __eq__ and __hash__
  47. # behavior. Note that we subsequently store this weak key directly in an
  48. # ORDINARY dictionary, since the newly constructed WeakIdKey's only use would
  49. # be a dictionary so it would have no strong references. Ensuring that
  50. # only live WeakIdKeys are in the map is handled by putting finalizers on the
  51. # original key object.
  52. # It is simpler to implement this with composition, but if we want to
  53. # directly reuse the callback mechanism on weakref, we need the weakref
  54. # and the key to be exactly the same object. Reusing the callback mechanism
  55. # minimizes the divergence between our implementation and Lib/weakref.py
  56. #
  57. # NB: Prefer using this when working with weakrefs of Tensors; e.g., do
  58. # WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of
  59. # easy to get wrong cases transparently for you.
  60. class WeakIdRef(weakref.ref):
  61. __slots__ = ["_id"]
  62. def __init__(self, key, callback=None) -> None:
  63. # Unlike stock weakref, which preserves hash semantics of the
  64. # original object but lazily defers hash calls until the first
  65. # time the user attempts to hash the weakref, we can eagerly
  66. # cache the id of the key as we know this is definitely the hash
  67. # method
  68. self._id = id(key)
  69. super().__init__(key, callback) # type: ignore[call-arg]
  70. def __call__(self):
  71. r = super().__call__()
  72. # Special logic for Tensor PyObject resurrection
  73. if hasattr(r, "_fix_weakref"):
  74. r._fix_weakref() # type: ignore[union-attr]
  75. return r
  76. def __hash__(self):
  77. return self._id
  78. def __eq__(self, other):
  79. # An attractive but wrong alternate implementation is to only test if
  80. # the stored _ids match. This can lead to an ABA problem if you have:
  81. #
  82. # a1 = A()
  83. # w1 = WeakIdRef(a1)
  84. # del a1
  85. # a2 = A() # suppose it gets the same ID as a1
  86. # w2 = WeakIdRef(a2)
  87. # print(w1 == w2)
  88. #
  89. # This should be False, as a1 and a2 are unrelated (and a1 is
  90. # dead anyway)
  91. a = self()
  92. b = other()
  93. if a is not None and b is not None:
  94. return a is b
  95. return self is other
  96. # This is the same as WeakIdRef but equality is checked using hash() rather than id.
  97. # This will be equivalent to the one above except for classes where hash is not their id.
  98. class _WeakHashRef(weakref.ref):
  99. __slots__ = ["_id"]
  100. def __init__(self, key, callback=None) -> None:
  101. # Unlike stock weakref, which preserves hash semantics of the
  102. # original object but lazily defers hash calls until the first
  103. # time the user attempts to hash the weakref, we can eagerly
  104. # cache the id of the key as we know this is definitely the hash
  105. # method
  106. self._id = hash(key)
  107. super().__init__(key, callback) # type: ignore[call-arg]
  108. def __call__(self):
  109. r = super().__call__()
  110. # Special logic for Tensor PyObject resurrection
  111. if hasattr(r, "_fix_weakref"):
  112. r._fix_weakref() # type: ignore[union-attr]
  113. return r
  114. def __hash__(self):
  115. return self._id
  116. def __eq__(self, other):
  117. # Use hash equality to determine ref equality.
  118. # ScriptObject implements __hash__ to return the wrapped IValue's id, so
  119. # this is equivalent to doing an identity comparison.
  120. a = self()
  121. b = other()
  122. if a is not None and b is not None:
  123. return hash(a) == hash(b)
  124. return self is other
  125. # This is directly adapted from cpython/Lib/weakref.py
  126. class WeakIdKeyDictionary(MutableMapping):
  127. def __init__(self, dict=None, ref_type=WeakIdRef) -> None: # CHANGED
  128. self.data = {}
  129. self.ref_type = ref_type # CHANGED
  130. def remove(k, selfref=ref(self)) -> None:
  131. self = selfref()
  132. if self is not None:
  133. if self._iterating:
  134. self._pending_removals.append(k)
  135. else:
  136. try:
  137. del self.data[k]
  138. except KeyError:
  139. pass
  140. self._remove = remove
  141. # A list of dead weakrefs (keys to be removed)
  142. self._pending_removals = []
  143. self._iterating = set()
  144. self._dirty_len = False
  145. if dict is not None:
  146. self.update(dict)
  147. def _commit_removals(self) -> None:
  148. # NOTE: We don't need to call this method before mutating the dict,
  149. # because a dead weakref never compares equal to a live weakref,
  150. # even if they happened to refer to equal objects.
  151. # However, it means keys may already have been removed.
  152. pop = self._pending_removals.pop
  153. d = self.data
  154. while True:
  155. try:
  156. key = pop()
  157. except IndexError:
  158. return
  159. try:
  160. del d[key]
  161. except KeyError:
  162. pass
  163. def _scrub_removals(self) -> None:
  164. d = self.data
  165. self._pending_removals = [k for k in self._pending_removals if k in d]
  166. self._dirty_len = False
  167. def __delitem__(self, key) -> None:
  168. self._dirty_len = True
  169. del self.data[self.ref_type(key)] # CHANGED
  170. def __getitem__(self, key):
  171. return self.data[self.ref_type(key)] # CHANGED
  172. def __len__(self) -> int:
  173. if self._dirty_len and self._pending_removals:
  174. # self._pending_removals may still contain keys which were
  175. # explicitly removed, we have to scrub them (see issue #21173).
  176. self._scrub_removals()
  177. return len(self.data) - len(self._pending_removals)
  178. def __repr__(self) -> str:
  179. return f"<{self.__class__.__name__} at {id(self):#x}>"
  180. def __setitem__(self, key, value) -> None:
  181. self.data[self.ref_type(key, self._remove)] = value # CHANGED
  182. def copy(self):
  183. new = WeakIdKeyDictionary()
  184. with _IterationGuard(self):
  185. for key, value in self.data.items():
  186. o = key()
  187. if o is not None:
  188. new[o] = value
  189. return new
  190. __copy__ = copy
  191. def __deepcopy__(self, memo):
  192. from copy import deepcopy
  193. new = self.__class__()
  194. with _IterationGuard(self):
  195. for key, value in self.data.items():
  196. o = key()
  197. if o is not None:
  198. new[o] = deepcopy(value, memo)
  199. return new
  200. def get(self, key, default=None):
  201. return self.data.get(self.ref_type(key), default) # CHANGED
  202. def __contains__(self, key) -> bool:
  203. try:
  204. wr = self.ref_type(key) # CHANGED
  205. except TypeError:
  206. return False
  207. return wr in self.data
  208. def items(self):
  209. with _IterationGuard(self):
  210. for wr, value in self.data.items():
  211. key = wr()
  212. if key is not None:
  213. yield key, value
  214. def keys(self):
  215. with _IterationGuard(self):
  216. for wr in self.data:
  217. obj = wr()
  218. if obj is not None:
  219. yield obj
  220. __iter__ = keys
  221. def values(self):
  222. with _IterationGuard(self):
  223. for wr, value in self.data.items():
  224. if wr() is not None:
  225. yield value
  226. def keyrefs(self):
  227. """Return a list of weak references to the keys.
  228. The references are not guaranteed to be 'live' at the time
  229. they are used, so the result of calling the references needs
  230. to be checked before being used. This can be used to avoid
  231. creating references that will cause the garbage collector to
  232. keep the keys around longer than needed.
  233. """
  234. return list(self.data)
  235. def popitem(self):
  236. self._dirty_len = True
  237. while True:
  238. key, value = self.data.popitem()
  239. o = key()
  240. if o is not None:
  241. return o, value
  242. # pyrefly: ignore [bad-override]
  243. def pop(self, key, *args):
  244. self._dirty_len = True
  245. return self.data.pop(self.ref_type(key), *args) # CHANGED
  246. def setdefault(self, key, default=None):
  247. return self.data.setdefault(
  248. self.ref_type(key, self._remove), default
  249. ) # CHANGED
  250. def update(self, dict=None, **kwargs) -> None: # type: ignore[override]
  251. d = self.data
  252. if dict is not None:
  253. if not hasattr(dict, "items"):
  254. dict = type({})(dict)
  255. for key, value in dict.items():
  256. d[self.ref_type(key, self._remove)] = value # CHANGED
  257. if kwargs:
  258. self.update(kwargs)
  259. def __ior__(self, other):
  260. self.update(other)
  261. return self
  262. def __or__(self, other):
  263. if isinstance(other, _collections_abc.Mapping):
  264. c = self.copy()
  265. c.update(other)
  266. return c
  267. return NotImplemented
  268. def __ror__(self, other):
  269. if isinstance(other, _collections_abc.Mapping):
  270. c = self.__class__()
  271. c.update(other)
  272. c.update(self)
  273. return c
  274. return NotImplemented
  275. # Default Mapping equality will tests keys for equality, but
  276. # we want to test ids for equality
  277. def __eq__(self, other):
  278. if not isinstance(other, Mapping):
  279. return NotImplemented
  280. return {id(k): v for k, v in self.items()} == {
  281. id(k): v for k, v in other.items()
  282. }
# Convenience alias: WeakTensorKeyDictionary is the same class object as
# WeakIdKeyDictionary, not a subclass.
WeakTensorKeyDictionary = WeakIdKeyDictionary
  285. class TensorWeakRef:
  286. """Wrapper around a weak ref of a Tensor that handles the _fix_weakref() call required when unwrapping a Tensor weakref."""
  287. ref: WeakRef[Tensor]
  288. def __init__(self, tensor: Tensor) -> None:
  289. if not isinstance(tensor, Tensor):
  290. raise AssertionError(f"expected torch.Tensor, got {type(tensor)}.")
  291. self.ref = weakref.ref(tensor)
  292. def __call__(self):
  293. out = self.ref()
  294. if out is None:
  295. return out
  296. if not isinstance(out, Tensor):
  297. raise AssertionError(f"expected torch.Tensor, got {type(out)}.")
  298. # TODO, add _fix_weakref type binding
  299. out._fix_weakref() # type: ignore[attr-defined]
  300. return out