| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353 |
- from __future__ import annotations
- import contextlib
- import dataclasses
- import enum
- import functools
- import logging
- import re
- import threading
- import traceback
- import unittest.mock
- import weakref
- from abc import abstractmethod
- from collections import defaultdict
- from contextlib import contextmanager
- from dataclasses import dataclass
- from typing import Any, Generic, NamedTuple, Optional, overload, TYPE_CHECKING, TypeVar
- from typing_extensions import dataclass_transform
- import torch
- from torch.utils import _pytree as pytree
- from torch.utils._ordered_set import OrderedSet
- from torch.utils._python_dispatch import is_traceable_wrapper_subclass
- from torch.utils._traceback import CapturedTraceback, format_frame
- from torch.utils.weak import WeakTensorKeyDictionary
- log = logging.getLogger(__name__)
- if TYPE_CHECKING:
- from collections.abc import Callable, Generator, Iterator
- from types import CodeType
- import sympy
- from torch._dynamo.backends.distributed import DDPOptimizerContext
- from torch._dynamo.codegen import PyCodegen
- from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta
- from torch._subclasses.fake_tensor import FakeTensorMode
- """
- torch._guards is the definitional source of truth for general purpose guard structures.
- An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
- and no guard installation notions here.
- """
- COMPILE_ID_PATTERN = re.compile(r"^(?P<frame_id>\d+)/(?P<frame_compile_id>\d+)$")  # matches "<frame_id>/<frame_compile_id>", both decimal
- CA_COMPILE_ID_PATTERN = re.compile(  # matches "!<compiled_autograd_id>" with an optional "/<frame_id>/<frame_compile_id>" suffix
- r"^!(?P<compiled_autograd_id>\d+)(?:/(?P<frame_id>\d+)/(?P<frame_compile_id>\d+))?$"
- )
- # [Note: Updating CompileId]
- #
- # CompileId represents a unique program-level identifier, and we want to keep that
- # property as the codebase evolves. This property is relied on even outside of the pytorch
- # repo, e.g. tlparse or other internal tooling. The in-memory format can be freely changed,
- # as those dependencies only consume the string serialization.
- #
- # The string form should be:
- # 1. Program-level uid: CompileId can uniquely identify a compiled graph.
- # 2. Storage efficient: This object is logged in nearly every entry. We should elide symbols when possible.
- # 3. Compact: The string form is directly displayed by some tools. Special symbols are okay.
- @dataclass(frozen=True, kw_only=True, slots=True)
- class CompileId:
- frame_id: int | None
- # This id is per-frame, and counts how many times we've compiled this
- # frame. This could have been a global id but having this be per-frame
- # gives you a better intuitive sense for how many recompiles have occurred
- # so far.
- frame_compile_id: int | None
- # torch.compiling a compiled autograd graph
- compiled_autograd_id: int | None = None
- # TODO: consider also tracking the recompilation count
- # See Note: Updating CompileId
- def __str__(self) -> str:
- # NOTE: Keep this in sync with both from_string and the tlparse repo
- if self.compiled_autograd_id is not None:
- if (self.frame_id is None) != (self.frame_compile_id is None):
- raise AssertionError(
- f"frame_id and frame_compile_id must both be None or both be set, "
- f"got frame_id={self.frame_id}, frame_compile_id={self.frame_compile_id}"
- )
- frame_str = ""
- if self.frame_id is not None:
- frame_str = f"/{self.frame_id}/{self.frame_compile_id}"
- return f"!{self.compiled_autograd_id}{frame_str}"
- else:
- if self.frame_id is None or self.frame_compile_id is None:
- raise AssertionError(
- f"frame_id and frame_compile_id must not be None, "
- f"got frame_id={self.frame_id}, frame_compile_id={self.frame_compile_id}"
- )
- return f"{self.frame_id}/{self.frame_compile_id}"
- @classmethod
- def from_string(cls, compile_id: str | None) -> CompileId | None:
- """
- Factory method that creates a CompileId from its string representation.
- Keep this in sync with the __str__ method.
- """
- if compile_id is None:
- return None
- try:
- for pattern in (COMPILE_ID_PATTERN, CA_COMPILE_ID_PATTERN):
- if match := pattern.match(compile_id):
- groups = match.groupdict()
- for k, v in groups.items():
- if v is not None:
- groups[k] = int(v)
- return cls(**groups) # type: ignore[arg-type]
- else:
- raise ValueError
- except Exception as e:
- raise ValueError(f"Invalid compile_id '{compile_id}'") from e
class TraceId(NamedTuple):
    """A compile id paired with the index of the analysis attempt."""

    compile_id: CompileId
    # This starts off as 0, and every time we restart analysis it goes
    # up by one
    attempt: int

    def __str__(self) -> str:
        # Keep this in sync with tlparse repo
        # The attempt suffix is elided for the first attempt.
        return (
            str(self.compile_id)
            if self.attempt == 0
            else f"{self.compile_id}_{self.attempt}"
        )
class GuardSource(enum.Enum):
    """Enumeration of the kinds of source a guard can be attached to."""

    LOCAL = 0
    GLOBAL = 1
    LOCAL_SPECIALIZED_NN_MODULE = 2
    GLOBAL_SPECIALIZED_NN_MODULE = 3
    CONSTANT = 4
    RANDOM_VALUE = 5
    SHAPE_ENV = 6
    LOCAL_FSDP_MODULE = 7
    GLOBAL_FSDP_MODULE = 8
    BACKWARD_STATE = 9
    EPHEMERAL = 10
    SYNTHETIC_LOCAL = 11
    LOCAL_UNSPECIALIZED_NN_MODULE = 12
    GLOBAL_UNSPECIALIZED_NN_MODULE = 13
    LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 14
    GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 15
    TEMP_LOCAL = 16

    def is_fsdp_module(self) -> bool:
        """Return True for the two ``*_FSDP_MODULE`` members."""
        return (
            self is GuardSource.GLOBAL_FSDP_MODULE
            or self is GuardSource.LOCAL_FSDP_MODULE
        )

    def is_specialized_nn_module(self) -> bool:
        """Return True for the ``*_SPECIALIZED_NN_MODULE`` members.

        When ``torch._dynamo.config._unsafe_skip_fsdp_module_guards`` is
        truthy, the FSDP module members count as specialized as well.
        """
        import torch._dynamo.config as config

        specialized = self in {
            GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
        }
        if config._unsafe_skip_fsdp_module_guards:
            return specialized or self.is_fsdp_module()
        return specialized

    def is_unspecialized_nn_module(self) -> bool:
        """Return True for any ``*_UNSPECIALIZED_*NN_MODULE`` member (builtin or not)."""
        unspecialized = {
            GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        }
        return self in unspecialized

    def is_unspecialized_builtin_nn_module(self) -> bool:
        """Return True for the two ``*_UNSPECIALIZED_BUILTIN_NN_MODULE`` members."""
        builtin = {
            GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        }
        return self in builtin

    def is_local(self) -> bool:
        """Return True for LOCAL and the ``LOCAL_*`` module members.

        SYNTHETIC_LOCAL and TEMP_LOCAL are not part of this set.
        """
        local_members = {
            GuardSource.LOCAL,
            GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_FSDP_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
            GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
        }
        return self in local_members
- """
- Base class for a "GuardBuilder" role.
- The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little
- confusing, as it's not a builder, but it is kept for the sake of avoiding a lot of renames and keeping the
- original reference to torchdynamo's GuardBuilder.
- Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based
- on GuardSource's select function.
- There is value in keeping this GuardBuilderBase empty to keep layering clean.
- """
class GuardBuilderBase:
    """Intentionally empty base type for guard builders."""
@dataclasses.dataclass(frozen=True)
class SLoc:
    """A framework location paired with an optional user location."""

    framework_loc: traceback.FrameSummary | str | None
    maybe_user_loc: str | None

    def __str__(self) -> str:
        """Render as ``"<user> (<framework>)"``, or ``"(<framework>)"`` without a user location."""
        # Strings are used verbatim; anything else goes through format_frame.
        if isinstance(self.framework_loc, str):
            framework_text = self.framework_loc
        else:
            framework_text = format_frame(self.framework_loc)

        if self.maybe_user_loc is None:
            return f"({framework_text})"
        return f"{self.maybe_user_loc} ({framework_text})"
class ShapeGuard(NamedTuple):
    """A boolean shape expression together with where it was recorded."""

    # The guarded expression.
    expr: sympy.logic.boolalg.Boolean
    # Location information recorded with the expression.
    sloc: SLoc
    size_oblivious: bool
- @dataclasses.dataclass(slots=True)
- class Guard:
- # originating_source is the source that called the make_guard method to
- # construct this guard object. The property name specifies what exactly it
- # is the guard is guarding on. The meaning of the name is dependent on the
- # create_fn; you must look at the use-site inside create_fn to know what
- # name means.
- #
- # That being said, although you might think this is just a "name", name is
- # usually an arbitrary Python expression that will be evaluated with all
- # globals (and locals, if you create a LOCAL guard) to extract the Python
- # object that we want to perform guard tests on. This evaluation
- # typically happens in GuardBuilder.eval. In these cases, name is
- # typically produced by originating_source.name (not to be confused with
- # GuardSource - the property source).
- #
- # Occasionally, name is not a valid Python expression; sometimes
- # it is meaningless. Example create_fns that are like this include
- # GRAD_MODE and SHAPE_ENV.
- originating_source: Source
- create_fn: Callable[[GuardBuilderBase, Guard], None]
- # Export only. These values are written to at time of guard check_fn creation.
- guard_types: list[str] | None = None
- code_list: list[str] | None = None
- obj_weakref: object | None = None
- guarded_class_weakref: weakref.ReferenceType[Any] | None = None
- stack: CapturedTraceback | None = None
- user_stack: traceback.StackSummary | None = None
- _hash: int | None = None
- _unserializable: bool = False
- def __hash__(self) -> int:
- if self._hash is None:
- self._hash = hash((self.name, self.source, id(self.create_fn)))
- return self._hash
- def sort_key(self) -> tuple[bool, int, int, str, int]:
- # Put the duplicate input guards at the end. The duplicate guards have
- # two sources while guard.name only considers one source.
- is_duplicate_input = (
- isinstance(self.create_fn, functools.partial)
- and self.create_fn.func is torch._dynamo.guards.GuardBuilder.DUPLICATE_INPUT
- )
- return (
- is_duplicate_input,
- self.source.value if self.source else -1,
- len(self.name),
- self.name,
- self.inner_create_fn().__code__.co_firstlineno,
- )
- def __lt__(self, other: Guard) -> bool:
- return self.sort_key() < other.sort_key()
- def inner_create_fn(self) -> Callable[[GuardBuilderBase, Guard], Any]:
- if isinstance(self.create_fn, functools.partial):
- return self.create_fn.func
- else:
- return self.create_fn
- @property
- def name(self) -> str:
- return self.originating_source.name
- @property
- def source(self) -> GuardSource:
- return self.originating_source.guard_source
- @staticmethod
- def weakref_to_str(obj_weakref: object) -> str:
- """
- This is a workaround of a Python weakref bug.
- `obj_weakref` is instance returned by `weakref.ref`,
- `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g:
- class MyConfig(dict):
- def __getattr__(self, x):
- return self[x]
- obj = MyConfig(offset=5)
- obj_weakref = weakref.ref(obj)
- str(obj_weakref) # raise error: KeyError: '__name__'
- """
- if isinstance(obj_weakref, weakref.ReferenceType):
- obj = obj_weakref()
- if obj is not None:
- return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
- else:
- return f"<weakref at {hex(id(obj_weakref))}; dead>"
- else:
- return str(obj_weakref)
- def __repr__(self) -> str:
- s = f"""
- {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
- {{
- 'guard_types': {self.guard_types},
- 'code': {self.code_list},
- 'obj_weakref': {self.weakref_to_str(self.obj_weakref)}
- 'guarded_class': {self.guarded_class_weakref}
- }}
- """
- return s
- def __str__(self) -> str:
- output = f"Name: {repr(self.name)}\n"
- source = self.source.name.lower() if self.source else ""
- output += f" Source: {source}\n"
- output += f" Create Function: {self.inner_create_fn().__name__}\n"
- output += f" Guard Types: {self.guard_types}\n"
- output += f" Code List: {self.code_list}\n"
- output += f" Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n"
- output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n"
- return output
- def create(self, builder: GuardBuilderBase) -> Any:
- try:
- return self.create_fn(builder, self)
- except Exception:
- log.exception("Error while creating guard:\n%s", str(self).rstrip())
- if self.stack:
- log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
- raise
- def is_specialized_nn_module(self) -> bool:
- return self.source.is_specialized_nn_module()
- def is_fsdp_module(self) -> bool:
- return self.source.is_fsdp_module()
- def is_local(self) -> bool:
- return self.source.is_local()
- def create_fn_name(self) -> str:
- if isinstance(self.create_fn, functools.partial):
- create_fn = self.create_fn.func # type: ignore[attr-defined]
- else:
- create_fn = self.create_fn
- return create_fn.__name__
- def set_export_info(
- self,
- guard_type: str,
- guarded_class: weakref.ReferenceType[Any] | None,
- code_list: list[str],
- obj_weakref: object,
- ) -> None:
- if not self.guard_types:
- self.guard_types = []
- self.guard_types.append(guard_type)
- if self.guarded_class_weakref not in (guarded_class, None):
- raise AssertionError(
- f"Guarded class id must be identical, or None, "
- f"got {self.guarded_class_weakref} vs {guarded_class}"
- )
- self.guarded_class_weakref = guarded_class
- if not self.code_list:
- self.code_list = code_list
- else:
- self.code_list.extend(code_list)
- # Some objects are ephemeral, e.g., list[slice(1, 2)]. If we have
- # multiple guards on the same object, the weakref can die between the
- # invocation of set_export_info calls. So a dead weakref is also
- # acceptable.
- is_valid = (
- self.obj_weakref in (obj_weakref, None)
- or callable(self.obj_weakref)
- and self.obj_weakref() is None
- )
- if not is_valid:
- raise AssertionError(
- f"Guarded object must be identical, None or ephemeral (dead weakref), "
- f"got {self.obj_weakref} vs {obj_weakref}"
- )
- self.obj_weakref = obj_weakref
- T = TypeVar("T")  # type parameter of Checkpointable[T]: the snapshot type returned by copy_graphstate and accepted by restore_graphstate
- """
- Parent structure for guard env expressions.
- A GuardEnvExpr can have any subtype.
- Note: All subtypes must be handled exhaustively in
- torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
- """
@dataclasses.dataclass(frozen=True)
class GuardEnvExpr:
    """Field-less, frozen base type for guard env expressions."""
- """
- A class representing a pair of duplicate inputs.
- input_source_a and input_source_b are the sources of the two inputs we have deduped.
- """
@dataclasses.dataclass(frozen=True)
class DuplicateInputs(GuardEnvExpr):
    """Two distinct sources recorded as a duplicate-input pair."""

    input_source_a: Source
    input_source_b: Source

    def __post_init__(self) -> None:
        """Reject a pair whose two sources compare equal.

        Raises:
            AssertionError: if ``input_source_a == input_source_b``.
        """
        if not (self.input_source_a == self.input_source_b):
            return
        raise AssertionError(
            f"input_source_a and input_source_b must be different, "
            f"got {self.input_source_a}"
        )
- """
- A class representing storage overlap relations among inputs that alias the same storage.
- Given that a set of tensors alias the same storage, this guard checks whether they actually
- have overlapping storages.
- While non_overlapping_sources represent input tensors that definitely don't have any storage
- overlapping with any other input, overlapping_sources represent tensors that either:
- 1. Do overlap some other input tensor
- 2. Might not overlap some other input tensor, but we are not sure
- """
@dataclasses.dataclass(frozen=True)
class StorageOverlap(GuardEnvExpr):
    """Partition of input sources into overlapping and non-overlapping groups."""

    # Sources recorded as overlapping (or possibly overlapping).
    overlapping_sources: list[Source]
    # Sources recorded as not overlapping.
    non_overlapping_sources: list[Source]
- """
- Checkpointable is an interface for driving state snapshotting, left purposely vague for now.
- copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that
- can also be taken in at restore_graphstate(T) calls.
- When to snapshot, is, at the moment, an implementation detail of upstream callers. Checkpointable
- does not provide any guarantees around consistency, idempotency, or safety of calling its APIs, yet.
- In the future, it will have a closer coupling to a generic Checkpoint management system.
- """
class Checkpointable(Generic[T]):
    """Interface for objects whose state can be snapshotted and restored."""

    @abstractmethod
    def copy_graphstate(self) -> T:
        """Return a snapshot of the current state."""

    @abstractmethod
    def restore_graphstate(self, state: T) -> None:
        """Restore from a snapshot previously returned by ``copy_graphstate``."""
class GuardsCheckpointState:
    """
    The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext
    """

    dynamo_guards: OrderedSet[Guard]

    def __init__(self, dynamo_guards: OrderedSet[Guard]) -> None:
        self.dynamo_guards = dynamo_guards

    def diff(self, other: GuardsCheckpointState) -> Optional[OrderedSet[Guard]]:
        """
        Compute the guards held by this state that ``other`` does not hold.

        Returns None when there are none; otherwise the result of
        ``self.dynamo_guards.difference(other.dynamo_guards)``. The comparison
        is one-directional: guards only present in ``other`` are not reported.
        """
        delta = self.dynamo_guards.difference(other.dynamo_guards)
        return None if len(delta) == 0 else delta

    def __eq__(self, other: object) -> bool:
        # Equal means "same type and diff() reports nothing".
        return isinstance(other, GuardsCheckpointState) and self.diff(other) is None
class ModuleContextCheckpointState:
    """Snapshot type for ModuleContext: a mapping of module names to modules."""

    nn_modules: dict[str, torch.nn.Module] = {}

    def __init__(self, nn_modules: dict[str, torch.nn.Module]) -> None:
        self.nn_modules = nn_modules

    def diff(self, other: ModuleContextCheckpointState) -> set[str] | None:
        """
        Compute the module keys held by this state that ``other`` does not hold.

        Returns None when there are none, otherwise a set() of those key
        names. Only keys are compared, and only in this direction.
        """
        missing = set(self.nn_modules).difference(other.nn_modules)
        return None if len(missing) == 0 else missing

    def __eq__(self, other: object) -> bool:
        # Equal means "same type and diff() reports nothing".
        return (
            isinstance(other, ModuleContextCheckpointState)
            and self.diff(other) is None
        )
class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
    """Checkpointable holder of the ``nn_modules`` mapping."""

    def __init__(self) -> None:
        self.nn_modules: dict[str, Any] = {}

    def copy_graphstate(self) -> ModuleContextCheckpointState:
        """Return a snapshot holding a shallow copy of ``nn_modules``."""
        snapshot = dict(self.nn_modules)
        return ModuleContextCheckpointState(snapshot)

    def restore_graphstate(self, state: ModuleContextCheckpointState) -> None:
        """Adopt the mapping stored in ``state`` (the dict is not copied).

        Raises:
            AssertionError: if ``state`` is not a ModuleContextCheckpointState.
        """
        if isinstance(state, ModuleContextCheckpointState):
            self.nn_modules = state.nn_modules
            return
        raise AssertionError(
            f"expected ModuleContextCheckpointState, got {type(state)}"
        )
class GlobalContextCheckpointState:
    """Snapshot type for GlobalContext: state name -> (callable, value)."""

    global_state: dict[str, tuple[Callable, Any]] = {}

    def __init__(self, global_states: dict[str, tuple[Callable, Any]]) -> None:
        self.global_state = global_states

    def diff(self, other: GlobalContextCheckpointState) -> set[str] | None:
        """
        Compute the global keys held by this state that ``other`` does not hold.

        Returns None when there are none, otherwise a set() of those key
        names. Only keys are compared, and only in this direction.
        """
        missing = set(self.global_state).difference(other.global_state)
        return None if len(missing) == 0 else missing

    def __eq__(self, other: object) -> bool:
        # Equal means "same type and diff() reports nothing".
        return (
            isinstance(other, GlobalContextCheckpointState)
            and self.diff(other) is None
        )
class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
    """
    This keeps track of the global torch state during tracing of a function.
    For example, torch.is_grad_enabled.
    """

    # The exact key set a restored state must carry.
    _supported_global_states = {
        "grad_enabled",
        "autocast_enabled",
        "autocast_cpu_enabled",
        "autocast_gpu_dtype",
        "autocast_cpu_dtype",
        "autocast_cache_enabled",
    }

    def __init__(self) -> None:
        self.global_state: dict[str, tuple[Callable, Any]] = {}

    def copy_graphstate(self) -> GlobalContextCheckpointState:
        """Wrap the current ``global_state`` dict (not copied) in a checkpoint state."""
        return GlobalContextCheckpointState(self.global_state)

    def restore_graphstate(self, state: GlobalContextCheckpointState) -> None:
        """Adopt ``state.global_state`` and re-apply every recorded entry.

        Each value is a ``(func, args)`` pair and is applied as ``func(args)``.

        Raises:
            AssertionError: if ``state`` has the wrong type, or if its keys
                are not exactly ``_supported_global_states``. Note that
                ``self.global_state`` has already been replaced when the key
                check runs.
        """
        if not isinstance(state, GlobalContextCheckpointState):
            raise AssertionError(
                f"expected GlobalContextCheckpointState, got {type(state)}"
            )

        self.global_state = state.global_state

        # Set equality already implies equal sizes, since dict keys are unique.
        restored_keys = set(self.global_state.keys())
        if restored_keys != self._supported_global_states:
            raise AssertionError(
                f"Global state mismatch: got keys {set(self.global_state.keys())}, "
                f"expected {self._supported_global_states}"
            )

        for setter, value in self.global_state.values():
            setter(value)
# Like a Set[Guard] but will record the user stack on all guards at the
# time they were installed at their destination
class GuardsSet:
    """Wrapper around an ordered set of guards that stamps stacks on insertion."""

    def __init__(self, inner: Optional[OrderedSet[Guard]] = None) -> None:
        # A provided set is adopted as-is, not copied.
        self.inner: OrderedSet[Guard] = OrderedSet() if inner is None else inner

    def __iter__(self) -> Iterator[Guard]:
        return iter(self.inner)

    def __len__(self) -> int:
        return len(self.inner)

    # Subtraction along with bool is typically used to determine the delta of
    # added guards between checkpoints for higher order ops
    def __sub__(self, other: GuardsSet) -> GuardsSet:
        remaining = self.inner - other.inner
        return GuardsSet(remaining)

    def __bool__(self) -> bool:
        return bool(self.inner)

    def add(
        self, guard: Guard, *, collect_debug_stack: bool = True, skip: int = 0
    ) -> None:
        """Insert ``guard`` unless it is already present.

        With ``collect_debug_stack`` set, a guard that has no ``stack`` /
        ``user_stack`` yet gets them filled in before insertion. ``skip`` is
        forwarded (plus one for this method's own frame) to
        ``CapturedTraceback.extract``, so the extraction has to stay directly
        inside this method.
        """
        if guard in self.inner:
            return

        if collect_debug_stack and guard.stack is None:
            guard.stack = CapturedTraceback.extract(skip=1 + skip)
        if collect_debug_stack and guard.user_stack is None:
            guard.user_stack = TracingContext.extract_stack()

        self.inner.add(guard)

    def update(self, *others: set[Guard]) -> None:
        """Add every guard from each of ``others``."""
        for guard_group in others:
            for guard in guard_group:
                # skip=1 accounts for this method's own frame.
                self.add(guard, skip=1)

    def remove_guards_with_source(self, source: Source) -> None:
        """Delete all guards that contains a given source"""
        from ._dynamo.source import is_from_source

        kept = OrderedSet(
            guard
            for guard in self.inner
            if not is_from_source(guard.originating_source, source)
        )
        self.inner = kept
- """
- A GuardsContext is a checkpointable representation of all the guards in the current tracing
- context. Its lifecycle is bound 1:1 to the tracing context, and it should never be instantiated
- directly outside of it. For passing around internal state representations of this object,
- prefer to extract them with copy_graphstate to produce a GuardsCheckpointState.
- """
class GuardsContext(Checkpointable[GuardsCheckpointState]):
    """Checkpointable holder of the dynamo guards and the aotautograd guard list."""

    def __init__(self) -> None:
        self.dynamo_guards: GuardsSet = GuardsSet()
        self.aotautograd_guards: list[GuardEnvExpr] = []

    def copy_graphstate(self) -> GuardsCheckpointState:
        """Return a snapshot holding a new OrderedSet built from the current dynamo guards."""
        guards_copy = OrderedSet(self.dynamo_guards.inner)
        return GuardsCheckpointState(guards_copy)

    def restore_graphstate(self, state: GuardsCheckpointState) -> None:
        """Replace the dynamo guards with those stored in ``state``.

        Raises:
            AssertionError: if ``state`` is not a GuardsCheckpointState.
        """
        # NB: "steals" the passed in state
        if isinstance(state, GuardsCheckpointState):
            self.dynamo_guards = GuardsSet(state.dynamo_guards)
            return
        raise AssertionError(f"expected GuardsCheckpointState, got {type(state)}")
class HopSubgraphCache:
    """Interface of the per-operator cache kept for higher order op subgraphs."""

    @abstractmethod
    def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None:
        """Record ``identifier`` under ``fn_id``."""

    @abstractmethod
    def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]:
        """Return the identifiers recorded under ``fn_id``."""

    @abstractmethod
    def add_autograd_key_entry(self, identifier: str, key: Callable) -> None:
        """Store the autograd entry for ``identifier``."""

    @abstractmethod
    def get_autograd_key_entry(self, identifier: str) -> Callable | None:
        """Return the autograd entry for ``identifier``, if any."""

    @abstractmethod
    def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None:
        """Store the proxy dispatch entry for ``identifier``."""

    @abstractmethod
    def get_proxy_dispatch_entry(self, identifier: str) -> Callable | None:
        """Return the proxy dispatch entry for ``identifier``, if any."""

    @abstractmethod
    def add_lazy_bwd_entry(
        self,
        identifier: str,
        tangent_metadata: tuple[object],
        gmod: torch.fx.GraphModule,
    ) -> int:
        """Store ``gmod`` for ``(identifier, tangent_metadata)`` and return an int."""

    @abstractmethod
    def get_lazy_bwd_entry(
        self, identifier: str, tangent_metadata: tuple[object]
    ) -> tuple[torch.fx.GraphModule | None, int | None]:
        """Return the stored ``(gmod, int)`` pair for the key, or Nones."""
class InvokeSubgraphCache(HopSubgraphCache):
    """Dictionary-backed implementation of HopSubgraphCache, plus an effects cache."""

    def __init__(self) -> None:
        self.autograd_cache: dict[str, Callable] = {}
        self.proxy_dispatch_cache: dict[str, Callable] = {}
        self.dynamo_installed_submodules: dict[int, list[str]] = defaultdict(list)
        self.lazy_bwd_cache: dict[
            str, dict[tuple[object], tuple[torch.fx.GraphModule, int]]
        ] = defaultdict(dict)
        # Maps identifier -> set of effect types
        self.effects_cache: dict[str, set] = {}

    def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None:
        self.dynamo_installed_submodules[fn_id].append(identifier)

    def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]:
        # .get() avoids creating an entry in the defaultdict on a miss.
        return self.dynamo_installed_submodules.get(fn_id, [])

    def add_autograd_key_entry(self, identifier: str, key: Callable) -> None:
        self.autograd_cache[identifier] = key

    def get_autograd_key_entry(self, identifier: str) -> Callable | None:
        return self.autograd_cache.get(identifier, None)

    def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None:
        self.proxy_dispatch_cache[identifier] = key

    def get_proxy_dispatch_entry(self, identifier: str) -> Callable | None:
        return self.proxy_dispatch_cache.get(identifier, None)

    def add_lazy_bwd_entry(
        self,
        identifier: str,
        tangent_metadata: tuple[object],
        gmod: torch.fx.GraphModule,
    ) -> int:
        """Store ``gmod`` and return the number of entries ``identifier`` had before."""
        # Save the number of existing graph modules in the dictionary to get the suffix
        per_identifier = self.lazy_bwd_cache[identifier]
        num_gmods = len(per_identifier)
        per_identifier[tangent_metadata] = (gmod, num_gmods)
        return num_gmods

    def get_lazy_bwd_entry(
        self, identifier: str, tangent_metadata: tuple[object]
    ) -> tuple[torch.fx.GraphModule | None, int | None]:
        """Return the stored ``(gmod, index)`` pair, or ``(None, None)`` on a miss."""
        # .get() avoids creating an entry in the defaultdict on a miss.
        per_identifier = self.lazy_bwd_cache.get(identifier)
        if per_identifier is None:
            return (None, None)
        return per_identifier.get(tangent_metadata, (None, None))

    def add_effects(self, identifier: str, effects: set) -> None:
        """Store the effect types for a given invoke_subgraph identifier.

        Raises:
            AssertionError: if a non-empty set was stored before and differs
                from ``effects``.
        """
        prev_effects = self.effects_cache.get(identifier, None)
        if prev_effects and effects != prev_effects:
            raise AssertionError(
                "Different number of effects were found for invoke_subgraph "
                f"call with identifier {identifier}. \n"
                f"Previously we had the following effects: {prev_effects}.\n"
                f"But now we have: {effects}."
            )
        self.effects_cache[identifier] = effects

    def get_effects(self, identifier: str) -> set | None:
        """Retrieve the effect types for a given invoke_subgraph identifier."""
        return self.effects_cache.get(identifier, None)
class HopDispatchSetCache:
    """Maps higher order operators to their subgraph cache."""

    def __init__(self) -> None:
        # Delayed import to avoid circular dependency
        from torch._higher_order_ops.invoke_subgraph import invoke_subgraph

        self.hop_cache_map = {invoke_subgraph: InvokeSubgraphCache()}

    def get_cache(self, op: torch._ops.HigherOrderOperator) -> HopSubgraphCache | None:
        """Return the cache registered for ``op``, or None when there is none."""
        return self.hop_cache_map.get(op)  # type: ignore[arg-type]
- _TLS = threading.local()  # thread-local holder; CompileContext and TracingContext read its "compile_context" / "tracing_context" attributes
- """
- TracingContext is the source of truth for all currently accumulated information
- needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems
- are open to managing their own TracingContext with that in mind.
- The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid
- having to plumb complex subsystems across multiple verticals.
- Ex: A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor.
- Accessing the current tracing context via
- TracingContext.get() allows users to accumulate their own guards for processing, without needing to know how
- to plumb objects back up to where frame interpretation happened.
- Note that you can end up with multiple TracingContext for a single compilation
- of a frame, as we reset the TracingContext whenever we restart analysis.
- CompileContext is a more overarching context that encompasses multiple restarts.
- """
class CompileContext:
    """Per-compilation context: the compile id, the attempt counter and shape env guards.

    The active instance is looked up on the thread-local ``_TLS`` under the
    attribute ``compile_context``.
    """

    @staticmethod
    def get() -> CompileContext:
        """Return the active CompileContext.

        Raises:
            AssertionError: if no compile context is set on this thread. This
                includes threads on which the thread-local attribute was never
                assigned at all; reading the attribute directly would raise
                AttributeError there, so the lookup goes through ``try_get``.
        """
        ctx = CompileContext.try_get()
        if ctx is None:
            raise AssertionError("compile_context is not set")
        return ctx

    @staticmethod
    def try_get() -> CompileContext | None:
        """Return the active CompileContext, or None when there is none."""
        return getattr(_TLS, "compile_context", None)

    def __init__(self, compile_id: CompileId | None) -> None:
        """
        Args:
            compile_id: the id of this compilation, or None.

        Raises:
            AssertionError: if ``compile_id`` is neither None nor a CompileId.
        """
        if compile_id is not None and not isinstance(compile_id, CompileId):
            raise AssertionError(
                f"compile_id must be None or CompileId, got {type(compile_id)}"
            )
        self.compile_id: CompileId | None = compile_id
        self.attempt = 0
        # Verbose ShapeEnv guards produced.
        self.shape_env_guards: list[str] = []

    @staticmethod
    def current_compile_id() -> CompileId | None:
        """Return the active context's compile id, or None without an active context."""
        self = CompileContext.try_get()
        if self is None:
            return None
        return self.compile_id

    @staticmethod
    def current_trace_id() -> TraceId | None:
        """Return ``TraceId(compile_id, attempt)`` for the active context.

        Returns None when there is no active context or it has no compile id.
        """
        self = CompileContext.try_get()
        if self is None:
            return None
        if self.compile_id is None:
            return None
        return TraceId(self.compile_id, self.attempt)
@dataclass
class InlinedCodeCache:
    """Cache for code-object-derived data used during inlining."""

    # NOTE(review): the field descriptions below are inferred from the names
    # and annotations only; confirm against the code that fills this cache.

    # Presumably the instruction list derived from a code object.
    instructions: list[Any]
    # Presumably maps an instruction to its integer position in `instructions`.
    indexof: dict[Any, int]
    # Presumably the code object's options keyed by option name.
    code_options: dict[str, Any]
class TracingContext:
    """
    Container for the state accumulated while tracing.

    The current instance is stored on `_TLS.tracing_context`; retrieve it
    with `TracingContext.try_get()` (may return None) or
    `TracingContext.get()` (raises if there is none).
    """

    @staticmethod
    def try_get() -> TracingContext | None:
        """
        Provides the currently installed TracingContext, or None.

        Note that it is a staticmethod, and invocations outside of
        `with tracing()` are valid but will return None.
        """
        return getattr(_TLS, "tracing_context", None)

    @staticmethod
    def get() -> TracingContext:
        """
        Return the currently installed TracingContext.

        Raises:
            RuntimeError: if no TracingContext is installed.
        """
        ctx = TracingContext.try_get()
        if ctx is None:
            raise RuntimeError(
                "TracingContext.get() must be called within an ongoing trace."
            )
        return ctx

    def __init__(self, fake_mode: FakeTensorMode | None) -> None:
        self.guards_context = GuardsContext()
        self.module_context = ModuleContext()
        self.global_context = GlobalContext()
        self.previously_inlined_functions: dict[Any, Any] = {}
        self.previously_cleaned_instructions: dict[Any, Any] = {}
        # Combined cache for inlined code data (instructions, indexof, code_options)
        self.inlined_code_cache: dict[Any, InlinedCodeCache] = {}
        self.fake_mode: FakeTensorMode | None = fake_mode
        self.frame_summary_stack: list[traceback.FrameSummary] = []
        # This is morally part of frame_summary_stack, but it is kept separate
        # for clarity.  As we process a frame, this variable gets updated
        # to keep track of what line we are in the function.  We make a
        # function call, this gets cleared and the frame location is pushed
        # to frame_summary_stack (prepping this variable for the inner frame's
        # progress)
        self.loc_in_frame: tuple[str, int, str] | None = None
        # this is only set after aot_autograd
        self.fw_metadata: ViewAndMutationMeta | None = None
        # this is only set when the DDPOptimizer is used
        self.ddp_optimizer_ctx: DDPOptimizerContext | None = None
        # this is only set after aot_autograd
        self.aot_graph_name: list[str] | None = None
        self.params_flat: list[Any] | None = None
        self.params_flat_unwrap_subclasses: list[Any] | None = None
        self.params_unwrapped_to_flat_index: list[Any] | None = None
        # this is for extended return calling convention from backend
        # compiler to aot_autograd
        # Per output, what the compiler specified stride of the output is,
        # or None if no stride is known.  This is always the HINT, it
        # is never a SymInt (it would be better if it was a SymInt, but
        # I can't conveniently get this from Inductor atm.  Also, be
        # careful not to accidentally induce guards on the SymInt if
        # you ever do change this in aot_autograd.py; you should check
        # on permutations preferentially.)
        self.output_strides: list[tuple[int, ...] | None] | None = None
        # When this is True, whenever we encounter an int in Dynamo tracing,
        # we will (1) force unspec it and (2) force it as a size-like unbacked
        # integer.  This is currently used when processing certain lists of
        # ints that are known to be size-like and may have 0/1 entries that we
        # must not specialize on.
        self.force_unspec_int_unbacked_size_like = False
        # See note [Tensor Fakification and Symbol Caching]
        self.tensor_to_context = WeakTensorKeyDictionary()
        # If this true, Aot Autograd will return output Fake Tensors with appropriate
        # meta on the first invocation
        # see note: [Returning Fake Tensors on First AOT Autograd Call]
        self.fakify_first_call = False
        self.hop_dispatch_set_cache = HopDispatchSetCache()
        # list of code objects for inlined functions
        self.traced_code: list[CodeType] = []

    def clear(self) -> None:
        """Reset the global state dict and empty the three inlining caches."""
        # Look at the note in output_graph.py in function `save_global_state`
        # for the context on clearing global context.
        self.global_context.global_state = {}
        self.previously_inlined_functions.clear()
        self.previously_cleaned_instructions.clear()
        self.inlined_code_cache.clear()

    @staticmethod
    @contextmanager
    def patch(**kwargs: Any) -> Generator[None, None, None]:
        """
        Temporarily override attributes of the current TracingContext.

        Each keyword must name an existing attribute. It is set to the given
        value for the duration of the `with` block and set back to its prior
        value on exit.

        Raises:
            RuntimeError: if there is no current TracingContext.
            AttributeError: if a keyword is not an existing attribute.
        """
        ctx = TracingContext.get()
        # getattr raises AttributeError on an invalid entry, before any
        # attribute has been modified.
        prior = {key: getattr(ctx, key) for key in kwargs}
        try:
            # Apply the overrides inside the try so that a failure part-way
            # through still puts back the attributes that were already changed.
            for key, val in kwargs.items():
                setattr(ctx, key, val)
            yield
        finally:
            for key, val in prior.items():
                setattr(ctx, key, val)

    @staticmethod
    def extract_stack() -> traceback.StackSummary:
        """
        Return the current frame stack as a StackSummary.

        The result is `frame_summary_stack` plus, when `loc_in_frame` is set,
        one extra entry for that location. It is empty when there is no
        current TracingContext.
        """
        tc = TracingContext.try_get()
        if tc is None:
            return traceback.StackSummary()
        stack = tc.frame_summary_stack
        if tc.loc_in_frame is not None:
            # Build a new list; frame_summary_stack itself is left untouched.
            stack = stack + [tc._populate_loc_in_frame_summary()]
        return traceback.StackSummary.from_list(stack)

    def _populate_loc_in_frame_summary(self) -> traceback.FrameSummary:
        """
        Build a FrameSummary from `loc_in_frame`.

        Raises:
            AssertionError: if `loc_in_frame` is None.
        """
        if self.loc_in_frame is None:
            raise AssertionError("loc_in_frame must not be None")
        filename, lineno, frame_name = self.loc_in_frame
        return traceback.FrameSummary(filename, lineno, frame_name, lookup_line=False)

    # Call this when you want to call into some code that isn't necessarily
    # associated with the current frame state
    @staticmethod
    @contextlib.contextmanager
    def clear_frame() -> Generator[None, None, None]:
        """
        Run the `with` block with an empty `frame_summary_stack` and no
        `loc_in_frame`; both are put back on exit.

        Raises:
            RuntimeError: if there is no current TracingContext.
        """
        tc = TracingContext.get()
        with (
            unittest.mock.patch.object(tc, "frame_summary_stack", []),
            unittest.mock.patch.object(tc, "loc_in_frame", None),
        ):
            try:
                yield
            except Exception as e:
                # Prevent real_stack from getting attached
                #
                # The invariant is that if an Exception has real_stack, we've
                # appropriately attached a user stack and we no longer need to
                # attach anything. Because we cannot conveniently interpose
                # when an exception is thrown, we instead interpose everywhere
                # we set what the user stack is (using the context
                # manager). However, our compiler stack does "tail calls"
                # (when it calls into user compiler), at which point the
                # parent exception frames would incorrectly attach an
                # incorrect frame.
                #
                # However, if, somehow, someone raised an exception within this
                # scope that had a stack (for example, because they are
                # restoring the user stack state appropriately as they process
                # node by node), we should respect it. Thus, we cannot
                # unconditionally set None.
                if not hasattr(e, "real_stack"):
                    e.real_stack = None  # type: ignore[attr-defined]
                raise

    @staticmethod
    @contextlib.contextmanager
    def current_frame(
        frame_summary: traceback.FrameSummary | None,
    ) -> Generator[None, None, None]:
        """
        Push `frame_summary` (if not None) onto `frame_summary_stack` and
        clear `loc_in_frame` for the duration of the `with` block.

        An exception escaping the block that has no `real_stack` attribute
        gets one set to `extract_stack()`. On exit the pushed frame is popped
        and `loc_in_frame` is put back.

        Raises:
            RuntimeError: if there is no current TracingContext.
        """
        # frame_summary can be None to solely take advantage of real_stack
        # attachment to thrown exceptions
        tc = TracingContext.get()
        if frame_summary is not None:
            tc.frame_summary_stack.append(frame_summary)
        old = tc.loc_in_frame
        tc.loc_in_frame = None
        try:
            yield
        except Exception as e:
            if not hasattr(e, "real_stack"):
                e.real_stack = tc.extract_stack()  # type: ignore[attr-defined]
            raise
        finally:
            if frame_summary is not None:
                tc.frame_summary_stack.pop()
            tc.loc_in_frame = old

    @staticmethod
    @contextlib.contextmanager
    def report_output_strides() -> Generator[
        list[tuple[int, ...] | None] | None, None, None
    ]:
        """
        Yield a fresh list installed as `output_strides` on the current
        TracingContext, putting back the previous value on exit.

        Yields None, and changes nothing, when there is no current
        TracingContext.
        """
        tc = TracingContext.try_get()
        if tc is None:
            yield None
            return
        old_output_strides = tc.output_strides
        tc.output_strides = []
        try:
            yield tc.output_strides
        finally:
            tc.output_strides = old_output_strides

    @staticmethod
    def set_current_loc(filename: str, lineno: int, frame_name: str) -> None:
        """
        Record `(filename, lineno, frame_name)` as `loc_in_frame`.

        Raises:
            RuntimeError: if there is no current TracingContext.
        """
        # Save the current location in the frame. Lazily generate the
        # framesummary.
        TracingContext.get().loc_in_frame = (filename, lineno, frame_name)

    @staticmethod
    def get_traced_code() -> list[CodeType] | None:
        """Return `traced_code` of the current TracingContext, or None if there is none."""
        tc = TracingContext.try_get()
        if tc is None:
            return None
        return tc.traced_code
@contextmanager
def compile_context(
    context: CompileContext | None,
) -> Generator[CompileContext | None, None, None]:
    """
    Install ``context`` on ``_TLS.compile_context`` for the duration of the
    ``with`` block and yield it.

    Whatever was installed before (None if the attribute was unset) is put
    back on exit, whether or not the block raised.
    """
    previous = getattr(_TLS, "compile_context", None)
    _TLS.compile_context = context
    try:
        yield context
    finally:
        _TLS.compile_context = previous
@contextmanager
def tracing(
    context: TracingContext | None,
) -> Generator[TracingContext | None, None, None]:
    """
    This function installs the passed in tracing context as a dynamic scoped
    global variable.

    Calls to TracingContext.try_get() while not under a `with tracing()`
    context will return None; TracingContext.get() raises RuntimeError instead.

    If an exception without a `real_stack` attribute escapes the block and
    `context` is not None, `real_stack` is set to `context.extract_stack()`.
    On exit, `shape_env.cleanup()` is called when the context has a fake mode
    with a shape env, and the previously installed context is put back.
    """
    old_context = getattr(_TLS, "tracing_context", None)
    _TLS.tracing_context = context
    try:
        yield context
    except Exception as e:
        if not hasattr(e, "real_stack") and context is not None:
            e.real_stack = context.extract_stack()  # type: ignore[attr-defined]
        raise
    finally:
        try:
            if (
                context is not None
                and context.fake_mode is not None
                and context.fake_mode.shape_env is not None
            ):
                context.fake_mode.shape_env.cleanup()
        finally:
            # Put the previous context back even if cleanup() raises;
            # otherwise this context would stay installed on the thread.
            _TLS.tracing_context = old_context
- @overload
- def dataclass_with_cached_hash(cls: type[T], **kwargs: Any) -> type[T]: ...
- @overload
- def dataclass_with_cached_hash(
- cls: None = None, **kwargs: Any
- ) -> Callable[[type[T]], type[T]]: ...
- @dataclass_transform()
- def dataclass_with_cached_hash(
- cls: type[T] | None = None, **kwargs: Any
- ) -> type[T] | Callable[[type[T]], type[T]]:
- def wrap(cls_inner: type[T]) -> type[T]:
- new_cls = dataclasses.dataclass(cls_inner, **kwargs)
- old_hash = cls_inner.__hash__
- def __hash__(self) -> int:
- if not hasattr(self, "_hash"):
- object.__setattr__(self, "_hash", old_hash(self))
- return self._hash
- def __reduce__(self):
- # Exclude _hash from pickling to ensure deterministic cache keys.
- # The _hash is a cached value that can be nondeterministically computed
- # (e.g., based on id() of objects), so it should not affect pickling.
- fields = dataclasses.fields(self)
- field_values = tuple(getattr(self, f.name) for f in fields)
- return (self.__class__, field_values)
- new_cls.__hash__ = __hash__
- new_cls.__reduce__ = __reduce__
- return new_cls # type: ignore[return-value]
- if cls is None:
- return wrap
- return wrap(cls)
# Subclasses can be found in torch/_dynamo/source.py
# TODO(voz): Consider a toplevel torch/_source.py
@dataclass_with_cached_hash(frozen=True)
class Source:
    """
    Base class for sources.

    A source exposes a `name` string built from `_name_template`, and
    `get_value` evaluates that template against caller-supplied globals and
    locals. `reconstruct`, `guard_source` and `_name_template` raise
    NotImplementedError here.
    """

    def is_dict_key(self) -> bool:
        """Always False on the base class."""
        return False

    def is_ephemeral(self) -> bool:
        """Always False on the base class."""
        return False

    def reconstruct(self, codegen: PyCodegen) -> None:
        """Not implemented on the base class."""
        raise NotImplementedError

    @functools.cached_property
    def guard_source(self) -> GuardSource:
        """Not implemented on the base class."""
        raise NotImplementedError

    @property
    def _name_template(self) -> str:
        """
        A template for the name of the source. Used to prevent code duplication between
        `name` and `get_value`.

        For non-ChainedSources, `name` and `get_value` use the returned string directly.

        For ChainedSources, `name` and `get_value` expect the return to be a format string
        with `{0}` present - `name` and `get_value` will apply different values to this function's
        returned format string.
        """
        raise NotImplementedError

    @functools.cached_property
    def name(self) -> str:
        """The name template, used as-is."""
        return self._name_template

    def get_value(
        self,
        globals: dict[str, Any],
        locals: dict[str, Any],
        cache: weakref.WeakKeyDictionary[Source, Any],
    ) -> Any:
        """
        Evaluate this source's name template and return the result.

        A value already present in `cache` for this source is returned
        without evaluating again; otherwise the evaluated value is stored in
        `cache` before being returned.
        """
        if self not in cache:
            cache[self] = eval(self._name_template, globals, locals)
        return cache[self]

    def make_guard(self, fn: Callable[..., Any]) -> Guard:
        """
        Return `Guard(self, fn)`.

        Raises:
            NotImplementedError: if `guard_source` is `GuardSource.CONSTANT`.
        """
        if self.guard_source is not GuardSource.CONSTANT:
            return Guard(self, fn)
        raise NotImplementedError

    def is_specialized_nn_module(self) -> bool:
        """Delegate to `guard_source.is_specialized_nn_module()`."""
        return self.guard_source.is_specialized_nn_module()

    def subguards_allowed(self) -> bool:
        """True if you can guard on attributes of this"""
        return self.guard_source != GuardSource.SYNTHETIC_LOCAL
# Subclasses can be found in torch/_dynamo/source.py
@dataclass_with_cached_hash(frozen=True)
class ChainedSource(Source):
    """
    A source defined relative to another source, `base`.

    Its `_name_template` is a format string whose `{0}` placeholder is filled
    with the base's name (for `name`) or with a temporary variable holding
    the base's value (for `get_value`).
    """

    base: Source

    def is_dict_key(self) -> bool:
        # Recurse until you either hit a ConstDictKey or a Source
        return self.base.is_dict_key()

    def is_ephemeral(self) -> bool:
        """Delegate to the base source."""
        return self.base.is_ephemeral()

    @functools.cached_property
    def guard_source(self) -> GuardSource:
        """The guard source of the base source."""
        return self.base.guard_source

    def get_base(self) -> Source:
        """Follow `base` links and return the first source that is not a ChainedSource."""
        current: Source = self
        while isinstance(current, ChainedSource):
            current = current.base
        return current

    @functools.cached_property
    def name(self) -> str:
        """The name template with the base's name substituted in."""
        return self._name_template.format(self.base.name)

    def get_value(
        self,
        globals: dict[str, Any],
        locals: dict[str, Any],
        cache: weakref.WeakKeyDictionary[Source, Any],
    ) -> Any:
        """
        Evaluate this source and return the result.

        The base's value is obtained first (through the same `cache`), bound
        to a temporary name in `locals`, and the name template is evaluated
        with that temporary name substituted in. The temporary binding is
        removed from `locals` before returning, and the result is stored in
        `cache`. A value already in `cache` is returned without evaluating.
        """
        if self in cache:
            return cache[self]

        # Pick a temporary name that does not clash with anything in `locals`.
        tmpvar = "tmp"
        counter = 0
        while tmpvar in locals:
            tmpvar = f"tmp{counter}"
            counter += 1

        locals[tmpvar] = self.base.get_value(globals, locals, cache)
        try:
            value = eval(self._name_template.format(tmpvar), globals, locals)
        finally:
            # Remove the temporary binding even when eval raises, so a failed
            # lookup does not leave a stray entry in the caller's locals.
            del locals[tmpvar]
        cache[self] = value
        return value
def detect_fake_mode(inputs: Any = None) -> FakeTensorMode | None:
    """
    Attempts to "detect" what the current fake mode is.  If there is one ambiently
    available from TracingContext, we preferentially use that.  Otherwise, we
    heuristically detect the fake mode via the following sources, in order of
    priority:

        - Currently active fake mode on stack
        - Fake mode associated with passed in tensors (inputs does not
          have to be flattened)

    Raises AssertionError if the candidates found are not all the same fake
    mode object. Returns None when no candidate is found.
    """
    from torch._subclasses.fake_tensor import (
        FakeTensor,
        FakeTensorMode,
        get_plain_tensors,
    )

    # If TracingContext has a fake_mode, use it authoritatively.
    # This is the case when Dynamo is driving compilation - any fake tensors
    # from other modes in the inputs will be refakified by the caller.
    tracing_ctx = TracingContext.try_get()
    if tracing_ctx and tracing_ctx.fake_mode is not None:
        return tracing_ctx.fake_mode

    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    # Each candidate is (mode, description, index); the description and index
    # are only used to build the mismatch error below.
    candidates = [
        (mode, "active fake mode", pos)
        for pos, mode in enumerate(reversed(_get_current_dispatch_mode_stack()))
        if isinstance(mode, FakeTensorMode)
    ]

    for pos, leaf in enumerate(pytree.tree_leaves(inputs)):
        if isinstance(leaf, FakeTensor):
            candidates.append((leaf.fake_mode, "fake tensor input", pos))
        if is_traceable_wrapper_subclass(leaf):
            plain: list[torch.Tensor | int | torch.SymInt] = []
            get_plain_tensors(leaf, out=plain)  # type: ignore[arg-type]
            inner_fakes = (t for t in plain if isinstance(t, FakeTensor))
            candidates.extend(
                (t.fake_mode, f"subclass input {pos}", inner_pos)
                for inner_pos, t in enumerate(inner_fakes)
            )

    if not candidates:
        return None

    # Every candidate must be the very same mode object as the first one.
    fake_mode, desc1, i1 = candidates[0]
    for m, desc2, i2 in candidates[1:]:
        if fake_mode is not m:
            raise AssertionError(
                f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
                f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
                f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
            )
    return fake_mode
def active_fake_mode() -> FakeTensorMode | None:
    """
    Inspects the dispatch mode stack for an active fake mode and returns it.
    Returns None if no fake mode is active.

    The stack is scanned in reverse order and the first FakeTensorMode
    encountered is returned.
    """
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    return next(
        (
            mode
            for mode in reversed(_get_current_dispatch_mode_stack())
            if isinstance(mode, FakeTensorMode)
        ),
        None,
    )
|