aot_compile.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. import dataclasses
  2. import importlib
  3. import inspect
  4. import io
  5. import logging
  6. import pickle
  7. import types
  8. from collections.abc import Callable, Sequence
  9. from contextlib import AbstractContextManager, ExitStack, nullcontext
  10. from dataclasses import dataclass
  11. from typing import Any, Optional, TYPE_CHECKING
  12. import torch
  13. import torch.fx
  14. from torch._dynamo.convert_frame import GraphRuntimeEnv
  15. from torch._dynamo.graph_utils import _graph_device_type
  16. from torch._dynamo.package import SystemInfo
  17. from . import convert_frame
  18. from .aot_compile_types import (
  19. BundledAOTAutogradSerializableCallable,
  20. SerializableCallable,
  21. )
  22. from .hooks import Hooks
  23. if TYPE_CHECKING:
  24. from .guards import GuardManagerWrapper
  25. from .package import SerializedCode, SourceInfo
  26. log = logging.getLogger(__name__)
  27. def bind_locals(
  28. signature: inspect.Signature, *args: Any, **kwargs: Any
  29. ) -> dict[str, Any]:
  30. bound_arguments = signature.bind(*args, **kwargs)
  31. bound_arguments.apply_defaults()
  32. return bound_arguments.arguments
  33. @dataclass
  34. class CompileArtifacts:
  35. signature: inspect.Signature
  36. guard_manager: Optional["GuardManagerWrapper"]
  37. guards_state: bytes
  38. backend_id: str
  39. compiled_fn: SerializableCallable
  40. original_code: types.CodeType
  41. runtime_env: GraphRuntimeEnv
  42. source_info: "SourceInfo"
  43. device_type: str
  44. backend_name: str
  45. system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
  46. def check_compatibility(self) -> None:
  47. current_system = SystemInfo.current()
  48. current_system.check_compatibility(self.system_info, self.device_type)
  49. class AOTCompilePickler(pickle.Pickler):
  50. def __init__(self, external_data: dict[str, object], buf: io.BytesIO) -> None:
  51. super().__init__(buf)
  52. self.external_data = external_data
  53. self.id_map: dict[int, str] = {
  54. id(value): key for key, value in external_data.items()
  55. }
  56. self.errors = {}
  57. def persistent_id(self, obj: object) -> int | str | None:
  58. if id(obj) in self.id_map:
  59. return self.id_map[id(obj)]
  60. elif isinstance(obj, torch.nn.Module):
  61. self.errors[id(obj)] = obj
  62. return id(obj)
  63. else:
  64. return None
  65. @classmethod
  66. def _unpickle_cell(cls, val: object) -> object:
  67. def _() -> object:
  68. return val
  69. assert _.__closure__ is not None
  70. return _.__closure__[0]
  71. @classmethod
  72. # pyrefly: ignore [implicit-any]
  73. def _unpickle_bound_method(cls, func: Callable, base: object) -> types.MethodType:
  74. return types.MethodType(func, base)
  75. @classmethod
  76. def _unpickle_module(cls, name: str) -> types.ModuleType:
  77. return importlib.import_module(name)
  78. @classmethod
  79. def _unpickle_code(cls, serialized_code: "SerializedCode") -> types.CodeType:
  80. from torch._dynamo.package import SerializedCode
  81. return SerializedCode.to_code_object(serialized_code)
  82. @classmethod
  83. def _unpickle_nested_function(
  84. cls,
  85. code: types.CodeType,
  86. module: str,
  87. qualname: str,
  88. argdefs: tuple[object, ...] | None,
  89. closure: tuple[types.CellType, ...] | None,
  90. ) -> types.FunctionType:
  91. f_globals = importlib.import_module(module).__dict__
  92. return types.FunctionType(code, f_globals, qualname, argdefs, closure)
  93. # pyrefly: ignore [bad-override]
  94. def reducer_override(self, obj: Any) -> Any:
  95. if isinstance(obj, type((lambda x: lambda: x)(0).__closure__[0])): # type: ignore[index] # noqa: PLC3002
  96. return type(self)._unpickle_cell, (obj.cell_contents,)
  97. elif inspect.iscode(obj):
  98. from torch._dynamo.package import SerializedCode
  99. return type(self)._unpickle_code, (SerializedCode.from_code_object(obj),)
  100. elif inspect.ismodule(obj):
  101. return type(self)._unpickle_module, (obj.__name__,)
  102. elif inspect.ismethod(obj):
  103. """
  104. By default, pickle will call getattr() directly on the self object
  105. for pickling bounded methods, this is not what we want, instead we
  106. always want to serialize the original function and the self object
  107. in their original form.
  108. """
  109. func = obj.__func__
  110. method_self = obj.__self__
  111. inner_func = getattr(method_self, func.__name__)
  112. if inspect.ismethod(inner_func):
  113. inner_func = inner_func.__func__
  114. if func is not inner_func:
  115. return type(self)._unpickle_bound_method, (func, method_self)
  116. elif inspect.isfunction(obj):
  117. if "<locals>" in obj.__qualname__:
  118. return type(self)._unpickle_nested_function, (
  119. obj.__code__,
  120. obj.__module__,
  121. obj.__qualname__,
  122. obj.__defaults__,
  123. obj.__closure__,
  124. )
  125. return NotImplemented
  126. class AOTCompileUnpickler(pickle.Unpickler):
  127. def __init__(self, external_data: dict[str, object], file: io.BytesIO) -> object:
  128. super().__init__(file)
  129. self.external_data = external_data
  130. def persistent_load(self, key: str) -> object:
  131. if key not in self.external_data:
  132. raise RuntimeError(
  133. f"Missing required external reference to data: {key}. "
  134. "Please load AOT compiled function with "
  135. "`external_data=<external data dictionary>`"
  136. f"{self.external_data}"
  137. )
  138. return self.external_data[key]
  139. @dataclass
  140. class AOTCompileSaveResult:
  141. serialized_data: bytes
  142. @dataclass
  143. class AOTCompiledFunction:
  144. _artifacts: CompileArtifacts
  145. _guard_check_enabled: bool = True
  146. _extra_globals: dict[str, object] | None = None
  147. def prepare_f_locals(self, *args: object, **kwargs: object) -> dict[str, object]:
  148. f_locals: dict[str, object] = {}
  149. env = self._artifacts.runtime_env
  150. if env.closure:
  151. assert env.bytecode.co_freevars and len(env.closure) == len(
  152. env.bytecode.co_freevars
  153. )
  154. f_locals = {
  155. name: cell.cell_contents
  156. for name, cell in zip(env.bytecode.co_freevars, env.closure)
  157. }
  158. f_locals.update(bind_locals(self._artifacts.signature, *args, **kwargs))
  159. return f_locals
  160. def guard_check(self, *args: Any, **kwargs: Any) -> bool:
  161. f_locals = self.prepare_f_locals(*args, **kwargs)
  162. assert self._artifacts.guard_manager is not None
  163. return self._artifacts.guard_manager.check(f_locals)
  164. def __post_init__(self) -> None:
  165. from .package import load_guard_manager, load_guards_state
  166. self._artifacts.check_compatibility()
  167. self.fn = self._artifacts.runtime_env.forward_callable(
  168. self._artifacts.backend_id,
  169. self._artifacts.compiled_fn,
  170. extra_globals=self._extra_globals,
  171. )
  172. if self._artifacts.guard_manager is None:
  173. guards_state = load_guards_state(self._artifacts.guards_state)
  174. self._artifacts.guard_manager = load_guard_manager(
  175. guards_state,
  176. self._artifacts.original_code,
  177. self.fn.__globals__,
  178. )
  179. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  180. assert self._artifacts.guard_manager is not None
  181. if self._guard_check_enabled and not self.guard_check(*args, **kwargs):
  182. f_locals = self.prepare_f_locals(*args, **kwargs)
  183. reason = str(self._artifacts.guard_manager.check_verbose(f_locals))
  184. raise RuntimeError(f"GuardManager check failed, reason: {reason}")
  185. return self.fn(*args, **kwargs)
  186. def source_info(self) -> "SourceInfo":
  187. return self._artifacts.source_info
  188. def save_compiled_function(
  189. self, path: str, external_data: dict[str, Any] | None = None
  190. ) -> AOTCompileSaveResult:
  191. with open(path, "wb") as f:
  192. result = type(self).serialize(self, external_data)
  193. f.write(result.serialized_data)
  194. return result
  195. @classmethod
  196. def serialize(
  197. cls, fn: "AOTCompiledFunction", external_data: dict[str, Any] | None = None
  198. ) -> AOTCompileSaveResult:
  199. from torch._dynamo.package import SerializedCode
  200. state = fn._artifacts.__dict__.copy()
  201. state["guard_manager"] = None
  202. state["runtime_env"] = dataclasses.replace(
  203. state["runtime_env"],
  204. bytecode=SerializedCode.from_code_object(state["runtime_env"].bytecode),
  205. )
  206. compiled_fn = state["compiled_fn"]
  207. state["compiled_fn"] = (
  208. type(compiled_fn).deserialize_compile_artifacts,
  209. type(compiled_fn).serialize_compile_artifacts(compiled_fn),
  210. )
  211. state["original_code"] = SerializedCode.from_code_object(state["original_code"])
  212. buf = io.BytesIO()
  213. pickler = AOTCompilePickler(external_data or {}, buf)
  214. pickler.dump(state)
  215. if pickler.errors:
  216. raise RuntimeError(
  217. f"Failed to serialize the following objects: {list(pickler.errors.values())}\n"
  218. "Please mark these as external data by using `external_data={'key': ...}`"
  219. )
  220. return AOTCompileSaveResult(serialized_data=buf.getvalue())
  221. @classmethod
  222. def deserialize(
  223. cls,
  224. data: bytes,
  225. f_globals: dict[str, object] | None = None,
  226. external_closure_data: dict[str, Any] | None = None,
  227. ) -> "AOTCompiledFunction":
  228. from torch._dynamo.package import SerializedCode
  229. f = io.BytesIO(data)
  230. f.seek(0)
  231. unpickler = AOTCompileUnpickler(external_closure_data or {}, f)
  232. state = unpickler.load()
  233. f.close()
  234. state["runtime_env"] = dataclasses.replace(
  235. state["runtime_env"],
  236. bytecode=SerializedCode.to_code_object(state["runtime_env"].bytecode),
  237. )
  238. deserializer, compiled_fn_state = state["compiled_fn"]
  239. with torch._inductor.config.patch(enable_autograd_for_aot=True):
  240. state["compiled_fn"] = deserializer(compiled_fn_state)
  241. state["original_code"] = SerializedCode.to_code_object(state["original_code"])
  242. artifacts = CompileArtifacts(**state)
  243. return cls(artifacts, _extra_globals=f_globals)
  244. def disable_guard_check(self) -> None:
  245. self._guard_check_enabled = False
  246. def aot_compile_fullgraph(
  247. model: Any,
  248. example_inputs: tuple[tuple[Any, ...], dict[str, Any]],
  249. hooks: Hooks,
  250. backend: Callable[[torch.fx.GraphModule, list[torch.Tensor]], SerializableCallable],
  251. dynamic: bool | None = None,
  252. ) -> AOTCompiledFunction:
  253. from torch._dynamo.guards import CheckFunctionManager
  254. from torch._dynamo.package import SourceInfo
  255. from torch._dynamo.utils import dynamo_timed, get_metrics_context
  256. from torch._guards import TracingContext
  257. args, kwargs = example_inputs
  258. dynamic_ctx = nullcontext()
  259. if dynamic is not None:
  260. from torch._dynamo.eval_frame import set_enable_dynamic
  261. dynamic_ctx = set_enable_dynamic(dynamic)
  262. with (
  263. get_metrics_context(),
  264. dynamo_timed("fullgraph_capture"),
  265. torch._functorch.config.patch(strict_autograd_cache=True),
  266. dynamic_ctx,
  267. ):
  268. capture_output = convert_frame.fullgraph_capture(model, args, kwargs)
  269. graph_capture_output = capture_output.graph_capture_output
  270. assert graph_capture_output.output_graph is not None
  271. if not hooks.guard_filter_fn:
  272. from torch._dynamo.types import GuardFilterEntry
  273. def new_guard_filter_fn(
  274. guard_entries: Sequence[GuardFilterEntry],
  275. ) -> Sequence[bool]:
  276. return [
  277. (
  278. not (
  279. g.is_global
  280. or g.guard_type
  281. in CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES
  282. )
  283. )
  284. for g in guard_entries
  285. ]
  286. hooks.guard_filter_fn = new_guard_filter_fn
  287. fn, _ = convert_frame.get_traced_fn(model)
  288. backend_input = capture_output.backend_input
  289. assert backend_input is not None
  290. backend_input.graph_module._backend_id = backend_input.backend_id # type: ignore[assignment]
  291. device_type = _graph_device_type(backend_input.graph_module.graph)
  292. assert (
  293. backend_input.fake_mode.shape_env
  294. is graph_capture_output.output_graph.shape_env
  295. )
  296. tracing_context = TracingContext(backend_input.fake_mode)
  297. tracing_context.tensor_to_context = backend_input.tensor_to_context
  298. with (
  299. torch._guards.tracing(tracing_context),
  300. torch._functorch.config.patch(
  301. {
  302. "strict_autograd_cache": True,
  303. "bypass_autograd_cache_key": True,
  304. "bundled_autograd_cache": True,
  305. "force_non_lazy_backward_lowering": True,
  306. "force_autograd_cache": True,
  307. }
  308. ),
  309. ):
  310. compiled_fn = backend(
  311. backend_input.graph_module, backend_input.example_inputs
  312. )
  313. # If Inductor backend is used, grab the compiled_fn from PrecompileContext
  314. # TODO: this should be replaced once we make the backend return the SerializableCallable directly.
  315. if isinstance(backend, torch._TorchCompileInductorWrapper) or (
  316. hasattr(backend, "compiler_fn")
  317. and isinstance(
  318. backend.compiler_fn, torch._dynamo.backends.common.AotAutograd
  319. )
  320. ):
  321. compiled_fn = BundledAOTAutogradSerializableCallable(compiled_fn)
  322. if not isinstance(compiled_fn, SerializableCallable):
  323. if hasattr(backend, "compiler_fn"):
  324. compiler_fn = backend.compiler_fn
  325. else:
  326. compiler_fn = backend
  327. raise RuntimeError(
  328. f"Compiled function type {type(compiled_fn)} (produced "
  329. + f"from backend {compiler_fn}) does not implement SerializableCallable."
  330. )
  331. check_fn = graph_capture_output.build_guards(
  332. fn.__code__, hooks=hooks, save=True, strict_error=True
  333. )
  334. assert check_fn.guards_state is not None
  335. source_info = SourceInfo(inlined_sources=set())
  336. for traced_code in graph_capture_output.traced_code:
  337. source_info.add_code(traced_code)
  338. artifacts = CompileArtifacts(
  339. signature=convert_frame._get_signature(fn),
  340. guard_manager=check_fn.guard_manager,
  341. guards_state=check_fn.guards_state,
  342. backend_id=backend_input.backend_id,
  343. compiled_fn=compiled_fn,
  344. original_code=fn.__code__,
  345. runtime_env=graph_capture_output.get_runtime_env(),
  346. source_info=source_info,
  347. device_type=device_type,
  348. backend_name=getattr(backend, "compiler_name", "unknown"),
  349. )
  350. aot_compiled_fn = AOTCompiledFunction(
  351. _artifacts=artifacts, _extra_globals=fn.__globals__
  352. )
  353. return aot_compiled_fn
  354. @dataclass
  355. class ModelInput:
  356. """
  357. WIP type: represents a single model input
  358. Which consists of a tuple of arguments and a set of contexts in which to run the model.
  359. For each ModelInput, we'll compile one full graph of the model, and then use the guards generated
  360. to dispatch between the compiled graphs.
  361. """
  362. args: tuple[Any]
  363. kwargs: dict[str, Any]
  364. contexts: list[AbstractContextManager[Any]]
  365. @dataclass
  366. class AOTCompiledModel:
  367. # Represents a single forward function of a model along with dispatch
  368. # compiled_results is serializable. We require the model to deserialize again.
  369. model: torch.nn.Module
  370. compiled_results: list[AOTCompiledFunction]
  371. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  372. for result in self.compiled_results:
  373. if result.guard_check(self.model, *args, **kwargs):
  374. return result(self.model, *args, **kwargs)
  375. # All guards failed, just run one of them and throw the guard check error.
  376. return self.compiled_results[0](self.model, *args, **kwargs)
  377. def serialize(self) -> bytes:
  378. data: list[bytes] = []
  379. for result in self.compiled_results:
  380. data.append(AOTCompiledFunction.serialize(result).serialized_data)
  381. return pickle.dumps(data)
  382. @classmethod
  383. def deserialize(cls, model: torch.nn.Module, data: bytes) -> "AOTCompiledModel":
  384. from torch._dynamo.utils import get_metrics_context
  385. from torch._guards import compile_context, CompileContext
  386. results: list[bytes] = pickle.loads(data)
  387. compiled_results = []
  388. for result in results:
  389. with (
  390. compile_context(CompileContext(convert_frame.get_compile_id({}))),
  391. get_metrics_context(),
  392. ):
  393. compiled_results.append(AOTCompiledFunction.deserialize(result))
  394. return cls(model, compiled_results)
  395. def aot_compile_module(
  396. model: torch.nn.Module,
  397. inputs: list[ModelInput],
  398. hooks: Hooks,
  399. backend: Callable[[torch.fx.GraphModule, list[torch.Tensor]], SerializableCallable],
  400. ) -> AOTCompiledModel:
  401. """
  402. Compiles a single nn.Module with any number of inputs, and returns a compiled forward function.
  403. """
  404. def compile_single_graph(model_input: ModelInput) -> AOTCompiledFunction:
  405. example_inputs = (model_input.args, model_input.kwargs)
  406. orig_forward = model.forward
  407. with ExitStack() as stack:
  408. for ctx in model_input.contexts:
  409. stack.enter_context(ctx)
  410. return aot_compile_fullgraph(
  411. orig_forward,
  412. example_inputs,
  413. hooks=hooks,
  414. backend=backend,
  415. )
  416. # pyrefly: ignore [implicit-any]
  417. compiled_results = []
  418. for model_input in inputs:
  419. log.info("Compiling input %s..", model_input)
  420. compiled_results.append(compile_single_graph(model_input))
  421. assert len(compiled_results) > 0
  422. return AOTCompiledModel(model, compiled_results)