python.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. import itertools
  2. import unittest.mock
  3. from collections.abc import Callable, Generator, Iterator
  4. from contextlib import contextmanager
  5. from typing import TypeVar, Union
  6. from typing_extensions import ParamSpec
  7. import torch
  8. import torch._C
  9. import torch._ops
  10. import torch.utils._python_dispatch
  11. import torch.utils._pytree as pytree
  12. from torch._C import DispatchKey
  13. __all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]
  14. no_python_dispatcher = torch._C._DisablePythonDispatcher
  15. enable_python_dispatcher = torch._C._EnablePythonDispatcher
  16. enable_pre_dispatch = torch._C._EnablePreDispatch
  17. CROSSREF_FUNCTIONALIZE = False
  18. _P = ParamSpec("_P")
  19. _T = TypeVar("_T")
  20. _R = TypeVar("_R")
  21. def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
  22. """
  23. Warning: the set of overloads this will report is very subtle. It is precisely
  24. the set of torch.ops functions that have actually been accessed from Python
  25. (e.g., we actually called torch.ops.aten.blah at some point. This is DIFFERENT
  26. from the set of registered operators, which will in general be a larger set,
  27. as this would include all operators which we ran C++ static initializers or
  28. Python operator registration on. This does not eagerly populate the list on
  29. torch.ops.aten; this list is lazy!
  30. In other words, this is good for traversing over everything that has an
  31. OpOverload object allocated in Python. We use it for cache invalidation, but
  32. don't rely on this list being complete.
  33. Note that even if we did report all C++ registered overloads, this isn't guaranteed
  34. to be complete either, as a subsequent lazy load of a library which triggers more
  35. registrations could add more things to the set.
  36. """
  37. for ns in torch.ops:
  38. packets = getattr(torch.ops, ns)
  39. for op_name in packets:
  40. packet = getattr(packets, op_name)
  41. for overload in packet:
  42. yield getattr(packet, overload)
  43. @contextmanager
  44. def suspend_functionalization() -> Generator[None, None, None]:
  45. f_tls = torch._C._dispatch_tls_is_dispatch_key_included(
  46. torch._C.DispatchKey.Functionalize
  47. )
  48. f_rv = torch._C._functionalization_reapply_views_tls()
  49. if f_tls:
  50. torch._disable_functionalization()
  51. try:
  52. yield
  53. finally:
  54. if f_tls:
  55. torch._enable_functionalization(reapply_views=f_rv)
  56. def check_tensor_metadata_matches(
  57. nv: torch.Tensor, rv: torch.Tensor, desc: Callable[[], str]
  58. ) -> None:
  59. if not callable(desc):
  60. raise AssertionError(f"desc must be callable, got {type(desc)}")
  61. if nv.size() != rv.size():
  62. raise AssertionError(f"{desc()}: sizes {nv.size()} != {rv.size()}")
  63. if nv.dtype != rv.dtype:
  64. raise AssertionError(f"{desc()}: dtype {nv.dtype} != {rv.dtype}")
  65. same_strides, idx = torch._prims_common.check_significant_strides(
  66. nv, rv, only_cuda=False
  67. )
  68. if not same_strides:
  69. raise AssertionError(
  70. f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"
  71. )
  72. def check_metadata_matches(n: object, r: object, desc: Callable[[], str]) -> None:
  73. if not callable(desc):
  74. raise AssertionError(f"desc must be callable, got {type(desc)}")
  75. n_vals, _n_spec = pytree.tree_flatten(n)
  76. r_vals, _r_spec = pytree.tree_flatten(r)
  77. # TODO: test the specs match; empirically sometimes we have a tuple
  78. # on one side and a list on the other
  79. if len(n_vals) != len(r_vals):
  80. raise AssertionError(f"{len(n_vals)} != {len(r_vals)}")
  81. for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
  82. if not isinstance(rv, torch.Tensor):
  83. continue
  84. check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
  85. class Lit:
  86. def __init__(self, s: str) -> None:
  87. self.s = s
  88. def __repr__(self) -> str:
  89. return self.s
  90. def _fmt(a: object) -> object:
  91. if isinstance(a, torch.Tensor):
  92. return Lit(
  93. f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
  94. )
  95. else:
  96. return a
  97. def make_crossref_functionalize(
  98. op: torch._ops.OpOverload[_P, _T], final_key: DispatchKey
  99. ) -> Union[Callable[_P, _T], DispatchKey]:
  100. from torch._subclasses.fake_tensor import FakeTensorMode
  101. # This case is pretty weird, suppress it for now
  102. if op is torch.ops.aten.lift_fresh.default:
  103. return final_key
  104. def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  105. fake_mode = FakeTensorMode()
  106. def fakeify_defun(t: _R) -> _R | torch._subclasses.fake_tensor.FakeTensor:
  107. if isinstance(t, torch.Tensor):
  108. if torch._is_functional_tensor(t):
  109. r = torch._from_functional_tensor(t)
  110. # NB: This assumes that the inner tensor sizes/strides match
  111. # the outer tensor sizes/strides. This doesn't necessarily have to
  112. # be the case, see discussion at
  113. # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
  114. if t.size() != r.size():
  115. raise AssertionError(f"size mismatch: {t.size()} != {r.size()}")
  116. if t.stride() != r.stride():
  117. raise AssertionError(
  118. f"stride mismatch: {t.stride()} != {r.stride()}"
  119. )
  120. else:
  121. r = t
  122. # TODO: suppress guards
  123. return fake_mode.from_tensor(r)
  124. return t
  125. def maybe_detach(t: _R) -> _R | torch.Tensor:
  126. if isinstance(t, torch.Tensor):
  127. return t.detach()
  128. else:
  129. return t
  130. # TODO: This probably does the wrong thing if you're running other
  131. # substantive modes with the normal op outside here
  132. with (
  133. torch.utils._python_dispatch._disable_current_modes(),
  134. suspend_functionalization(),
  135. ):
  136. f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
  137. orig_f_args, orig_f_kwargs = pytree.tree_map(
  138. maybe_detach, (f_args, f_kwargs)
  139. )
  140. with fake_mode:
  141. f_r = op(*f_args, **f_kwargs) # pyrefly: ignore [invalid-param-spec]
  142. r = op._op_dk(final_key, *args, **kwargs)
  143. def desc() -> str:
  144. fmt_args = ", ".join(
  145. itertools.chain(
  146. (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
  147. (
  148. f"{k}={pytree.tree_map(_fmt, v)}"
  149. for k, v in orig_f_kwargs.items()
  150. ),
  151. )
  152. )
  153. return f"{op}({fmt_args})"
  154. check_metadata_matches(f_r, r, desc)
  155. return r
  156. return handler
  157. # NB: enabling this is slow, don't do it in a hot loop. This is purely
  158. # for debugging purposes.
  159. @contextmanager
  160. def enable_crossref_functionalize() -> Generator[None, None, None]:
  161. for op in all_py_loaded_overloads():
  162. op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
  163. try:
  164. with (
  165. enable_python_dispatcher(),
  166. unittest.mock.patch("torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True),
  167. ):
  168. yield
  169. finally:
  170. for op in all_py_loaded_overloads():
  171. op._uncache_dispatch(torch._C.DispatchKey.Functionalize)