| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506 |
- # mypy: allow-untyped-defs
- import gc
- import sys
- from typing import Any, NamedTuple
- import types
- import weakref
- import json
- from tempfile import NamedTemporaryFile
- import torch
- from torch.cuda._memory_viz import _frames_fmt, _block_extra
- import atexit
- import logging
# Module-level logger shared by the cycle observers defined in this file.
logger = logging.getLogger(name=__name__)
def observe_garbage(observer):
    """
    Report the objects found by Python's cycle collector to ``observer``.

    A ``gc.callbacks`` hook turns on ``gc.DEBUG_SAVEALL`` when a collection
    starts, so unreachable objects are kept in ``gc.garbage`` instead of being
    freed. When the collection stops, a one-shot ``sys.setprofile`` hook is
    installed. On its second profile event it does the following in order:

    - forces a full collection when the finished one was not generation 2;
    - calls ``observer(gc.garbage)``;
    - clears ``gc.garbage`` and turns ``DEBUG_SAVEALL`` off;
    - collects once more so the saved cycles are really freed, logging a
      warning if ``torch.cuda.memory_allocated()`` changed meanwhile.

    The first profile event is presumably the return from the GC callback
    itself.

    Args:
        observer: callable invoked with the ``gc.garbage`` list. The list is
            cleared right after the call, so the observer must copy anything it
            wants to keep. If it raises, the exception propagates, but the
            debug flag is still reset and the saved garbage is still released.

    Returns:
        A zero-argument function that uninstalls the hook; calling it more than
        once is harmless.
    """
    enabled = True

    def disable() -> None:
        # when GC runs during exit, things like `sys` will already be unloaded
        # so we have to disable the callback to avoid hitting errors.
        nonlocal enabled
        enabled = False

    atexit.register(disable)

    def gc_callback(phase, info) -> None:
        nonlocal enabled
        if not enabled:
            return
        if phase == "start":
            # Keep everything this collection finds unreachable in gc.garbage.
            gc.set_debug(gc.DEBUG_SAVEALL)
        elif phase == "stop":
            orig_trace = sys.getprofile()
            self_return = [False]

            def do_collect(*args, **kwargs):
                nonlocal enabled
                if not self_return[0]:
                    # Skip the first profile event (presumably the return from
                    # gc_callback itself, i.e. still inside the collection).
                    self_return[0] = True
                else:
                    sys.setprofile(orig_trace)
                    # The collections below would otherwise re-enter gc_callback.
                    enabled = False
                    try:
                        try:
                            # things in gc.garbage have survived a collection
                            # so to free them we have to collect a generation greater than them
                            # but that might _also_ free other stuff and we don't want to miss
                            # that stuff. So we have to now force gc at the highest level here,
                            # report all of what we found, _then_ we can free it up.
                            if info['generation'] != 2:
                                gc.collect()
                            observer(gc.garbage)
                        finally:
                            # Always drop the saved objects and stop saving new
                            # ones, even if the observer raised; otherwise later
                            # collections keep piling objects into gc.garbage.
                            gc.garbage.clear()
                            gc.set_debug(0)
                        # we have to re-run GC to clean up the cycles
                        # we saved from before.
                        before = torch.cuda.memory_allocated()
                        gc.collect()
                        after = torch.cuda.memory_allocated()
                        if before != after:
                            logger.warning("CUDA Memory changed during GC, %d bytes freed.", before - after)
                    finally:
                        enabled = True
                # Keep any previously installed profiler working.
                if orig_trace is not None:
                    return orig_trace(*args, **kwargs)

            sys.setprofile(do_collect)

    gc.callbacks.append(gc_callback)

    # provide a way to disarm the callback
    def remove() -> None:
        if gc_callback in gc.callbacks:
            gc.callbacks.remove(gc_callback)
        # The exit-time guard is pointless once the callback is gone; drop it so
        # repeated observe/remove cycles do not accumulate atexit handlers.
        atexit.unregister(disable)

    return remove
- # Function to visualize cycles adapted from refcycle:
- # Copyright 2013 Mark Dickinson
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
def _get_cell_type():
    """Return the type of closure cell objects, obtained from a throwaway closure."""
    def f(x=None):
        return lambda: x
    return type(f().__closure__[0])


CellType = _get_cell_type()


def annotated_references(obj):
    """
    Return known information about references held by the given object.

    Returns a mapping from ``id(referent)`` to a list of edge descriptions.
    There may be more than one edge leading to any particular referent,
    hence the need for a list. Descriptions are currently strings, such as an
    attribute name (``"__dict__"``), an index (``"[0]"``), a dict key
    (``"['k']"``), ``"key"``/``"element"``, or ``"local <name>"`` for frame
    locals.
    """
    references: dict[int, list[str]] = {}

    def add_reference(name, target) -> None:
        references.setdefault(id(target), []).append(name)

    def add_attrs(*attrs) -> None:
        for attr in attrs:
            if hasattr(obj, attr):
                add_reference(attr, getattr(obj, attr))

    def add_cell_references() -> None:
        try:
            add_attrs("cell_contents")
        except ValueError:
            # if cell_contents is empty, accessing it raises ValueError;
            # in this case there is no object to annotate
            pass

    def add_function_references() -> None:
        # Note the comma after "__doc__": without it Python concatenates the
        # two adjacent literals into the non-existent "__doc____qualname__".
        add_attrs(
            "__defaults__",
            "__closure__",
            "__globals__",
            "__code__",
            "__name__",
            "__module__",
            "__doc__",
            "__qualname__",
            "__annotations__",
            "__kwdefaults__",
        )

    def add_sequence_references() -> None:
        for position, item in enumerate(obj):
            add_reference(f"[{position}]", item)

    def add_dict_references() -> None:
        for key, value in obj.items():
            add_reference("key", key)
            add_reference(f"[{repr(key)}]", value)

    def add_set_references() -> None:
        for elt in obj:
            add_reference("element", elt)

    def add_bound_method_references() -> None:
        add_attrs("__self__", "__func__", "im_class")

    def add_weakref_references() -> None:
        # For subclasses of weakref, we can't reliably distinguish the
        # callback (if any) from other attributes.
        if type(obj) is weakref.ref:
            referents = gc.get_referents(obj)
            if len(referents) == 1:
                target = referents[0]
                add_reference("__callback__", target)

    def add_frame_references() -> None:
        f_locals = obj.f_locals
        add_attrs("f_back", "f_code", "f_builtins", "f_globals", "f_trace", "f_locals")
        # Some badly-behaved code replaces the f_locals dict with
        # something that doesn't support the full dict interface. So we
        # only continue with the annotation if f_locals is a Python dict.
        if type(f_locals) is dict:
            for name, local in f_locals.items():
                add_reference(f"local {name}", local)

    def add_getset_descriptor_references() -> None:
        add_attrs("__objclass__", "__name__", "__doc__")

    type_based_references = {
        tuple: add_sequence_references,
        list: add_sequence_references,
        dict: add_dict_references,
        set: add_set_references,
        frozenset: add_set_references,
        types.FunctionType: add_function_references,
        types.FrameType: add_frame_references,
        CellType: add_cell_references,
        types.MethodType: add_bound_method_references,
        weakref.ref: add_weakref_references,
        types.GetSetDescriptorType: add_getset_descriptor_references,
    }

    # Walk the MRO so subclasses of the handled builtins are annotated too.
    for type_ in type(obj).__mro__:
        if type_ in type_based_references:
            type_based_references[type_]()

    add_attrs("__dict__", "__class__")
    if isinstance(obj, type):
        add_attrs("__mro__")

    return references
- ###############################################################################
- # Object annotations.
BASE_TYPES = (int, float, complex, type(None), str, bytes)
FRAME_FILENAME_LIMIT = 32


def object_annotation(obj):
    """
    Return a string to be used for Graphviz nodes.

    The string should be short but as informative as possible:

    - instances of BASE_TYPES use their ``repr``;
    - lists and tuples preview up to their first 8 items, followed by
      ``, ...N`` when N more items were omitted;
    - functions, bound methods, dicts, modules, types, weakrefs and frames get
      a kind-specific label;
    - CUDA tensors include their shape;
    - anything else is labelled with its module-qualified type name.
    """

    def format_sequence(obj):
        # Scalars are shown by repr, everything else only by its type name.
        body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for x in obj[:8])
        if len(obj) > 8:
            body = f'{body}, ...{len(obj) - 8}'
        return body

    # For basic types, use the repr.
    if isinstance(obj, BASE_TYPES):
        return repr(obj)
    # isinstance instead of comparing type(obj).__name__ to 'function': a
    # user class that merely happens to be named "function" must not be
    # treated as a function (it would crash on a missing __name__).
    if isinstance(obj, types.FunctionType):
        return f"function\n{obj.__name__}"
    elif isinstance(obj, types.MethodType):
        try:
            func_name = obj.__func__.__qualname__
        except AttributeError:
            func_name = "<anonymous>"
        return f"instancemethod\n{func_name}"
    elif isinstance(obj, list):
        return f"[{format_sequence(obj)}]"
    elif isinstance(obj, tuple):
        return f"({format_sequence(obj)})"
    elif isinstance(obj, dict):
        return f"dict[{len(obj)}]"
    elif isinstance(obj, types.ModuleType):
        return f"module\n{obj.__name__}"
    elif isinstance(obj, type):
        return f"type\n{obj.__name__}"
    elif isinstance(obj, weakref.ref):
        referent = obj()
        if referent is None:
            return "weakref (dead referent)"
        else:
            return f"weakref to id 0x{id(referent):x}"
    elif isinstance(obj, types.FrameType):
        filename = obj.f_code.co_filename
        # Keep only the tail of long paths so the label stays compact.
        if len(filename) > FRAME_FILENAME_LIMIT:
            filename = "..." + filename[-(FRAME_FILENAME_LIMIT - 3):]
        return f"frame\n{filename}:{obj.f_lineno}"
    elif is_cuda_tensor(obj):
        return f"object\n{type(obj).__module__}.{type(obj).__name__} ({obj.shape})"
    else:
        return f"object\n{type(obj).__module__}.{type(obj).__name__}"
- class Node(NamedTuple):
- label: str
- context: str | None
- root: bool
- referrents: list[tuple[str, int]]
def create_graph(objects, *, context=None, filter=None):
    """
    Build a reference graph over ``objects``, pruned to nodes that can reach a root.

    Every object becomes a ``Node`` labelled by ``object_annotation``. Edges
    follow ``gc.get_referents`` and are restricted to targets that are also in
    ``objects``. Each edge is labelled with the names reported by
    ``annotated_references``, or ``"?"`` when none is known. A node is a root
    when ``filter(obj)`` is truthy. Only nodes from which some root is
    reachable are returned, renumbered densely.

    Args:
        objects: iterable of objects to graph (e.g. ``gc.garbage``).
        context: callable mapping an object to optional extra text stored on
            its node; defaults to ``cuda_allocation_context()``.
        filter: predicate selecting root objects; defaults to ``is_cuda_tensor``.

    Returns:
        list of ``Node``; each ``referrents`` entry is a ``(label, index)`` pair
        indexing into the returned list.
    """
    if context is None:
        context = cuda_allocation_context()
    if filter is None:
        filter = is_cuda_tensor

    # Weak proxies are dropped up front (presumably because they forward
    # attribute access to a referent that may already be dead).
    objects = [obj for obj in objects if not isinstance(obj, weakref.ProxyTypes)]
    nodes = [Node(object_annotation(obj), context(obj), filter(obj), []) for obj in objects]
    node_referrers: list[list[int]] = [[] for _ in objects]
    id_to_node = {id(obj): i for i, obj in enumerate(objects)}

    for obj in objects:
        fidx = id_to_node[id(obj)]
        f = nodes[fidx]
        references = annotated_references(obj)
        # gc.get_referents reports an object once per reference to it, but
        # annotated_references already returns every label for that target;
        # handling each distinct referent once avoids duplicated edges.
        visited: set[int] = set()
        for referrent in gc.get_referents(obj):
            rid = id(referrent)
            if rid in visited:
                continue
            visited.add(rid)
            tidx = id_to_node.get(rid)
            if tidx is None:
                continue
            node_referrers[tidx].append(fidx)
            for label in references.get(rid, ["?"]):
                f.referrents.append((label, tidx))

    # Walk referrer edges backwards from every root: what we reach is exactly
    # the set of nodes that (transitively) keep some root alive.
    to_search = [i for i, n in enumerate(nodes) if n.root]
    to_keep = set()
    while to_search:
        idx = to_search.pop()
        if idx in to_keep:
            continue
        to_keep.add(idx)
        to_search.extend(node_referrers[idx])

    # Renumber the kept nodes densely and drop edges to discarded nodes.
    id_to_filtered_id: dict[int, int] = {}
    filtered: list[Any] = []
    for i, n in enumerate(nodes):
        if i in to_keep:
            id_to_filtered_id[i] = len(id_to_filtered_id)
            filtered.append(n)
    for n in filtered:
        n.referrents[:] = [(label, id_to_filtered_id[idx])
                           for (label, idx) in n.referrents
                           if idx in id_to_filtered_id]
    return filtered
def escape(n):
    """Encode ``n`` as a JSON literal (for a str: a double-quoted, escaped ASCII string)."""
    encoded = json.dumps(n)
    return encoded
def is_cuda_tensor(obj):
    """Return True only for real (non-FakeTensor) tensors whose device type is CUDA."""
    if not isinstance(obj, torch.Tensor):
        return False
    if obj.device.type != "cuda":
        return False
    return not isinstance(obj, torch._subclasses.FakeTensor)
def cuda_allocation_context():
    """
    Snapshot the CUDA caching allocator and return an allocation-site lookup.

    The returned callable maps an object to the formatted allocation stack of
    its storage when the object is a CUDA tensor whose storage pointer equals
    the start of an ``active_allocated`` block in the snapshot; otherwise it
    returns None.
    """
    snapshot = torch.cuda.memory._snapshot()
    frames_by_address = {}
    for segment in snapshot['segments']:
        # Blocks are laid out back to back inside a segment, so each block's
        # address is the segment address plus the sizes of the blocks before it.
        cursor = segment['address']
        for block in segment['blocks']:
            if block['state'] == 'active_allocated':
                block_frames, _size = _block_extra(block)
                frames_by_address[cursor] = block_frames
            cursor += block['size']

    def object_context(obj):
        if not is_cuda_tensor(obj):
            return None
        stack = frames_by_address.get(obj.untyped_storage().data_ptr())
        if stack is None:
            return None
        return '\n'.join(_frames_fmt(stack, full_filename=True))

    return object_context
def to_dot(nodes):
    """
    Render ``nodes`` as Graphviz DOT source.

    Node ``i`` is declared with DOT id ``i`` and outlined red when it is a root,
    black otherwise. Every ``(label, j)`` entry in ``referrents`` becomes a
    labelled edge ``i -> j``.
    """
    header = ["digraph GraphName {", "node [shape=rect];", 'rankdir=LR;']
    node_decls = [
        f'{idx} [label={escape(node.label)}, color={"red" if node.root else "black"}];'
        for idx, node in enumerate(nodes)
    ]
    edge_decls = [
        f'{src} -> {dst} [label = {escape(edge_label)}]'
        for src, node in enumerate(nodes)
        for edge_label, dst in node.referrents
    ]
    return '\n'.join(header + node_decls + edge_decls + ["}\n"])
_template = """
<!DOCTYPE html>
<html>
<head>
  <style>
    body {
      margin: 0;
      padding: 0;
      overflow: hidden;
    }

    #container {
      display: flex;
      flex-direction: column;
      height: 100vh;
    }

    #main {
      flex: 2;
      height: 60vh;
      overflow: clip;
    }

    #preContainer {
      flex: 1;
      height: 40vh;
      overflow: auto;
    }

    pre {
      margin: 0;
      padding: 10px;
    }
  </style>
</head>
<body>
  <div id="container">
    <div id="main">
    </div>
    <div id="preContainer">
      <pre id="stacktrace">Mouse over tensor objects to see where they were allocated.</pre>
    </div>
  </div>
<script src='https://cdnjs.cloudflare.com/ajax/libs/viz.js/1.8.0/viz-lite.js'></script>
<script>
let dot = $DOT
let image = Viz(dot, {format: 'svg', 'totalMemory': 1024*1024*1024});
let main = document.getElementById('main')
main.innerHTML = image
let svg = main.firstElementChild

// Panning and zooming logic
let isPanning = false;
let startX, startY;
let viewBox = { x: 0, y: 0, width: parseFloat(svg.getAttribute('width')), height: parseFloat(svg.getAttribute('height')) };
svg.removeAttribute('width');
svg.removeAttribute('height');
function updateViewBox() {
    svg.setAttribute('viewBox', `${viewBox.x} ${viewBox.y} ${viewBox.width} ${viewBox.height}`);
}
updateViewBox()
svg.setAttribute('preserveAspectRatio', 'xMidYMid meet');
svg.addEventListener('mousedown', function(e) {
    isPanning = true;
    startX = e.clientX;
    startY = e.clientY;
});
svg.addEventListener('mousemove', function(e) {
    if (!isPanning) return;
    const dx = (e.clientX - startX) * (viewBox.width / svg.clientWidth);
    const dy = (e.clientY - startY) * (viewBox.height / svg.clientHeight);
    viewBox.x -= dx;
    viewBox.y -= dy;
    startX = e.clientX;
    startY = e.clientY;
    updateViewBox();
});
svg.addEventListener('mouseup', function() {
    isPanning = false;
});
svg.addEventListener('mouseleave', function() {
    isPanning = false;
});
svg.addEventListener('wheel', function(e) {
    e.preventDefault();
    const zoomFactor = 0.1;
    const zoomAmount = e.deltaY > 0 ? 1 + zoomFactor : 1 - zoomFactor;
    // Calculate mouse position relative to the SVG
    const rect = svg.getBoundingClientRect();
    const mouseX = e.clientX - rect.left;
    const mouseY = e.clientY - rect.top;
    const mouseXRel = mouseX / svg.clientWidth;
    const mouseYRel = mouseY / svg.clientHeight;
    // Adjust viewBox to zoom around the mouse position
    const newWidth = viewBox.width * zoomAmount;
    const newHeight = viewBox.height * zoomAmount;
    viewBox.x += (viewBox.width - newWidth) * mouseXRel;
    viewBox.y += (viewBox.height - newHeight) * mouseYRel;
    viewBox.width = newWidth;
    viewBox.height = newHeight;
    updateViewBox();
});
$LISTENERS
</script>
</body>
</html>
"""

_listener_template = """
document.getElementById('node{id}').addEventListener('mouseover', function(event) {{
  document.getElementById("stacktrace").textContent = {stack}
}})
"""


def _js_string(text):
    """
    Encode ``text`` as a JS string literal that is safe inside an inline script.

    ``escape`` yields a JSON string literal, which is also a valid JS literal.
    Every ``<`` is additionally replaced by its JSON unicode escape, so text
    such as a closing script tag or an HTML comment opener inside an object
    label cannot terminate or confuse the surrounding script element.
    """
    return escape(text).replace('<', '\\u003c')


def to_html(nodes):
    """
    Render ``nodes`` (from ``create_graph``) as a self-contained HTML page.

    The page draws the DOT graph with Viz.js. For every node whose
    ``context`` is not None, it installs a mouseover handler that shows the
    node label followed by its context in the bottom pane.
    """
    listeners = []
    for i, n in enumerate(nodes):
        if n.context is None:
            continue
        # Assumes Viz.js gives the SVG group of DOT node ``i`` the element id
        # ``node{i + 1}``.
        s = _listener_template.format(id=str(i + 1), stack=_js_string(f'{n.label}:\n{n.context}'))
        # pyrefly: ignore [bad-argument-type]
        listeners.append(s)
    dot = to_dot(nodes)
    # Substitute both placeholders in a single pass: chained str.replace calls
    # would also rewrite a "$LISTENERS" substring that happens to appear in
    # the already-inserted DOT text.
    head, _, rest = _template.partition('$DOT')
    middle, _, tail = rest.partition('$LISTENERS')
    # JSON (not Python repr) is what a JS literal expects.
    return head + _js_string(dot) + middle + '\n'.join(listeners) + tail
def observe_tensor_cycles(callback):
    """
    Record CUDA allocation history and watch the cycle collector for CUDA tensors.

    Allocation-history recording is enabled with ``max_entries=100000``.
    Whenever a collection yields non-empty garbage that contains at least one
    CUDA tensor, ``callback`` receives an HTML page visualizing that garbage.
    Garbage without CUDA tensors is only noted at INFO level.

    Returns:
        The disarm function produced by ``observe_garbage``.
    """
    torch.cuda.memory._record_memory_history(max_entries=100000)

    def observer(garbage) -> None:
        if not garbage:
            return
        if any(is_cuda_tensor(obj) for obj in garbage):
            callback(to_html(create_graph(garbage)))
        else:
            logger.info("No CUDA Tensors found in garbage")

    return observe_garbage(observer)
def warn_tensor_cycles():
    """
    Install a warning that reports whenever a cycle that is holding CUDA memory is observed.

    The warning produces an .html file that visualizes the cycle,
    and links it to the stack frame that allocated the CUDA tensor.

    Reference cycles are freed by the cycle collector rather than being cleaned up
    when the objects in the cycle first become unreachable. If a cycle points to a tensor,
    the CUDA memory for that tensor will not be freed until garbage collection runs.
    Accumulation of CUDA allocations can lead to out of memory errors (OOMs), as well as
    non-deterministic allocation behavior which is harder to debug.

    Returns:
        The disarm function returned by ``observe_tensor_cycles``.
    """
    logger.info("Watching Python reference cycles for CUDA Tensors.")

    def write_and_log(html) -> None:
        # delete=False: the default would remove the file as soon as the
        # `with` block exits, leaving the logged path pointing at nothing.
        with NamedTemporaryFile('w', suffix='.html', delete=False, encoding='utf-8') as f:
            f.write(html)
        # Logged after the file is closed, so its content is fully flushed.
        logger.warning('Reference cycle includes a CUDA Tensor see visualization of cycle %s', f.name)

    return observe_tensor_cycles(write_and_log)
|