external_utils.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. """
  2. This module contains utility functions that are explicitly allowed to be called during
  3. TorchDynamo compilation. These functions are carefully vetted to ensure they work
  4. correctly within the TorchDynamo tracing and compilation process.
  5. Key functionality groups:
  6. - Compilation State:
  7. Functions for checking compilation state (is_compiling)
  8. - Function Wrapping:
  9. Utilities for wrapping functions (wrap_inline, wrap_numpy) to work with
  10. TorchDynamo compilation
  11. - Autograd Hooks:
  12. Functions and classes for handling autograd hooks and backward passes
  13. (call_hook, FakeBackwardCFunction, etc.)
  14. - Tensor Operations:
  15. Utility functions for tensor operations and transformations
  16. """
  17. import functools
  18. import warnings
  19. from collections.abc import Callable
  20. from typing import Any, Optional, TYPE_CHECKING, TypeVar, Union
  21. from typing_extensions import deprecated, ParamSpec
  22. import torch
  23. import torch.utils._pytree as pytree
  24. try:
  25. import numpy as np
  26. except ModuleNotFoundError:
  27. np = None # type: ignore[assignment]
# Generic parameters for the signature-preserving wrappers in this module:
# _P stands for a wrapped callable's parameter list, _R for its return type.
_P = ParamSpec("_P")
_R = TypeVar("_R")
if TYPE_CHECKING:
    # TorchScript does not support `@deprecated`
    # This is a workaround to avoid breaking TorchScript
    # Only static type checkers evaluate this branch, so only they see the
    # deprecation. At runtime TYPE_CHECKING is False and the undecorated
    # definition in the `else` branch is the one that exists (no warning).
    @deprecated(
        "`torch._dynamo.external_utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
        category=FutureWarning,
    )
    def is_compiling() -> bool:
        return torch.compiler.is_compiling()

else:

    def is_compiling() -> bool:
        """
        Indicates whether we are tracing/compiling with torch.compile() or torch.export().

        This is a thin alias that forwards to ``torch.compiler.is_compiling()``.
        """
        # NOTE: With `@torch.compile(backend="eager")`, torch._dynamo.is_compiling() will get traced
        # and return true. torch.compiler.is_compiling() is skipped and will return false.
        return torch.compiler.is_compiling()
  47. def wrap_inline(fn: Callable[_P, _R]) -> Callable[_P, _R]:
  48. """
  49. Create an extra frame around fn that is not in skipfiles.
  50. """
  51. @functools.wraps(fn)
  52. def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
  53. return fn(*args, **kwargs)
  54. return inner
  55. def call_hook(
  56. hook: Callable[..., Optional[torch.Tensor]], *args: Any, **kwargs: Any
  57. ) -> torch.Tensor:
  58. """
  59. Used by compiled autograd to handle hook returning None.
  60. """
  61. result = hook(*args)
  62. if result is None:
  63. return args[0]
  64. elif kwargs.get("hook_type") == "post_acc_grad_hook":
  65. raise RuntimeError("Tensor post accumulate grad hooks should return None.")
  66. return result
  67. def wrap_numpy(f: Callable[_P, _R]) -> Callable[_P, _R]:
  68. r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
  69. from ``torch.Tensor``s to ``torch.Tensor``s.
  70. """
  71. if not np:
  72. return f
  73. @functools.wraps(f)
  74. def wrap(*args: _P.args, **kwargs: _P.kwargs) -> pytree.PyTree:
  75. args, kwargs = pytree.tree_map_only(
  76. torch.Tensor, lambda x: x.numpy(), (args, kwargs)
  77. )
  78. # pyrefly: ignore [invalid-param-spec]
  79. out = f(*args, **kwargs)
  80. # pyrefly: ignore [missing-attribute]
  81. return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out)
  82. return wrap
  83. class FakeBackwardCFunction:
  84. def __init__(
  85. self,
  86. real: torch.autograd.function.BackwardCFunction,
  87. saved_tensors: list[torch.Tensor],
  88. ) -> None:
  89. self.real = real
  90. self.saved_tensors = saved_tensors
  91. def __getattr__(self, name: str) -> Any:
  92. if name == "saved_variables":
  93. warnings.warn(
  94. "'saved_variables' is deprecated; use 'saved_tensors'",
  95. DeprecationWarning,
  96. )
  97. return self.saved_tensors
  98. return getattr(self.real, name)
  99. def call_backward(
  100. backward_c_function: torch.autograd.function.BackwardCFunction,
  101. saved_tensors: list[torch.Tensor],
  102. *args: Any,
  103. ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
  104. fake = FakeBackwardCFunction(backward_c_function, saved_tensors)
  105. grads = fake._forward_cls.backward(fake, *args) # type: ignore[attr-defined]
  106. if not isinstance(grads, tuple):
  107. grads = (grads,)
  108. return grads
  109. def normalize_as_list(x: Any) -> list[Any]:
  110. if isinstance(x, tuple):
  111. return list(x)
  112. elif isinstance(x, list):
  113. return x
  114. return [x]
  115. def untyped_storage_size(x: torch.Tensor) -> int:
  116. return x.untyped_storage().size()
class FakeCompiledAutogradEngine:
    """
    Static helpers that queue and run "final callbacks" as plain Python.

    The callback list is passed in explicitly to every helper. No state is
    stored on the class.
    """

    @staticmethod
    def queue_callback(
        final_callbacks: list[Callable[[], None]], cb: Callable[[], None]
    ) -> None:
        # Append ``cb`` so a later exec_final_callbacks call will run it.
        final_callbacks.append(cb)

    @staticmethod
    def exec_final_callbacks(final_callbacks: list[Callable[[], None]]) -> None:
        # Run every callback in insertion order, then empty the list.
        # ``len(final_callbacks)`` is re-evaluated on each iteration, so
        # callbacks appended while the loop runs (e.g. a callback calling
        # queue_callback) are executed too. If a callback raises, the list is
        # not cleared.
        # NOTE(review): presumably written as an explicit index loop rather
        # than a for-loop so that it behaves the same under TorchDynamo
        # tracing; confirm before refactoring.
        i = 0
        while i < len(final_callbacks):
            cb = final_callbacks[i]
            cb()
            i += 1
        final_callbacks.clear()

    @staticmethod
    def _exec_final_callbacks_stub() -> None:
        # Intentional no-op.
        # NOTE(review): presumably a placeholder referenced elsewhere; confirm.
        pass
  134. def call_hook_from_backward_state(
  135. *args: Any, bw_state: Any, hook_name: str, **kwargs: Any
  136. ) -> Any:
  137. return getattr(bw_state, hook_name)(*args, **kwargs)
  138. class _ApplyBackwardHook(torch.autograd.Function):
  139. """Custom autograd function that applies a hook during backward.
  140. This is used to implement register_hook on intermediate tensors without
  141. requiring compiled autograd. The hook function is captured in the context
  142. and applied during the backward pass.
  143. """
  144. @staticmethod
  145. # pyre-ignore[14]: Inconsistent override is expected for autograd.Function
  146. def forward(
  147. ctx: Any, tensor: torch.Tensor, hook_fn: Callable[..., Any]
  148. ) -> torch.Tensor: # type: ignore[override]
  149. ctx.hook_fn = hook_fn
  150. return tensor.view_as(tensor)
  151. @staticmethod
  152. def backward(ctx: Any, grad: torch.Tensor) -> tuple[torch.Tensor, None]: # type: ignore[override]
  153. result = ctx.hook_fn(grad)
  154. if result is None:
  155. result = grad
  156. return result, None
  157. def call_module_hooks_from_backward_state(
  158. _: Any, result: Any, *args: Any, bw_state: Any, hooks_name: str, module_name: str
  159. ) -> Any:
  160. module = getattr(bw_state, module_name)
  161. hooks = getattr(bw_state, hooks_name)
  162. for hook in hooks:
  163. new_result = hook(module, result, *args)
  164. if new_result is not None:
  165. result = new_result
  166. return result
# used for torch._dynamo.disable(recursive=False)
def get_nonrecursive_disable_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """
    Wrap ``fn`` in a pass-through that refuses to run under torch.export.

    Returns:
        A wrapper (with ``fn``'s metadata copied via ``functools.wraps``) that
        raises RuntimeError when ``torch.compiler.is_exporting()`` is true and
        otherwise calls ``fn`` with the given arguments and returns its result.
    """
    # wrap function to get the right error message
    # this function is in external_utils so that convert_frame doesn't skip it.
    @functools.wraps(fn)
    def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        if torch.compiler.is_exporting():
            raise RuntimeError(
                "Non-recursive torch.compiler.disable is not supported with torch.export."
            )
        return fn(*args, **kwargs)

    return nonrecursive_disable_wrapper
def wrap_dunder_call_ctx_manager(self: Any, func: Callable[_P, _R]) -> Callable[_P, _R]:
    """
    Apply self as a ctx manager around a call to func

    Args:
        self: an object usable in a ``with`` statement. It is entered and
            exited around every call. Despite its name, this is an ordinary
            function argument.
        func: the callable to invoke inside the context.

    Returns:
        A new function that, on each call, enters ``self``, calls ``func`` with
        the given arguments, and returns its result.
    """
    # NOTE: do not functools.wraps(func) because we don't ever want this frame to be skipped!
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        with self:
            return func(*args, **kwargs)

    return inner
  188. # Use only on ints marked dynamic via torch.empty(0, integer)
  189. # Currently only way to mark ints as dynamic: https://github.com/pytorch/pytorch/issues/129623
  190. def unwrap_maybe_dynamic_int(x: Union[torch.Tensor, int]) -> int:
  191. if isinstance(x, torch.Tensor):
  192. # x.size() is expected to be [0, dynamic_int]
  193. return x.size(1)
  194. return x
  195. def call_accumulate_grad(
  196. variable: torch.Tensor, grad: torch.Tensor, has_post_hooks: bool
  197. ) -> None:
  198. updated_grad = torch._dynamo.compiled_autograd.ops.AccumulateGrad( # type: ignore[attr-defined]
  199. [grad], variable, variable.grad, has_post_hooks
  200. )
  201. variable.grad = updated_grad[0]
def wrap_inline_with_error_on_graph_break(
    fn: Callable[_P, _R], error_on_graph_break: bool
) -> Callable[_P, _R]:
    """
    Return a wrapper that calls ``fn`` inside ``torch._dynamo.error_on_graph_break``.

    Args:
        fn: the callable to wrap.
        error_on_graph_break: chooses the flag the wrapper passes to the
            context manager: ``True`` when this is truthy, ``False`` otherwise.

    Returns:
        A new function that forwards all arguments to ``fn`` and returns its result.
    """
    # NB: need multiple definitions in order to prevent `error_on_graph_break`
    # from being a freevar of wrapper (each variant hard-codes True/False).
    # NOTE: do not functools.wraps(fn) because we don't ever want these wrappers to be skipped!
    if error_on_graph_break:

        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
            with torch._dynamo.error_on_graph_break(True):
                return fn(*args, **kwargs)

    else:

        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
            with torch._dynamo.error_on_graph_break(False):
                return fn(*args, **kwargs)

    return wrapper
  217. def filter_out_const_values(tup: tuple[Any, ...], masks: list[bool]) -> tuple[Any, ...]:
  218. """
  219. masks is a list of bools, where True means the corresponding element in tup
  220. is a const value. Filter out the const values.
  221. """
  222. out = []
  223. for mask_idx, mask in enumerate(masks):
  224. if not mask:
  225. out.append(tup[mask_idx])
  226. return tuple(out)
  227. def insert_const_values_with_mask(
  228. tup: tuple[Any, ...], masks: list[bool], values: tuple[Any, ...]
  229. ) -> tuple[Any, ...]:
  230. """
  231. masks and values are of same length. For indices where the mask is True, use
  232. the const_values to fill in.
  233. """
  234. out = []
  235. idx = 0
  236. for mask_idx, mask in enumerate(masks):
  237. if mask:
  238. out.append(values[mask_idx])
  239. else:
  240. out.append(tup[idx])
  241. idx += 1
  242. return tuple(out)