_cycles.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
  1. # mypy: allow-untyped-defs
  2. import gc
  3. import sys
  4. from typing import Any, NamedTuple
  5. import types
  6. import weakref
  7. import json
  8. from tempfile import NamedTemporaryFile
  9. import torch
  10. from torch.cuda._memory_viz import _frames_fmt, _block_extra
  11. import atexit
  12. import logging
  13. logger = logging.getLogger(__name__)
def observe_garbage(observer):
    """
    Install a ``gc`` callback that hands unreachable objects to ``observer``.

    While the callback is enabled, each collection is run with
    ``gc.DEBUG_SAVEALL`` so that unreachable objects are kept in
    ``gc.garbage`` instead of being freed. Once the collection has stopped,
    ``observer(gc.garbage)`` is called, the list is cleared, debugging is
    switched off and another collection is run to actually free the objects.

    Args:
        observer: callable invoked with the ``gc.garbage`` list.

    Returns:
        A zero-argument callable that removes the callback from
        ``gc.callbacks``.
    """
    # Shared on/off switch for gc_callback; flipped by disable() and do_collect().
    enabled = True

    def disable() -> None:
        # when GC runs during exit, things like `sys` will already be unloaded
        # so we have to disable the callback to avoid hitting errors.
        nonlocal enabled
        enabled = False

    atexit.register(disable)

    def gc_callback(phase, info) -> None:
        nonlocal enabled
        if not enabled:
            return
        if phase == "start":
            # Keep everything this collection finds in gc.garbage.
            gc.set_debug(gc.DEBUG_SAVEALL)
        elif phase == "stop":
            # Remember the currently installed profile function so it can be
            # restored and forwarded to.
            orig_trace = sys.getprofile()
            # One-element list used as a mutable flag shared with do_collect.
            self_return = [False]

            def do_collect(*args, **kwargs):
                nonlocal enabled
                if not self_return[0]:
                    # First profile event: only record that it happened.
                    # NOTE(review): presumably this is the event generated
                    # while still returning from the gc callback, so the real
                    # work is deferred to the next event -- confirm.
                    self_return[0] = True
                else:
                    # Second event: uninstall this hook and do the reporting.
                    sys.setprofile(orig_trace)
                    # gc_callback returns early while this is False, so the
                    # gc.collect() calls below do not re-enter this logic.
                    enabled = False
                    try:
                        # things in gc.garbage have survived a collection
                        # so to free them we have to collect a generation greater than them
                        # but that might _also_ free other stuff and we don't want to miss
                        # that stuff. So we have to now force gc at the highest level here,
                        # report all of what we found, _then_ we can free it up.
                        if info['generation'] != 2:
                            gc.collect()
                        observer(gc.garbage)
                        gc.garbage.clear()
                        # we have to re-run GC to clean up the cycles
                        # we saved from before.
                        gc.set_debug(0)
                        before = torch.cuda.memory_allocated()
                        gc.collect()
                        after = torch.cuda.memory_allocated()
                        if before != after:
                            logger.warning("CUDA Memory changed during GC, %d bytes freed.", before - after)
                    finally:
                        # Re-arm the callback even if the observer raised.
                        enabled = True
                # Forward the event to whatever profile function was installed before.
                if orig_trace is not None:
                    return orig_trace(*args, **kwargs)

            sys.setprofile(do_collect)

    gc.callbacks.append(gc_callback)

    # provide a way to disarm the callback
    def remove() -> None:
        gc.callbacks.remove(gc_callback)

    return remove
  66. # Function to visualize cycles adapted from refcycle:
  67. # Copyright 2013 Mark Dickinson
  68. #
  69. # Licensed under the Apache License, Version 2.0 (the "License");
  70. # you may not use this file except in compliance with the License.
  71. # You may obtain a copy of the License at
  72. #
  73. # http://www.apache.org/licenses/LICENSE-2.0
  74. #
  75. # Unless required by applicable law or agreed to in writing, software
  76. # distributed under the License is distributed on an "AS IS" BASIS,
  77. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  78. # See the License for the specific language governing permissions and
  79. # limitations under the License.
  80. def _get_cell_type():
  81. def f(x=None):
  82. return lambda: x
  83. return type(f().__closure__[0])
  84. CellType = _get_cell_type()
  85. def annotated_references(obj):
  86. """
  87. Return known information about references held by the given object.
  88. Returns a mapping from referents to lists of descriptions. Note that there
  89. may be more than one edge leading to any particular referent; hence the
  90. need for a list. Descriptions are currently strings.
  91. """
  92. references: dict[int, list[str]] = {}
  93. def add_reference(name, obj) -> None:
  94. references.setdefault(id(obj), []).append(name)
  95. def add_attrs(*attrs) -> None:
  96. for attr in attrs:
  97. if hasattr(obj, attr):
  98. add_reference(attr, getattr(obj, attr))
  99. def add_cell_references() -> None:
  100. try:
  101. add_attrs("cell_contents")
  102. except ValueError:
  103. # if cell_contents is empty,
  104. # accessing it raises ValueError
  105. # in this case there is no object to
  106. # annotate
  107. pass
  108. def add_function_references() -> None:
  109. add_attrs("__defaults__",
  110. "__closure__",
  111. "__globals__",
  112. "__code__",
  113. "__name__",
  114. "__module__",
  115. "__doc__"
  116. "__qualname__",
  117. "__annotations__",
  118. "__kwdefaults__")
  119. def add_sequence_references() -> None:
  120. for position, item in enumerate(obj):
  121. add_reference(f"[{position}]", item)
  122. def add_dict_references() -> None:
  123. for key, value in obj.items():
  124. add_reference("key", key)
  125. add_reference(f"[{repr(key)}]", value)
  126. def add_set_references() -> None:
  127. for elt in obj:
  128. add_reference("element", elt)
  129. def add_bound_method_references() -> None:
  130. add_attrs("__self__", "__func__", "im_class")
  131. def add_weakref_references() -> None:
  132. # For subclasses of weakref, we can't reliably distinguish the
  133. # callback (if any) from other attributes.
  134. if type(obj) is weakref.ref:
  135. referents = gc.get_referents(obj)
  136. if len(referents) == 1:
  137. target = referents[0]
  138. add_reference("__callback__", target)
  139. def add_frame_references() -> None:
  140. f_locals = obj.f_locals
  141. add_attrs("f_back", "f_code", "f_builtins", "f_globals", "f_trace", "f_locals")
  142. # Some badly-behaved code replaces the f_locals dict with
  143. # something that doesn't support the full dict interface. So we
  144. # only continue with the annotation if f_locals is a Python dict.
  145. if type(f_locals) is dict:
  146. for name, local in obj.f_locals.items():
  147. add_reference(f"local {name}", local)
  148. def add_getset_descriptor_references() -> None:
  149. add_attrs("__objclass__", "__name__", "__doc__")
  150. type_based_references = {
  151. tuple: add_sequence_references,
  152. list: add_sequence_references,
  153. dict: add_dict_references,
  154. set: add_set_references,
  155. frozenset: add_set_references,
  156. types.FunctionType: add_function_references,
  157. types.FrameType: add_frame_references,
  158. CellType: add_cell_references,
  159. types.MethodType: add_bound_method_references,
  160. weakref.ref: add_weakref_references,
  161. types.GetSetDescriptorType: add_getset_descriptor_references,
  162. }
  163. for type_ in type(obj).__mro__:
  164. if type_ in type_based_references:
  165. type_based_references[type_]()
  166. add_attrs("__dict__", "__class__")
  167. if isinstance(obj, type):
  168. add_attrs("__mro__")
  169. return references
  170. ###############################################################################
  171. # Object annotations.
  172. BASE_TYPES = (int, float, complex, type(None), str, bytes)
  173. FRAME_FILENAME_LIMIT = 32
  174. def object_annotation(obj):
  175. """
  176. Return a string to be used for Graphviz nodes.
  177. The string should be short but as informative as possible.
  178. """
  179. def format_sequence(obj):
  180. body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for x in obj[:8])
  181. if len(obj) > 8:
  182. body = f'{body}, ...{len(obj) - 8}'
  183. return body
  184. # For basic types, use the repr.
  185. if isinstance(obj, BASE_TYPES):
  186. return repr(obj)
  187. if type(obj).__name__ == 'function':
  188. return f"function\n{obj.__name__}"
  189. elif isinstance(obj, types.MethodType):
  190. try:
  191. func_name = obj.__func__.__qualname__
  192. except AttributeError:
  193. func_name = "<anonymous>"
  194. return f"instancemethod\n{func_name}"
  195. elif isinstance(obj, list):
  196. return f"[{format_sequence(obj)}]"
  197. elif isinstance(obj, tuple):
  198. return f"({format_sequence(obj)})"
  199. elif isinstance(obj, dict):
  200. return f"dict[{len(obj)}]"
  201. elif isinstance(obj, types.ModuleType):
  202. return f"module\n{obj.__name__}"
  203. elif isinstance(obj, type):
  204. return f"type\n{obj.__name__}"
  205. elif isinstance(obj, weakref.ref):
  206. referent = obj()
  207. if referent is None:
  208. return "weakref (dead referent)"
  209. else:
  210. return f"weakref to id 0x{id(referent):x}"
  211. elif isinstance(obj, types.FrameType):
  212. filename = obj.f_code.co_filename
  213. if len(filename) > FRAME_FILENAME_LIMIT:
  214. filename = "..." + filename[-(FRAME_FILENAME_LIMIT - 3):]
  215. return f"frame\n{filename}:{obj.f_lineno}"
  216. elif is_cuda_tensor(obj):
  217. return f"object\n{type(obj).__module__}.{type(obj).__name__} ({obj.shape})"
  218. else:
  219. return f"object\n{type(obj).__module__}.{type(obj).__name__}"
class Node(NamedTuple):
    """One vertex of the reference graph produced by ``create_graph``."""

    # Display text for the object (``create_graph`` fills it from ``object_annotation``).
    label: str
    # Value returned by the ``context`` callable for the object, or None.
    context: str | None
    # Value returned by the ``filter`` callable; ``to_dot`` colors root nodes red.
    root: bool
    # Outgoing edges as ``(edge label, index of the referenced node)`` pairs.
    # The list is mutated in place by ``create_graph``.
    referrents: list[tuple[str, int]]
  225. def create_graph(objects, *, context=None, filter=None):
  226. if context is None:
  227. context = cuda_allocation_context()
  228. if filter is None:
  229. filter = is_cuda_tensor
  230. objects = [obj for obj in objects if not isinstance(obj, weakref.ProxyTypes)]
  231. nodes = [Node(object_annotation(obj), context(obj), filter(obj), []) for obj in objects]
  232. node_referrers: list[list[int]] = [[] for obj in objects]
  233. id_to_node = {id(obj): i for i, obj in enumerate(objects)}
  234. for obj in objects:
  235. fidx = id_to_node[id(obj)]
  236. f = nodes[fidx]
  237. references = annotated_references(obj)
  238. for referrent in gc.get_referents(obj):
  239. rid = id(referrent)
  240. tidx = id_to_node.get(rid)
  241. if tidx is None:
  242. continue
  243. labels = references.get(rid, ["?"])
  244. node_referrers[tidx].append(fidx)
  245. for label in labels:
  246. f.referrents.append((label, tidx))
  247. to_search = [i for i, n in enumerate(nodes) if n.root]
  248. to_keep = set()
  249. while to_search:
  250. idx = to_search.pop()
  251. if idx in to_keep:
  252. continue
  253. to_keep.add(idx)
  254. referrers = node_referrers[idx]
  255. to_search.extend(referrers)
  256. id_to_filtered_id: dict[int, int] = {}
  257. filtered: list[Any] = []
  258. for i, n in enumerate(nodes):
  259. if i in to_keep:
  260. id_to_filtered_id[i] = len(id_to_filtered_id)
  261. filtered.append(n)
  262. for n in filtered:
  263. n.referrents[:] = [(label, id_to_filtered_id[idx])
  264. for (label, idx) in n.referrents
  265. if idx in id_to_filtered_id]
  266. return filtered
  267. def escape(n):
  268. return json.dumps(n)
  269. def is_cuda_tensor(obj):
  270. return (
  271. isinstance(obj, torch.Tensor) and
  272. obj.device.type == "cuda" and
  273. not isinstance(obj, torch._subclasses.FakeTensor)
  274. )
  275. def cuda_allocation_context():
  276. snapshot = torch.cuda.memory._snapshot()
  277. addr_to_frame = {}
  278. for seg in snapshot['segments']:
  279. addr = seg['address']
  280. for blk in seg['blocks']:
  281. if blk['state'] == 'active_allocated':
  282. frames, _real_size = _block_extra(blk)
  283. addr_to_frame[addr] = frames
  284. addr += blk['size']
  285. def object_context(obj):
  286. if is_cuda_tensor(obj):
  287. addr = obj.untyped_storage().data_ptr()
  288. frames = addr_to_frame.get(addr)
  289. if frames is not None:
  290. return '\n'.join(_frames_fmt(frames, full_filename=True))
  291. return None
  292. return object_context
  293. def to_dot(nodes):
  294. lines = ["digraph GraphName {", "node [shape=rect];", 'rankdir=LR;']
  295. for i, n in enumerate(nodes):
  296. lines.append(f'{i} [label={escape(n.label)}, color={"red" if n.root else "black"}];')
  297. for i, f in enumerate(nodes):
  298. for label, j in f.referrents:
  299. lines.append(f'{i} -> {j} [label = {escape(label)}]')
  300. lines.append("}\n")
  301. return '\n'.join(lines)
# HTML page used by to_html. Two placeholders are substituted with
# str.replace: $DOT becomes the DOT source as a quoted string literal and
# $LISTENERS becomes the generated mouseover handlers. The page renders the
# graph to SVG with viz.js (loaded from a CDN) and adds mouse panning and
# wheel zooming by editing the SVG viewBox.
_template = """
<!DOCTYPE html>
<html>
<head>
<style>
body {
margin: 0;
padding: 0;
overflow: hidden;
}
#container {
display: flex;
flex-direction: column;
height: 100vh;
}
#main {
flex: 2;
height: 60vh;
overflow: clip;
}
#preContainer {
flex: 1;
height: 40vh;
overflow: auto;
}
pre {
margin: 0;
padding: 10px;
}
</style>
</head>
<body>
<div id="container">
<div id="main">
</div>
<div id="preContainer">
<pre id="stacktrace">Mouse over tensor objects to see where they were allocated.</pre>
</div>
</div>
<script src='https://cdnjs.cloudflare.com/ajax/libs/viz.js/1.8.0/viz-lite.js'></script>
<script>
let dot = $DOT
let image = Viz(dot, {format: 'svg', 'totalMemory': 1024*1024*1024});
let main = document.getElementById('main')
main.innerHTML = image
let svg = main.firstElementChild
// Panning and zooming logic
let isPanning = false;
let startX, startY;
let viewBox = { x: 0, y: 0, width: parseFloat(svg.getAttribute('width')), height: parseFloat(svg.getAttribute('height')) };
svg.removeAttribute('width');
svg.removeAttribute('height');
function updateViewBox() {
svg.setAttribute('viewBox', `${viewBox.x} ${viewBox.y} ${viewBox.width} ${viewBox.height}`);
}
updateViewBox()
svg.setAttribute('preserveAspectRatio', 'xMidYMid meet');
svg.addEventListener('mousedown', function(e) {
isPanning = true;
startX = e.clientX;
startY = e.clientY;
});
svg.addEventListener('mousemove', function(e) {
if (!isPanning) return;
const dx = (e.clientX - startX) * (viewBox.width / svg.clientWidth);
const dy = (e.clientY - startY) * (viewBox.height / svg.clientHeight);
viewBox.x -= dx;
viewBox.y -= dy;
startX = e.clientX;
startY = e.clientY;
updateViewBox();
});
svg.addEventListener('mouseup', function() {
isPanning = false;
});
svg.addEventListener('mouseleave', function() {
isPanning = false;
});
svg.addEventListener('wheel', function(e) {
e.preventDefault();
const zoomFactor = 0.1;
const zoomAmount = e.deltaY > 0 ? 1 + zoomFactor : 1 - zoomFactor;
// Calculate mouse position relative to the SVG
const rect = svg.getBoundingClientRect();
const mouseX = e.clientX - rect.left;
const mouseY = e.clientY - rect.top;
const mouseXRel = mouseX / svg.clientWidth;
const mouseYRel = mouseY / svg.clientHeight;
// Adjust viewBox to zoom around the mouse position
const newWidth = viewBox.width * zoomAmount;
const newHeight = viewBox.height * zoomAmount;
viewBox.x += (viewBox.width - newWidth) * mouseXRel;
viewBox.y += (viewBox.height - newHeight) * mouseYRel;
viewBox.width = newWidth;
viewBox.height = newHeight;
updateViewBox();
});
$LISTENERS
</script>
</body>
</html>
"""

# JavaScript snippet filled in with str.format by to_html, once per node that
# has a context: {id} is the numeric suffix of the SVG element id and {stack}
# is a quoted string literal shown in the "stacktrace" <pre>. Literal braces
# are doubled so that str.format leaves them in place.
_listener_template = """
document.getElementById('node{id}').addEventListener('mouseover', function(event) {{
document.getElementById("stacktrace").textContent = {stack}
}})
"""
  409. def to_html(nodes):
  410. listeners = []
  411. for i, n in enumerate(nodes):
  412. if n.context is None:
  413. continue
  414. s = _listener_template.format(id=str(i + 1), stack=escape(f'{n.label}:\n{n.context}'))
  415. # pyrefly: ignore [bad-argument-type]
  416. listeners.append(s)
  417. dot = to_dot(nodes)
  418. return _template.replace('$DOT', repr(dot)).replace('$LISTENERS', '\n'.join(listeners))
  419. def observe_tensor_cycles(callback):
  420. torch.cuda.memory._record_memory_history(max_entries=100000)
  421. def observer(garbage) -> None:
  422. if garbage:
  423. if not any(is_cuda_tensor(obj) for obj in garbage):
  424. logger.info("No CUDA Tensors found in garbage")
  425. return
  426. callback(to_html(create_graph(garbage)))
  427. return observe_garbage(observer)
  428. def warn_tensor_cycles():
  429. """
  430. Install a warning that reports whenever a cycle that is holding CUDA memory is observed.
  431. The warning produces an .html file that visualizes the cycle,
  432. and links it to the stack frame that allocated the CUDA tensor.
  433. Reference cycles are freed by the cycle collector rather than being cleaned up
  434. when the objects in the cycle first become unreachable. If a cycle points to a tensor,
  435. the CUDA memory for that tensor will not be freed until garbage collection runs.
  436. Accumulation of CUDA allocations can lead to out of memory errors (OOMs), as well as
  437. non-deterministic allocation behavior which is harder to debug.
  438. """
  439. logger.info("Watching Python reference cycles for CUDA Tensors.")
  440. def write_and_log(html) -> None:
  441. with NamedTemporaryFile('w', suffix='.html') as f:
  442. f.write(html)
  443. logger.warning('Reference cycle includes a CUDA Tensor see visualization of cycle %s', f.name)
  444. return observe_tensor_cycles(write_and_log)