| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204 |
- import itertools
- import unittest.mock
- from collections.abc import Callable, Generator, Iterator
- from contextlib import contextmanager
- from typing import TypeVar, Union
- from typing_extensions import ParamSpec
- import torch
- import torch._C
- import torch._ops
- import torch.utils._python_dispatch
- import torch.utils._pytree as pytree
- from torch._C import DispatchKey
- __all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]
- no_python_dispatcher = torch._C._DisablePythonDispatcher
- enable_python_dispatcher = torch._C._EnablePythonDispatcher
- enable_pre_dispatch = torch._C._EnablePreDispatch
- CROSSREF_FUNCTIONALIZE = False
- _P = ParamSpec("_P")
- _T = TypeVar("_T")
- _R = TypeVar("_R")
def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
    """
    Lazily yield every OpOverload that has been materialized from Python.

    Warning: the set reported here is subtle. It is precisely the set of
    torch.ops functions that have actually been accessed from Python (e.g.,
    some code evaluated torch.ops.aten.blah at some point). That is DIFFERENT
    from the set of registered operators, which is in general larger: it also
    includes every operator registered by C++ static initializers or by
    Python operator registration. This does not eagerly populate the listing
    on torch.ops.aten; the listing is lazy!

    In other words, this is good for traversing everything that has an
    OpOverload object allocated in Python. We use it for cache invalidation,
    but don't rely on the result being complete.

    Even reporting all C++-registered overloads would not make it complete,
    because a later lazy load of a library can trigger more registrations.
    """
    for namespace_name in torch.ops:
        namespace = getattr(torch.ops, namespace_name)
        for packet_name in namespace:
            op_packet = getattr(namespace, packet_name)
            # Each name listed by the packet resolves to a cached OpOverload.
            yield from (getattr(op_packet, ov_name) for ov_name in op_packet)
@contextmanager
def suspend_functionalization() -> Generator[None, None, None]:
    """
    Temporarily switch off functionalization in the dispatch TLS.

    If the Functionalize dispatch key is included on entry, functionalization
    is disabled for the body of the ``with`` block. On exit, including exit
    by exception, it is re-enabled with the ``reapply_views`` setting that
    was in effect on entry. If the key was not included, the body simply runs.
    """
    functionalize_on = torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.Functionalize
    )
    saved_reapply_views = torch._C._functionalization_reapply_views_tls()

    if not functionalize_on:
        # Nothing to suspend or restore.
        yield
        return

    torch._disable_functionalization()
    try:
        yield
    finally:
        torch._enable_functionalization(reapply_views=saved_reapply_views)
def check_tensor_metadata_matches(
    nv: torch.Tensor, rv: torch.Tensor, desc: Callable[[], str]
) -> None:
    """
    Assert that two tensors agree on size, dtype and significant strides.

    Checks run in that order and stop at the first mismatch.

    Args:
        nv: the first tensor to compare.
        rv: the second tensor to compare.
        desc: zero-argument callable that produces the prefix of the error
            message. It is only invoked when a mismatch is reported.

    Raises:
        AssertionError: if ``desc`` is not callable, or if any of the compared
            properties differ.
    """
    if not callable(desc):
        raise AssertionError(f"desc must be callable, got {type(desc)}")

    lhs_size, rhs_size = nv.size(), rv.size()
    if lhs_size != rhs_size:
        raise AssertionError(f"{desc()}: sizes {lhs_size} != {rhs_size}")

    lhs_dtype, rhs_dtype = nv.dtype, rv.dtype
    if lhs_dtype != rhs_dtype:
        raise AssertionError(f"{desc()}: dtype {lhs_dtype} != {rhs_dtype}")

    # Only "significant" strides are compared; the exact notion is defined by
    # torch._prims_common.check_significant_strides.
    strides_ok, mismatch_at = torch._prims_common.check_significant_strides(
        nv, rv, only_cuda=False
    )
    if strides_ok:
        return
    raise AssertionError(
        f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {mismatch_at})"
    )
def check_metadata_matches(n: object, r: object, desc: Callable[[], str]) -> None:
    """
    Check that two (possibly nested) results carry matching tensor metadata.

    Both structures are flattened with pytree and compared leaf by leaf.
    Only positions where the ``r`` leaf is a Tensor are checked; the
    per-tensor comparison is delegated to ``check_tensor_metadata_matches``.

    Args:
        n: first result structure. In the crossref handler of this module,
            this is the fake-mode result.
        r: second result structure. In the crossref handler, this is the real
            result.
        desc: zero-argument callable producing the description used in error
            messages. It is only invoked when a mismatch is reported.

    Raises:
        AssertionError: if ``desc`` is not callable, if the flattened leaf
            counts differ, if a Tensor leaf in ``r`` lines up with a
            non-Tensor leaf in ``n``, or if tensor metadata disagrees.
    """
    if not callable(desc):
        raise AssertionError(f"desc must be callable, got {type(desc)}")
    n_vals, _n_spec = pytree.tree_flatten(n)
    r_vals, _r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    if len(n_vals) != len(r_vals):
        raise AssertionError(
            f"{desc()}: number of outputs {len(n_vals)} != {len(r_vals)}"
        )
    for i, (nv, rv) in enumerate(zip(n_vals, r_vals)):
        if not isinstance(rv, torch.Tensor):
            continue
        # Report a type mismatch with context, instead of failing later
        # with an AttributeError on nv.size().
        if not isinstance(nv, torch.Tensor):
            raise AssertionError(
                f"{desc()} output {i}: expected a Tensor, got {type(nv)}"
            )
        # Bind i eagerly so the description stays correct even if the
        # callable outlived this loop iteration.
        check_tensor_metadata_matches(nv, rv, lambda i=i: f"{desc()} output {i}")
class Lit:
    """
    A value whose ``repr`` is exactly the wrapped string.

    Useful for embedding pre-rendered source text inside containers that
    are later formatted with ``repr`` (a list's repr shows the raw text
    without quotes).
    """

    def __init__(self, s: str) -> None:
        # Stored verbatim; __repr__ hands it back unchanged.
        self.s = s

    def __repr__(self) -> str:
        text = self.s
        return text
- def _fmt(a: object) -> object:
- if isinstance(a, torch.Tensor):
- return Lit(
- f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
- )
- else:
- return a
def make_crossref_functionalize(
    op: torch._ops.OpOverload[_P, _T], final_key: DispatchKey
) -> Union[Callable[_P, _T], DispatchKey]:
    """
    Build a handler for ``op`` that cross-checks it against FakeTensorMode.

    The returned handler runs ``op`` twice:

    1. On fake copies of the inputs, under a fresh ``FakeTensorMode``, with
       the current Python modes disabled and functionalization suspended.
       Functional tensors are unwrapped before being faked.
    2. For real, by redispatching to ``final_key`` with the original
       arguments.

    It then checks, via ``check_metadata_matches``, that the fake and real
    results agree on tensor metadata, and returns the real result.

    Args:
        op: the overload to wrap.
        final_key: dispatch key the real computation is sent to, via
            ``op._op_dk``.

    Returns:
        ``final_key`` itself (no handler) for ``aten.lift_fresh.default``;
        otherwise the handler callable.
    """
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird, suppress it for now
    if op is torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        fake_mode = FakeTensorMode()

        def fakeify_defun(t: _R) -> _R | torch._subclasses.fake_tensor.FakeTensor:
            # Turn tensor leaves into fake tensors owned by fake_mode,
            # unwrapping functional tensors first; non-tensors pass through.
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides. This doesn't necessarily have to
                    # be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    if t.size() != r.size():
                        raise AssertionError(f"size mismatch: {t.size()} != {r.size()}")
                    if t.stride() != r.stride():
                        raise AssertionError(
                            f"stride mismatch: {t.stride()} != {r.stride()}"
                        )
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t: _R) -> _R | torch.Tensor:
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        # TODO: This probably does the wrong thing if you're running other
        # substantive modes with the normal op outside here
        with (
            torch.utils._python_dispatch._disable_current_modes(),
            suspend_functionalization(),
        ):
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            # Snapshot the fake inputs (detached) before running the op, so
            # that desc() describes the inputs as they were passed.
            # NOTE(review): presumably this guards against in-place ops that
            # change input metadata — confirm.
            orig_f_args, orig_f_kwargs = pytree.tree_map(
                maybe_detach, (f_args, f_kwargs)
            )
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)  # pyrefly: ignore [invalid-param-spec]
        # The real computation runs outside the disabled-modes/suspended
        # functionalization context, redispatched straight to final_key.
        r = op._op_dk(final_key, *args, **kwargs)

        def desc() -> str:
            # Render a pseudo-call, with tensors shown as torch.empty_strided
            # placeholders. Positional and keyword values are both rendered
            # with repr(), so that e.g. string kwargs appear quoted.
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (
                        f"{k}={pytree.tree_map(_fmt, v)!r}"
                        for k, v in orig_f_kwargs.items()
                    ),
                )
            )
            return f"{op}({fmt_args})"

        check_metadata_matches(f_r, r, desc)
        return r

    return handler
# NB: enabling this is slow, don't do it in a hot loop. This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize() -> Generator[None, None, None]:
    """
    Turn on functionalization cross-referencing inside the ``with`` block.

    While active, the Python dispatcher is enabled and the module flag
    ``torch._dispatch.python.CROSSREF_FUNCTIONALIZE`` is patched to True.

    The Functionalize-key dispatch cache of every Python-loaded OpOverload is
    dropped both before entering and after leaving (even on error). This
    presumably forces the key to be re-resolved under the new flag value,
    and again once it is restored.
    """

    def _drop_functionalize_caches() -> None:
        functionalize_key = torch._C.DispatchKey.Functionalize
        for overload in all_py_loaded_overloads():
            overload._uncache_dispatch(functionalize_key)

    _drop_functionalize_caches()
    try:
        with (
            enable_python_dispatcher(),
            unittest.mock.patch("torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True),
        ):
            yield
    finally:
        _drop_functionalize_caches()
|