import dataclasses import importlib import inspect import io import logging import pickle import types from collections.abc import Callable, Sequence from contextlib import AbstractContextManager, ExitStack, nullcontext from dataclasses import dataclass from typing import Any, Optional, TYPE_CHECKING import torch import torch.fx from torch._dynamo.convert_frame import GraphRuntimeEnv from torch._dynamo.graph_utils import _graph_device_type from torch._dynamo.package import SystemInfo from . import convert_frame from .aot_compile_types import ( BundledAOTAutogradSerializableCallable, SerializableCallable, ) from .hooks import Hooks if TYPE_CHECKING: from .guards import GuardManagerWrapper from .package import SerializedCode, SourceInfo log = logging.getLogger(__name__) def bind_locals( signature: inspect.Signature, *args: Any, **kwargs: Any ) -> dict[str, Any]: bound_arguments = signature.bind(*args, **kwargs) bound_arguments.apply_defaults() return bound_arguments.arguments @dataclass class CompileArtifacts: signature: inspect.Signature guard_manager: Optional["GuardManagerWrapper"] guards_state: bytes backend_id: str compiled_fn: SerializableCallable original_code: types.CodeType runtime_env: GraphRuntimeEnv source_info: "SourceInfo" device_type: str backend_name: str system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current) def check_compatibility(self) -> None: current_system = SystemInfo.current() current_system.check_compatibility(self.system_info, self.device_type) class AOTCompilePickler(pickle.Pickler): def __init__(self, external_data: dict[str, object], buf: io.BytesIO) -> None: super().__init__(buf) self.external_data = external_data self.id_map: dict[int, str] = { id(value): key for key, value in external_data.items() } self.errors = {} def persistent_id(self, obj: object) -> int | str | None: if id(obj) in self.id_map: return self.id_map[id(obj)] elif isinstance(obj, torch.nn.Module): self.errors[id(obj)] = obj return id(obj) else: return None @classmethod def _unpickle_cell(cls, val: object) -> object: def _() -> object: return val assert _.__closure__ is not None return _.__closure__[0] @classmethod # pyrefly: ignore [implicit-any] def _unpickle_bound_method(cls, func: Callable, base: object) -> types.MethodType: return types.MethodType(func, base) @classmethod def _unpickle_module(cls, name: str) -> types.ModuleType: return importlib.import_module(name) @classmethod def _unpickle_code(cls, serialized_code: "SerializedCode") -> types.CodeType: from torch._dynamo.package import SerializedCode return SerializedCode.to_code_object(serialized_code) @classmethod def _unpickle_nested_function( cls, code: types.CodeType, module: str, qualname: str, argdefs: tuple[object, ...] | None, closure: tuple[types.CellType, ...] | None, ) -> types.FunctionType: f_globals = importlib.import_module(module).__dict__ return types.FunctionType(code, f_globals, qualname, argdefs, closure) # pyrefly: ignore [bad-override] def reducer_override(self, obj: Any) -> Any: if isinstance(obj, type((lambda x: lambda: x)(0).__closure__[0])): # type: ignore[index] # noqa: PLC3002 return type(self)._unpickle_cell, (obj.cell_contents,) elif inspect.iscode(obj): from torch._dynamo.package import SerializedCode return type(self)._unpickle_code, (SerializedCode.from_code_object(obj),) elif inspect.ismodule(obj): return type(self)._unpickle_module, (obj.__name__,) elif inspect.ismethod(obj): """ By default, pickle will call getattr() directly on the self object for pickling bounded methods, this is not what we want, instead we always want to serialize the original function and the self object in their original form. """ func = obj.__func__ method_self = obj.__self__ inner_func = getattr(method_self, func.__name__) if inspect.ismethod(inner_func): inner_func = inner_func.__func__ if func is not inner_func: return type(self)._unpickle_bound_method, (func, method_self) elif inspect.isfunction(obj): if "" in obj.__qualname__: return type(self)._unpickle_nested_function, ( obj.__code__, obj.__module__, obj.__qualname__, obj.__defaults__, obj.__closure__, ) return NotImplemented class AOTCompileUnpickler(pickle.Unpickler): def __init__(self, external_data: dict[str, object], file: io.BytesIO) -> object: super().__init__(file) self.external_data = external_data def persistent_load(self, key: str) -> object: if key not in self.external_data: raise RuntimeError( f"Missing required external reference to data: {key}. " "Please load AOT compiled function with " "`external_data=`" f"{self.external_data}" ) return self.external_data[key] @dataclass class AOTCompileSaveResult: serialized_data: bytes @dataclass class AOTCompiledFunction: _artifacts: CompileArtifacts _guard_check_enabled: bool = True _extra_globals: dict[str, object] | None = None def prepare_f_locals(self, *args: object, **kwargs: object) -> dict[str, object]: f_locals: dict[str, object] = {} env = self._artifacts.runtime_env if env.closure: assert env.bytecode.co_freevars and len(env.closure) == len( env.bytecode.co_freevars ) f_locals = { name: cell.cell_contents for name, cell in zip(env.bytecode.co_freevars, env.closure) } f_locals.update(bind_locals(self._artifacts.signature, *args, **kwargs)) return f_locals def guard_check(self, *args: Any, **kwargs: Any) -> bool: f_locals = self.prepare_f_locals(*args, **kwargs) assert self._artifacts.guard_manager is not None return self._artifacts.guard_manager.check(f_locals) def __post_init__(self) -> None: from .package import load_guard_manager, load_guards_state self._artifacts.check_compatibility() self.fn = self._artifacts.runtime_env.forward_callable( self._artifacts.backend_id, self._artifacts.compiled_fn, extra_globals=self._extra_globals, ) if self._artifacts.guard_manager is None: guards_state = load_guards_state(self._artifacts.guards_state) self._artifacts.guard_manager = load_guard_manager( guards_state, self._artifacts.original_code, self.fn.__globals__, ) def __call__(self, *args: Any, **kwargs: Any) -> Any: assert self._artifacts.guard_manager is not None if self._guard_check_enabled and not self.guard_check(*args, **kwargs): f_locals = self.prepare_f_locals(*args, **kwargs) reason = str(self._artifacts.guard_manager.check_verbose(f_locals)) raise RuntimeError(f"GuardManager check failed, reason: {reason}") return self.fn(*args, **kwargs) def source_info(self) -> "SourceInfo": return self._artifacts.source_info def save_compiled_function( self, path: str, external_data: dict[str, Any] | None = None ) -> AOTCompileSaveResult: with open(path, "wb") as f: result = type(self).serialize(self, external_data) f.write(result.serialized_data) return result @classmethod def serialize( cls, fn: "AOTCompiledFunction", external_data: dict[str, Any] | None = None ) -> AOTCompileSaveResult: from torch._dynamo.package import SerializedCode state = fn._artifacts.__dict__.copy() state["guard_manager"] = None state["runtime_env"] = dataclasses.replace( state["runtime_env"], bytecode=SerializedCode.from_code_object(state["runtime_env"].bytecode), ) compiled_fn = state["compiled_fn"] state["compiled_fn"] = ( type(compiled_fn).deserialize_compile_artifacts, type(compiled_fn).serialize_compile_artifacts(compiled_fn), ) state["original_code"] = SerializedCode.from_code_object(state["original_code"]) buf = io.BytesIO() pickler = AOTCompilePickler(external_data or {}, buf) pickler.dump(state) if pickler.errors: raise RuntimeError( f"Failed to serialize the following objects: {list(pickler.errors.values())}\n" "Please mark these as external data by using `external_data={'key': ...}`" ) return AOTCompileSaveResult(serialized_data=buf.getvalue()) @classmethod def deserialize( cls, data: bytes, f_globals: dict[str, object] | None = None, external_closure_data: dict[str, Any] | None = None, ) -> "AOTCompiledFunction": from torch._dynamo.package import SerializedCode f = io.BytesIO(data) f.seek(0) unpickler = AOTCompileUnpickler(external_closure_data or {}, f) state = unpickler.load() f.close() state["runtime_env"] = dataclasses.replace( state["runtime_env"], bytecode=SerializedCode.to_code_object(state["runtime_env"].bytecode), ) deserializer, compiled_fn_state = state["compiled_fn"] with torch._inductor.config.patch(enable_autograd_for_aot=True): state["compiled_fn"] = deserializer(compiled_fn_state) state["original_code"] = SerializedCode.to_code_object(state["original_code"]) artifacts = CompileArtifacts(**state) return cls(artifacts, _extra_globals=f_globals) def disable_guard_check(self) -> None: self._guard_check_enabled = False def aot_compile_fullgraph( model: Any, example_inputs: tuple[tuple[Any, ...], dict[str, Any]], hooks: Hooks, backend: Callable[[torch.fx.GraphModule, list[torch.Tensor]], SerializableCallable], dynamic: bool | None = None, ) -> AOTCompiledFunction: from torch._dynamo.guards import CheckFunctionManager from torch._dynamo.package import SourceInfo from torch._dynamo.utils import dynamo_timed, get_metrics_context from torch._guards import TracingContext args, kwargs = example_inputs dynamic_ctx = nullcontext() if dynamic is not None: from torch._dynamo.eval_frame import set_enable_dynamic dynamic_ctx = set_enable_dynamic(dynamic) with ( get_metrics_context(), dynamo_timed("fullgraph_capture"), torch._functorch.config.patch(strict_autograd_cache=True), dynamic_ctx, ): capture_output = convert_frame.fullgraph_capture(model, args, kwargs) graph_capture_output = capture_output.graph_capture_output assert graph_capture_output.output_graph is not None if not hooks.guard_filter_fn: from torch._dynamo.types import GuardFilterEntry def new_guard_filter_fn( guard_entries: Sequence[GuardFilterEntry], ) -> Sequence[bool]: return [ ( not ( g.is_global or g.guard_type in CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES ) ) for g in guard_entries ] hooks.guard_filter_fn = new_guard_filter_fn fn, _ = convert_frame.get_traced_fn(model) backend_input = capture_output.backend_input assert backend_input is not None backend_input.graph_module._backend_id = backend_input.backend_id # type: ignore[assignment] device_type = _graph_device_type(backend_input.graph_module.graph) assert ( backend_input.fake_mode.shape_env is graph_capture_output.output_graph.shape_env ) tracing_context = TracingContext(backend_input.fake_mode) tracing_context.tensor_to_context = backend_input.tensor_to_context with ( torch._guards.tracing(tracing_context), torch._functorch.config.patch( { "strict_autograd_cache": True, "bypass_autograd_cache_key": True, "bundled_autograd_cache": True, "force_non_lazy_backward_lowering": True, "force_autograd_cache": True, } ), ): compiled_fn = backend( backend_input.graph_module, backend_input.example_inputs ) # If Inductor backend is used, grab the compiled_fn from PrecompileContext # TODO: this should be replaced once we make the backend return the SerializableCallable directly. if isinstance(backend, torch._TorchCompileInductorWrapper) or ( hasattr(backend, "compiler_fn") and isinstance( backend.compiler_fn, torch._dynamo.backends.common.AotAutograd ) ): compiled_fn = BundledAOTAutogradSerializableCallable(compiled_fn) if not isinstance(compiled_fn, SerializableCallable): if hasattr(backend, "compiler_fn"): compiler_fn = backend.compiler_fn else: compiler_fn = backend raise RuntimeError( f"Compiled function type {type(compiled_fn)} (produced " + f"from backend {compiler_fn}) does not implement SerializableCallable." ) check_fn = graph_capture_output.build_guards( fn.__code__, hooks=hooks, save=True, strict_error=True ) assert check_fn.guards_state is not None source_info = SourceInfo(inlined_sources=set()) for traced_code in graph_capture_output.traced_code: source_info.add_code(traced_code) artifacts = CompileArtifacts( signature=convert_frame._get_signature(fn), guard_manager=check_fn.guard_manager, guards_state=check_fn.guards_state, backend_id=backend_input.backend_id, compiled_fn=compiled_fn, original_code=fn.__code__, runtime_env=graph_capture_output.get_runtime_env(), source_info=source_info, device_type=device_type, backend_name=getattr(backend, "compiler_name", "unknown"), ) aot_compiled_fn = AOTCompiledFunction( _artifacts=artifacts, _extra_globals=fn.__globals__ ) return aot_compiled_fn @dataclass class ModelInput: """ WIP type: represents a single model input Which consists of a tuple of arguments and a set of contexts in which to run the model. For each ModelInput, we'll compile one full graph of the model, and then use the guards generated to dispatch between the compiled graphs. """ args: tuple[Any] kwargs: dict[str, Any] contexts: list[AbstractContextManager[Any]] @dataclass class AOTCompiledModel: # Represents a single forward function of a model along with dispatch # compiled_results is serializable. We require the model to deserialize again. model: torch.nn.Module compiled_results: list[AOTCompiledFunction] def __call__(self, *args: Any, **kwargs: Any) -> Any: for result in self.compiled_results: if result.guard_check(self.model, *args, **kwargs): return result(self.model, *args, **kwargs) # All guards failed, just run one of them and throw the guard check error. return self.compiled_results[0](self.model, *args, **kwargs) def serialize(self) -> bytes: data: list[bytes] = [] for result in self.compiled_results: data.append(AOTCompiledFunction.serialize(result).serialized_data) return pickle.dumps(data) @classmethod def deserialize(cls, model: torch.nn.Module, data: bytes) -> "AOTCompiledModel": from torch._dynamo.utils import get_metrics_context from torch._guards import compile_context, CompileContext results: list[bytes] = pickle.loads(data) compiled_results = [] for result in results: with ( compile_context(CompileContext(convert_frame.get_compile_id({}))), get_metrics_context(), ): compiled_results.append(AOTCompiledFunction.deserialize(result)) return cls(model, compiled_results) def aot_compile_module( model: torch.nn.Module, inputs: list[ModelInput], hooks: Hooks, backend: Callable[[torch.fx.GraphModule, list[torch.Tensor]], SerializableCallable], ) -> AOTCompiledModel: """ Compiles a single nn.Module with any number of inputs, and returns a compiled forward function. """ def compile_single_graph(model_input: ModelInput) -> AOTCompiledFunction: example_inputs = (model_input.args, model_input.kwargs) orig_forward = model.forward with ExitStack() as stack: for ctx in model_input.contexts: stack.enter_context(ctx) return aot_compile_fullgraph( orig_forward, example_inputs, hooks=hooks, backend=backend, ) # pyrefly: ignore [implicit-any] compiled_results = [] for model_input in inputs: log.info("Compiling input %s..", model_input) compiled_results.append(compile_single_graph(model_input)) assert len(compiled_results) > 0 return AOTCompiledModel(model, compiled_results)