| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326 |
- from __future__ import annotations
- import functools
- import warnings
- from typing import Any, TYPE_CHECKING
- import torch
- import torch.utils._pytree as pytree
- from torch._subclasses.fake_tensor import (
- FakeTensor,
- FakeTensorMode,
- MetadataMismatchError,
- tree_flatten_only,
- UnsupportedFakeTensorException,
- )
- from torch.utils._python_dispatch import TorchDispatchMode
- if TYPE_CHECKING:
- from collections.abc import Callable, Mapping, Sequence
- from torch._ops import OpOverload
- from torch.utils._pytree import PyTree
- aten = torch._ops.ops.aten
def outputs_alias_inputs(outputs: PyTree, inputs: PyTree) -> bool:
    """Return True if any output tensor shares a storage with any input tensor.

    Only ``torch.Tensor`` leaves of the two pytrees are considered, and
    tensors for which ``torch._C._has_storage`` is false are skipped.
    Storages are compared through the ``_cdata`` handle of the typed storage.
    """
    # Collect the storage handles of every input tensor first.
    seen_input_storages = set()
    for inp in tree_flatten_only(torch.Tensor, inputs):
        if torch._C._has_storage(inp):
            seen_input_storages.add(inp._typed_storage()._cdata)

    # Then look for the first output whose storage handle was seen above.
    for out in tree_flatten_only(torch.Tensor, outputs):
        if not torch._C._has_storage(out):
            continue
        if out._typed_storage()._cdata in seen_input_storages:
            return True
    return False
def outputs_are_inputs(outputs: PyTree, inputs: PyTree) -> bool:
    """Return True if any output tensor is the very same object as an input tensor.

    Identity is decided with ``id()`` on the ``torch.Tensor`` leaves of both
    pytrees; non-tensor leaves are ignored.
    """
    input_tensors = tree_flatten_only(torch.Tensor, inputs)
    output_tensors = tree_flatten_only(torch.Tensor, outputs)
    input_identities = set(map(id, input_tensors))
    # "Not disjoint" means at least one output id is also an input id.
    return not input_identities.isdisjoint(map(id, output_tensors))
def output_alias_each_other(outputs: PyTree) -> bool:
    """Return True if two tensor leaves of ``outputs`` share the same storage.

    Tensors for which ``torch._C._has_storage`` is false are ignored; the
    remaining ones are compared through their typed storage's ``_cdata``.
    """
    storage_handles = [
        out._typed_storage()._cdata
        for out in tree_flatten_only(torch.Tensor, outputs)
        if torch._C._has_storage(out)
    ]
    # A duplicate handle collapses inside the set, shrinking it below the list.
    return len(set(storage_handles)) != len(storage_handles)
def _check_alias_info(
    context: str,
    real_out: PyTree,
    real_in: PyTree,
    fake_out: PyTree,
    fake_in: PyTree,
) -> None:
    """Verify that the real and fake runs agree on aliasing relationships.

    Three properties are compared in order, each computed once for the real
    (inputs, outputs) and once for the fake ones:

    1. whether any output shares storage with an input,
    2. whether any output is the same object as an input,
    3. whether any two outputs share storage with each other.

    Raises:
        MetadataMismatchError: on the first property whose real and fake
            answers differ. The message is prefixed with ``context`` and
            reports the fake value followed by the real value.
    """

    def require_match(check_name: str, real_flag: bool, fake_flag: bool) -> None:
        # Shared compare-and-raise step for all three properties.
        if real_flag != fake_flag:
            raise MetadataMismatchError(
                f"{context} mismatch in {check_name} check {fake_flag} != {real_flag}"
            )

    require_match(
        "outputs_alias_inputs",
        outputs_alias_inputs(real_out, real_in),
        outputs_alias_inputs(fake_out, fake_in),
    )
    require_match(
        "outputs_are_inputs",
        outputs_are_inputs(real_out, real_in),
        outputs_are_inputs(fake_out, fake_in),
    )
    require_match(
        "outputs_alias_each_other",
        output_alias_each_other(real_out),
        output_alias_each_other(fake_out),
    )
def is_sdpa_error(func: OpOverload, idx: int, e: Exception) -> bool:
    """Tell whether ``e`` is a "Devices" error on a known attention-op output.

    Returns True only when all three hold:

    * ``func`` is one of the flash / efficient / cuDNN attention overloads
      listed below,
    * ``idx`` is one of the output positions paired with that op family
      (6 or 7 for flash and cuDNN, 2 or 3 for efficient attention),
    * the text ``"Devices"`` occurs in ``repr(e)``.

    Otherwise returns False.
    """

    def mentions_devices() -> bool:
        return "Devices" in repr(e)

    # Flash attention family.
    is_flash = (
        func is aten._scaled_dot_product_flash_attention.default
        or func is aten._flash_attention_forward.default
    )
    if is_flash and idx in (6, 7) and mentions_devices():
        return True

    # Memory-efficient attention family.
    is_efficient = (
        func is aten._scaled_dot_product_efficient_attention.default
        or func is aten._efficient_attention_forward.default
    )
    if is_efficient and idx in (2, 3) and mentions_devices():
        return True

    # cuDNN attention.
    is_cudnn = func is aten._scaled_dot_product_cudnn_attention.default
    return bool(is_cudnn and idx in (6, 7) and mentions_devices())
def try_convert_fake_to_real(
    ten_list: list[FakeTensor | Any],
) -> list[FakeTensor | torch.Tensor | Any]:
    """
    Attempt to convert fake tensors to a corresponding real tensor with the correct underlying storage by looking up
    the FakeTensorMode meta to real storage mapping. On failure to find the storage mapping, the FakeTensor will
    remain in the list.

    A FakeTensor is also left untouched when its layout is not ``torch.strided`` or when its storage offset,
    sizes or strides contain a ``torch.SymInt`` that has no hint (no concrete value to materialize with).

    Note: this is not currently optimized (makes copies of the meta converter internal dictionaries)

    Args:
        ten_list: Items to convert; entries that are not FakeTensors are passed through unchanged.

    Returns:
        ``ten_list`` itself when it contains no FakeTensor, otherwise a new list of the same length in which
        each convertible FakeTensor is replaced by a clone of a real tensor viewing the mapped real storage.
    """
    fake_tensor = next(
        (item for item in ten_list if isinstance(item, FakeTensor)), None
    )
    if fake_tensor is None:
        return ten_list

    # The mode of the first FakeTensor found supplies the lookup tables for the whole list.
    fake_mode = fake_tensor.fake_mode
    meta_converter = fake_mode.fake_tensor_converter.meta_converter
    desc = meta_converter.describer

    # Invert both mappings so we can walk: fake storage -> key -> real storage.
    storage_to_key = {v: k for k, v in meta_converter.storage_memo.items()}
    key_to_real_storage = {v: k for k, v in desc.lookup_storage.items()}

    def resolve_hints(values: list[torch.SymInt | int]) -> list[int] | None:
        """Map SymInts to their hints; return None if any SymInt has no hint."""
        resolved = []
        for value in values:
            if isinstance(value, torch.SymInt):
                if not value.node.has_hint():
                    return None
                value = value.node.hint
            resolved.append(value)
        return resolved

    out = []
    for t in ten_list:
        if not isinstance(t, FakeTensor) or t.layout != torch.strided:
            out.append(t)
            continue

        key = storage_to_key.get(t.untyped_storage())
        real_storage = None if key is None else key_to_real_storage.get(key)
        if real_storage is None:
            out.append(t)
            continue

        ndim = len(t.shape)
        geometry = resolve_hints([t.storage_offset(), *t.shape, *t.stride()])
        if geometry is None:
            # An unhinted symbolic dimension cannot be made concrete: keep the fake tensor.
            out.append(t)
            continue
        stor_offset = geometry[0]
        size = geometry[1 : 1 + ndim]
        stride = geometry[1 + ndim :]

        new_tensor = torch.empty(
            [],
            dtype=t.dtype,
            device=t.device,
        )
        new_tensor.set_(
            real_storage,
            storage_offset=stor_offset,
            size=size,
            stride=stride,
        )
        # Clone so the returned tensor does not alias the looked-up real storage.
        out.append(new_tensor.clone())
    return out
- def _check_fake_real_tensors(
- real_out: torch.Tensor,
- fake_out: FakeTensor,
- context: str = "",
- sizes: bool = True,
- strides: bool = False,
- storage_offset: bool = True,
- requires_grad: bool = True,
- ) -> None:
- if requires_grad:
- if real_out.requires_grad != fake_out.requires_grad:
- raise MetadataMismatchError(
- f"{context} mismatched requires_grad-ness of outputs. "
- f"This usually means that you have added autograd support "
- f"for your operator at a dispatch key other than Autograd, "
- f"which will lead to problems"
- )
- if torch._C._has_storage(real_out):
- r_offset = real_out.storage_offset()
- f_offset = fake_out.storage_offset()
- if r_offset != f_offset:
- raise MetadataMismatchError(f"{context} mismatched storage offset")
- torch._prims.utils.compare_tensor_meta(
- real_out,
- fake_out,
- check_sizes=sizes,
- check_strides=strides,
- allow_rhs_unbacked=True,
- )
class CrossRefFakeMode(TorchDispatchMode):
    """Dispatch mode that cross-checks real op results against fake-tensor results.

    For every eligible op, the op is first run on fake copies of the inputs
    and then on the real inputs. The two results are compared for number of
    leaves, aliasing relationships (optional) and per-tensor metadata. A
    mismatch raises; otherwise the real result is returned unchanged.
    """

    def __init__(
        self,
        ignore_op_fn: Callable[[OpOverload], bool] | None = None,
        *,
        check_strides: bool = True,
        check_aliasing: bool = True,
        only_check_ops_with_meta: bool = True,
    ) -> None:
        """Configure which ops are checked and how strictly.

        Args:
            ignore_op_fn: Predicate on the op; ops for which it returns True
                are not cross-checked. Defaults to ignoring nothing.
            check_strides: Forwarded as ``strides`` to the per-tensor check.
            check_aliasing: Whether to compare aliasing relationships.
            only_check_ops_with_meta: When True, only ops for which
                ``torch._subclasses.fake_impls.has_meta`` is true are checked.
        """
        super().__init__()
        if ignore_op_fn is None:
            self.ignore_op_fn = lambda fn: False
        else:
            self.ignore_op_fn = ignore_op_fn
        self.check_strides = check_strides
        self.check_aliasing = check_aliasing
        self.only_check_ops_with_meta = only_check_ops_with_meta

    def _should_cross_check(self, func: OpOverload) -> bool:
        """Return True if ``func`` should also be run on fake tensors."""
        # empty_like excluded for now due to sparse complex
        # aten._to_dense.default this one is getting called with csc
        if func in (
            aten.lift_fresh.default,
            aten.lift_fresh_copy.default,
            aten.set_.source_Storage_storage_offset,
        ):
            return False
        if self.ignore_op_fn(func):
            return False
        if self.only_check_ops_with_meta and not torch._subclasses.fake_impls.has_meta(
            func
        ):
            return False
        # Ops carrying any of these tags are skipped.
        skipped_tags = (
            torch.Tag.dynamic_output_shape,
            torch.Tag.inplace_view,
            torch.Tag.data_dependent_output,
        )
        for tag in skipped_tags:
            if tag in func.tags:
                return False
        return True

    def __torch_dispatch__(
        self,
        func: OpOverload,
        types: Sequence[type],
        args: Sequence[object] = (),
        kwargs: Mapping[str, object] | None = None,
    ) -> object:
        """Run ``func`` on the real inputs, cross-checking against a fake run.

        Returns:
            The result of ``func(*args, **kwargs)`` on the real inputs.

        Raises:
            AssertionError: if the real and fake results differ in number of
                leaves, or disagree on which leaves are tensors.
            MetadataMismatchError: if aliasing or tensor metadata differ.
        """
        if not kwargs:
            kwargs = {}

        fake_args: Sequence[object] = ()
        fake_kwargs: Mapping[str, object] = {}
        fake_r = None

        if self._should_cross_check(func):
            # Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow
            from torch.fx.experimental.symbolic_shapes import ShapeEnv

            try:
                # TODO: enable_python_dispatcher() here
                with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
                    to_fake = functools.partial(
                        fake_mode.from_tensor, static_shapes=True
                    )
                    fake_args, fake_kwargs = pytree.tree_map_only(
                        torch.Tensor, to_fake, (args, kwargs)
                    )
                    with warnings.catch_warnings():
                        fake_r = func(*fake_args, **fake_kwargs)
            except UnsupportedFakeTensorException:
                # No fake result is available; the op still runs for real below.
                pass

        context = (
            f"When comparing the output of {func} on FakeTensor and concrete Tensors, "
            f"found"
        )
        r = func(*args, **kwargs)

        # Nothing to compare against: hand back the real result as-is.
        if fake_r is None:
            return r

        r_flat = pytree.tree_leaves(r)
        f_flat = pytree.tree_leaves(fake_r)
        if len(f_flat) != len(r_flat):
            raise AssertionError(
                f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"
            )

        if self.check_aliasing:
            _check_alias_info(
                context, r, (args, kwargs), fake_r, (fake_args, fake_kwargs)
            )

        single_output = len(r_flat) == 1
        for idx, (r_out, f_out) in enumerate(zip(r_flat, f_flat)):
            r_is_ten = isinstance(r_out, torch.Tensor)
            if r_is_ten != isinstance(f_out, torch.Tensor):
                raise AssertionError(f"{context} mismatched number of tensor outputs")
            if not r_is_ten:
                continue
            try:
                _check_fake_real_tensors(
                    r_out,
                    f_out,
                    sizes=True,
                    strides=self.check_strides,
                    storage_offset=True,
                    requires_grad=True,
                )
            except Exception as e:
                # Errors recognised by is_sdpa_error are tolerated for this output.
                if is_sdpa_error(func, idx, e):
                    continue
                if single_output:
                    error_message = f"{context} mismatched tensor metadata: {e}"
                else:
                    error_message = (
                        f"{context} mismatched tensor metadata for output[{idx}]: {e}"
                    )
                raise MetadataMismatchError(error_message) from e
        return r
|