_python_dispatch.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. import contextlib
  4. import functools
  5. import warnings
  6. from collections import deque
  7. from dataclasses import dataclass
  8. from typing import cast, overload, Protocol, TYPE_CHECKING
  9. from typing_extensions import TypeIs
  10. import torch
  11. import torchgen
  12. import torchgen.model
  13. from torch._C import (
  14. _get_dispatch_stack_at,
  15. _len_torch_dispatch_stack,
  16. _pop_torch_dispatch_stack,
  17. _push_on_torch_dispatch_stack,
  18. DispatchKey,
  19. )
  20. from torch._C._dynamo.guards import set_is_in_mode_without_ignore_compile_internals
  21. if TYPE_CHECKING:
  22. from collections.abc import Sequence
  23. # TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
  24. # - We need a better user-facing api for _DisableTorchDispatch that
  25. # is able to selectively disable __torch_dispatch__ of a particular class.
  26. # - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
  27. # - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
# Module-level flags describing the current TorchDispatchMode nesting.
# TorchDispatchMode.__enter__ saves the previous value of each flag on the
# mode instance and updates it; TorchDispatchMode.__exit__ restores it.

# True while any TorchDispatchMode has been entered.
_is_in_torch_dispatch_mode = False
# True while at least one entered mode reports is_infra_mode() == False.
_is_in_non_infra_torch_dispatch_mode = False
# If inside any mode that has ignore_compile_internals() = False
_is_in_any_mode_without_ignore_compile_internals = False
  32. def is_in_torch_dispatch_mode(include_infra_modes: bool = True) -> bool:
  33. return (
  34. _is_in_torch_dispatch_mode
  35. if include_infra_modes
  36. else _is_in_non_infra_torch_dispatch_mode
  37. )
  38. def is_in_any_mode_without_ignore_compile_internals() -> bool:
  39. return _is_in_any_mode_without_ignore_compile_internals
  40. def any_torch_dispatch_mode_on_stack() -> bool:
  41. stack_len = torch._C._len_torch_dispatch_stack()
  42. for idx in range(stack_len):
  43. mode = _get_dispatch_stack_at(idx)
  44. # Apply filters first
  45. if mode.is_infra_mode():
  46. continue
  47. if mode.ignore_compile_internals():
  48. continue
  49. return True
  50. return False
  51. class TorchDispatchMode:
  52. """
  53. A ``TorchDispatchMode`` allows you to override the meaning of all
  54. ``__torch_dispatch__`` overridable functions within a dynamic scope,
  55. without having to actually create a tensor subclass or manually
  56. monkey-patch functions in the PyTorch API. Some common situations
  57. where you should use a mode:
  58. * You want to override the meaning of factory functions, or other
  59. functions that do not otherwise take a tensor as an argument
  60. (these cannot be overridden with tensor subclasses).
  61. * You want to override the behavior of all functions without needing
  62. to wrap your inputs in tensor subclasses; e.g., if you are just
  63. interested in logging intermediate computations.
  64. * You want to control the order of execution of various tensor
  65. subclasses explicitly, rather than implicitly via the return of
  66. ``NotImplemented``.
  67. Independent subclasses of :class:`TorchDispatchMode` are compositional:
  68. modes can be pushed onto a stack using ``with MyMode():``.
  69. When you call functions in the PyTorch API inside your
  70. ``__torch_dispatch__`` implementation, by default, they will forward on to
  71. the next mode on the mode stack. If you want recursively call back into
  72. your current ``__torch_dispatch__`` implementation, either explicitly
  73. invoke ``self.__torch_dispatch__(...)``, or use the context manager
  74. ``self`` to make PyTorch
  75. API self-referential (beware of infinite loops, in this case!)
  76. """
  77. # - When False, custom torch dispatch mode will error out explicitly when a hop
  78. # is called under the mode.
  79. # - When True, custom torch dispatch mode's __torch_dispatch__ will be triggered.
  80. # Mode authors can implement how the mode interacts with higher order operators.
  81. supports_higher_order_operators = False
  82. def __init_subclass__(cls, **kwargs):
  83. super().__init_subclass__(**kwargs)
  84. if cls._should_skip_dynamo():
  85. if "__torch_dispatch__" in cls.__dict__:
  86. raw = cls.__dict__["__torch_dispatch__"]
  87. if not isinstance(raw, classmethod):
  88. cls.__torch_dispatch__ = torch._disable_dynamo(raw, recursive=True)
  89. def __init__(self, _dispatch_key=None):
  90. if _dispatch_key is not None:
  91. if not isinstance(_dispatch_key, torch._C.DispatchKey):
  92. raise AssertionError("_dispatch_key must be a torch._C.DispatchKey")
  93. self.__dict__["_dispatch_key"] = _dispatch_key
  94. self.old_dispatch_mode_flags: deque[bool] = deque()
  95. self.old_non_infra_dispatch_mode_flags: deque[bool] = deque()
  96. self.old_without_ignore_compile_internals_dispatch_mode_flags: deque[bool] = (
  97. deque()
  98. )
  99. def _lazy_init_old_dispatch_mode_flags(self):
  100. if not hasattr(self, "old_dispatch_mode_flags"):
  101. self.old_dispatch_mode_flags: deque[bool] = deque() # type: ignore[no-redef]
  102. if not hasattr(self, "old_non_infra_dispatch_mode_flags"):
  103. self.old_non_infra_dispatch_mode_flags: deque[bool] = deque() # type: ignore[no-redef]
  104. if not hasattr(
  105. self, "old_without_ignore_compile_internals_dispatch_mode_flags"
  106. ):
  107. self.old_without_ignore_compile_internals_dispatch_mode_flags: deque[ # type: ignore[no-redef]
  108. bool
  109. ] = deque()
  110. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  111. raise NotImplementedError
  112. def __enter__(self):
  113. global _is_in_torch_dispatch_mode
  114. global _is_in_non_infra_torch_dispatch_mode
  115. global _is_in_any_mode_without_ignore_compile_internals
  116. # Previously, there wasn't any state in this class' constructor
  117. # super calls were added to existing modes, but for any new modes
  118. # this will replicate the previous behavior of not strictly needing
  119. # to call super().__init__()
  120. self._lazy_init_old_dispatch_mode_flags()
  121. self.old_dispatch_mode_flags.append(_is_in_torch_dispatch_mode)
  122. _is_in_torch_dispatch_mode = True
  123. self.old_non_infra_dispatch_mode_flags.append(
  124. _is_in_non_infra_torch_dispatch_mode
  125. )
  126. _is_in_non_infra_torch_dispatch_mode = (
  127. _is_in_non_infra_torch_dispatch_mode or not self.is_infra_mode()
  128. )
  129. self.old_without_ignore_compile_internals_dispatch_mode_flags.append(
  130. _is_in_any_mode_without_ignore_compile_internals
  131. )
  132. _is_in_any_mode_without_ignore_compile_internals = (
  133. _is_in_any_mode_without_ignore_compile_internals
  134. or not self.ignore_compile_internals()
  135. )
  136. set_is_in_mode_without_ignore_compile_internals(
  137. _is_in_any_mode_without_ignore_compile_internals
  138. )
  139. _push_mode(self)
  140. return self
  141. def __exit__(self, exc_type, exc_val, exc_tb):
  142. mb_dk_or_mode_key = self.__dict__.get("_dispatch_key", None)
  143. if mb_dk_or_mode_key is None:
  144. # Today, mode keys are not used at all in the per-dispatch-key-mode logic (for pre-dispatch)
  145. # We should probably revisit this.
  146. mb_dk_or_mode_key = self.__dict__.get("_mode_key", None)
  147. global _is_in_torch_dispatch_mode
  148. _is_in_torch_dispatch_mode = self.old_dispatch_mode_flags.pop()
  149. global _is_in_non_infra_torch_dispatch_mode
  150. _is_in_non_infra_torch_dispatch_mode = (
  151. self.old_non_infra_dispatch_mode_flags.pop()
  152. )
  153. global _is_in_any_mode_without_ignore_compile_internals
  154. _is_in_any_mode_without_ignore_compile_internals = (
  155. self.old_without_ignore_compile_internals_dispatch_mode_flags.pop()
  156. )
  157. set_is_in_mode_without_ignore_compile_internals(
  158. _is_in_any_mode_without_ignore_compile_internals
  159. )
  160. _pop_mode(mb_dk_or_mode_key)
  161. @classmethod
  162. def push(cls, *args, **kwargs):
  163. warnings.warn(
  164. "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`",
  165. stacklevel=2,
  166. )
  167. instance = cls(*args, **kwargs)
  168. return instance
  169. @classmethod
  170. def is_infra_mode(cls) -> bool:
  171. return False
  172. @classmethod
  173. def _should_skip_dynamo(cls) -> bool:
  174. """Skip Dynamo when the flag is set to True
  175. This is temporary measure to rollout a feature
  176. that skips PT2 compilation inside __torch_dispatch__
  177. frames.
  178. If this flag is off, we would expect following:
  179. class YoloMode(TorchDispatchMode):
  180. @classmethod
  181. def _should_skip_dynamo(cls):
  182. return False
  183. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  184. return torch.ops.aten.mul.Tensor(args[0], args[1])
  185. x = torch.ones(5)
  186. with YoloMode():
  187. out = torch.compile(torch.add, backend=backend, fullgraph=True)(x, x)
  188. # instead of recursively disabling, we are compiling into __torch_dispatch__
  189. assert len(backend.graphs) == 1
  190. """
  191. return True
  192. @classmethod
  193. def ignore_compile_internals(cls) -> bool:
  194. """Ignore operators that are compiled via torch.compile.
  195. If ``True``, then this TorchDispatchMode ignores operators that
  196. are optimized by :func:`torch.compile`. Mechanically, this involves
  197. turning off the TorchDispatchMode throughout the whole compilation process,
  198. and turning it back on for the runtime of the compiled artifact(s).
  199. For example,
  200. @torch.compile
  201. def f(x):
  202. return x.sin().cos()
  203. with LoggingMode():
  204. f(x)
  205. The above example will not log anything if
  206. ``LoggingMode.ignore_compile_internals()`` is True.
  207. torch.compile will fuse sin() and cos() into a single operation
  208. and this TorchDispatchMode will not be passed sin and cos.
  209. If ``False`` (default), :func:`torch.compile` will respect
  210. the eager semantics of passing this TorchDispatchMode all
  211. operators that would have run during eager execution.
  212. The way this will usually happen is that :func:`torch.compile`
  213. will just fallback to eager-mode PyTorch.
  214. """
  215. if cls.is_infra_mode():
  216. return True
  217. return False
  218. def _get_current_dispatch_mode() -> TorchDispatchMode | None:
  219. """
  220. Return the top user mode on the stack (the next one that would be
  221. executed) if there are any.
  222. """
  223. stack_len = _len_torch_dispatch_stack()
  224. if stack_len > 0:
  225. return _get_dispatch_stack_at(stack_len - 1)
  226. return None
  227. def _detect_infra_mode(key):
  228. if key not in (
  229. torch._C._TorchDispatchModeKey.FUNCTIONAL,
  230. torch._C._TorchDispatchModeKey.PROXY,
  231. ):
  232. raise AssertionError(
  233. f"key must be either FUNCTIONAL ({torch._C._TorchDispatchModeKey.FUNCTIONAL}) \
  234. or PROXY ({torch._C._TorchDispatchModeKey.PROXY}) _TorchDispatchModeKey, \
  235. got {key}"
  236. )
  237. from torch._ops import _get_dispatch_mode_pre_dispatch
  238. pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key)
  239. post_dispatch_mode = torch._C._get_dispatch_mode(key)
  240. if pre_dispatch_mode is not None and post_dispatch_mode is not None:
  241. raise AssertionError(
  242. "At most one of pre_dispatch_mode and post_dispatch_mode may be active"
  243. )
  244. if pre_dispatch_mode is None:
  245. return post_dispatch_mode
  246. return pre_dispatch_mode
  247. def _unset_infra_mode(key):
  248. from torch._ops import _get_dispatch_mode_pre_dispatch, unset_mode_pre_dispatch
  249. pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key)
  250. post_dispatch_mode = torch._C._get_dispatch_mode(key)
  251. if pre_dispatch_mode and post_dispatch_mode:
  252. raise AssertionError(
  253. "Can't have active infra mode on both pre and post dispatch mode stack"
  254. )
  255. if pre_dispatch_mode:
  256. mode = unset_mode_pre_dispatch(key)
  257. return mode
  258. if post_dispatch_mode:
  259. return torch._C._unset_dispatch_mode(key)
  260. def _disable_infra_mode(key):
  261. if key not in (
  262. torch._C._TorchDispatchModeKey.FUNCTIONAL,
  263. torch._C._TorchDispatchModeKey.PROXY,
  264. ):
  265. raise AssertionError(
  266. "key must be either FUNCTIONAL or PROXY _TorchDispatchModeKey"
  267. )
  268. mode_unset = _unset_infra_mode(key)
  269. try:
  270. yield mode_unset
  271. finally:
  272. if mode_unset is not None:
  273. _push_mode(mode_unset)
  274. def _get_current_dispatch_mode_stack() -> list[TorchDispatchMode]:
  275. """
  276. Returns the current stack of dispatch modes, with the most recent
  277. (i.e., the one that will be processed first) at the end of the
  278. list (standard stack convention).
  279. """
  280. stack_len = _len_torch_dispatch_stack()
  281. return [_get_dispatch_stack_at(i) for i in range(stack_len)]
  282. def _push_mode(mode: TorchDispatchMode) -> None:
  283. k = mode._dispatch_key if hasattr(mode, "_dispatch_key") else None
  284. if k is not None and k != torch._C.DispatchKey.PreDispatch:
  285. raise AssertionError(
  286. "mode._dispatch_key must be None or DispatchKey.PreDispatch"
  287. )
  288. if k is None:
  289. _push_on_torch_dispatch_stack(mode)
  290. return
  291. from torch._ops import _set_mode_pre_dispatch, get_cached_ops
  292. # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
  293. # Clear the cache of every op that has been used so far, for this particular key.
  294. ks = torch._C._functionality_to_backend_keys(k)
  295. for op in get_cached_ops():
  296. for key in ks:
  297. op._uncache_dispatch(key)
  298. _set_mode_pre_dispatch(mode)
  299. def _pop_mode(k: DispatchKey | torch._C._TorchDispatchModeKey | None = None):
  300. if k == torch._C.DispatchKey.PreDispatch: # type: ignore[attr-defined]
  301. from torch._ops import _pop_mode_from_pre_dispatch
  302. return _pop_mode_from_pre_dispatch()
  303. if k is None or isinstance(k, torch._C._TorchDispatchModeKey):
  304. return _pop_torch_dispatch_stack(k)
  305. @contextlib.contextmanager
  306. def _pop_mode_temporarily(k: DispatchKey | None = None):
  307. old = _pop_mode(k)
  308. try:
  309. yield old
  310. finally:
  311. _push_mode(old)
  312. @contextlib.contextmanager
  313. def _disable_current_modes():
  314. from torch._ops import (
  315. _len_torch_dispatch_stack_pre_dispatch,
  316. _pop_mode_from_pre_dispatch,
  317. )
  318. from torch._subclasses.functional_tensor import FunctionalTensorMode
  319. from torch._subclasses.schema_check_mode import SchemaCheckMode
  320. from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
  321. mode_len_pre_dispatch = _len_torch_dispatch_stack_pre_dispatch()
  322. old_pre_dispatch_modes = [
  323. _pop_mode_from_pre_dispatch() for _ in range(mode_len_pre_dispatch)
  324. ]
  325. has_proxy_mode_in_pre_dispatch = False
  326. has_functional_mode_in_pre_dispatch = False
  327. has_schema_check_mode_in_pre_dispatch = False
  328. for i in old_pre_dispatch_modes:
  329. if isinstance(i, ProxyTorchDispatchMode):
  330. has_proxy_mode_in_pre_dispatch = True
  331. if isinstance(i, FunctionalTensorMode):
  332. has_functional_mode_in_pre_dispatch = True
  333. if isinstance(i, SchemaCheckMode):
  334. has_schema_check_mode_in_pre_dispatch = True
  335. mode_len = _len_torch_dispatch_stack()
  336. old_modes = [_pop_mode() for _ in range(mode_len)]
  337. for old in old_modes:
  338. if (
  339. isinstance(old, FunctionalTensorMode)
  340. and has_functional_mode_in_pre_dispatch
  341. ):
  342. raise AssertionError(
  343. "Can't have FunctionalMode available both in PreDispatch and Python Key"
  344. )
  345. if isinstance(old, ProxyTorchDispatchMode) and has_proxy_mode_in_pre_dispatch:
  346. raise AssertionError(
  347. "Can't have ProxyTorchDispatchMode available both in PreDispatch and Python Key"
  348. )
  349. if isinstance(old, SchemaCheckMode) and has_schema_check_mode_in_pre_dispatch:
  350. raise AssertionError(
  351. "Can't have SchemaCheckMode available both in PreDispatch and Python Key"
  352. )
  353. # Manually disable proxy and fake modes, if any are active
  354. try:
  355. yield old_pre_dispatch_modes + old_modes
  356. finally:
  357. for mode in reversed(old_modes):
  358. _push_mode(mode)
  359. for mode in reversed(old_pre_dispatch_modes):
  360. _push_mode(mode)
  361. class BaseTorchDispatchMode(TorchDispatchMode):
  362. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  363. if kwargs is None:
  364. kwargs = {}
  365. return func(*args, **kwargs)
# Subtypes which have __tensor_flatten__ and __tensor_unflatten__.
class TensorWithFlatten(Protocol):
    """
    Structural type used as the narrowed type of
    ``is_traceable_wrapper_subclass`` / ``is_traceable_wrapper_subclass_type``.

    It declares the two flatten/unflatten magic methods plus the handful of
    tensor members (``shape``, ``stride``, ``size``, ``storage_offset``,
    ``dim``, ``to``) that are listed below.
    """

    def __tensor_flatten__(self) -> tuple[Sequence[str], object]: ...

    # NOTE(review): every parameter here is annotated ``int``, while the
    # docstring of is_traceable_wrapper_subclass describes inner_tensors and
    # ctx as dicts and outer_size/outer_stride as sizes/strides -- confirm
    # whether these annotations are placeholders.
    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int
    ) -> torch.Tensor: ...

    # It would be really nice to be able to say that the return of
    # is_traceable_wrapper_subclass() is Intersection[torch.Tensor,
    # TensorWithFlatten] - but that doesn't exist.
    # The members below are declared so the narrowed type still exposes them.

    shape: torch._C.Size

    @overload
    def stride(self, dim: None = None) -> tuple[int, ...]: ...

    @overload
    def stride(self, dim: int) -> int: ...

    @overload
    def size(self, dim: None = None) -> tuple[int, ...]: ...

    @overload
    def size(self, dim: int) -> int: ...

    def storage_offset(self) -> int: ...

    def dim(self) -> int: ...

    # Three overloads of ``to``: by dtype, by device (+ optional dtype), and
    # by another tensor.
    @overload
    def to(
        self,
        dtype: torch.types._dtype,
        non_blocking: bool = False,
        copy: bool = False,
        *,
        memory_format: torch.memory_format | None = None,
    ) -> torch.Tensor: ...

    @overload
    def to(
        self,
        device: torch._prims_common.DeviceLikeType | None = None,
        dtype: torch.types._dtype | None = None,
        non_blocking: bool = False,
        copy: bool = False,
        *,
        memory_format: torch.memory_format | None = None,
    ) -> torch.Tensor: ...

    @overload
    def to(
        self,
        other: torch.Tensor,
        non_blocking: bool = False,
        copy: bool = False,
        *,
        memory_format: torch.memory_format | None = None,
    ) -> torch.Tensor: ...
  415. def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
  416. """
  417. Returns whether or not a tensor subclass that implements __torch_dispatch__
  418. is 'traceable' with torch.compile.
  419. In order for a tensor subclass to support TorchDispatchMode-style tracing in PT2,
  420. It must implement two magic methods: __tensor_flatten__ and __tensor_unflatten__.
  421. It is also expected to obey some restrictions around traceability and aliasing:
  422. * The subclass's __torch_dispatch__() implementation should desugar into pytorch
  423. dispatcher operations that can be traced into a graph.
  424. * The subclass should use return_and_correct_aliasing(). This is needed today to make
  425. sure that torch.compile does the right thing in a few cases around input mutation
  426. and output aliasing.
  427. Expected magic method signatures:
  428. attrs, ctx = t.__tensor_flatten__()
  429. attrs: list of attribute name strings for inner tensors
  430. ctx: dict containing any other subclass-specific metadata needed for unflattening
  431. t = MySubClass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)
  432. inner_tensors: dict mapping attribute name -> tensor for each inner tensor
  433. ctx: dict with subclass metadata in the form that __tensor_flatten__() produces
  434. outer_size: expected (possibly symbolic) size that the returned subclass
  435. instance should have. Note that this arg is useful for certain subclasses
  436. that require the shape info to be constructed. In most cases, this arg can be
  437. safely ignored.
  438. outer_stride: expected (possibly symbolic) stride that the returned subclass
  439. instance should have. Note that this arg is useful for certain subclasses
  440. that require the stride info to be constructed. In most cases, this arg can be
  441. safely ignored.
  442. """
  443. is_subclass = isinstance(t, torch.Tensor) and type(t) is not torch.Tensor
  444. return (
  445. is_subclass
  446. and hasattr(t, "__tensor_flatten__")
  447. and hasattr(t, "__tensor_unflatten__")
  448. )
  449. def is_traceable_wrapper_subclass_type(t: type) -> TypeIs[type[TensorWithFlatten]]:
  450. """Same as above, but takes a type argument instead of an instance."""
  451. return (
  452. issubclass(t, torch.Tensor)
  453. and t is not torch.Tensor
  454. and hasattr(t, "__tensor_flatten__")
  455. and hasattr(t, "__tensor_unflatten__")
  456. )
  457. def transform_subclass(t, callback, outer_size=None, outer_stride=None):
  458. """
  459. Given a traceable, wrapper tensor subclass ``t`` that implements
  460. ``__torch_dispatch__`` and holds some inner tensors,
  461. and a callback of type ``Callable[[str, torch.Tensor], torch.Tensor]``,
  462. `transform_subclass` will construct a fresh instance of the wrapper tensor subclass.
  463. It will do so by grabbing each inner tensor attribute from the wrapper,
  464. passing them into ``callback`` to get a transformed tensor,
  465. and putting each transformed tensor into the fresh tensor subclass instance.
  466. Note: this function will not handle ensuring that the fresh subclass
  467. gets the same (autograd, and aliasing) metadata as the original tensor.
  468. This is generally handled in other subsystems like AOTAutograd.
  469. """
  470. outer_size = outer_size if outer_size is not None else t.size()
  471. outer_stride = outer_stride if outer_stride is not None else t.stride()
  472. attrs, ctx = t.__tensor_flatten__()
  473. transformed_tensors_dict = {}
  474. for attr in attrs:
  475. transformed_tensors_dict[attr] = callback(attr, getattr(t, attr))
  476. sub = type(t).__tensor_unflatten__(
  477. transformed_tensors_dict, ctx, outer_size, outer_stride
  478. )
  479. # NB: Purposefully guard here to simplify the inner / outer symbols.
  480. # Using sym_eq() for symbolic comparison can result in an expression that's too
  481. # difficult to guard on, so we use == here.
  482. if sub.shape != outer_size:
  483. raise AssertionError(
  484. f"Expected return value from {type(t)}__tensor_unflatten__() to have "
  485. f"shape equal to {outer_size}, but got: {sub.shape}"
  486. )
  487. if sub.stride() != outer_stride:
  488. raise AssertionError(
  489. f"Expected return value from {type(t)}__tensor_unflatten__() to have "
  490. f"stride equal to {outer_stride}, but got: {sub.stride()}"
  491. )
  492. return sub
  493. def _correct_storage_aliasing(func, schema_info, args, outs) -> None:
  494. """
  495. Given: an OpOverload, a SchemaInfo (cached information from torchgen about schema),
  496. and the inputs/outputs to the OpOverload,
  497. this function checks to see if func is a view operator
  498. (by checking if any of the outputs in the op's schema
  499. are immutable aliases of inputs).
  500. If so, this function manually aliases the storage of the output tensor
  501. with its corresponding input tensor alias.
  502. It does this by unsafely overwriting the storage field of the output tensor
  503. to be the same storage as the input.
  504. """
  505. if not isinstance(func, torch._ops.OpOverload):
  506. raise AssertionError(f"func must be an OpOverload, got {type(args)}")
  507. if not isinstance(args, tuple):
  508. raise AssertionError(f"args must be a tuple, got {type(args)}")
  509. if not isinstance(outs, (list, tuple)):
  510. raise AssertionError(f"outs must be a list or tuple, got {type(args)}")
  511. def alias_non_inplace_storage(arg, ret) -> None:
  512. # This is hopefully a reasonable assert:
  513. # subclasses that rely on this API for output aliasing
  514. # should always return wrapper tensor subclasses for us to manually alias.
  515. # in theory if a subclass that needs this API wants to sometimes return
  516. # plain tensors, we could remove the assert and just not perform the aliasing,
  517. # but it seems safer to learn more about this case first.
  518. #
  519. # Performance note: This is all just to assert that the argument and result
  520. # types match, checking that is cheaper than is_traceable_wrapper_subclass_type,
  521. # and multiple returns are relatively unlikely, so just check up front!
  522. arg_type = type(arg)
  523. ret_type = type(ret)
  524. if arg_type is not ret_type and (
  525. is_traceable_wrapper_subclass_type(arg_type)
  526. or is_traceable_wrapper_subclass_type(ret_type)
  527. ):
  528. ret_list = ret if isinstance(ret, list) else [ret]
  529. for r in ret_list:
  530. if type(arg) is not type(r):
  531. raise AssertionError(
  532. f"Called {str(func)} with input of type {type(arg)}\n"
  533. f"and output of type {type(ret)}. But expected types to match."
  534. )
  535. # Need to call a non-dispatcher helper, because we explicitly do **not**
  536. # want our subclass to intercept the set_() call.
  537. # instead, our subclass should directly have its storage swapped out.
  538. # we **explicitly** don't want to reset the sizes on ret, if the storage implies a size change.
  539. # Why?
  540. # The purpose of this API is *not* to change the size/strides of our output- we assume it's already correct.
  541. # We just want to "fix up" the storage aliasing, without modifying or output's metadata.
  542. # Example: out = inp.expand(inp.shape[0], inp.shape[0])
  543. # This requires swapping the storage of out to be the same as inp,
  544. # but we do *not* want it to change the sizes/strides that were compute for out.
  545. if isinstance(ret, list):
  546. for r in ret:
  547. torch._functionalize_unsafe_set(r, arg)
  548. else:
  549. if not isinstance(ret, torch.Tensor):
  550. raise AssertionError(f"expected torch.Tensor, got {type(ret)}")
  551. torch._functionalize_unsafe_set(ret, arg)
  552. for arg_idx, return_idx in schema_info.read_only_alias_match_indexes:
  553. alias_non_inplace_storage(args[arg_idx], outs[return_idx])
  554. def _get_write_alias(x) -> str | None:
  555. alias_set = x.alias_set
  556. if not alias_set or not x.is_write:
  557. return None
  558. # torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing
  559. if len(alias_set) != 1:
  560. raise AssertionError("Expected alias_set to contain exactly one element")
  561. # timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for
  562. # set of size 1 on Python 3.13.
  563. return next(iter(alias_set))
# This abstracts over the fact that in return_and_correct_aliasing,
# we sometimes use torchgen schema parsing (for aten ops, since torchscript's schema parsing is sometimes buggy),
# and sometimes use torchscript schema parsing (for custom ops, for which torchgen parsing is untested).
@dataclass
class AliasInfo:
    """Alias annotation of a single schema argument or return."""

    # Alias names attached to this argument/return; empty when it carries no
    # alias annotation.
    alias_set: set[str]
    # True when the alias annotation marks the value as written to.
    is_write: bool
    # Name of the argument/return as given by the schema parser.
    name: str | None
@dataclass
class SchemaInfo:
    """Alias information for one operator schema, as built by ``get_alias_info``."""

    # Alias info per schema argument.
    args: list[AliasInfo]
    # Alias info per schema return.
    outs: list[AliasInfo]
    # NOTE(review): presumably marks in-place view ops; how it is computed is
    # not visible here -- confirm against get_alias_info.
    is_inplace_view_op: bool
    # [_get_write_alias(x) for x in outs]. Guaranteed to contain no Nones; an
    # all-Nones result is represented as None (not an empty list), and we
    # don't support some-but-not-all-Nones.
    outs_write_aliases: list[str] | None
    # List of (arg_idx, return_idx) where args[arg_idx].alias_set &
    # outs[return_idx].alias_set is not empty, and not args[arg_idx].is_write.
    read_only_alias_match_indexes: list[tuple[int, int]]
  584. # Given an OpOverload, returns schema information on it.
  585. # This is cached for efficiency, since it can involve running torchgen
  586. @functools.cache
  587. def get_alias_info(func) -> SchemaInfo:
  588. # For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations
  589. # properly for some ops that output tensorlists)
  590. if func.namespace == "aten":
  591. torchgen_schema_str = str(func._schema)
  592. if not torchgen_schema_str.startswith("aten::"):
  593. raise AssertionError(
  594. "Expected torchgen schema string to start with 'aten::'"
  595. )
  596. # remove the aten:: namespace, which is added by the torchscript parser,
  597. # and torchgen doesn't know how to handle
  598. torchgen_schema_str = torchgen_schema_str[6:]
  599. import re
  600. # the torchscript parser ends up converting int[2]=1 into int[2]=[1, 1],
  601. # which torchgen chokes on.
  602. torchgen_schema_str = re.sub(r"=\[[0, ]+\]", "=0", torchgen_schema_str)
  603. torchgen_schema_str = re.sub(r"=\[[1, ]+\]", "=1", torchgen_schema_str)
  604. # for aten::rot90 / aten:fft_*
  605. torchgen_schema_str = re.sub(
  606. r"=\[(-?[0-9]+), (-?[0-9]+)\]", r"=[\1,\2]", torchgen_schema_str
  607. )
  608. torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str)
  609. arg_schemas = [
  610. AliasInfo(
  611. alias_set=(
  612. set() if a.annotation is None else set(a.annotation.alias_set)
  613. ),
  614. is_write=a.annotation is not None and a.annotation.is_write,
  615. name=a.name,
  616. )
  617. for a in torchgen_schema.arguments.flat_all
  618. ]
  619. out_schemas = [
  620. AliasInfo(
  621. alias_set=(
  622. set() if a.annotation is None else set(a.annotation.alias_set)
  623. ),
  624. is_write=a.annotation is not None and a.annotation.is_write,
  625. name=a.name,
  626. )
  627. for a in torchgen_schema.returns
  628. ]
  629. else:
  630. # For non-aten ops, torchgen is untested so we rely on torchscript schema parsing
  631. arg_schemas = [
  632. AliasInfo(
  633. alias_set=(
  634. set() if a.alias_info is None else set(a.alias_info.before_set)
  635. ),
  636. is_write=a.alias_info is not None and a.alias_info.is_write,
  637. name=a.name,
  638. )
  639. for a in func._schema.arguments
  640. ]
  641. out_schemas = [
  642. AliasInfo(
  643. alias_set=(
  644. set() if a.alias_info is None else set(a.alias_info.before_set)
  645. ),
  646. is_write=a.alias_info is not None and a.alias_info.is_write,
  647. name=a.name,
  648. )
  649. for a in func._schema.returns
  650. ]
  651. read_only_alias_match_indexes = []
  652. for arg_idx, schema_arg in enumerate(arg_schemas):
  653. for return_idx, schema_out in enumerate(out_schemas):
  654. is_read_only_alias_match = (
  655. schema_arg.alias_set & schema_out.alias_set
  656. ) and not schema_arg.is_write
  657. if is_read_only_alias_match:
  658. read_only_alias_match_indexes.append((arg_idx, return_idx))
  659. outs_write_aliases_list: list[str | None] = [
  660. _get_write_alias(r) for r in out_schemas
  661. ]
  662. non_nones = sum(x is not None for x in outs_write_aliases_list)
  663. if non_nones == 0:
  664. outs_write_aliases: list[str] | None = None
  665. elif non_nones != len(outs_write_aliases_list):
  666. # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)"
  667. raise RuntimeError("Unsupported schema: " + str(func._schema))
  668. else:
  669. outs_write_aliases = cast(list[str], outs_write_aliases_list)
  670. schema_info = SchemaInfo(
  671. args=arg_schemas,
  672. outs=out_schemas,
  673. # This check is surprisingly expensive because pybind11 enum_s are
  674. # inefficient. Just cache it.
  675. is_inplace_view_op=torch.Tag.inplace_view in func.tags,
  676. outs_write_aliases=outs_write_aliases,
  677. read_only_alias_match_indexes=read_only_alias_match_indexes,
  678. )
  679. return schema_info
  680. def autograd_would_have_decomposed(
  681. func: torch._ops.OpOverload, flat_args: Sequence[torch.Tensor | object]
  682. ) -> bool:
  683. """
  684. Suppose that an operator has CompositeImplicitAutograd decomp registered.
  685. Would autograd have used this decomposition? It will only use it if there
  686. isn't an explicit backend registration for the device as well. This function
  687. will tell if this would have occurred.
  688. Why do we need to apply these decompositions later? When inference mode is
  689. on, the autograd key is bypassed entirely, so a lower level mode cannot rely
  690. on the decomposition have been applied. It's easy to accidentally never apply
  691. the decomposition, resulting in an operator showing up in a graph that
  692. is unexpected.
  693. Why do we need to AVOID applying the decomposition when autograd wouldn't
  694. have decomposed? If autograd doesn't decompose, this means in eager mode
  695. we would have run the fused kernel. It must be possible to trace this
  696. fused kernel directly into the graph for fidelity with eager (NB: a user
  697. has the option of then further decomposing at proxy tensor mode via
  698. decomposition table, but we must preserve it to proxy mode to have the
  699. choice.)
  700. Why does functionalization need to also perform the test here? This is
  701. because some CompositeImplicitAutograd decompositions are not functional.
  702. If we are eventually going to decompose, we need to do this while we can
  703. still turn functionalization back on, so those decompositions get functionalized.
  704. So an early decomposition in functionalization may still be necessary. Note that
  705. if proxy tensor decomposition process could turn functionalization back on, this
  706. wouldn't be necessary, and maybe that is a useful thing to do anyway because
  707. the decomposition table is user specified and a user could violate the functional
  708. decomp requirement with a bad decomp. If this happened, then you could always
  709. pass through functionalization.
  710. """
  711. has_backend_registration = False
  712. for a in flat_args:
  713. if isinstance(a, torch.Tensor):
  714. backend_key = torch._C._parse_dispatch_key(
  715. torch._C._dispatch_key_for_device(a.device.type)
  716. )
  717. if backend_key is None:
  718. raise AssertionError(
  719. f"failed to parse dispatch key for device {a.device.type}"
  720. )
  721. # TODO: use func.has_kernel_for_dispatch_key(backend_key)
  722. # but this one checks py_impl and CompositeImplicitAutograd
  723. # incorrectly shows up as has backend reg here
  724. has_backend_registration = torch._C._dispatch_has_kernel_for_dispatch_key(
  725. func.name(), backend_key
  726. )
  727. # in theory we should take all backend keys and take the highest priority one
  728. # to properly mimic the dispatcher,
  729. # this just grabs the first tensor and takes its device key
  730. break
  731. return not has_backend_registration
  732. def return_and_correct_aliasing(func, args, kwargs, out):
  733. """
  734. This function should be used by wrapper tensor ``__torch_dispatch__`` subclasses
  735. that would like to work with torch.compile. It ensures that the subclass
  736. properly implements the aliasing behavior of every op,
  737. which is needed for correctness in AOTAutograd.
  738. This function will handle:
  739. * When we see a view op, we will alias the storages of any
  740. input and output tensor subclasses
  741. * When we see an inplace or out= op, we will directly
  742. return the corresponding input tensor, instead of returning
  743. a (potentially) fresh output tensor.
  744. """
  745. # Caching here because torchgen parsing is definitely not fast, and this function is called
  746. # once for every op in the graph during functionalization.
  747. schema_info = get_alias_info(func)
  748. def get_arg_from_alias(output_alias, schema_info, args, kwargs):
  749. new_args, new_kwargs = torch.fx.operator_schemas.normalize_function( # type: ignore[misc]
  750. func, args=args, kwargs=kwargs
  751. )
  752. arg_indices = [
  753. i for i, a in enumerate(schema_info.args) if output_alias in a.alias_set
  754. ]
  755. # For any dispatcher op with an output alias, we expect it to map to exactly one alias in the schema's input arguments.
  756. if len(arg_indices) != 1:
  757. raise AssertionError(
  758. "Expected exactly one argument index for the given output alias"
  759. )
  760. idx = arg_indices[0]
  761. arg_info = schema_info.args[idx]
  762. if arg_info.name is not None and arg_info.name in new_kwargs:
  763. return new_kwargs[arg_info.name]
  764. return new_args[idx]
  765. # Fix up the storages of any outs so that they point to the same storage as the input,
  766. # if func is a view op.
  767. _correct_storage_aliasing(
  768. func, schema_info, args, (out,) if not isinstance(out, tuple) else out
  769. )
  770. # For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's
  771. # metadata is set correctly.
  772. if schema_info.is_inplace_view_op:
  773. # no_dispatch() to make sure that we secretly change the metadata on the wrapper,
  774. # but don't end up dispatching the op anywhere else.
  775. mutated_args = [
  776. x
  777. for i, x in enumerate(args)
  778. if _get_write_alias(schema_info.args[i]) is not None
  779. ]
  780. # Assumption: we have a very small number of inplace_view ops that follow a strict schema:
  781. # there is only a single argument that gets its metadata mutated.
  782. if len(mutated_args) != 1:
  783. raise AssertionError(
  784. "expected exactly one mutated arg for inplace_view ops"
  785. )
  786. # This check exists because we generally *do* want to update the metadata of any wrapper subclasses,
  787. # but FunctionalTensor is special: it overrides all size/stride calls to plumb to the inner tensor.
  788. # so we don't actually need to update the metadata (and attempting to do so causes errors)
  789. from torch._subclasses.functional_tensor import FunctionalTensor
  790. if not isinstance(mutated_args[0], FunctionalTensor):
  791. with torch.utils._mode_utils.no_dispatch():
  792. # See Note: [Fake Tensor Dispatch Keys]
  793. # we're borrowing the way it modifies dispatch key TLS.
  794. meta_in_tls = torch._C._meta_in_tls_dispatch_include()
  795. torch._C._set_meta_in_tls_dispatch_include(True)
  796. try:
  797. func(*args, **kwargs)
  798. finally:
  799. torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
  800. # Next: we need to make sure to return inputs directly, if the output is a mutable alias (e.g. add_()).
  801. schema_info_outs_write_aliases = schema_info.outs_write_aliases
  802. # simple case: none of our outputs have mutable aliases, so we can return the output as-is
  803. if schema_info_outs_write_aliases is None:
  804. return out
  805. if len(schema_info_outs_write_aliases) == 1:
  806. return get_arg_from_alias(
  807. schema_info_outs_write_aliases[0], schema_info, args, kwargs
  808. )
  809. # In the multi-return case, all aten ops return a tuple / list, so cast accordingly.
  810. outs_to_return = type(out)(
  811. [
  812. (get_arg_from_alias(write_alias, schema_info, args, kwargs))
  813. for write_alias in schema_info_outs_write_aliases
  814. ]
  815. )
  816. return outs_to_return