| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504 |
- import dataclasses
- import importlib
- import inspect
- import io
- import logging
- import pickle
- import types
- from collections.abc import Callable, Sequence
- from contextlib import AbstractContextManager, ExitStack, nullcontext
- from dataclasses import dataclass
- from typing import Any, Optional, TYPE_CHECKING
- import torch
- import torch.fx
- from torch._dynamo.convert_frame import GraphRuntimeEnv
- from torch._dynamo.graph_utils import _graph_device_type
- from torch._dynamo.package import SystemInfo
- from . import convert_frame
- from .aot_compile_types import (
- BundledAOTAutogradSerializableCallable,
- SerializableCallable,
- )
- from .hooks import Hooks
- if TYPE_CHECKING:
- from .guards import GuardManagerWrapper
- from .package import SerializedCode, SourceInfo
# Module-level logger; aot_compile_module uses it to report which input is
# being compiled.
log = logging.getLogger(__name__)
def bind_locals(
    signature: inspect.Signature, *args: Any, **kwargs: Any
) -> dict[str, Any]:
    """Bind call arguments to ``signature`` and return them keyed by parameter name.

    Parameters the caller omitted are filled in with their declared defaults
    (including empty ``()`` / ``{}`` for variadic parameters), so every
    parameter of ``signature`` is present in the result.

    Raises:
        TypeError: propagated from ``Signature.bind`` when the arguments do
            not match the signature.
    """
    binding = signature.bind(*args, **kwargs)
    binding.apply_defaults()
    return binding.arguments
@dataclass
class CompileArtifacts:
    """Everything an AOT compile produces that is needed to run or reload it.

    Instances are created by ``aot_compile_fullgraph`` and wrapped by
    ``AOTCompiledFunction``, which serializes them (dropping the live guard
    manager) and rebuilds them on load.
    """

    # Signature of the traced function; used to bind call arguments when
    # building the locals that guards are checked against.
    signature: inspect.Signature
    # Live guard manager. None after deserialization; rebuilt from
    # ``guards_state`` by AOTCompiledFunction.__post_init__.
    guard_manager: Optional["GuardManagerWrapper"]
    # Serialized guard state from which ``guard_manager`` can be rebuilt.
    guards_state: bytes
    # Identifier of the backend-compiled graph; passed to
    # runtime_env.forward_callable together with ``compiled_fn``.
    backend_id: str
    # Backend output; must implement SerializableCallable.
    compiled_fn: SerializableCallable
    # Code object of the original (traced) function.
    original_code: types.CodeType
    # Runtime environment (transformed bytecode, closure, ...) of the capture.
    runtime_env: GraphRuntimeEnv
    # Source information for the code objects that were traced.
    source_info: "SourceInfo"
    # Device type inferred from the captured graph.
    device_type: str
    # Backend's ``compiler_name`` attribute, or "unknown".
    backend_name: str
    # System description captured when the artifacts were created.
    system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)

    def check_compatibility(self) -> None:
        """Verify that the running system can execute these artifacts.

        Delegates to ``SystemInfo.check_compatibility`` on the *current*
        system, passing the ``system_info`` recorded at compile time and this
        artifact's ``device_type``.
        """
        SystemInfo.current().check_compatibility(self.system_info, self.device_type)
- class AOTCompilePickler(pickle.Pickler):
- def __init__(self, external_data: dict[str, object], buf: io.BytesIO) -> None:
- super().__init__(buf)
- self.external_data = external_data
- self.id_map: dict[int, str] = {
- id(value): key for key, value in external_data.items()
- }
- self.errors = {}
- def persistent_id(self, obj: object) -> int | str | None:
- if id(obj) in self.id_map:
- return self.id_map[id(obj)]
- elif isinstance(obj, torch.nn.Module):
- self.errors[id(obj)] = obj
- return id(obj)
- else:
- return None
- @classmethod
- def _unpickle_cell(cls, val: object) -> object:
- def _() -> object:
- return val
- assert _.__closure__ is not None
- return _.__closure__[0]
- @classmethod
- # pyrefly: ignore [implicit-any]
- def _unpickle_bound_method(cls, func: Callable, base: object) -> types.MethodType:
- return types.MethodType(func, base)
- @classmethod
- def _unpickle_module(cls, name: str) -> types.ModuleType:
- return importlib.import_module(name)
- @classmethod
- def _unpickle_code(cls, serialized_code: "SerializedCode") -> types.CodeType:
- from torch._dynamo.package import SerializedCode
- return SerializedCode.to_code_object(serialized_code)
- @classmethod
- def _unpickle_nested_function(
- cls,
- code: types.CodeType,
- module: str,
- qualname: str,
- argdefs: tuple[object, ...] | None,
- closure: tuple[types.CellType, ...] | None,
- ) -> types.FunctionType:
- f_globals = importlib.import_module(module).__dict__
- return types.FunctionType(code, f_globals, qualname, argdefs, closure)
- # pyrefly: ignore [bad-override]
- def reducer_override(self, obj: Any) -> Any:
- if isinstance(obj, type((lambda x: lambda: x)(0).__closure__[0])): # type: ignore[index] # noqa: PLC3002
- return type(self)._unpickle_cell, (obj.cell_contents,)
- elif inspect.iscode(obj):
- from torch._dynamo.package import SerializedCode
- return type(self)._unpickle_code, (SerializedCode.from_code_object(obj),)
- elif inspect.ismodule(obj):
- return type(self)._unpickle_module, (obj.__name__,)
- elif inspect.ismethod(obj):
- """
- By default, pickle will call getattr() directly on the self object
- for pickling bounded methods, this is not what we want, instead we
- always want to serialize the original function and the self object
- in their original form.
- """
- func = obj.__func__
- method_self = obj.__self__
- inner_func = getattr(method_self, func.__name__)
- if inspect.ismethod(inner_func):
- inner_func = inner_func.__func__
- if func is not inner_func:
- return type(self)._unpickle_bound_method, (func, method_self)
- elif inspect.isfunction(obj):
- if "<locals>" in obj.__qualname__:
- return type(self)._unpickle_nested_function, (
- obj.__code__,
- obj.__module__,
- obj.__qualname__,
- obj.__defaults__,
- obj.__closure__,
- )
- return NotImplemented
class AOTCompileUnpickler(pickle.Unpickler):
    """Unpickler counterpart of ``AOTCompilePickler``.

    Persistent ids written by the pickler (the keys of its ``external_data``)
    are resolved against the ``external_data`` mapping given here.

    NOTE(security): like any ``pickle.Unpickler`` this can execute arbitrary
    code while loading; only use it on data from a trusted source.
    """

    def __init__(self, external_data: dict[str, object], file: io.BytesIO) -> None:
        super().__init__(file)
        # Persistent id -> live object to substitute for it.
        self.external_data = external_data

    def persistent_load(self, key: str) -> object:
        """Return the externally supplied object registered under ``key``.

        Raises:
            RuntimeError: if ``key`` was not supplied in ``external_data``.
        """
        if key not in self.external_data:
            # List only the keys: values may be large objects (e.g. modules)
            # whose repr would swamp the message.
            raise RuntimeError(
                f"Missing required external reference to data: {key}. "
                "Please load AOT compiled function with "
                "`external_data=<external data dictionary>`. "
                f"Available keys: {list(self.external_data)}"
            )
        return self.external_data[key]
@dataclasses.dataclass
class AOTCompileSaveResult:
    """Outcome of serializing an ``AOTCompiledFunction``."""

    # Pickle payload produced by AOTCompiledFunction.serialize.
    serialized_data: bytes
@dataclass
class AOTCompiledFunction:
    """Callable wrapper around ahead-of-time compiled artifacts.

    On construction it checks that the current system is compatible with the
    artifacts, builds the runtime callable via
    ``runtime_env.forward_callable`` and, if the guard manager is missing
    (e.g. after deserialization), rebuilds it from the serialized guard state.
    Calls are guarded: unless :meth:`disable_guard_check` was called,
    arguments that fail the guards raise ``RuntimeError`` instead of running
    the compiled code.
    """

    _artifacts: CompileArtifacts
    # When False, __call__ skips the guard check.
    _guard_check_enabled: bool = True
    # Extra globals forwarded to runtime_env.forward_callable.
    _extra_globals: dict[str, object] | None = None

    def prepare_f_locals(self, *args: object, **kwargs: object) -> dict[str, object]:
        """Build the local-variable mapping the guards are evaluated against.

        Free variables of the traced bytecode (taken from the runtime
        closure, if any) come first; the call arguments, bound to the original
        signature with defaults applied, are merged on top and win on clashes.
        """
        f_locals: dict[str, object] = {}
        env = self._artifacts.runtime_env
        if env.closure:
            # Closure cells must line up one-to-one with the free variables.
            assert env.bytecode.co_freevars and len(env.closure) == len(
                env.bytecode.co_freevars
            )
            f_locals = {
                name: cell.cell_contents
                for name, cell in zip(env.bytecode.co_freevars, env.closure)
            }
        f_locals.update(bind_locals(self._artifacts.signature, *args, **kwargs))
        return f_locals

    def guard_check(self, *args: Any, **kwargs: Any) -> bool:
        """Return whether the guards accept these call arguments."""
        f_locals = self.prepare_f_locals(*args, **kwargs)
        assert self._artifacts.guard_manager is not None
        return self._artifacts.guard_manager.check(f_locals)

    def __post_init__(self) -> None:
        from .package import load_guard_manager, load_guards_state

        self._artifacts.check_compatibility()
        self.fn = self._artifacts.runtime_env.forward_callable(
            self._artifacts.backend_id,
            self._artifacts.compiled_fn,
            extra_globals=self._extra_globals,
        )
        if self._artifacts.guard_manager is None:
            # serialize() drops the live guard manager; rebuild it from the
            # saved guard state against the runtime callable's globals.
            guards_state = load_guards_state(self._artifacts.guards_state)
            self._artifacts.guard_manager = load_guard_manager(
                guards_state,
                self._artifacts.original_code,
                self.fn.__globals__,
            )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the compiled function, checking guards first if enabled.

        Raises:
            RuntimeError: if guard checking is enabled and the guards reject
                the arguments; the message includes the verbose failure reason.
        """
        assert self._artifacts.guard_manager is not None
        if self._guard_check_enabled and not self.guard_check(*args, **kwargs):
            f_locals = self.prepare_f_locals(*args, **kwargs)
            reason = str(self._artifacts.guard_manager.check_verbose(f_locals))
            raise RuntimeError(f"GuardManager check failed, reason: {reason}")
        return self.fn(*args, **kwargs)

    def source_info(self) -> "SourceInfo":
        """Return the source information recorded for the traced code."""
        return self._artifacts.source_info

    def save_compiled_function(
        self, path: str, external_data: dict[str, Any] | None = None
    ) -> AOTCompileSaveResult:
        """Serialize this function and write the bytes to ``path``.

        Serialization happens *before* the file is opened, so a failure (for
        example a module that was not supplied in ``external_data``) does not
        leave a truncated or empty file at ``path``.
        """
        result = type(self).serialize(self, external_data)
        with open(path, "wb") as f:
            f.write(result.serialized_data)
        return result

    @classmethod
    def serialize(
        cls, fn: "AOTCompiledFunction", external_data: dict[str, Any] | None = None
    ) -> AOTCompileSaveResult:
        """Pickle ``fn``'s artifacts into bytes.

        The live guard manager is dropped (it is rebuilt from ``guards_state``
        on load), code objects are converted to ``SerializedCode``, and the
        compiled function is stored as a ``(deserializer, payload)`` pair from
        its type's ``serialize_compile_artifacts``. Objects in
        ``external_data`` are stored by key instead of by value.

        Raises:
            RuntimeError: if any ``torch.nn.Module`` reachable from the state
                was not supplied in ``external_data``.
        """
        from torch._dynamo.package import SerializedCode

        state = fn._artifacts.__dict__.copy()
        state["guard_manager"] = None
        state["runtime_env"] = dataclasses.replace(
            state["runtime_env"],
            bytecode=SerializedCode.from_code_object(state["runtime_env"].bytecode),
        )
        compiled_fn = state["compiled_fn"]
        state["compiled_fn"] = (
            type(compiled_fn).deserialize_compile_artifacts,
            type(compiled_fn).serialize_compile_artifacts(compiled_fn),
        )
        state["original_code"] = SerializedCode.from_code_object(state["original_code"])
        buf = io.BytesIO()
        pickler = AOTCompilePickler(external_data or {}, buf)
        pickler.dump(state)
        if pickler.errors:
            raise RuntimeError(
                f"Failed to serialize the following objects: {list(pickler.errors.values())}\n"
                "Please mark these as external data by using `external_data={'key': ...}`"
            )
        return AOTCompileSaveResult(serialized_data=buf.getvalue())

    @classmethod
    def deserialize(
        cls,
        data: bytes,
        f_globals: dict[str, object] | None = None,
        external_closure_data: dict[str, Any] | None = None,
    ) -> "AOTCompiledFunction":
        """Rebuild an ``AOTCompiledFunction`` from bytes produced by :meth:`serialize`.

        Args:
            data: Serialized payload.
            f_globals: Extra globals for the runtime callable.
            external_closure_data: Objects to substitute for the persistent
                references recorded at save time.

        NOTE(security): this unpickles ``data``; only load trusted payloads.
        """
        from torch._dynamo.package import SerializedCode

        with io.BytesIO(data) as f:
            state = AOTCompileUnpickler(external_closure_data or {}, f).load()
        state["runtime_env"] = dataclasses.replace(
            state["runtime_env"],
            bytecode=SerializedCode.to_code_object(state["runtime_env"].bytecode),
        )
        deserializer, compiled_fn_state = state["compiled_fn"]
        with torch._inductor.config.patch(enable_autograd_for_aot=True):
            state["compiled_fn"] = deserializer(compiled_fn_state)
        state["original_code"] = SerializedCode.to_code_object(state["original_code"])
        artifacts = CompileArtifacts(**state)
        return cls(artifacts, _extra_globals=f_globals)

    def disable_guard_check(self) -> None:
        """Make subsequent calls skip the guard check."""
        self._guard_check_enabled = False
def aot_compile_fullgraph(
    model: Any,
    example_inputs: tuple[tuple[Any, ...], dict[str, Any]],
    hooks: Hooks,
    backend: Callable[[torch.fx.GraphModule, list[torch.Tensor]], SerializableCallable],
    dynamic: bool | None = None,
) -> AOTCompiledFunction:
    """Capture ``model`` as one full graph, compile it, and wrap the result.

    Args:
        model: Callable to trace; passed to ``convert_frame.fullgraph_capture``
            and ``convert_frame.get_traced_fn``.
        example_inputs: ``(args, kwargs)`` used for the capture.
        hooks: Dynamo hooks used when building guards. NOTE: if
            ``hooks.guard_filter_fn`` is unset, a default filter is assigned
            onto this object, i.e. the caller's ``hooks`` is mutated.
        backend: Compiler called with the captured graph module and its
            example inputs.
        dynamic: If not ``None``, ``set_enable_dynamic(dynamic)`` is entered
            alongside the metrics/timing contexts below.

    Returns:
        An ``AOTCompiledFunction`` holding the compiled callable, its guards
        and the metadata needed to serialize it.

    Raises:
        RuntimeError: if the backend's output does not implement
            ``SerializableCallable``.
    """
    from torch._dynamo.guards import CheckFunctionManager
    from torch._dynamo.package import SourceInfo
    from torch._dynamo.utils import dynamo_timed, get_metrics_context
    from torch._guards import TracingContext

    args, kwargs = example_inputs

    # No-op context unless the caller asked to override dynamic shapes.
    dynamic_ctx = nullcontext()
    if dynamic is not None:
        from torch._dynamo.eval_frame import set_enable_dynamic

        dynamic_ctx = set_enable_dynamic(dynamic)

    with (
        get_metrics_context(),
        dynamo_timed("fullgraph_capture"),
        torch._functorch.config.patch(strict_autograd_cache=True),
        dynamic_ctx,
    ):
        capture_output = convert_frame.fullgraph_capture(model, args, kwargs)
        graph_capture_output = capture_output.graph_capture_output
        assert graph_capture_output.output_graph is not None

        if not hooks.guard_filter_fn:
            from torch._dynamo.types import GuardFilterEntry

            # Default filter: keep a guard only if it is not global and its
            # type is not in CheckFunctionManager's set of guard types that
            # cannot be serialized.
            def new_guard_filter_fn(
                guard_entries: Sequence[GuardFilterEntry],
            ) -> Sequence[bool]:
                return [
                    (
                        not (
                            g.is_global
                            or g.guard_type
                            in CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES
                        )
                    )
                    for g in guard_entries
                ]

            # Assigned onto the caller's object, so it persists after return.
            hooks.guard_filter_fn = new_guard_filter_fn

        fn, _ = convert_frame.get_traced_fn(model)

        backend_input = capture_output.backend_input
        assert backend_input is not None
        backend_input.graph_module._backend_id = backend_input.backend_id  # type: ignore[assignment]
        device_type = _graph_device_type(backend_input.graph_module.graph)
        # The backend's fake mode must share the captured graph's ShapeEnv.
        assert (
            backend_input.fake_mode.shape_env
            is graph_capture_output.output_graph.shape_env
        )
        tracing_context = TracingContext(backend_input.fake_mode)
        tracing_context.tensor_to_context = backend_input.tensor_to_context
        # NOTE(review): the functorch flags below presumably force AOTAutograd
        # to emit a bundled, cacheable artifact that can be serialized;
        # confirm against torch._functorch.config.
        with (
            torch._guards.tracing(tracing_context),
            torch._functorch.config.patch(
                {
                    "strict_autograd_cache": True,
                    "bypass_autograd_cache_key": True,
                    "bundled_autograd_cache": True,
                    "force_non_lazy_backward_lowering": True,
                    "force_autograd_cache": True,
                }
            ),
        ):
            compiled_fn = backend(
                backend_input.graph_module, backend_input.example_inputs
            )
            # If Inductor backend is used, grab the compiled_fn from PrecompileContext
            # TODO: this should be replaced once we make the backend return the SerializableCallable directly.
            if isinstance(backend, torch._TorchCompileInductorWrapper) or (
                hasattr(backend, "compiler_fn")
                and isinstance(
                    backend.compiler_fn, torch._dynamo.backends.common.AotAutograd
                )
            ):
                compiled_fn = BundledAOTAutogradSerializableCallable(compiled_fn)

        if not isinstance(compiled_fn, SerializableCallable):
            # Prefer the wrapped compiler_fn (if any) in the error message.
            if hasattr(backend, "compiler_fn"):
                compiler_fn = backend.compiler_fn
            else:
                compiler_fn = backend
            raise RuntimeError(
                f"Compiled function type {type(compiled_fn)} (produced "
                + f"from backend {compiler_fn}) does not implement SerializableCallable."
            )

        check_fn = graph_capture_output.build_guards(
            fn.__code__, hooks=hooks, save=True, strict_error=True
        )

        # The serialized guard state is required to rebuild guards on load.
        assert check_fn.guards_state is not None

        # Collect source info for each code object traced into the graph.
        source_info = SourceInfo(inlined_sources=set())
        for traced_code in graph_capture_output.traced_code:
            source_info.add_code(traced_code)

        artifacts = CompileArtifacts(
            signature=convert_frame._get_signature(fn),
            guard_manager=check_fn.guard_manager,
            guards_state=check_fn.guards_state,
            backend_id=backend_input.backend_id,
            compiled_fn=compiled_fn,
            original_code=fn.__code__,
            runtime_env=graph_capture_output.get_runtime_env(),
            source_info=source_info,
            device_type=device_type,
            backend_name=getattr(backend, "compiler_name", "unknown"),
        )
        aot_compiled_fn = AOTCompiledFunction(
            _artifacts=artifacts, _extra_globals=fn.__globals__
        )

    return aot_compiled_fn
@dataclass
class ModelInput:
    """One set of example inputs for ``aot_compile_module`` (work-in-progress type).

    Bundles the positional and keyword arguments for a single model call with
    the context managers to enter while that call is compiled. One full graph
    of the model is compiled per ``ModelInput``, and the guards generated for
    each graph are used to dispatch between the compiled graphs at call time.
    """

    # Positional arguments for one model call (any number of them).
    args: tuple[Any, ...]
    # Keyword arguments for the same call.
    kwargs: dict[str, Any]
    # Context managers entered, in order, while this input is compiled.
    contexts: list[AbstractContextManager[Any]]
@dataclass
class AOTCompiledModel:
    """One forward function of a model compiled for several inputs, with dispatch.

    ``compiled_results`` holds one ``AOTCompiledFunction`` per compiled input;
    calls are routed to the first one whose guards accept the arguments.
    Only ``compiled_results`` is serialized, so the model itself must be
    supplied again to :meth:`deserialize`.
    """

    model: torch.nn.Module
    compiled_results: list[AOTCompiledFunction]

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Every compiled function takes the model as its leading argument.
        call_args = (self.model, *args)
        for candidate in self.compiled_results:
            if candidate.guard_check(*call_args, **kwargs):
                return candidate(*call_args, **kwargs)
        # No guard matched: call the first result anyway so its own guard check
        # raises with a descriptive reason (unless guard checking was disabled).
        return self.compiled_results[0](*call_args, **kwargs)

    def serialize(self) -> bytes:
        """Serialize every compiled function and pickle the list of payloads."""
        blobs = [
            AOTCompiledFunction.serialize(compiled).serialized_data
            for compiled in self.compiled_results
        ]
        return pickle.dumps(blobs)

    @classmethod
    def deserialize(cls, model: torch.nn.Module, data: bytes) -> "AOTCompiledModel":
        """Rebuild an ``AOTCompiledModel`` from :meth:`serialize` output and ``model``.

        NOTE(security): this unpickles ``data``; only load trusted payloads.
        """
        from torch._dynamo.utils import get_metrics_context
        from torch._guards import compile_context, CompileContext

        blobs: list[bytes] = pickle.loads(data)
        restored: list[AOTCompiledFunction] = []
        for blob in blobs:
            # Each function is deserialized under a fresh compile context.
            ctx = CompileContext(convert_frame.get_compile_id({}))
            with compile_context(ctx), get_metrics_context():
                restored.append(AOTCompiledFunction.deserialize(blob))
        return cls(model, restored)
def aot_compile_module(
    model: torch.nn.Module,
    inputs: list[ModelInput],
    hooks: Hooks,
    backend: Callable[[torch.fx.GraphModule, list[torch.Tensor]], SerializableCallable],
    dynamic: bool | None = None,
) -> AOTCompiledModel:
    """Compile ``model.forward`` once per example input and bundle the results.

    Args:
        model: Module whose ``forward`` is compiled.
        inputs: Example inputs. One full graph is compiled for each entry,
            with that entry's ``contexts`` entered during compilation.
        hooks: Forwarded to ``aot_compile_fullgraph`` for every input.
        backend: Forwarded to ``aot_compile_fullgraph`` for every input.
        dynamic: Forwarded to ``aot_compile_fullgraph`` for every input;
            ``None`` (the default) leaves the dynamic-shape setting unchanged.

    Returns:
        An ``AOTCompiledModel`` that dispatches between the compiled graphs
        using their guards.

    Raises:
        ValueError: if ``inputs`` is empty.
    """

    def compile_single_graph(model_input: ModelInput) -> AOTCompiledFunction:
        example_inputs = (model_input.args, model_input.kwargs)
        orig_forward = model.forward
        with ExitStack() as stack:
            for ctx in model_input.contexts:
                stack.enter_context(ctx)
            return aot_compile_fullgraph(
                orig_forward,
                example_inputs,
                hooks=hooks,
                backend=backend,
                dynamic=dynamic,
            )

    compiled_results: list[AOTCompiledFunction] = []
    for model_input in inputs:
        log.info("Compiling input %s..", model_input)
        compiled_results.append(compile_single_graph(model_input))
    if not compiled_results:
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        raise ValueError("aot_compile_module requires at least one ModelInput")
    return AOTCompiledModel(model, compiled_results)
|