fake_utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. from __future__ import annotations
  2. import functools
  3. import warnings
  4. from typing import Any, TYPE_CHECKING
  5. import torch
  6. import torch.utils._pytree as pytree
  7. from torch._subclasses.fake_tensor import (
  8. FakeTensor,
  9. FakeTensorMode,
  10. MetadataMismatchError,
  11. tree_flatten_only,
  12. UnsupportedFakeTensorException,
  13. )
  14. from torch.utils._python_dispatch import TorchDispatchMode
  15. if TYPE_CHECKING:
  16. from collections.abc import Callable, Mapping, Sequence
  17. from torch._ops import OpOverload
  18. from torch.utils._pytree import PyTree
# Shorthand for the ATen operator namespace, used to reference op overloads
# such as ``aten.lift_fresh.default`` by identity.
aten = torch._ops.ops.aten
  20. def outputs_alias_inputs(outputs: PyTree, inputs: PyTree) -> bool:
  21. input_storages = {
  22. inp._typed_storage()._cdata
  23. for inp in tree_flatten_only(torch.Tensor, inputs)
  24. if torch._C._has_storage(inp)
  25. }
  26. return any(
  27. torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages
  28. for out in tree_flatten_only(torch.Tensor, outputs)
  29. )
  30. def outputs_are_inputs(outputs: PyTree, inputs: PyTree) -> bool:
  31. input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)}
  32. return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs))
  33. def output_alias_each_other(outputs: PyTree) -> bool:
  34. storages = set()
  35. for out in tree_flatten_only(torch.Tensor, outputs):
  36. if not torch._C._has_storage(out):
  37. continue
  38. stor = out._typed_storage()._cdata
  39. if stor in storages:
  40. return True
  41. storages.add(stor)
  42. return False
  43. def _check_alias_info(
  44. context: str,
  45. real_out: PyTree,
  46. real_in: PyTree,
  47. fake_out: PyTree,
  48. fake_in: PyTree,
  49. ) -> None:
  50. r_aliasing = outputs_alias_inputs(real_out, real_in)
  51. f_aliasing = outputs_alias_inputs(fake_out, fake_in)
  52. if r_aliasing != f_aliasing:
  53. raise MetadataMismatchError(
  54. f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}"
  55. )
  56. r_identity_eq = outputs_are_inputs(real_out, real_in)
  57. f_identity_eq = outputs_are_inputs(fake_out, fake_in)
  58. if r_identity_eq != f_identity_eq:
  59. raise MetadataMismatchError(
  60. f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}"
  61. )
  62. r_output_alias_each_other = output_alias_each_other(real_out)
  63. f_output_alias_each_other = output_alias_each_other(fake_out)
  64. if r_output_alias_each_other != f_output_alias_each_other:
  65. raise MetadataMismatchError(
  66. f"{context} mismatch in outputs_alias_each_other check "
  67. f"{f_output_alias_each_other} != {r_output_alias_each_other}"
  68. )
  69. def is_sdpa_error(func: OpOverload, idx: int, e: Exception) -> bool:
  70. if (
  71. (
  72. func is aten._scaled_dot_product_flash_attention.default
  73. or func is aten._flash_attention_forward.default
  74. )
  75. and idx in (6, 7)
  76. and "Devices" in repr(e)
  77. ):
  78. return True
  79. if (
  80. (
  81. func is aten._scaled_dot_product_efficient_attention.default
  82. or func is aten._efficient_attention_forward.default
  83. )
  84. and idx in (2, 3)
  85. and "Devices" in repr(e)
  86. ):
  87. return True
  88. if (
  89. func is aten._scaled_dot_product_cudnn_attention.default
  90. and idx in (6, 7)
  91. and "Devices" in repr(e)
  92. ):
  93. return True
  94. return False
  95. def try_convert_fake_to_real(
  96. ten_list: list[FakeTensor | Any],
  97. ) -> list[FakeTensor | torch.Tensor | Any]:
  98. """
  99. Attempt to convert fake tensors to a corresponding real tensor with the correct underlying storage by looking up
  100. the FakeTensorMode meta to real storage mapping. On failure to find the storage mapping, the FakeTensor will
  101. remain in the list.
  102. Note: this is not currently optimized (makes copies of the meta converter internal dictionaries)
  103. """
  104. fake_tensor = next(
  105. (item for item in ten_list if isinstance(item, FakeTensor)), None
  106. )
  107. if fake_tensor is None:
  108. return ten_list
  109. fake_mode = fake_tensor.fake_mode
  110. meta_converter = fake_mode.fake_tensor_converter.meta_converter
  111. desc = meta_converter.describer
  112. storage_to_key = {v: k for k, v in meta_converter.storage_memo.items()}
  113. key_to_real_storage = {v: k for k, v in desc.lookup_storage.items()}
  114. out = []
  115. for t in ten_list:
  116. if not isinstance(t, FakeTensor) or t.layout != torch.strided:
  117. out.append(t)
  118. continue
  119. key = storage_to_key.get(t.untyped_storage())
  120. real_storage = None if key is None else key_to_real_storage.get(key)
  121. if real_storage is None:
  122. out.append(t)
  123. continue
  124. unhinted = False
  125. def map_symint(s: torch.SymInt | int) -> int:
  126. nonlocal unhinted
  127. if not isinstance(s, torch.SymInt):
  128. return s
  129. unhinted = unhinted if not unhinted else s.node.has_hint()
  130. return s.node.hint
  131. stor_offset = map_symint(t.storage_offset())
  132. size = [map_symint(s) for s in t.shape]
  133. stride = [map_symint(s) for s in t.stride()]
  134. if unhinted:
  135. out.append(t)
  136. continue
  137. new_tensor = torch.empty(
  138. [],
  139. dtype=t.dtype,
  140. device=t.device,
  141. )
  142. new_tensor.set_(
  143. real_storage,
  144. storage_offset=stor_offset,
  145. size=size,
  146. stride=stride,
  147. )
  148. out.append(new_tensor.clone())
  149. return out
  150. def _check_fake_real_tensors(
  151. real_out: torch.Tensor,
  152. fake_out: FakeTensor,
  153. context: str = "",
  154. sizes: bool = True,
  155. strides: bool = False,
  156. storage_offset: bool = True,
  157. requires_grad: bool = True,
  158. ) -> None:
  159. if requires_grad:
  160. if real_out.requires_grad != fake_out.requires_grad:
  161. raise MetadataMismatchError(
  162. f"{context} mismatched requires_grad-ness of outputs. "
  163. f"This usually means that you have added autograd support "
  164. f"for your operator at a dispatch key other than Autograd, "
  165. f"which will lead to problems"
  166. )
  167. if torch._C._has_storage(real_out):
  168. r_offset = real_out.storage_offset()
  169. f_offset = fake_out.storage_offset()
  170. if r_offset != f_offset:
  171. raise MetadataMismatchError(f"{context} mismatched storage offset")
  172. torch._prims.utils.compare_tensor_meta(
  173. real_out,
  174. fake_out,
  175. check_sizes=sizes,
  176. check_strides=strides,
  177. allow_rhs_unbacked=True,
  178. )
  179. class CrossRefFakeMode(TorchDispatchMode):
  180. def __init__(
  181. self,
  182. ignore_op_fn: Callable[[OpOverload], bool] | None = None,
  183. *,
  184. check_strides: bool = True,
  185. check_aliasing: bool = True,
  186. only_check_ops_with_meta: bool = True,
  187. ) -> None:
  188. super().__init__()
  189. self.ignore_op_fn = (
  190. ignore_op_fn if ignore_op_fn is not None else lambda fn: False
  191. )
  192. self.check_strides = check_strides
  193. self.check_aliasing = check_aliasing
  194. self.only_check_ops_with_meta = only_check_ops_with_meta
  195. def __torch_dispatch__(
  196. self,
  197. func: OpOverload,
  198. types: Sequence[type],
  199. args: Sequence[object] = (),
  200. kwargs: Mapping[str, object] | None = None,
  201. ) -> object:
  202. kwargs = kwargs or {}
  203. fake_r = None
  204. fake_args: Sequence[object] = ()
  205. fake_kwargs: Mapping[str, object] = {}
  206. # empty_like excluded for now due to sparse complex
  207. # aten._to_dense.default this one is getting called with csc
  208. if (
  209. func
  210. not in (
  211. aten.lift_fresh.default,
  212. aten.lift_fresh_copy.default,
  213. aten.set_.source_Storage_storage_offset,
  214. )
  215. and not self.ignore_op_fn(func)
  216. and (
  217. not self.only_check_ops_with_meta
  218. or torch._subclasses.fake_impls.has_meta(func)
  219. )
  220. and torch.Tag.dynamic_output_shape not in func.tags
  221. and torch.Tag.inplace_view not in func.tags
  222. and torch.Tag.data_dependent_output not in func.tags
  223. ):
  224. # Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow
  225. from torch.fx.experimental.symbolic_shapes import ShapeEnv
  226. try:
  227. # TODO: enable_python_dispatcher() here
  228. with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
  229. fake_args, fake_kwargs = pytree.tree_map_only(
  230. torch.Tensor,
  231. functools.partial(fake_mode.from_tensor, static_shapes=True),
  232. (args, kwargs),
  233. )
  234. with warnings.catch_warnings():
  235. fake_r = func(*fake_args, **fake_kwargs)
  236. except UnsupportedFakeTensorException:
  237. pass
  238. context = (
  239. f"When comparing the output of {func} on FakeTensor and concrete Tensors, "
  240. f"found"
  241. )
  242. r = func(*args, **kwargs)
  243. if fake_r is not None:
  244. r_flat = pytree.tree_leaves(r)
  245. f_flat = pytree.tree_leaves(fake_r)
  246. if len(f_flat) != len(r_flat):
  247. raise AssertionError(
  248. f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"
  249. )
  250. if self.check_aliasing:
  251. _check_alias_info(
  252. context, r, (args, kwargs), fake_r, (fake_args, fake_kwargs)
  253. )
  254. for idx, (r_out, f_out) in enumerate(
  255. zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r))
  256. ):
  257. r_is_ten = isinstance(r_out, torch.Tensor)
  258. if r_is_ten != isinstance(f_out, torch.Tensor):
  259. raise AssertionError(
  260. f"{context} mismatched number of tensor outputs"
  261. )
  262. if r_is_ten:
  263. try:
  264. _check_fake_real_tensors(
  265. r_out,
  266. f_out,
  267. sizes=True,
  268. strides=self.check_strides,
  269. storage_offset=True,
  270. requires_grad=True,
  271. )
  272. except Exception as e:
  273. if is_sdpa_error(func, idx, e):
  274. continue
  275. error_message = (
  276. f"{context} mismatched tensor metadata: {e}"
  277. if len(r_flat) == 1
  278. else f"{context} mismatched tensor metadata for output[{idx}]: {e}"
  279. )
  280. raise MetadataMismatchError(error_message) from e
  281. return r