__init__.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740
  1. # mypy: allow-untyped-defs
  2. import io
  3. from collections.abc import Callable
  4. from dataclasses import dataclass
  5. from typing import Any, Optional, TYPE_CHECKING, TypeVar, Union
  6. from typing_extensions import ParamSpec
  7. import torch
  8. from torch._higher_order_ops.invoke_subgraph import NestedCompileRegionOptions
  9. from . import config
  10. if TYPE_CHECKING:
  11. from ._cache import CacheInfo
  12. __all__ = [
  13. "compile",
  14. "config",
  15. "assume_constant_result",
  16. "reset",
  17. "allow_in_graph",
  18. "substitute_in_graph",
  19. "list_backends",
  20. "disable",
  21. "set_stance",
  22. "set_enable_guard_collectives",
  23. "cudagraph_mark_step_begin",
  24. "load_compiled_function",
  25. "wrap_numpy",
  26. "is_compiling",
  27. "is_dynamo_compiling",
  28. "is_exporting",
  29. "save_cache_artifacts",
  30. "load_cache_artifacts",
  31. "keep_portable_guards_unsafe",
  32. "skip_guard_on_inbuilt_nn_modules_unsafe",
  33. "skip_guard_on_all_nn_modules_unsafe",
  34. "keep_tensor_guards_unsafe",
  35. "skip_guard_on_globals_unsafe",
  36. "skip_all_guards_unsafe",
  37. "nested_compile_region",
  38. ]
  39. _P = ParamSpec("_P")
  40. _R = TypeVar("_R")
  41. FuncType = Callable[..., Any]
  42. F = TypeVar("F", bound=FuncType)
  43. def compile(*args, **kwargs):
  44. """
  45. See :func:`torch.compile` for details on the arguments for this function.
  46. """
  47. # pyrefly: ignore [not-iterable]
  48. return torch.compile(*args, **kwargs)
  49. def reset() -> None:
  50. """
  51. This function clears all compilation caches and restores the system to its initial state.
  52. It is recommended to call this function, especially after using operations like `torch.compile(...)`
  53. to ensure a clean state before another unrelated compilation
  54. """
  55. import torch._dynamo
  56. torch._dynamo.reset()
  57. def allow_in_graph(fn):
  58. """
  59. Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function
  60. and instead directly write it to the graph when encountered.
  61. If you are using :func:`torch.compile` (with backend="inductor" (the default)), or
  62. :func:`torch.export.export`, and trying to black-box a Python function throughout
  63. all tracing, do not use this API.
  64. Instead, please create a custom operator (see `PyTorch Custom Operators Landing Page
  65. <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html>`_)
  66. .. warning::
  67. If you're a typical torch.compile user (e.g. you're applying torch.compile to
  68. a model to make it run faster), you probably don't want to use this function.
  69. :func:`allow_in_graph` is a footgun because it skips the compiler frontend
  70. (Dynamo) that is responsible for doing safety checks (graph breaks, handling
  71. closures, etc). Incorrect usage will lead to difficult-to-debug silent
  72. incorrectness issues.
  73. Given a Python function with no allow_in_graph decorator, regular execution
  74. of torch.compile traces through the function. :func:`allow_in_graph` changes
  75. it so that the frontend does not trace inside the function, but the compiler
  76. backend still traces through it. Compare this to custom operators, which
  77. treats a function as a black box throughout the torch.compile stack. The following
  78. table compares these mechanisms.
  79. +------------------------+-----------------------+--------------------------------+
  80. | Mechanism | Frontend (Dynamo) | Backend (AOTAutograd+Inductor) |
  81. +========================+=======================+================================+
  82. | no decorator | trace inside | trace inside |
  83. +------------------------+-----------------------+--------------------------------+
  84. | allow_in_graph | opaque callable | trace inside |
  85. +------------------------+-----------------------+--------------------------------+
  86. | custom op | opaque callable | opaque callable |
  87. +------------------------+-----------------------+--------------------------------+
  88. One common use case for :func:`allow_in_graph()` is as an escape hatch for the compiler
  89. frontend: if you know the function works w.r.t. to the downstream components of the
  90. compilation stack (AOTAutograd and Inductor) but there is a Dynamo bug that prevents it from
  91. symbolically introspecting the function properly (or if your code is in C/C++ and
  92. therefore cannot be introspected with Dynamo), then one can decorate said function
  93. with :func:`allow_in_graph` to bypass Dynamo.
  94. We require that ``fn`` adhere to the following restrictions. Failure to adhere
  95. results in undefined behavior:
  96. - The inputs to ``fn`` must be Proxy-able types in the FX graph. Valid types include:
  97. Tensor/int/bool/float/None/List[Tensor?]/List[int?]/List[float?]
  98. Tuple[Tensor?, ...]/Tuple[int?, ...]/Tuple[float?, ...]/torch.dtype/torch.device
  99. - The outputs to ``fn`` must be Proxy-able types in the FX graph (see previous bullet)
  100. - all Tensors used inside of ``fn`` must be passed directly as inputs to ``fn``
  101. (as opposed to being captured variables).
  102. Args:
  103. fn: A callable representing the function to be included in the graph.
  104. If ``fn`` is a list or tuple of callables it recursively applies
  105. :func:`allow_in_graph()` to each function and returns a new list or
  106. tuple containing the modified functions.
  107. Example::
  108. torch.compiler.allow_in_graph(my_custom_function)
  109. @torch.compile(...)
  110. def fn(x):
  111. x = torch.add(x, 1)
  112. x = my_custom_function(x)
  113. x = torch.add(x, 1)
  114. return x
  115. fn(...)
  116. Will capture a single graph containing ``my_custom_function()``.
  117. """
  118. import torch._dynamo
  119. return torch._dynamo.allow_in_graph(fn)
  120. def substitute_in_graph(
  121. original_fn: Callable[_P, _R],
  122. *,
  123. can_constant_fold_through: bool = False,
  124. skip_signature_check: bool = False,
  125. ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
  126. """
  127. Register a polyfill handler for a function, usually a C function from the C extension, to be
  128. used in place of the original function when inlining the original function in the graph.
  129. .. note::
  130. The polyfill handler is only used when inlining the original function. It is not used when
  131. the original function is called directly. In the eager mode, the decorated function calls
  132. the performant C function rather than the polyfill handler.
  133. The polyfill handler is a function that will be called in place of the original function when
  134. inlining the original function. The polyfill handler should have the same signature and the same
  135. behavior as the original function.
  136. Args:
  137. original_fn (callable): The original function, usually a C function, to register a polyfill
  138. handler for.
  139. can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant
  140. folded through. That is, if the polyfill handler is a pure function and its arguments
  141. are constant, the result of the polyfill handler can be constant folded during the
  142. compilation. Defaults to ``False``.
  143. skip_signature_check (bool, optional): Whether to skip the signature check between the
  144. original function and the polyfill handler. Defaults to ``False``.
  145. Returns:
  146. A decorator that registers the polyfill handler for the original function.
  147. Example::
  148. >>> import operator
  149. >>> operator.indexOf([1, 2, 3, 4, 5], 3)
  150. 2
  151. >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
  152. ... # xdoctest: +SKIP("Long tracebacks")
  153. Traceback (most recent call last):
  154. ...
  155. torch._dynamo.exc.Unsupported: ...
  156. >>> @torch.compiler.substitute_in_graph(operator.indexOf)
  157. ... def indexOf(a, b, /):
  158. ... for i, item in enumerate(a):
  159. ... if item is b or item == b:
  160. ... return i
  161. ... raise ValueError("sequence.index(x): x not in sequence")
  162. >>>
  163. >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
  164. 2
  165. """
  166. import torch._dynamo
  167. return torch._dynamo.substitute_in_graph(
  168. original_fn,
  169. can_constant_fold_through=can_constant_fold_through,
  170. skip_signature_check=skip_signature_check,
  171. )
  172. def list_backends(exclude_tags=("debug", "experimental")) -> list[str]:
  173. """
  174. Return valid strings that can be passed to `torch.compile(..., backend="name")`.
  175. Args:
  176. exclude_tags(optional): A tuple of strings representing tags to exclude.
  177. """
  178. import torch._dynamo
  179. return torch._dynamo.list_backends(exclude_tags)
  180. def assume_constant_result(fn):
  181. """
  182. This function is used to mark a function `fn` as having a constant result.
  183. This allows the compiler to optimize away your function.
  184. Returns The same function `fn`
  185. Args:
  186. fn: The function to be marked as having a constant result.
  187. .. warning::
  188. `assume_constant_result` can if invalid cause safety and soundness issues, :func:`torch.compile`
  189. will not attempt to validate whether the constant assumption is true or not
  190. """
  191. import torch._dynamo
  192. return torch._dynamo.assume_constant_result(fn)
  193. def disable(fn=None, recursive=True, *, reason=None):
  194. """
  195. This function provides a decorator to disable compilation on a function.
  196. It also provides the option of recursively disabling called functions.
  197. Args:
  198. fn (optional): The function to disable
  199. recursive (optional): A boolean value indicating whether the disabling should be recursive.
  200. reason (optional): A string value indicating the reason for disabling the function.
  201. """
  202. import torch._dynamo
  203. return torch._dynamo.disable(fn, recursive, reason=reason)
  204. def set_stance(
  205. stance: str = "default",
  206. *,
  207. skip_guard_eval_unsafe: bool = False,
  208. force_backend: Union[str, Callable[..., Any], None] = None,
  209. ):
  210. """
  211. Set the current stance of the compiler.
  212. Can be used as a function, context manager, or decorator.
  213. Do not use this function inside a `torch.compile` region - an error will be raised otherwise.
  214. .. code-block:: python
  215. @torch.compile
  216. def foo(x): ...
  217. @torch.compiler.set_stance("force_eager")
  218. def bar():
  219. # will not be compiled
  220. foo(...)
  221. bar()
  222. with torch.compiler.set_stance("force_eager"):
  223. # will also not be compiled
  224. foo(...)
  225. torch.compiler.set_stance("force_eager")
  226. # will also not be compiled
  227. foo(...)
  228. torch.compiler.set_stance("default")
  229. # will be compiled
  230. foo(...)
  231. Args:
  232. stance: The stance to set the compiler to. Valid values are:
  233. - "default": The default stance, used for normal compilation.
  234. - "force_eager": Ignore all `torch.compile` directives.
  235. - "eager_on_recompile": Run code eagerly when a recompile is necessary.
  236. If there is cached compiled code valid for the input, it will still be used.
  237. - "fail_on_recompile": Raise an error when recompiling a function.
  238. - "eager_then_compile": Run the first invocation in eager mode, then compile on
  239. subsequent calls. This is beneficial for dynamic shapes as it allows inferring
  240. dynamism from the first two invocations instead of wasting a static compile on
  241. the first invocation.
  242. - "aot_eager_then_compile": Run the first invocation with AOT eager to get memory
  243. benefits from activation checkpointing, then compile on subsequent calls. Like
  244. eager_then_compile, this improves handling of dynamic shapes by avoiding an
  245. initial static compile.
  246. skip_guard_eval_unsafe: A flag to run only differentiating guards.
  247. CAUTION - This flag is unsafe and should only be used if your setup
  248. meets the following conditions.
  249. torch.compile uses a guard system to support recompilations and
  250. choose which compiled artifact to run at runtime. These guards,
  251. though efficient, add some overhead, which may impact performance in
  252. scenarios where you need to optimize for minimal guard processing
  253. time. This API enables you to disable guard evaluation, assuming
  254. that you have warmed up the compiled model with a sufficient variety
  255. of inputs. This assumption means that, after the warmup phase, no
  256. further recompilations will be necessary. If this assumption fails,
  257. there is a risk of silently producing incorrect results (hence the
  258. term "unsafe" in the API name).
  259. force_backend: If `stance` is "default", this argument can be used to force `torch.compile`
  260. to use a specific backend. Otherwise, an error is raised.
  261. """
  262. import torch._dynamo
  263. return torch._dynamo.set_stance(
  264. stance,
  265. skip_guard_eval_unsafe=skip_guard_eval_unsafe,
  266. force_backend=force_backend,
  267. )
  268. # forbid in graph
  269. set_stance._dynamo_forbidden = True # type: ignore[attr-defined]
  270. def set_enable_guard_collectives(enabled: bool):
  271. """
  272. Enables use of collectives *during* guard evaluation to synchronize behavior
  273. across ranks. This is expensive: we have to issue a collective every time
  274. we enter a compiled code region, even if no rank actually would need to
  275. compile. This can help prevent NCCL hangs by ensuring that we never have a
  276. situation where one rank starts recompiling while other ranks don't compile;
  277. it is especially useful in conjunction with enable_compiler_collectives
  278. where such a situation would immediately cause a hang (as it is necessary
  279. for all ranks to compile at the same time to run compiler collectives). Like
  280. compiler collectives, you can only run this on SPMD programs; you will hang
  281. otherwise. Note that a guard collective is only issued if there is any
  282. compiled code to guard on; if this the first time we encounter a frame or
  283. the frame is skipped, we don't issue collectives.
  284. Returns the previous setting of enabled.
  285. """
  286. from torch._C._dynamo.eval_frame import set_guard_complete_hook # noqa: F401
  287. from torch._dynamo.eval_frame import guard_collectives_hook
  288. if enabled:
  289. return set_guard_complete_hook(guard_collectives_hook) is not None # type: ignore[arg-type]
  290. else:
  291. return set_guard_complete_hook(None) is not None
  292. set_enable_guard_collectives._dynamo_forbidden = True # type: ignore[attr-defined]
  293. def cudagraph_mark_step_begin():
  294. """
  295. Indicates that a new iteration of inference or training is about to begin.
  296. CUDA Graphs will free tensors of a prior iteration. A new iteration is started on each invocation of
  297. torch.compile, so long as there is not a pending backward that has not been called.
  298. If that heuristic is wrong, such as in the following example, manually mark it with this api.
  299. .. code-block:: python
  300. @torch.compile(mode="reduce-overhead")
  301. def rand_foo():
  302. return torch.rand([4], device="cuda")
  303. for _ in range(5):
  304. torch.compiler.cudagraph_mark_step_begin()
  305. rand_foo() + rand_foo()
  306. For more details, see `torch.compiler_cudagraph_trees <https://docs.pytorch.org/docs/main/user_guide/torch_compiler/torch.compiler_cudagraph_trees.html>`__ # noqa: B950
  307. """
  308. from torch._inductor import cudagraph_trees
  309. cudagraph_trees.mark_step_begin()
  310. def wrap_numpy(fn):
  311. r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
  312. from ``torch.Tensor``s to ``torch.Tensor``s.
  313. It is designed to be used with :func:`torch.compile` with ``fullgraph=True``. It allows to
  314. compile a NumPy function as if it were a PyTorch function. This allows you to run NumPy code
  315. on CUDA or compute its gradients.
  316. .. note::
  317. This decorator does not work without :func:`torch.compile`.
  318. Example::
  319. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  320. >>> # Compile a NumPy function as a Tensor -> Tensor function
  321. >>> @torch.compile(fullgraph=True)
  322. >>> @torch.compiler.wrap_numpy
  323. >>> def fn(a: np.ndarray):
  324. >>> return np.sum(a * a)
  325. >>> # Execute the NumPy function using Tensors on CUDA and compute the gradients
  326. >>> x = torch.arange(6, dtype=torch.float32, device="cuda", requires_grad=True)
  327. >>> out = fn(x)
  328. >>> out.backward()
  329. >>> print(x.grad)
  330. tensor([ 0., 2., 4., 6., 8., 10.], device='cuda:0')
  331. """
  332. from torch._dynamo.external_utils import wrap_numpy as wrap
  333. return wrap(fn)
  334. _is_compiling_flag: bool = False
  335. _is_exporting_flag: bool = False
  336. def is_compiling() -> bool:
  337. """
  338. Indicates whether a graph is executed/traced as part of torch.compile() or torch.export().
  339. Note that there are 2 other related flags that should deprecated eventually:
  340. * torch._dynamo.external_utils.is_compiling()
  341. * torch._utils.is_compiling()
  342. Example::
  343. >>> def forward(self, x):
  344. >>> if not torch.compiler.is_compiling():
  345. >>> pass # ...logic that is not needed in a compiled/traced graph...
  346. >>>
  347. >>> # ...rest of the function...
  348. """
  349. if torch.jit.is_scripting():
  350. return False
  351. else:
  352. return _is_compiling_flag
  353. def is_dynamo_compiling() -> bool:
  354. """
  355. Indicates whether a graph is traced via TorchDynamo.
  356. It's stricter than is_compiling() flag, as it would only be set to True when
  357. TorchDynamo is used.
  358. Example::
  359. >>> def forward(self, x):
  360. >>> if not torch.compiler.is_dynamo_compiling():
  361. >>> pass # ...logic that is not needed in a TorchDynamo-traced graph...
  362. >>>
  363. >>> # ...rest of the function...
  364. """
  365. return False
  366. def is_exporting() -> bool:
  367. """
  368. Indicated whether we're under exporting.
  369. It's stricter than is_compiling() flag, as it would only be set to True when
  370. torch.export is used.
  371. Example::
  372. >>> def forward(self, x):
  373. >>> if not torch.compiler.is_exporting():
  374. >>> pass # ...logic that is not needed in export...
  375. >>>
  376. >>> # ...rest of the function...
  377. """
  378. return _is_exporting_flag
  379. def save_cache_artifacts() -> Optional[tuple[bytes, "CacheInfo"]]:
  380. """
  381. Serializes all the cache artifacts that were created during the compilation
  382. Example:
  383. - Execute torch.compile
  384. - Call torch.compiler.save_cache_artifacts()
  385. """
  386. from ._cache import CacheArtifactManager
  387. if torch._dynamo.config.caching_precompile:
  388. from torch._dynamo.precompile_context import PrecompileContext
  389. PrecompileContext.save_to_dynamo_cache()
  390. return CacheArtifactManager.serialize()
  391. def load_cache_artifacts(serialized_artifacts: bytes) -> Optional["CacheInfo"]:
  392. """
  393. Hot loads cache artifacts that were previously serialized via
  394. save_cache_artifacts
  395. Example:
  396. # From a previous invocation
  397. artifacts = torch.compiler.save_cache_artifacts()
  398. torch.compiler.load_cache_artifacts(artifacts[0])
  399. """
  400. from ._cache import CacheArtifactManager, CacheInfo
  401. artifacts = CacheArtifactManager.deserialize(serialized_artifacts)
  402. if artifacts is not None:
  403. return CacheArtifactManager.populate_caches(artifacts)
  404. return None
  405. def keep_portable_guards_unsafe(guard_entries):
  406. """
  407. A common function to only keep guards that can be used in both Python and non-Python environments.
  408. This includes:
  409. - Tensor metadata and dynamic shape information.
  410. - Global contexts state (e.g. autocast, no_grad, etc.)
  411. This is unsafe to use by default.
  412. To use this API, use guard_filter_fn argument while calling torch.compile
  413. >> opt_mod = torch.compile(
  414. >> mod,
  415. >> options={"guard_filter_fn": torch.compiler.keep_global_context_and_tensor_guards_unsafe},
  416. >> )
  417. """
  418. return [
  419. (
  420. g.guard_type in ("GLOBAL_STATE", "SHAPE_ENV")
  421. or (g.guard_type == "TENSOR_MATCH" and not g.is_global)
  422. )
  423. for g in guard_entries
  424. ]
  425. def skip_guard_on_inbuilt_nn_modules_unsafe(guard_entries):
  426. """
  427. A common function to skip guards on the inbuilt nn modules like
  428. torch.nn.Linear. This is unsafe to use by default. But for majority of
  429. torch.compile users, the model code does not modify the inbuilt nn module
  430. attributes. They can benefit from reduction in guard latency overhead using
  431. this API.
  432. To use this API, use guard_filter_fn argument while calling torch.compile
  433. >> opt_mod = torch.compile(
  434. >> mod,
  435. >> options={"guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe},
  436. >> )
  437. """
  438. return [
  439. not entry.orig_guard.source.is_unspecialized_builtin_nn_module()
  440. for entry in guard_entries
  441. ]
  442. def skip_guard_on_all_nn_modules_unsafe(guard_entries):
  443. """
  444. A common function to skip guards on all nn modules, both user defined as
  445. well inbuilt nn modules (like torch.nn.Linear). This is unsafe to use by
  446. default. But for majority of torch.compile users, the model code does not
  447. modify the nn module attributes. They can benefit from reduction in guard
  448. latency overhead using this API.
  449. To use this API, use guard_filter_fn argument while calling torch.compile
  450. >> opt_mod = torch.compile(
  451. >> mod,
  452. >> options={"guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe},
  453. >> )
  454. """
  455. return [
  456. not entry.orig_guard.source.is_unspecialized_nn_module()
  457. for entry in guard_entries
  458. ]
  459. def keep_tensor_guards_unsafe(guard_entries, keep_parameters=False):
  460. """
  461. A common function to keep tensor guards on all tensors. This is unsafe to
  462. use by default. But if you don't expect any changes in the model code, you
  463. can just keep the tensor guards.
  464. >> opt_mod = torch.compile(
  465. >> mod,
  466. >> options={"guard_filter_fn": torch.compiler.keep_tensor_guards},
  467. >> )
  468. """
  469. keep_flags = []
  470. for entry in guard_entries:
  471. if entry.guard_type == "TENSOR_MATCH":
  472. if not isinstance(entry.value, torch.nn.Parameter):
  473. keep_flags.append(True)
  474. elif keep_parameters:
  475. keep_flags.append(True)
  476. else:
  477. keep_flags.append(False)
  478. else:
  479. keep_flags.append(False)
  480. return keep_flags
  481. def skip_guard_on_globals_unsafe(guard_entries):
  482. """
  483. A common function to skip guards on all globals. This is unsafe to use by
  484. default. But if you don't expect any changes in the globals, you can just
  485. keep the tensor guards.
  486. >> opt_mod = torch.compile(
  487. >> mod,
  488. >> options={"guard_filter_fn": torch.compiler.skip_guard_on_globals},
  489. >> )
  490. """
  491. return [not entry.is_global for entry in guard_entries]
  492. def skip_all_guards_unsafe(guard_entries):
  493. """
  494. A function for skipping all guards on a compiled function.
  495. WARNING: This function will drop all the safety guarantees from Dynamo
  496. compiled function. Use this with caution.
  497. To use this API, use guard_filter_fn argument while calling torch.compile
  498. >> opt_mod = torch.compile(
  499. >> mod,
  500. >> options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe},
  501. >> )
  502. """
  503. return [False for entry in guard_entries]
  504. def nested_compile_region(
  505. fn=None, options: Optional[NestedCompileRegionOptions] = None
  506. ):
  507. """
  508. Tells **``torch.compile``** that the marked set of operations forms a nested
  509. compile region (which is often repeated in the full model) whose code can be
  510. compiled once and safely reused. ``nested_compile_region`` can also be used
  511. as a decorator.
  512. During **``torch.compile``** tracing, the compiler applies *hierarchical
  513. compilation* with ``nested_compile_region``: it emits optimized code for the
  514. marked region the first time it is encountered and re-emits (or "stamps
  515. out") the previously compiled code on every subsequent invocation. This can
  516. substantially reduce overall compile time for deeply-stacked,
  517. structurally-identical components such as the transformer layers of a
  518. large-language-model (LLM).
  519. Outside a ``torch.compile`` context—i.e., in standard eager execution—the
  520. call is a no-op, so existing workflows remain unaffected.
  521. Note that ``nested_compile_region`` **does not** promise that a region will
  522. be compiled exactly once. If the compiler detects that new input conditions
  523. (shape, dtype, device, stride, globals etc.) make the cached version invalid
  524. to reuse, it will transparently re-compile the region. Using it is
  525. therefore *safe*: correctness is always preserved, and you pay the extra
  526. compilation cost only when required.
  527. Args:
  528. fn: The function to wrap
  529. options: Optional backend to use for compiling the subgraph.
  530. Warning: this is an experimental feature under development and
  531. not ready for use yet.
  532. """
  533. if options is not None:
  534. from torch._dynamo import config as dynamo_config
  535. if not dynamo_config.enable_invoke_subgraph_regional_compile:
  536. raise RuntimeError(
  537. "nested_compile_region config is an experimental feature for testing only."
  538. )
  539. from torch._higher_order_ops.invoke_subgraph import (
  540. mark_compile_region as _mark_compile_region,
  541. )
  542. return _mark_compile_region(fn, options=options)
  543. def load_compiled_function(
  544. file: io.IOBase,
  545. *,
  546. f_globals: dict[str, object] | None = None,
  547. external_data: dict[str, Any] | None = None,
  548. ) -> Callable[..., Any]:
  549. """
  550. Load an aot-compiled function from a file.
  551. .. warning::
  552. This API is currently experimental and subject to change.
  553. Args:
  554. file: A file-like object containing the serialized compiled function.
  555. f_globals: Optional global scope enclosing the compiled function.
  556. external_data: Optional data to be loaded into the runtime environment
  557. of the compiled function. This should contains the same
  558. data as AOTCompileResult.external_data returned from save_compiled_function() call.
  559. Returns:
  560. A torch-compiled function with compilation preloaded from disk.
  561. """
  562. from torch._dynamo.aot_compile import AOTCompiledFunction
  563. data = file.read()
  564. return AOTCompiledFunction.deserialize(data, f_globals, external_data)