| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334 |
- import collections
- import contextlib
- import copy
- import dataclasses
- import functools
- import io
- import itertools
- import json
- import logging
- import os
- import os.path
- import pickle
- import pstats
- import shutil
- import tempfile
- import traceback
- from collections.abc import Callable, Iterator, Sequence
- from typing import Any, IO, Optional, Union
- from unittest.mock import patch
- import torch
- from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
- from torch import fx
- from torch._dynamo.repro.after_aot import save_graph_repro
- from torch._dynamo.utils import get_debug_dir
- from torch._inductor import utils
- from torch._logging import getArtifactLogger
- from torch._logging._internal import trace_structured
- from torch._utils_internal import signpost_event
- from torch.fx.graph_module import GraphModule
- from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
- from torch.fx.passes.tools_common import legalize_graph
- from torch.types import FileLike
- from torch.utils._ordered_set import OrderedSet
- from torch.utils._pytree import tree_map
- from . import config, ir # noqa: F811, this is needed
- from .ir import ExternKernel
- from .scheduler import (
- BaseSchedulerNode,
- FusedSchedulerNode,
- NopKernelSchedulerNode,
- OutputNode,
- SchedulerNode,
- )
- from .virtualized import V
log = logging.getLogger(__name__)

# Graph execution tracking for debugging.
# record_and_log_graph_execution_order() sets GRAPH_EXECUTION_ORDER to [],
# GRAPH_COMPILE_IDS to {} and RECORD_GRAPH_EXECUTION to True on entry, and resets
# them to None / None / False on exit; log_graph_execution() emits
# GRAPH_EXECUTION_ORDER. NOTE(review): the code that appends to these containers
# is not in this part of the module — confirm the record schema with its writers.
GRAPH_EXECUTION_ORDER: Optional[list[dict[str, object]]] = None
RECORD_GRAPH_EXECUTION: bool = False
GRAPH_COMPILE_IDS: Optional[dict[int, Optional[str]]] = None

# Artifact loggers used by log_ir_pre_fusion() / log_ir_post_fusion().
ir_pre_fusion_log = getArtifactLogger(__name__, "ir_pre_fusion")
ir_post_fusion_log = getArtifactLogger(__name__, "ir_post_fusion")

# Loosely typed alias for a list of scheduler nodes (elements are not checked).
SchedulerNodeList = list[Any]
# (buffer name, number of original FX node names attributed to that buffer);
# produced by get_node_name_to_buf_meta().
BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
# Graphviz command passed as `prog` when drawing the original FX graph; the -G
# attributes cap layout iterations, presumably so large graphs still render in
# reasonable time.
GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]
@functools.cache
def has_dot() -> bool:
    """Report whether the Graphviz ``dot`` executable is available on ``PATH``.

    The lookup runs once; the result is memoized for the rest of the process.
    """
    dot_executable = shutil.which("dot")
    return dot_executable is not None
def draw_buffers(
    nodes: list[BaseSchedulerNode],
    print_graph: bool = False,
    fname: Optional[str] = None,
) -> None:
    """
    Draw a graph of scheduler nodes in fname.svg.

    The scheduler nodes are converted with create_fx_from_snodes(). Each FX node
    that carries ``fusion_meta`` gets a ``TensorMetadata`` annotation. Its
    ``shape`` slot holds the node's (normalized) fusion group. Its ``dtype`` slot
    holds the dtype of the scheduler node's IR node when that node is an
    ``ir.ComputedBuffer``. The drawer renders both in the node label.

    Args:
        nodes: scheduler nodes to draw.
        print_graph: if True, also print the intermediate FX graph.
        fname: output file name; defaults to the name of the graph being compiled.
    """
    if not has_dot():
        log.warning("draw_buffers() requires `graphviz` package")
        return

    if fname is None:
        fname = get_graph_being_compiled()

    graph = create_fx_from_snodes(nodes)

    for node in graph.nodes:
        # Placeholder and output nodes carry no fusion metadata.
        if "fusion_meta" not in node.meta:
            continue
        fusion_meta = node.meta["fusion_meta"]
        group = fusion_meta.group
        if isinstance(group, tuple):
            # Keep only the second element of tuple groups; wrap a bare int so
            # the result is always a tuple.
            if isinstance(group[1], int):
                group = (group[1],)
            else:
                group = group[1]

        # gather meta data
        # An FX node is never an ir.ComputedBuffer itself, so look at the IR node
        # owned by the scheduler node this FX node stands for.
        dtype = None
        ir_node = getattr(fusion_meta.snode, "node", None)
        if isinstance(ir_node, ir.ComputedBuffer):
            dtype = ir_node.get_dtype()

        metadata = TensorMetadata(group, dtype, None, None, None, None, None)  # type: ignore[arg-type]
        node.meta["tensor_meta"] = metadata

    if print_graph:
        print(graph)

    gm = GraphModule({}, graph)
    legalize_graph(gm)
    gm.graph.lint()
    draw_graph(
        gm, fname, clear_meta=False, dot_graph_shape=config.trace.dot_graph_shape
    )
def create_fx_from_snodes(snodes: list[BaseSchedulerNode]) -> fx.Graph:
    """
    Creates a FX Graph from a list of SchedulerNode objects.

    Graph construction:
    - Each scheduler node becomes a ``call_function`` node named after the
      scheduler node. Its target is a dummy function named
      ``"<node_type>: <fused kernel name>"``.
    - A buffer read by a node becomes an edge from the node that produced it.
      If no node in ``snodes`` produced it, the edge comes from a new
      placeholder instead.
    - Nodes whose outputs are consumed by an ``OutputNode`` become graph
      outputs.
    - ``meta["fusion_meta"]`` on every call node holds
      ``FusionMeta(group, snode, type)``.
    """

    def get_fake_func(name: str) -> Callable[..., int]:
        def func1(*args: Any) -> int:
            return 0

        func1.__name__ = name
        return func1

    def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool:
        # Defined once here instead of being re-created on every loop iteration.
        # A fused node is an output if any of its members is one.
        if isinstance(snode, FusedSchedulerNode):
            return any(in_output(x) for x in snode.snodes)
        return any(
            isinstance(user.node, OutputNode)
            for buf in snode.get_outputs()
            for user in buf.users
        )

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])

    buf_to_fx_node = {}
    node_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None

    outputs = []
    group: Any = None
    # create call_function node for each Buffer and Kernel
    # pyrefly: ignore [bad-assignment]
    for snode in snodes:
        if snode.is_extern():
            node_type = "extern"
            group = node_type
        elif snode.is_template():
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):
            node_type = "nop"
            group = node_type
        elif isinstance(snode, SchedulerNode):
            node_type = "compute"
            group = snode.group
        elif isinstance(snode, FusedSchedulerNode):
            node_type = "fused"
            group = snode.group
        else:
            raise RuntimeError("Unknown node type")

        fused_name = torch._inductor.utils.get_fused_kernel_name(
            snode.get_nodes(), "original_aten"
        )
        func_name = f"{node_type}: {fused_name}"
        node_func = get_fake_func(func_name)
        kwargs = {}
        if hasattr(snode, "get_device"):
            kwargs = {"device": snode.get_device()}
        fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)  # type: ignore[arg-type]

        if in_output(snode):
            outputs.append(fx_node)
        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)

        node_to_fx_node[name] = fx_node
        for buf in snode.get_outputs():
            buf_to_fx_node[buf.get_name()] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes
    for snode in snodes:
        name = snode.get_name()
        deps = snode.read_writes.reads

        fx_node = node_to_fx_node[name]
        new_args = []
        for dep in deps:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                # Buffer not produced by any node in `snodes`: model it as a
                # placeholder inserted ahead of the first call_function node.
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            if dep_node == fx_node:  # to avoid cycles
                continue
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph
def update_orig_fx_node_name_to_buf_name(
    nodes: Optional[SchedulerNodeList],
    node_name_to_buf_name: dict[str, str],
    parent_buf_name: Optional[str] = None,
    n_origins: int = 0,
) -> None:
    """
    Map each original FX node name to the scheduler buffer that represents it.

    For every scheduler node, the names of its IR node's ``origins`` are
    recorded in ``node_name_to_buf_name`` (updated in place).
    - Fused scheduler nodes (``get_nodes()`` returns more than one node) are
      walked recursively.
    - All members of a fused node are attributed to the outermost node's name.
    - An existing entry is never overwritten, so when several buffers share an
      origin the first one visited wins.

    ``n_origins`` is accepted for compatibility but is not used.
    """
    if nodes is None:
        return

    for snode in nodes:
        own_name = snode.get_name()
        members = snode.get_nodes()
        # The outermost (fused) node's name takes precedence over the node's own.
        owner = own_name if parent_buf_name is None else parent_buf_name

        if members is not None and len(members) > 1:
            # for FusedSchedulerNode, traverse recursively into get_nodes()
            update_orig_fx_node_name_to_buf_name(members, node_name_to_buf_name, owner)
            continue

        # pyrefly: ignore [bad-argument-type, unsupported-operation]
        assert len(members) == 1 and members[0] == snode

        ir_node = snode.node
        if ir_node is None or ir_node.origins is None:
            continue

        for origin in ir_node.origins:
            # when buf1 and buf2 both have origin=node1, node1 is drawn per buf1
            node_name_to_buf_name.setdefault(origin.name, owner)
def get_node_name_to_buf_meta(
    node_name_to_buf_name: dict[str, str],
) -> dict[str, BufMeta]:
    """
    Build a ``BufMeta(buf_name, n_origin)`` for every FX node name.

    ``n_origin`` is the number of FX node names that map to the same buffer.
    Output order follows the input mapping.
    """
    # Node names are dict keys and therefore unique, so counting occurrences of
    # each buffer name equals counting its distinct node names.
    node_count_per_buf = collections.Counter(node_name_to_buf_name.values())
    return {
        node_name: BufMeta(buf_name, node_count_per_buf[buf_name])
        for node_name, buf_name in node_name_to_buf_name.items()
    }
def annotate_orig_fx_with_snodes(
    gm: torch.fx.GraphModule,
    snodes: SchedulerNodeList,
) -> None:
    """
    Annotate nodes of the original FX graph ``gm`` with their scheduler buffers.

    If an FX node's name appears among the IR origins of ``snodes``, its
    ``node.meta["buf_meta"]`` is set. The value is
    ``BufMeta(buffer name, number of FX nodes mapped to that buffer)``.
    Other nodes are left untouched.
    """
    node_name_to_buf_name: dict[str, str] = {}
    update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)
    # The mapping is always a dict; an empty one means nothing to annotate.
    if not node_name_to_buf_name:
        return
    node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name)
    for node in gm.graph.nodes:
        buf_meta = node_name_to_buf_meta.get(node.name)
        if buf_meta is not None:
            node.meta["buf_meta"] = buf_meta
@contextlib.contextmanager
def enable_aot_logging() -> Iterator[None]:
    """
    Capture AOTAutograd debug logs to a file when ``TORCH_COMPILE_DEBUG=1``.

    When that variable is set, for the duration of the context:
    - ``functorch.compile.config.debug_partitioner`` is patched to True.
    - A DEBUG-level file handler is attached to the
      ``torch._functorch.aot_autograd`` logger. It writes to
      ``<debug dir>/torchinductor/aot_<graph name>_debug.log``.

    Everything is undone on exit, including when setup fails part-way. Without
    the variable the context manager does nothing.
    """
    compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

    import torch._functorch.aot_autograd

    # Named aot_log so it does not shadow this module's `log`.
    aot_log = logging.getLogger(torch._functorch.aot_autograd.__name__)

    if not compile_debug:
        yield
        return

    # The ExitStack unwinds (in reverse) the handler registration, the file
    # handle and the config patch, even if any setup step below raises.
    with contextlib.ExitStack() as stack:
        # Enable all graphs to be logged to a file by setting the flags to True
        # and the log level of the file logger to DEBUG
        stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))

        path = os.path.join(get_debug_dir(), "torchinductor")
        os.makedirs(path, exist_ok=True)

        fh = logging.FileHandler(
            os.path.join(
                path,
                f"aot_{get_aot_graph_name()}_debug.log",
            )
        )
        # removeHandler() alone leaves the log file open; close it explicitly.
        stack.callback(fh.close)
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        aot_log.addHandler(fh)
        stack.callback(aot_log.removeHandler, fh)
        yield
# Used for provenance tracking.
# These live at module level rather than on DebugContext, presumably because
# they are populated outside of the DebugContext that later dumps them
# (NOTE(review): confirm against the code that writes them).
# reset_provenance_globals() swaps all of them for fresh defaults and restores
# the previous values afterwards.

# Merged as-is into the mapping returned by dump_inductor_provenance_info();
# presumably holds the "preToPost"/"postToPre" output of
# create_mapping_pre_post_grad_nodes().
_inductor_post_to_pre_grad_nodes: dict[str, dict[str, list[str]]] = {}
# Triton kernel name -> list of post-grad node names; the input to
# create_node_mapping_kernel_to_post_grad().
_inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {}
# Id of the pre-grad graph. dump_inductor_provenance_info() only builds the
# node mapping when this is truthy; reset_provenance_globals() sets it to -1.
_pre_grad_graph_id: Optional[int] = None
# Stack traces keyed by pre-grad node name / kernel name (populated elsewhere;
# exact key semantics are not visible here).
_inductor_pre_grad_node_stack_trace: dict[str, str] = {}
_inductor_kernel_stack_trace: dict[str, list[str]] = {}
# Kernel provenance debug handle; set back to 0 by
# reset_inductor_kernel_provenance_debug_handle() and reset_provenance_globals().
_inductor_kernel_provenance_debug_handle: int = 0
def reset_inductor_kernel_provenance_debug_handle() -> None:
    """Set the module-level kernel provenance debug handle back to zero."""
    global _inductor_kernel_provenance_debug_handle
    initial_handle = 0
    _inductor_kernel_provenance_debug_handle = initial_handle
@contextlib.contextmanager
def reset_provenance_globals() -> Iterator[None]:
    """
    Temporarily replace every provenance-tracking global with a fresh default.

    On entry the current values are snapshotted, with the dicts shallow-copied.
    The globals are then set to ``-1``, empty dicts and ``0``. On exit, normal
    or exceptional, the snapshot is written back.
    """
    global _pre_grad_graph_id, _inductor_post_to_pre_grad_nodes
    global _inductor_triton_kernel_to_post_grad_node_info
    global _inductor_pre_grad_node_stack_trace, _inductor_kernel_stack_trace
    global _inductor_kernel_provenance_debug_handle

    snapshot = (
        _pre_grad_graph_id,
        _inductor_post_to_pre_grad_nodes.copy(),
        _inductor_triton_kernel_to_post_grad_node_info.copy(),
        _inductor_pre_grad_node_stack_trace.copy(),
        _inductor_kernel_stack_trace.copy(),
        _inductor_kernel_provenance_debug_handle,
    )

    (
        _pre_grad_graph_id,
        _inductor_post_to_pre_grad_nodes,
        _inductor_triton_kernel_to_post_grad_node_info,
        _inductor_pre_grad_node_stack_trace,
        _inductor_kernel_stack_trace,
        _inductor_kernel_provenance_debug_handle,
    ) = (-1, {}, {}, {}, {}, 0)

    try:
        yield
    finally:
        (
            _pre_grad_graph_id,
            _inductor_post_to_pre_grad_nodes,
            _inductor_triton_kernel_to_post_grad_node_info,
            _inductor_pre_grad_node_stack_trace,
            _inductor_kernel_stack_trace,
            _inductor_kernel_provenance_debug_handle,
        ) = snapshot
class DebugContext:
    """
    Per-compilation debug handler, installed via ``V.set_debug_handler``.

    Behavior:
    - With ``config.trace.enabled``, entering the context creates a fresh
      ``<debug dir>/torchinductor/<graph name>.<n>`` directory and optionally
      captures ``torch._inductor`` logs into it.
    - Attribute lookups that are not real attributes (e.g.
      ``V.debug.fx_graph(...)``) go through ``__getattr__``. They are forwarded
      to ``DebugFormatter`` when the same-named ``config.trace`` flag is on, and
      become no-ops otherwise.
    """

    # Class-level counter: successive debug directories get increasing suffixes.
    _counter = itertools.count()

    @staticmethod
    def create_debug_dir(folder_name: str) -> Optional[str]:
        """Create and return a debug directory that did not previously exist.

        Candidates are ``<debug_dir>/torchinductor/<folder_name>.<n>`` for
        increasing ``n``; the first one that can be created is returned.
        """
        debug_dir = config.trace.debug_dir or get_debug_dir()
        for n in DebugContext._counter:
            dirname = os.path.join(
                debug_dir,
                "torchinductor",
                f"{folder_name}.{n}",
            )
            try:
                # Creating the leaf with exist_ok=False makes "check and create"
                # a single atomic step, so two processes cannot both claim the
                # same directory (unlike an exists() check followed by makedirs).
                os.makedirs(dirname)
            except FileExistsError:
                continue
            return dirname
        return None

    def __init__(self) -> None:
        # Optional profiler (presumably a cProfile.Profile); it is not created
        # here, but if set it is disabled and dumped in __exit__.
        self._prof = None
        # Debug directory for this context; stays None unless tracing is enabled.
        self._path = None
        # Cleanup actions registered while entering; unwound in __exit__.
        self._stack = contextlib.ExitStack()

    def copy(self, new_path: str) -> None:
        """Copy the debug directory to ``new_path``, which must end in ``.debug``.

        An existing ``new_path`` is replaced. The copy runs under a
        ``<new_path>.lock`` file lock. OSErrors are logged and ignored. This is
        a no-op when the context has no debug directory.
        """
        if not self._path:
            return
        assert new_path.endswith(".debug"), new_path
        from filelock import FileLock

        try:
            with FileLock(f"{new_path}.lock"):
                if os.path.exists(new_path):
                    shutil.rmtree(new_path)
                shutil.copytree(self._path, new_path)
        except OSError:
            log.warning(
                "Failed to copy debug files from %s to %s", self._path, new_path
            )

    def fopen(
        self,
        filename: str,
        write_mode: str = "w",
        *args: Any,
        **kwargs: Any,
    ) -> IO[Any]:
        """Open ``filename`` inside the debug directory; the caller closes it."""
        assert self._path
        return open(os.path.join(self._path, filename), write_mode, *args, **kwargs)

    @contextlib.contextmanager
    def fopen_context(
        self,
        filename: str,
        write_mode: str = "w",
        *args: Any,
        **kwargs: Any,
    ) -> Iterator[IO[Any]]:
        """Context-managed variant of ``fopen``; the file is closed on exit."""
        assert self._path
        with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f:
            yield f

    def filename(self, suffix: str) -> str:
        """Return the path of ``suffix`` inside the debug directory."""
        assert self._path
        return os.path.join(self._path, suffix)

    def upload_tar(self) -> None:
        """Archive the debug directory and pass it to ``config.trace.upload_tar``.

        This is a no-op unless that hook is set. The archive is written inside
        the directory as ``<dirname>.tar.gz``, with the directory as its root.
        """
        if config.trace.upload_tar is not None:
            import tarfile

            assert self._path
            arcname = os.path.basename(self._path)
            archive_name = f"{arcname}.tar.gz"
            tar_file = os.path.join(self._path, archive_name)
            # The archive lives inside the directory being archived; skip it so
            # the tarball does not contain a partial copy of itself.
            self_member = f"{arcname}/{archive_name}"

            def skip_self(info: "tarfile.TarInfo") -> Optional["tarfile.TarInfo"]:
                return None if info.name == self_member else info

            with tarfile.open(tar_file, "w:gz") as tar:
                tar.add(self._path, arcname=arcname, filter=skip_self)
            config.trace.upload_tar(tar_file)

    def __enter__(self) -> None:
        if config.debug:
            dynamo_log = logging.getLogger("torch._dynamo")
            prev_level = dynamo_log.level
            dynamo_log.setLevel(logging.DEBUG)

            def reset_log_level(level: Any) -> None:
                dynamo_log.setLevel(level)

            self._stack.callback(reset_log_level, prev_level)

        self._stack.enter_context(V.set_debug_handler(self))

        if not config.trace.enabled:
            return

        self._path = self.create_debug_dir(get_aot_graph_name())  # type: ignore[assignment]

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)

    def _setup_log_capture(
        self,
        filename: str,
        level: int,
    ) -> None:
        """Mirror ``torch._inductor`` records at ``level`` or above into ``filename``.

        The handler, the file and the logger-level change are all undone when
        the context exits.
        """
        inductor_log = logging.getLogger("torch._inductor")
        fd = self._stack.enter_context(self.fopen(filename))
        ch = logging.StreamHandler(fd)
        ch.setLevel(level)
        ch.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        inductor_log.addHandler(ch)
        prev_level = inductor_log.level
        # NOTE(review): if the logger is NOTSET (0), min() keeps it NOTSET, so its
        # effective level is inherited from the parent and lower-severity records
        # will not reach this handler.
        inductor_log.setLevel(min(prev_level, level))
        # Restore the logger's own level on exit (as __enter__ does for
        # torch._dynamo) so the change does not outlive this context.
        self._stack.callback(inductor_log.setLevel, prev_level)
        self._stack.callback(inductor_log.removeHandler, ch)

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        try:
            if self._prof:
                self._prof.disable()
                self._save_profile_data()

            if self._path:
                self.upload_tar()
                log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
        finally:
            # Always unwind handlers/patches, even if profiling or upload failed.
            self._stack.close()

    def _save_profile_data(self) -> None:
        """Dump raw profile data plus cumulative/total-time summaries (top 100)."""
        assert self._prof
        self._prof.dump_stats(self.filename("compile.prof"))
        with self.fopen("compile.stats") as fd:
            stats = pstats.Stats(self._prof, stream=fd)
            stats.strip_dirs()
            stats.sort_stats("cumtime")
            stats.print_stats(100)
            stats.sort_stats("tottime")
            stats.print_stats(100)

    def __getattr__(self, name: str) -> Optional[Callable[..., None]]:
        # Only reached for names that are not real attributes, i.e. the trace
        # artifact hooks. A hook is live when tracing is enabled and the
        # config.trace flag of the same name is truthy.
        if config.trace.enabled and getattr(config.trace, name):
            try:
                return getattr(DebugFormatter(self), name)
            except Exception:
                log.warning("Ignoring exception in debug code", exc_info=True)
                return None
        else:

            def ignored(*args: Any, **kwargs: Any) -> None:
                pass

            return ignored
class DebugFormatter:
    """
    Writers for the individual trace artifacts of a ``DebugContext``.

    ``DebugContext.__getattr__`` hands out bound methods of this class, each
    gated by the ``config.trace`` flag of the same name. All files are written
    into the owning context's debug directory.
    """

    def __init__(self, handler: DebugContext) -> None:
        self.fopen = handler.fopen
        self.fopen_context = handler.fopen_context
        self.filename = handler.filename
        self.handler = handler

    def fx_graph(
        self,
        gm: torch.fx.GraphModule,
        inputs: list[torch.Tensor],
    ) -> None:
        """Write ``fx_graph_runnable.py`` (a repro) and ``fx_graph_readable.py``."""
        with self.fopen("fx_graph_runnable.py") as fd:
            save_dir = None
            if torch._inductor.config.trace.save_real_tensors:
                inputs = torch._subclasses.fake_utils.try_convert_fake_to_real(inputs)
                save_dir = os.path.dirname(fd.name)

            # dont try to use stable hash torchinductor compilation if saving real tensors
            # and avoid recursively trying to save real tensors inside of the inductor compilation
            # regardless
            # NOTE(review): stable_hash is True exactly when save_real_tensors is set,
            # which reads as the opposite of the comment above — confirm the intended
            # semantics against save_graph_repro.
            stable_hash = torch._inductor.config.trace.save_real_tensors
            with torch._inductor.config.patch(
                {"trace.enabled": False, "trace.save_real_tensors": False}
            ):
                save_graph_repro(
                    fd,
                    gm,
                    inputs,
                    "inductor",
                    save_dir=save_dir,
                    stable_hash=stable_hash,
                )

        with self.fopen("fx_graph_readable.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def fx_graph_transformed(
        self,
        gm: torch.fx.GraphModule,
        inputs: list[torch.Tensor],
    ) -> None:
        """Write the readable form of the transformed graph."""
        with self.fopen("fx_graph_transformed.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def ir_pre_fusion(self, nodes: SchedulerNodeList) -> None:
        """Write each node's ``debug_str()`` before fusion."""
        with self.fopen("ir_pre_fusion.txt") as fd:
            fd.write(self._write_ir(nodes))

    def ir_post_fusion(self, nodes: SchedulerNodeList) -> None:
        """Write each node's ``debug_str()`` after fusion."""
        with self.fopen("ir_post_fusion.txt") as fd:
            fd.write(self._write_ir(nodes))

    @staticmethod
    def _write_ir(nodes: SchedulerNodeList) -> str:
        """Concatenate ``debug_str()`` of every node, separated by blank lines."""
        buf = io.StringIO()
        for node in nodes:
            buf.write(node.debug_str())
            buf.write("\n\n\n")
        return buf.getvalue()

    def graph_diagram(self, nodes: SchedulerNodeList) -> None:
        """Render the scheduler-node graph to ``graph_diagram.svg``."""
        draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))

    def draw_orig_fx_graph(
        self,
        gm: torch.fx.GraphModule,
        nodes: SchedulerNodeList,
    ) -> None:
        """Render ``gm``, annotated with its scheduler buffers, to an SVG."""
        annotate_orig_fx_with_snodes(gm, nodes)
        draw_graph(
            gm,
            fname=self.filename("orig_fx_graph_diagram.svg"),
            clear_meta=False,
            prog=GRAPHVIZ_COMMAND_SCALABLE,
            parse_stack_trace=True,
            dot_graph_shape=config.trace.dot_graph_shape,
        )

    def output_code(self, filename: str, extension: str = "py") -> None:
        """Copy the generated code file into the debug directory."""
        shutil.copy(filename, self.filename(f"output_code.{extension}"))

    def log_autotuning_results(
        self,
        name: str,
        input_nodes: list[ir.IRNode],
        timings: dict["ChoiceCaller", float],  # type: ignore[name-defined] # noqa: F821
        elapse: float,
        precompile_elapse: float,
        prescreening_elapse: Optional[float],
    ) -> None:
        """Append one JSON line per autotuning choice to
        ``autotuning_result_json_list.txt``.

        Each line merges three parts:
        - the choice's ``info_dict()``;
        - properties shared by the whole run: op name, device info, input node
          descriptions and timings;
        - the choice's ``benchmark_result``.
        """
        from .ir import FixedLayout

        def build_node_info(node: ir.IRNode) -> dict[str, Any]:
            """Describe ``node`` (recursing into ``node.data``) best-effort."""
            node_info: dict[str, Any] = {
                "name": node.name if hasattr(node, "name") else "",
                "type": type(node).__name__,
            }
            # IR node kinds differ in which accessors they support, so each field
            # below is filled in only if it can be computed.
            with contextlib.suppress(Exception):
                layout = node.get_output_spec()
                if isinstance(layout, FixedLayout):
                    # Substitute optimization hints for sizes/strides/offset.
                    static_layout = FixedLayout(
                        layout.device,
                        dtype=layout.dtype,
                        size=V.graph.sizevars.optimization_hints(layout.size),
                        stride=V.graph.sizevars.optimization_hints(layout.stride),
                        offset=V.graph.sizevars.optimization_hint(
                            layout.offset, fallback=0
                        ),
                    )
                    node_info["layout"] = str(static_layout)
                else:
                    node_info["layout"] = str(layout)

            optional_fields: list[tuple[str, Callable[[], str]]] = [
                ("dtype", lambda: str(node.get_dtype())),
                ("device", lambda: str(node.get_device())),
                (
                    "stride",
                    lambda: str(V.graph.sizevars.optimization_hints(node.get_stride())),
                ),
                (
                    "size",
                    lambda: str(V.graph.sizevars.optimization_hints(node.get_size())),
                ),
                (
                    "numel",
                    lambda: str(V.graph.sizevars.optimization_hint(node.get_numel())),
                ),
            ]
            for key, compute in optional_fields:
                with contextlib.suppress(Exception):
                    node_info[key] = compute()

            if hasattr(node, "data") and isinstance(node.data, ir.IRNode):
                node_info["data"] = build_node_info(node.data)

            return node_info

        # torch.cuda.get_device_name() raises when CUDA is unavailable (e.g. when
        # autotuning on CPU or another backend), which would abort this logging;
        # record None as the device name in that case.
        cuda_available = torch.cuda.is_available()
        general_properties = {
            "op_name": name,
            "cuda_device_name": (
                torch.cuda.get_device_name() if cuda_available else None
            ),
            "cuda_device_count": torch.cuda.device_count(),
            "input_nodes": [build_node_info(node) for node in input_nodes],
            "autotuning_time": elapse,
            "precompile_time": precompile_elapse,
            "prescreening_time": prescreening_elapse,
        }
        with self.fopen_context(
            "autotuning_result_json_list.txt", "at", encoding="utf-8"
        ) as fd:
            for caller, time in timings.items():
                info_dict = dict(caller.info_dict())
                info_dict.update(general_properties)
                info_dict["benchmark_result"] = time
                json.dump(info_dict, fd)
                fd.write("\n")
def log_ir_pre_fusion(nodes: SchedulerNodeList) -> None:
    """Send the pre-fusion IR to the ``ir_pre_fusion`` artifact log and trace hook."""
    # Render the IR text only when the artifact logger will actually emit it.
    if ir_pre_fusion_log.isEnabledFor(logging.INFO):
        rendered = DebugFormatter._write_ir(nodes)
        ir_pre_fusion_log.info("BEFORE FUSION\n%s", rendered)
    # V.debug is presumably the active DebugContext; its __getattr__ makes this a
    # no-op unless the matching trace flag is enabled.
    V.debug.ir_pre_fusion(nodes)
def log_ir_post_fusion(nodes: SchedulerNodeList) -> None:
    """Send the post-fusion IR to the ``ir_post_fusion`` artifact log and trace hook."""
    # Render the IR text only when the artifact logger will actually emit it.
    if ir_post_fusion_log.isEnabledFor(logging.INFO):
        rendered = DebugFormatter._write_ir(nodes)
        ir_post_fusion_log.info("AFTER FUSION\n%s", rendered)
    # V.debug is presumably the active DebugContext; its __getattr__ makes this a
    # no-op unless the matching trace flag is enabled.
    V.debug.ir_post_fusion(nodes)
def _dump_collective_schedule(schedule: list[Union[str, None]]) -> None:
    """Emit ``schedule`` as the ``inductor_collective_schedule`` JSON artifact.

    Any failure is logged at DEBUG level and otherwise ignored.
    """

    def artifact_metadata() -> dict[str, str]:
        return {"name": "inductor_collective_schedule", "encoding": "json"}

    def artifact_payload() -> list[Union[str, None]]:
        return schedule

    try:
        trace_structured(
            "artifact",
            metadata_fn=artifact_metadata,
            payload_fn=artifact_payload,
        )
    except Exception:
        log.debug(
            "Failed to log inductor_collective_schedule via structured logging",
            exc_info=True,
        )
def log_collective_schedule(nodes: Sequence[BaseSchedulerNode]) -> None:
    """Log the kernel names of collective ops, in schedule order, if any exist."""
    schedule: list[Union[str, None]] = []
    for snode in nodes:
        op = getattr(snode, "node", None)
        if isinstance(op, ir._CollectiveKernel):
            schedule.append(getattr(op, "python_kernel_name", None))
    # Only log when there is at least one collective op
    if schedule:
        _dump_collective_schedule(schedule)
def log_runtime_and_tensor_meta(node_runtimes: Sequence[tuple[Any, float]]) -> None:
    """Log per-op runtime estimates and output tensor metadata for TLParse.

    Emits one ``inductor_runtime_and_tensor_meta`` JSON artifact shaped like
    ``{"ops": [{"name", "type", "estimated_runtime_ns", "outputs"}, ...]}``.
    Any failure is logged at DEBUG level and swallowed (diagnostics only).
    """
    try:
        hints = V.graph.sizevars.optimization_hints

        def as_hint_list(seq: Optional[Sequence[Any]]) -> list[Any]:
            if seq is None:
                return []
            return list(hints(seq))

        def format_dtype(dtype: Any) -> Optional[str]:
            return None if dtype is None else str(dtype).removeprefix("torch.")

        def describe_outputs(snode: Any) -> list[dict[str, Any]]:
            # Best-effort: stop at the first buffer whose metadata cannot be
            # read, keeping the entries gathered before it.
            described: list[dict[str, Any]] = []
            try:
                for buf in snode.get_outputs():
                    irnode = buf.node
                    shape = irnode.maybe_get_size()
                    if isinstance(irnode.layout, ir.Layout):
                        stride = irnode.get_stride()
                    else:
                        stride = None
                    dtype = irnode.maybe_get_dtype()
                    described.append(
                        {
                            "shape": as_hint_list(shape),
                            "stride": as_hint_list(stride),
                            "dtype": format_dtype(dtype),
                        }
                    )
            except Exception:
                pass
            return described

        ops: list[dict[str, Any]] = []
        for snode, runtime_ns in node_runtimes:
            op_name = getattr(snode.node, "python_kernel_name", snode.get_name())
            kind = "collective" if utils.is_collective(snode.node) else "compute"
            ops.append(
                {
                    "name": op_name,
                    "type": kind,
                    "estimated_runtime_ns": runtime_ns,
                    "outputs": describe_outputs(snode),
                }
            )

        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "inductor_runtime_and_tensor_meta",
                "encoding": "json",
            },
            payload_fn=lambda: {"ops": ops},
        )
    except Exception:
        log.debug("Failed to log inductor_runtime_and_tensor_meta", exc_info=True)
def log_graph_execution() -> None:
    """Emit a structured artifact with the graph execution order.

    Nothing is emitted when no order has been recorded (``None`` or empty).
    Structured-logging errors are logged at DEBUG level and ignored.
    """
    if GRAPH_EXECUTION_ORDER:
        try:
            trace_structured(
                "artifact",
                metadata_fn=lambda: {"name": "graph_execution", "encoding": "json"},
                payload_fn=lambda: {"graph_execution_order": GRAPH_EXECUTION_ORDER},
            )
        except Exception:
            log.debug("Failed to log graph_execution", exc_info=True)
@contextlib.contextmanager
def record_and_log_graph_execution_order() -> Iterator[None]:
    """Record graph execution order and log it once on exit.

    While the context is active, RECORD_GRAPH_EXECUTION is True and the
    order/compile-id containers start out empty. On exit, normal or via an
    exception, log_graph_execution() is called and then all three globals are
    cleared.
    """
    global RECORD_GRAPH_EXECUTION, GRAPH_EXECUTION_ORDER, GRAPH_COMPILE_IDS
    GRAPH_EXECUTION_ORDER, GRAPH_COMPILE_IDS, RECORD_GRAPH_EXECUTION = [], {}, True
    try:
        yield
    finally:
        log_graph_execution()
        RECORD_GRAPH_EXECUTION, GRAPH_EXECUTION_ORDER, GRAPH_COMPILE_IDS = (
            False,
            None,
            None,
        )
@dataclasses.dataclass
class TensorMetadataHolder:
    """Pair an FX ``TensorMetadata`` with a ``torch.device``."""

    # NOTE(review): users of this holder are not visible in this part of the file.
    tensor_metadata: TensorMetadata
    device: torch.device


# Module-level counter; presumably numbers saved-argument dumps (no users are
# visible here — confirm).
save_args_cnt = itertools.count()
- def create_mapping_pre_post_grad_nodes(
- pre_grad_graph_id: Optional[int],
- post_to_pre_grad_nodes_json: dict[str, Any],
- ) -> dict[str, dict[str, list[str]]]:
- """
- Create bidirectional mappings between pre_grad graph nodes
- and post_grad graph code nodes, and vice versa.
- """
- # return a dummy dict if there's any error
- empty_return: dict[str, dict[str, list[str]]] = {
- "preToPost": {},
- "postToPre": {},
- }
- if not isinstance(post_to_pre_grad_nodes_json, dict):
- log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict")
- return empty_return
- if not isinstance(pre_grad_graph_id, int):
- # pre_grad_graph_id may be empty if there's no pre_grad graph
- # and there's only a backward graph from backward pass engine
- return empty_return
- pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet)
- post_to_pre: dict[str, Any] = collections.defaultdict(OrderedSet)
- try:
- def check_format(node: dict[str, Any]) -> bool:
- if not isinstance(node, dict):
- log.error(
- "Provenance tacking error: node provenance in post_to_pre_grad_nodes_json is not a dict"
- )
- return False
- if "graph_id" not in node or "name" not in node or "from_node" not in node:
- log.error(
- "Provenance tacking error: node provenance in post_to_pre_grad_nodes_json has wrong format"
- )
- return False
- return True
- for outer_key, node_array in post_to_pre_grad_nodes_json.items():
- if not isinstance(node_array, list):
- log.error(
- "Provenance tacking error: post_to_pre_grad_nodes_json value is not a list"
- )
- return empty_return
- for node in node_array:
- if not check_format(node):
- return empty_return
- # Check the current node first
- if node.get("graph_id") == pre_grad_graph_id:
- pre_to_post[node["name"]].add(outer_key)
- post_to_pre[outer_key].add(node["name"])
- # Check nested from_node array recursively, add node with the right graph_id to the map
- stack = [(n, outer_key) for n in node.get("from_node", [])]
- while stack:
- current_node, parent_key = stack.pop()
- if not check_format(current_node):
- return empty_return
- if current_node.get("graph_id") == pre_grad_graph_id:
- pre_to_post[current_node["name"]].add(parent_key)
- post_to_pre[parent_key].add(current_node["name"])
- stack.extend(
- (n, parent_key) for n in current_node.get("from_node", [])
- )
- def convert_sets_to_lists(d: dict[str, Any]) -> None:
- for key in d:
- d[key] = list(d[key])
- d = dict(d)
- # convert to list because set is not JSON serializable
- convert_sets_to_lists(pre_to_post)
- convert_sets_to_lists(post_to_pre)
- return {
- "preToPost": pre_to_post,
- "postToPre": post_to_pre,
- }
- except Exception as e:
- # Since this is just logging code, it should never interfere with regular
- # program execution, so we use this try-except to guard against any error
- signpost_event(
- "inductor",
- "provenance_tracking_error",
- {
- "function": "create_mapping_pre_post_grad_nodes",
- "error_msg": str(e),
- "stack_trace": traceback.format_exc(),
- },
- )
- log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json)
- log.error("pre_grad_graph_id: %s", pre_grad_graph_id)
- return empty_return
- def create_node_mapping_kernel_to_post_grad(
- triton_kernel_to_post_grad_json: dict[str, Any],
- ) -> dict[str, dict[str, Any]]:
- """Create bidirectional mappings between triton kernel name and post_grad
- graph code nodes, and vice versa.
- """
- # return a dummy dict if there's any error
- empty_return: dict[str, dict[str, Any]] = {
- "cppCodeToPost": {},
- "postToCppCode": {},
- }
- if not isinstance(triton_kernel_to_post_grad_json, dict):
- log.error(
- "Provenance tacking error: triton_kernel_to_post_grad_json is not a dict"
- )
- return empty_return
- post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet)
- try:
- for outer_key, node_array in triton_kernel_to_post_grad_json.items():
- if not isinstance(node_array, list):
- log.error(
- "Provenance tacking error: triton_kernel_to_post_grad_json value is not a list"
- )
- return empty_return
- for curr_node in node_array:
- post_to_cpp_code[curr_node].add(outer_key)
- def convert_sets_to_lists(d: dict[str, Any]) -> None:
- for key in d:
- d[key] = list(d[key])
- d = dict(d)
- # convert to list because set is not JSON serializable
- convert_sets_to_lists(post_to_cpp_code)
- return {
- "cppCodeToPost": triton_kernel_to_post_grad_json,
- "postToCppCode": post_to_cpp_code,
- }
- except Exception as e:
- # Since this is just logging code, it should never interfere with regular
- # program execution, so we use this try-except to guard against any error
- signpost_event(
- "inductor",
- "provenance_tracking_error",
- {
- "function": "create_mapping_kernel_to_post_grad",
- "error_msg": str(e),
- "stack_trace": traceback.format_exc(),
- },
- )
- log.error(
- "triton_kernel_to_post_grad_json: %s", triton_kernel_to_post_grad_json
- )
- return empty_return
def dump_inductor_provenance_info() -> dict[str, Any]:
    """Assemble the provenance-tracking node mappings for this compilation.

    If a pre_grad graph id has been recorded, the pre/post_grad node mappings
    are merged with the kernel <-> post_grad mappings built by
    ``create_node_mapping_kernel_to_post_grad``. When ``config.trace.enabled``
    is set, that merged mapping is also written to
    ``inductor_provenance_tracking_node_mappings.json`` through
    ``V.debug.fopen``.

    A ``"version"`` entry is then added to the returned dict. It is added after
    the optional file write, so the file itself does not contain it.

    Debug-only: any exception is reported via ``signpost_event`` and ``{}`` is
    returned.
    """
    # The module-level provenance state is only read here, so no ``global``
    # declarations are needed.
    try:
        mapping: dict[str, Any] = {}
        if _pre_grad_graph_id:
            kernel_mapping = create_node_mapping_kernel_to_post_grad(
                _inductor_triton_kernel_to_post_grad_node_info
            )
            mapping = {**_inductor_post_to_pre_grad_nodes, **kernel_mapping}
            if config.trace.enabled:
                with V.debug.fopen(
                    "inductor_provenance_tracking_node_mappings.json", "w"
                ) as out_file:
                    json.dump(mapping, out_file)
        # Bump this whenever the node-mapping layout changes, so the tlparse
        # tool can tell which mapping version it is reading.
        mapping["version"] = 2.0
        return mapping
    except Exception as err:
        # Provenance dumping is purely diagnostic; never let it break compilation.
        signpost_event(
            "inductor",
            "provenance_tracking_error",
            {
                "function": "dump_inductor_provenance_info",
                "error_msg": str(err),
                "stack_trace": traceback.format_exc(),
            },
        )
        return {}
def create_kernel_information_json() -> dict[str, dict[str, list[str]]]:
    """Create kernel information JSON.

    One entry is produced per kernel name found in either the recorded stack
    traces or the kernel -> post_grad node mapping. Kernels from the stack-trace
    table come first, in insertion order. Each entry holds:

    * ``"stack_traces"``: the kernel's recorded stack traces (``[]`` if none);
    * ``"post_grad_nodes"``: the post_grad node names recorded for it
      (``[]`` if none);
    * ``"pre_grad_nodes"``: pre_grad nodes reached through the ``"postToPre"``
      mapping, de-duplicated in first-seen order.

    The ``stack_traces`` / ``post_grad_nodes`` lists are the objects stored in
    the module-level tables, not copies.

    Debug-only: on any exception the error is reported via ``signpost_event``
    and ``{}`` is returned.
    """
    try:
        post_to_pre = _inductor_post_to_pre_grad_nodes.get("postToPre", {})

        def describe(name: str) -> dict[str, list[str]]:
            # Gather everything recorded for a single kernel name.
            post_nodes = _inductor_triton_kernel_to_post_grad_node_info.get(name, [])
            pre_nodes: OrderedSet[str] = OrderedSet(
                pre_node
                for post_node in post_nodes
                for pre_node in post_to_pre.get(post_node, [])
            )
            return {
                "stack_traces": _inductor_kernel_stack_trace.get(name, []),
                "post_grad_nodes": post_nodes,
                "pre_grad_nodes": list(pre_nodes),
            }

        # Ordered union of kernel names: stack-trace kernels first, then any
        # kernels only present in the post_grad node table.
        kernel_names: OrderedSet[str] = OrderedSet(_inductor_kernel_stack_trace)
        kernel_names.update(_inductor_triton_kernel_to_post_grad_node_info)
        return {name: describe(name) for name in kernel_names}
    except Exception as err:
        signpost_event(
            "inductor",
            "provenance_tracking_error",
            {
                "function": "create_kernel_information_json",
                "error_msg": str(err),
                "stack_trace": traceback.format_exc(),
            },
        )
        return {}
def set_kernel_post_grad_provenance_tracing(
    node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
    kernel_name: str,
    is_extern: bool = False,
) -> Optional[int]:
    """Record the post_grad nodes and stack traces behind one generated kernel.

    Every successful call increments the module-level debug handle and stores
    provenance under the key ``f"{kernel_name}:{handle}"``:

    * ``_inductor_triton_kernel_to_post_grad_node_info[key]`` collects the names
      of the contributing FX nodes (duplicates skipped, insertion order kept);
    * ``_inductor_kernel_stack_trace[key]`` collects their stack traces.

    Args:
        node_schedule: an ``ExternKernel`` when ``is_extern`` is true. Otherwise
            a list of scheduler nodes; ``EnableReduction`` / ``DisableReduction``
            markers and nodes whose ``.node`` is ``None`` are skipped.
        kernel_name: base kernel name; the debug handle is appended to it.
        is_extern: selects which of the two ``node_schedule`` forms is expected.

    Returns:
        The new debug handle. Returns ``None`` when
        ``config.trace.provenance_tracking_level == 0``, or when any error
        occurs (errors are reported via ``signpost_event``, never raised).
    """
    # Only the handle counter is rebound; the two tables are mutated in place.
    global _inductor_kernel_provenance_debug_handle

    if config.trace.provenance_tracking_level == 0:
        return None

    def _append_missing(dst: list[str], names: Any) -> None:
        # Append each name not yet in ``dst``; names appended earlier in this
        # same loop already count as present.
        for name in names:
            if name not in dst:
                dst.append(name)

    try:
        from .codegen.simd_kernel_features import DisableReduction, EnableReduction

        _inductor_kernel_provenance_debug_handle += 1
        key = f"{kernel_name}:{_inductor_kernel_provenance_debug_handle}"

        if is_extern:
            assert isinstance(node_schedule, ExternKernel)
            node_names = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
                key, []
            )
            # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel.
            # "origin_node" is more precise and says that the contents of this node corresponds
            # EXACTLY to the output of a particular FX node, but it's not always available
            if node_schedule.origin_node:
                _append_missing(node_names, [node_schedule.origin_node.name])
            else:
                _append_missing(
                    node_names, (origin.name for origin in node_schedule.origins)
                )
            traces = list(node_schedule.get_stack_traces())
        else:
            assert isinstance(node_schedule, list)
            trace_set: OrderedSet[str] = OrderedSet()
            for snode in node_schedule:
                # Skip reduction markers and scheduler nodes without an IR node.
                if snode in (EnableReduction, DisableReduction) or snode.node is None:
                    continue
                node_names = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
                    key, []
                )
                # pyrefly: ignore [missing-attribute]
                trace_set.update(snode.node.get_stack_traces())
                _append_missing(
                    node_names,
                    # pyrefly: ignore [missing-attribute]
                    (origin.name for origin in snode.node.origins),
                )
            traces = list(trace_set)

        _inductor_kernel_stack_trace.setdefault(key, []).extend(traces)
        return _inductor_kernel_provenance_debug_handle
    except Exception as err:
        # Provenance tracking is purely diagnostic; never let it break codegen.
        signpost_event(
            "inductor",
            "provenance_tracking_error",
            {
                "function": "set_kernel_post_grad_provenance_tracing",
                "error_msg": str(err),
                "stack_trace": traceback.format_exc(),
            },
        )
        return None
def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:
    """
    This function is used to save arguments for a compile_fx_inner function call
    to the file system. Later on one can replay the compile_fx_inner call
    with the saved arguments using load_args_and_run_compile_fx_inner.

    Every ``torch.Tensor`` in ``args`` / ``kwargs`` is replaced by a
    ``TensorMetadataHolder`` (tensor metadata + device) before pickling, so no
    tensor data is written. Arguments are saved to
    ``<tempdir>/inductor_saved_args/compile_fx_inner_<N>.pkl``, where ``N`` comes
    from the module-level ``save_args_cnt`` counter. If DEBUG logging is enabled,
    replay instructions are printed.
    """
    folder = os.path.join(tempfile.gettempdir(), "inductor_saved_args")
    # exist_ok avoids the exists()/mkdir() race when several processes save
    # arguments at the same time.
    os.makedirs(folder, exist_ok=True)

    def handle_tensor(x: Any) -> Any:
        """
        Pickle FakeTensor will result in error:
        AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'
        Convert all Tensor to metadata. This may also makes pickle faster.
        """
        if isinstance(x, torch.Tensor):
            return TensorMetadataHolder(_extract_tensor_metadata(x), x.device)
        else:
            return x

    args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))

    fn_name = "compile_fx_inner"
    path = os.path.join(folder, f"{fn_name}_{next(save_args_cnt)}.pkl")
    with open(path, "wb") as f:
        pickle.dump((args_to_save, kwargs_to_save), f)

    if log.isEnabledFor(logging.DEBUG):
        message = f"""
Arguments for a compile_fx_inner call is saved to {path}. To replay the call,
run the following:

from torch._inductor.debug import load_args_and_run_compile_fx_inner
load_args_and_run_compile_fx_inner({path!r})
        """
        # call print rather than log.debug. log.debug will print message
        # prefix for each line which makes the code snippet harder to be
        # copied.
        # Not a big deal since the code is already been guarded by checking
        # the log level.
        print(message)
def load_args_and_run_compile_fx_inner(path: str) -> Any:
    """Replay a ``compile_fx_inner`` call from a file of pickled arguments.

    The file is expected to contain the ``(args, kwargs)`` tuple written by
    ``save_args_for_compile_fx_inner``. Each ``TensorMetadataHolder`` in it is
    turned back into a tensor via ``torch._dynamo.testing.rand_strided``, using
    the recorded shape, stride, dtype and device. That conversion and the
    ``compile_fx_inner`` call both run inside
    ``FakeTensorMode(allow_non_fake_inputs=True)`` with ``config.save_args``
    patched to ``False``.

    Returns:
        Whatever ``compile_fx_inner`` returns.

    NOTE: ``pickle.load`` can execute arbitrary code. Only pass files you
    created yourself, never files from an untrusted source.
    """
    from torch._inductor.compile_fx import compile_fx_inner

    with open(path, "rb") as saved:
        # SECURITY: unpickling runs arbitrary code; trusted, locally-saved files only.
        saved_args, saved_kwargs = pickle.load(saved)

    def materialize(value: Any) -> Any:
        # Leave anything that is not a saved tensor placeholder unchanged.
        if not isinstance(value, TensorMetadataHolder):
            return value
        meta = value.tensor_metadata
        return torch._dynamo.testing.rand_strided(
            meta.shape,
            meta.stride,
            meta.dtype,
            value.device,
        )

    with torch._subclasses.FakeTensorMode(
        allow_non_fake_inputs=True
    ), config.patch("save_args", False):
        real_args, real_kwargs = tree_map(materialize, (saved_args, saved_kwargs))
        return compile_fx_inner(*real_args, **real_kwargs)
def aot_inductor_minifier_wrapper(
    func: Callable[..., str],
    exported_program: torch.export.ExportedProgram,
    *,
    inductor_configs: dict[str, Any],
    package_path: Optional[FileLike] = None,
) -> str:
    """Run an AOTInductor compile function, dumping minifier repros on failure.

    ``func`` is called as ``func(gm, args, kwargs, inductor_configs=...,
    package_path=..., load_and_run=use_minifier)`` and its result is returned.
    Here ``gm`` is ``exported_program.module(check_guards=False)``,
    ``args`` / ``kwargs`` are ``exported_program.example_inputs``, and
    ``use_minifier`` is ``config.aot_inductor.dump_aoti_minifier``.

    When ``use_minifier`` is set:

    * ``repro_level == 3``: the original program is dumped before compiling, so
      a repro exists even if compilation segfaults.
    * ``repro_level == 4``: inputs are flattened, the module is re-exported, and
      ``func`` is first run with ``load_and_run=True`` and
      ``check_accuracy="accuracy"``.
    * Any other exception dumps a repro with command ``"run"`` for
      ``repro_level == 1``, and ``"minify"`` otherwise.

    An ``AccuracyError`` always dumps an ``"aot_inductor_accuracy"`` repro with
    command ``"minify"``, whether or not ``use_minifier`` is set. In every case
    the exception is re-raised.
    """
    from torch._dynamo.debug_utils import AccuracyError
    from torch._dynamo.repro.aoti import dump_to_minify
    from torch._inductor import config
    from torch._inductor.compile_fx import _aoti_flatten_inputs

    use_minifier = config.aot_inductor.dump_aoti_minifier
    gm = exported_program.module(check_guards=False)
    assert isinstance(gm, torch.fx.GraphModule)
    args, kwargs = exported_program.example_inputs
    try:
        if use_minifier and config.aot_inductor.repro_level == 3:
            # Always dump the original module in case we have segfaults
            dump_to_minify(
                exported_program,
                "aot_inductor",
                options=inductor_configs,
            )
        if use_minifier and config.aot_inductor.repro_level == 4:
            # Check for accuracy
            # We will first flatten the inputs before compiling and checking for accuracy.
            # This is ok because we will flatten the inputs in the minifier anyway.
            gm_copy = copy.deepcopy(gm)
            example_inputs_copy = copy.deepcopy(exported_program.example_inputs)
            config_copy = copy.deepcopy(inductor_configs)
            flat_example_inputs, config_copy = _aoti_flatten_inputs(
                gm_copy,
                example_inputs_copy[0],
                example_inputs_copy[1],
                options=config_copy,
            )
            tuple_inputs = tuple(flat_example_inputs)
            flattened_ep = torch.export.export(gm_copy, tuple_inputs, strict=False)
            func(
                flattened_ep.module(check_guards=False),
                tuple_inputs,
                inductor_configs=config_copy,
                package_path=package_path,
                load_and_run=True,
                check_accuracy="accuracy",
            )
        return func(
            gm,
            args,
            kwargs,
            inductor_configs=inductor_configs,
            package_path=package_path,
            load_and_run=use_minifier,
        )
    except AccuracyError:
        dump_to_minify(
            exported_program,
            "aot_inductor_accuracy",
            command="minify",
            options=inductor_configs,
        )
        log.warning("Accuracy failed")
        # Bare ``raise`` re-raises with the original traceback untouched.
        raise
    except Exception:
        if use_minifier:
            command = "run" if config.aot_inductor.repro_level == 1 else "minify"
            dump_to_minify(
                exported_program,
                "aot_inductor",
                command=command,
                options=inductor_configs,
            )
        raise
|