fake_impls.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693
  1. from __future__ import annotations
  2. import functools
  3. import itertools
  4. import math
  5. import operator
  6. import sys
  7. from functools import reduce
  8. from typing import Any, cast as typing_cast, TYPE_CHECKING, TypeVar, Union
  9. from typing_extensions import ParamSpec
  10. import torch
  11. import torch._custom_op
  12. import torch._logging
  13. import torch._prims_common as utils
  14. from torch._dispatch.python import no_python_dispatcher
  15. from torch._ops import OpOverload
  16. from torch._prims_common import (
  17. canonicalize_dim,
  18. elementwise_dtypes,
  19. ELEMENTWISE_TYPE_PROMOTION_KIND,
  20. is_boolean_dtype,
  21. is_contiguous,
  22. is_contiguous_for_memory_format_or_false,
  23. is_contiguous_or_false,
  24. is_float_dtype,
  25. is_integer_dtype,
  26. make_contiguous_strides_for,
  27. ShapeType,
  28. )
  29. from torch._subclasses.fake_tensor import (
  30. DataDependentOutputException,
  31. DynamicOutputShapeException,
  32. FakeTensor,
  33. in_kernel_invocation_manager,
  34. run_fallback_kernel,
  35. UnsupportedOperatorException,
  36. )
  37. from torch.fx.operator_schemas import _normalize_function_or_error
  38. from torch.utils._stats import count_label
  39. if TYPE_CHECKING:
  40. from collections.abc import Callable, Sequence
  41. from torch._subclasses.fake_tensor import FakeTensorMode
  42. from torch.types import IntLikeType
  # Anything a fake kernel may receive as a tensor: a FakeTensor or a plain tensor.
  FakeTensorLike = Union[FakeTensor, torch.Tensor]

  # Generic placeholders used by the decorator and helper signatures in this file.
  _P = ParamSpec("_P")
  _R = TypeVar("_R")
  _T = TypeVar("_T")

  pytree = torch.utils._pytree

  __all__ = [
      "op_implementations_checks",
      "get_fast_op_impls",
      "stride_incorrect_op",
      "has_meta",
  ]

  # Exact-match registry: maps an OpOverload to its fake implementation.
  # pyrefly: ignore [implicit-any]
  op_implementations_dict = {}
  # Predicate registry: (check, implementation) pairs appended by register_op_impl.
  # pyrefly: ignore [implicit-any]
  op_implementations_checks = []

  aten = torch._ops.ops.aten
  def ordered_set(*items: _T) -> dict[_T, bool]:
      """Build an insertion-ordered "set" out of ``items``.

      The result is a dict whose keys are the distinct items in first-seen
      order and whose values are all ``True``.
      """
      return {item: True for item in items}
  # Reports whether the backend behind ``device`` supports
  # non-contiguous tensors.
  def is_noncontiguous_supported(device: torch.device) -> bool:
      """Return ``False`` for ``"hpu"`` devices and ``True`` for any other type."""
      if device.type == "hpu":
          return False
      return True
  # Overloads that build a new tensor from an existing one. The `constructors`
  # rule reads the default output device from their "input" argument.
  _like_tensor_constructors = ordered_set(
      # empty_like / full_like / ones_like
      aten.empty_like.default,
      aten.empty_like.out,
      aten.full_like.default,
      aten.full_like.out,
      aten.ones_like.default,
      aten.ones_like.out,
      # rand_like
      aten.rand_like.default,
      aten.rand_like.generator,
      aten.rand_like.out,
      aten.rand_like.generator_out,
      # randn_like
      aten.randn_like.default,
      aten.randn_like.generator,
      aten.randn_like.out,
      aten.randn_like.generator_out,
      # randint_like
      aten.randint_like.default,
      aten.randint_like.generator,
      aten.randint_like.Tensor,
      aten.randint_like.Tensor_generator,
      aten.randint_like.Tensor_out,
      aten.randint_like.Tensor_generator_out,
      aten.randint_like.out,
      aten.randint_like.generator_out,
      aten.randint_like.low_dtype,
      aten.randint_like.low_generator_dtype,
      aten.randint_like.low_dtype_out,
      aten.randint_like.low_generator_dtype_out,
      # zeros_like
      aten.zeros_like.default,
      aten.zeros_like.out,
      # new_* factories
      aten.new_empty.default,
      aten.new_empty.out,
      aten.new_empty_strided.default,
      aten.new_empty_strided.out,
      aten.new_full.default,
      aten.new_full.out,
      aten.new_zeros.default,
      aten.new_zeros.out,
      aten.new_ones.default,
      aten.new_ones.out,
  )
  # Overloads grouped under the name "device not kwarg".
  # NOTE(review): presumably these take a device that is not a keyword-only
  # argument; confirm against the code that consumes this set.
  _device_not_kwarg_ops = ordered_set(
      aten._resize_output_.default,
      aten._nested_tensor_from_tensor_list.default,
      aten._nested_tensor_from_tensor_list.out,
      aten.pin_memory.default,
      aten.to.device,
      aten.to.prim_Device,
      aten.is_pinned.default,
      aten._pin_memory.default,
      aten._pin_memory.out,
      aten._resize_output.default,
      aten._resize_output.out,
  )

  # this op is never actually used
  _non_kwarg_device_constructors = (aten._list_to_tensor,)
  def contains_tensor_types(type_: Any) -> bool:
      """Return True if ``type_`` is a tensor type or contains one.

      ``type_`` is checked with ``isSubtypeOf`` against the tensor type, and
      its ``containedTypes()`` are searched recursively.
      """
      tensor_type = torch._C.TensorType.get()
      if type_.isSubtypeOf(tensor_type):
          return True
      for inner in type_.containedTypes():
          if contains_tensor_types(inner):
              return True
      return False
  @functools.cache
  def _is_tensor_constructor(func: OpOverload) -> bool:
      """Return True if ``func`` takes no tensor arguments and returns one tensor.

      Raises:
          AssertionError: if ``func`` is not an ``OpOverload``.
      """
      if not isinstance(func, OpOverload):
          raise AssertionError(f"func must be an OpOverload, got {type(func)}")

      schema = func._schema
      # Any tensor-typed argument, however deeply nested, disqualifies the op.
      for arg in schema.arguments:
          if contains_tensor_types(arg.type):
              return False

      # TODO: no real reason to restrict multiple outputs
      returns = schema.returns
      if len(returns) != 1:
          return False
      return returns[0].type is torch._C.TensorType.get()
  def register_op_impl(
      run_impl_check: Callable[[OpOverload], bool]
      | OpOverload
      | list[OpOverload]
      | tuple[OpOverload, ...],
  ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
      """Return a decorator that registers a fake implementation.

      ``run_impl_check`` decides where the decorated function is recorded:

      * an ``OpOverload`` is stored in ``op_implementations_dict``; registering
        the same overload twice raises ``AssertionError``;
      * a list or tuple is registered element by element through this same
        function;
      * anything else must be callable (``AssertionError`` otherwise) and is
        appended, paired with the implementation, to
        ``op_implementations_checks``.

      The decorator always returns the function it was given.
      """

      def impl_decorator(op_impl: Callable[_P, _R]) -> Callable[_P, _R]:
          if isinstance(run_impl_check, (list, tuple)):
              for entry in run_impl_check:
                  register_op_impl(entry)(op_impl)
              return op_impl

          if isinstance(run_impl_check, OpOverload):
              if run_impl_check in op_implementations_dict:
                  raise AssertionError(f"duplicate registration: {run_impl_check}")
              op_implementations_dict[run_impl_check] = op_impl
              return op_impl

          if not callable(run_impl_check):
              raise AssertionError(
                  f"run_impl_check must be callable, got {type(run_impl_check)}"
              )
          op_implementations_checks.append((run_impl_check, op_impl))
          return op_impl

      return impl_decorator
  def _is_op_registered_to_fake_rule(op: OpOverload) -> bool:
      """Return True if ``op`` is a key of ``op_implementations_dict``."""
      registered = op in op_implementations_dict
      return registered
  def _deregister_op_impl(op: OpOverload) -> None:
      """Forget any fake implementation registered for ``op``.

      Drops ``op`` from ``op_implementations_dict`` if present, then removes
      the first entry of ``op_implementations_checks`` whose check *is* ``op``.
      """
      op_implementations_dict.pop(op, None)
      registered = next(
          (entry for entry in op_implementations_checks if entry[0] is op),
          None,
      )
      if registered is not None:
          op_implementations_checks.remove(registered)
  @register_op_impl(op_implementations_dict.__contains__)
  def dispatch_to_op_implementations_dict(
      fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
  ) -> Any:
      """Forward the call to the implementation registered for ``func``."""
      impl = op_implementations_dict[func]
      return impl(fake_mode, func, *args, **kwargs)
  @register_op_impl(_is_tensor_constructor)
  @register_op_impl(list(_like_tensor_constructors))
  def constructors(
      fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
  ) -> FakeTensor:
      """Fake rule for tensor constructors and ``*_like`` / ``new_*`` factories.

      The op is run with ``device="meta"`` and the result is wrapped in a
      ``FakeTensor`` that reports the device the caller asked for. When no
      device was given, the input tensor's device is used for the "like" ops
      and CPU otherwise.

      Raises:
          AssertionError: if ``func`` is in ``_non_kwarg_device_constructors``.
          UnsupportedOperatorException: if a ``names`` keyword was passed.
      """
      if func in _non_kwarg_device_constructors:
          raise AssertionError(
              f"func must not be in _non_kwarg_device_constructors, got {func}"
          )

      normalized = _normalize_function_or_error(
          func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
      )
      new_kwargs = normalized[1]

      if "names" in kwargs:
          # REASON: "torch.compile doesn't support named tensors"
          raise UnsupportedOperatorException(func)

      if func in _like_tensor_constructors:
          # TODO: file issue
          template = new_kwargs.pop("input")
          fallback_device = template.device
          call_args = (template,)
      else:
          # cpu is default device if none is specified
          fallback_device = torch.device("cpu")
          call_args = ()

      requested = new_kwargs.pop("device", None)
      out_device = fallback_device if requested is None else requested
      new_kwargs["device"] = torch.device("meta")

      # _like constructors have fake tensor inputs (maybe this causes the non-like
      # to fail? hmmm)
      with in_kernel_invocation_manager(fake_mode):
          meta_result = func(*call_args, **new_kwargs)
      return FakeTensor(fake_mode, meta_result, out_device)
  @register_op_impl(aten.is_pinned.default)
  def non_kwarg_is_pinned(
      fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
  ) -> bool:
      """Fake rule for ``aten.is_pinned``: call the op on the input alone."""
      normalized = _normalize_function_or_error(
          func, args, kwargs, normalize_to_only_use_kwargs=True
      )
      tensor = normalized[1].pop("input")
      # The device argument is deliberately dropped: it is deprecated and
      # is_pinned does not actually use it.
      with in_kernel_invocation_manager(fake_mode):
          pinned = func(tensor)
      return pinned
  # Legacy profiler ops return Tensors but don't follow tensor constructor patterns.
  # They take string arguments and should not have device/dtype parameters added.
  @register_op_impl(torch.ops.profiler._record_function_enter.default)
  def _record_function_enter(
      fake_mode: FakeTensorMode, func: OpOverload, name: str, args: object | None = None
  ) -> FakeTensor:
      """Run the real profiler op and return a fake stand-in for its handle."""
      # The real implementation produces a real handle tensor.
      with in_kernel_invocation_manager(fake_mode):
          handle = func(name, args)
      # Mirror the handle on the meta device, then wrap it as a FakeTensor
      # that reports CPU as its device.
      meta_copy = torch.empty_like(handle, device="meta")
      return FakeTensor(fake_mode, meta_copy, torch.device("cpu"))
  @register_op_impl(torch.ops.profiler._record_function_exit.default)
  def _record_function_exit(
      fake_mode: FakeTensorMode, func: OpOverload, handle: Any
  ) -> None:
      """No-op fake rule: the op's return type is void and fake tensors need nothing done."""
      return None
  @register_op_impl(torch.ops.profiler._record_function_enter_new.default)
  def _record_function_enter_new(
      fake_mode: FakeTensorMode, func: OpOverload, name: str, args: object | None = None
  ) -> Any:
      """Call the real op and pass its result through unwrapped.

      The real implementation returns a custom class rather than a tensor, so
      there is nothing to convert into a FakeTensor.
      """
      with in_kernel_invocation_manager(fake_mode):
          record = func(name, args)
          return record
  @register_op_impl(aten.to.prim_Device)
  @register_op_impl(aten.to.device)
  def non_kwarg_to(
      fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
  ) -> FakeTensor:
      """Fake rule for the ``aten.to`` overloads that carry a device argument.

      The op runs with the device swapped for ``meta``. The result is then
      converted to a fake tensor on the requested device, or on the input's
      device when the requested one is falsy.
      """
      normalized = _normalize_function_or_error(
          func, args, kwargs, normalize_to_only_use_kwargs=True
      )
      new_kwargs = normalized[1]

      source = new_kwargs.pop("input")
      requested = new_kwargs["device"]
      if requested:
          out_device = requested
      else:
          out_device = source.device
      new_kwargs["device"] = torch.device("meta")

      with in_kernel_invocation_manager(fake_mode):
          meta_result = func(source, **new_kwargs)
      # TODO: I think this does the wrong thing if the result is the input itself
      converter = fake_mode.fake_tensor_converter
      return converter.from_meta_and_device(fake_mode, meta_result, out_device)
  def stride_incorrect_op(op: OpOverload) -> bool:
      """Predicate for ops whose meta implementation has incorrect strides.

      Always returns ``False``, so no op is currently selected.
      """
      return False
  # These operators have meta implementations with incorrect strides
  @register_op_impl(stride_incorrect_op)
  def workaround_stride_incorrect_op(
      fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
  ) -> FakeTensor:
      """Work around meta implementations that report incorrect strides.

      With static shapes and fallback kernels allowed, the op is run through
      ``run_fallback_kernel`` to obtain the real strides.

      Raises:
          UnsupportedOperatorException: if fallback kernels are disallowed or
              any top-level argument is symbolic.
      """

      def is_symbolic(x: object) -> bool:
          if isinstance(x, FakeTensor):
              return x._has_symbolic_sizes_strides
          return isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool))

      if not fake_mode.allow_fallback_kernels:
          raise UnsupportedOperatorException(func)

      # Only static shapes can fall back to eager for the real strides.
      top_level_values = itertools.chain(args, kwargs.values())
      if any(is_symbolic(value) for value in top_level_values):
          raise UnsupportedOperatorException(func)

      flat_args, args_spec = pytree.tree_flatten((args, kwargs))
      return run_fallback_kernel(
          fake_mode,
          func,
          flat_args,
          args_spec,
          # TODO: refactor to lambda so we don't instantiate extra errors before
          # calling
          RuntimeError("Cannot run fallback kernel for stride_incorrect_op"),
      )
  # Don't fall back to the default device handling here,
  # since the device of `the_template` is ignored.
  @register_op_impl(aten.resize_as_.default)
  def resize_as_(
      fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
  ) -> FakeTensor:
      """Run ``aten.resize_as_`` directly under the kernel invocation manager."""
      with in_kernel_invocation_manager(fake_mode):
          resized = func(*args, **kwargs)
          return resized
  @register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
  def _sparse_coo_tensor_with_dims_and_tensors(
      fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
  ) -> FakeTensor:
      """Handle this op with the generic ``constructors`` rule."""
      fake_result = constructors(fake_mode, func, *args, **kwargs)
      return fake_result
  # Applies to every op tagged dynamic_output_shape except the three listed
  # below; index.Tensor is data-dependent in only some conditions.
  @register_op_impl(
      lambda func: (
          torch.Tag.dynamic_output_shape in func.tags
          and func
          not in (
              aten.index.Tensor,
              aten.nonzero.default,
              aten.repeat_interleave.Tensor,
          )
      )
  )
  def dyn_shape(
      fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
  ) -> None:
      """Always raise ``DynamicOutputShapeException`` for ``func``."""
      raise DynamicOutputShapeException(func)
  315. def _unique(
  316. fake_mode: FakeTensorMode,
  317. func: OpOverload,
  318. arg: FakeTensor,
  319. dim: int | None,
  320. sorted: bool = True,
  321. return_inverse: bool = False,
  322. return_counts: bool = False,
  323. *,
  324. unique_consecutive: bool = False,
  325. ) -> tuple[FakeTensor, FakeTensor, FakeTensor]:
  326. if (
  327. fake_mode.shape_env is None
  328. or not fake_mode.shape_env.allow_dynamic_output_shape_ops
  329. ):
  330. # Without symints/symfloats, cannot handle this
  331. raise DynamicOutputShapeException(func)
  332. nnz = arg.unique_consecutive_memo if unique_consecutive else arg.unique_memo
  333. # Do not use a memo for unique_dim
  334. if dim is not None or nnz is None:
  335. # Avoid importing sympy at a module level
  336. from torch.fx.experimental.symbolic_shapes import (
  337. _constrain_range_for_size,
  338. has_free_symbols,
  339. )
  340. if not has_free_symbols(arg.numel()) and arg.numel() == 0:
  341. # If numel is zero, then the output size must be zero.
  342. # In this case, we must not allocate an unbacked SymInt,
  343. # because if we do, it will immediately get refined to
  344. # zero, but this will be inconsistent with size oblivious
  345. # tests (which will continue to claim that the unbacked
  346. # symint cannot equal zero). We could also unconditionally
  347. # allocate an unbacked SymInt and not refine its range,
  348. # but this seems more precise.
  349. nnz = 0
  350. else:
  351. nnz = fake_mode.shape_env.create_unbacked_symint()
  352. maxval = sys.maxsize - 1
  353. numel = arg.numel() if dim is None else arg.size(dim)
  354. if not has_free_symbols(numel):
  355. maxval = int(numel)
  356. _constrain_range_for_size(nnz, max=maxval)
  357. if dim is None:
  358. if unique_consecutive:
  359. arg.unique_consecutive_memo = nnz
  360. else:
  361. arg.unique_memo = nnz
  362. if dim is None:
  363. # pyrefly: ignore[no-matching-overload]
  364. ret = [arg.new_empty((nnz,))]
  365. else:
  366. # pyrefly: ignore[no-matching-overload]
  367. ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])]
  368. return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu")
  369. if return_inverse or return_if_dim_and_cpu:
  370. inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],))
  371. else:
  372. inverse = arg.new_empty(0)
  373. ret.append(inverse)
  374. if return_counts or return_if_dim_and_cpu:
  375. counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],))
  376. else:
  377. counts = arg.new_empty(0)
  378. ret.append(counts)
  379. return tuple(ret)
  @register_op_impl(aten._unique2.default)
  def unique2(
      fake_mode: FakeTensorMode,
      func: OpOverload,
      arg: FakeTensor,
      sorted: bool = True,
      return_inverse: bool = False,
      return_counts: bool = False,
  ) -> tuple[FakeTensor, FakeTensor, FakeTensor]:
      """Fake rule for ``aten._unique2``: ``_unique`` with no ``dim``."""
      return _unique(
          fake_mode,
          func,
          arg,
          dim=None,
          sorted=sorted,
          return_inverse=return_inverse,
          return_counts=return_counts,
      )
  @register_op_impl(aten.select.int)
  def meta_select(
      fake_mode: FakeTensorMode,
      func: OpOverload,
      self: FakeTensor,
      dim: int,
      index: IntLikeType,
  ) -> FakeTensor:
      """Fake rule for ``aten.select.int``.

      Drops dimension ``dim`` and returns an ``as_strided`` view whose storage
      offset points at ``index``. When the sign of ``index`` cannot be
      determined, a fresh unbacked SymInt is used as the offset.

      Returns ``NotImplemented`` for sparse inputs.

      Raises:
          DataDependentOutputException: if the sign of ``index`` is unknown
              and the mode has no shape env or disallows scalar outputs.
      """
      from torch.fx.experimental.symbolic_shapes import guard_or_false

      if self.is_sparse:
          return NotImplemented

      ndim = self.dim()
      torch._check_index(
          ndim != 0,
          lambda: "select() cannot be applied to a 0-dim tensor.",
      )

      if dim < 0:
          dim = dim + ndim
      size = self.size(dim)

      out_size = list(self.size())
      out_stride = list(self.stride())

      if guard_or_false(index >= 0):
          offset = self.storage_offset() + index * out_stride[dim]
      elif guard_or_false(index < 0):
          offset = self.storage_offset() + (index + size) * out_stride[dim]
      else:
          shape_env = fake_mode.shape_env
          if shape_env is None or not (
              shape_env.allow_scalar_outputs or fake_mode.allow_scalar_outputs
          ):
              raise DataDependentOutputException(func)
          # index is data-dependent: the element could be at index or at
          # index + size, so the storage offset gets a new data-dependent symbol.
          offset = fake_mode.shape_env.create_unbacked_symint()

      out_size.pop(dim)
      out_stride.pop(dim)

      if offset is None:
          raise AssertionError("new_storage_offset must not be None")
      # pyrefly: ignore[bad-return]
      return self.as_strided(out_size, out_stride, offset)
  @register_op_impl(aten.unique_dim.default)
  def unique_dim(
      fake_mode: FakeTensorMode,
      func: OpOverload,
      arg: FakeTensor,
      dim: int,
      sorted: bool = True,
      return_inverse: bool = False,
      return_counts: bool = False,
  ) -> tuple[FakeTensor, FakeTensor, FakeTensor]:
      """Fake rule for ``aten.unique_dim``: ``_unique`` along ``dim``."""
      # normalize dim to be non-negative
      if dim >= 0:
          normalized_dim = dim
      else:
          normalized_dim = dim % max(arg.ndim, 1)
      return _unique(
          fake_mode,
          func,
          arg,
          normalized_dim,
          sorted,
          return_inverse,
          return_counts,
      )
  @register_op_impl(aten.unique_consecutive.default)
  def unique_consecutive(
      fake_mode: FakeTensorMode,
      func: OpOverload,
      arg: FakeTensor,
      return_inverse: bool = False,
      return_counts: bool = False,
      dim: int | None = None,
  ) -> tuple[FakeTensor, FakeTensor, FakeTensor]:
      """Fake rule for ``aten.unique_consecutive``.

      Calls ``_unique`` with ``sorted=False`` and the consecutive-specific
      memo slot selected.
      """
      return _unique(
          fake_mode,
          func,
          arg,
          dim,
          sorted=False,
          return_inverse=return_inverse,
          return_counts=return_counts,
          unique_consecutive=True,
      )
  469. # This function is python match of computeStride_impl in TensorUtils.cpp
  470. def _compute_stride(
  471. old_shape: Sequence[IntLikeType],
  472. old_stride: Sequence[IntLikeType],
  473. new_shape: Sequence[IntLikeType],
  474. size_oblivious: bool = False,
  475. ) -> list[IntLikeType] | None:
  476. from torch.fx.experimental.symbolic_shapes import (
  477. guard_or_false,
  478. guard_or_true,
  479. sym_eq,
  480. )
  481. def maybe_guard_or_false(x: Any) -> Any:
  482. if size_oblivious:
  483. return guard_or_false(x)
  484. return x
  485. def maybe_guard_or_true(x: Any) -> Any:
  486. if size_oblivious:
  487. return guard_or_true(x)
  488. return x
  489. if len(old_shape) == 0:
  490. return [1] * len(new_shape)
  491. numel = reduce(operator.mul, old_shape, 1)
  492. zero_numel = maybe_guard_or_false(numel == 0)
  493. if zero_numel and maybe_guard_or_false(sym_eq(old_shape, new_shape)):
  494. return list(old_stride)
  495. new_stride: list[IntLikeType] = [0] * len(new_shape)
  496. if zero_numel:
  497. for view_d in range(len(new_shape) - 1, -1, -1):
  498. if view_d == len(new_shape) - 1:
  499. new_stride[view_d] = 1
  500. else:
  501. new_stride[view_d] = (
  502. max(new_shape[view_d + 1], 1) * new_stride[view_d + 1]
  503. )
  504. return new_stride
  505. view_d = len(new_shape) - 1
  506. # Annotate type here to support type checking
  507. chunk_base_stride: IntLikeType = old_stride[-1]
  508. tensor_numel: IntLikeType = 1
  509. view_numel: IntLikeType = 1
  510. for tensor_d in range(len(old_shape) - 1, -1, -1):
  511. tensor_numel *= old_shape[tensor_d]
  512. if tensor_d == 0 or (
  513. maybe_guard_or_true(old_shape[tensor_d - 1] != 1)
  514. and maybe_guard_or_true(
  515. old_stride[tensor_d - 1] != tensor_numel * chunk_base_stride
  516. )
  517. ):
  518. while view_d >= 0 and (
  519. maybe_guard_or_true(view_numel < tensor_numel)
  520. or maybe_guard_or_false(new_shape[view_d] == 1)
  521. ):
  522. new_stride[view_d] = view_numel * chunk_base_stride
  523. view_numel *= new_shape[view_d]
  524. view_d -= 1
  525. if maybe_guard_or_true(view_numel != tensor_numel):
  526. return None
  527. if tensor_d > 0:
  528. chunk_base_stride = old_stride[tensor_d - 1]
  529. tensor_numel = 1
  530. view_numel = 1
  531. if view_d != -1:
  532. return None
  533. return new_stride
  534. def _view_has_unbacked_input(
  535. a: torch.Tensor, shape: ShapeType | tuple[ShapeType]
  536. ) -> bool:
  537. from torch.fx.experimental.symbolic_shapes import has_hint
  538. shape = utils.extract_shape_from_varargs(shape, validate=False)
  539. return (
  540. any(not has_hint(s) for s in a.size())
  541. or any(not has_hint(s) for s in a.stride())
  542. or any(not has_hint(s) for s in shape)
  543. )
  544. def _view_unbacked_meta(
  545. a: torch.Tensor,
  546. shape: ShapeType | tuple[ShapeType],
  547. size_oblivious_enabled: bool = True,
  548. ) -> torch.Tensor:
  549. from torch._prims import view_of
  550. from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq
  551. # Creates a valid shape
  552. shape = utils.extract_shape_from_varargs(shape, validate=False)
  553. # Reshape may be given a shape with a -1 length
  554. # This indicates that the dimension's length should be inferred
  555. shape = utils.infer_size(shape, a.numel())
  556. # Special-cases reshaping zero dim tensors
  557. if a.ndim == 0:
  558. _a = a
  559. for length in shape:
  560. torch._check(length == 1)
  561. _a = torch._refs.unsqueeze(_a, -1)
  562. if _a is a:
  563. return view_of(a)
  564. else:
  565. return _a # type: ignore[return-value]
  566. # Special-cases reshaping to zero dim tensors
  567. if len(shape) == 0:
  568. _a = a
  569. for length in a.shape:
  570. torch._check(length == 1)
  571. _a = torch._refs.squeeze(_a, -1)
  572. if _a is a:
  573. return view_of(a)
  574. else:
  575. return _a # type: ignore[return-value]
  576. shape_numel = reduce(operator.mul, shape, 1)
  577. torch._check(
  578. a.numel() == shape_numel,
  579. lambda: f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
  580. )
  581. if len(shape) == len(a.shape) and guard_or_false(sym_eq(shape, a.shape)):
  582. return view_of(a)
  583. if is_contiguous_or_false(a) if size_oblivious_enabled else is_contiguous(a):
  584. strides = make_contiguous_strides_for(shape)
  585. return a.as_strided(shape, strides) # type: ignore[return-value]
  586. new_strides = _compute_stride(
  587. a.size(), a.stride(), shape, size_oblivious=size_oblivious_enabled
  588. )
  589. if new_strides is not None:
  590. return a.as_strided(shape, new_strides) # type: ignore[return-value]
  591. # If we fail to do size oblivious view, and backed_size_oblivious was on,
  592. # then we redo everything by looking at hints and guarding instead of failing.
  593. # Also if the expression has unbacked symbols, then we run again with size_oblivious_enabled=False
  594. # to throw a data dependent error.
  595. if size_oblivious_enabled and (
  596. torch.fx.experimental._config.backed_size_oblivious
  597. or _view_has_unbacked_input(a, shape)
  598. ):
  599. return _view_unbacked_meta(a, shape, size_oblivious_enabled=False)
  600. msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
  601. raise ValueError(msg)
  @register_op_impl(aten._reshape_copy.default)
  def _reshape_copy(
      fake_mode: FakeTensorMode, func: OpOverload, a: FakeTensor, *shape: Any
  ) -> FakeTensor | Exception:
      """Fake rule for ``aten._reshape_copy``.

      A contiguous input is viewed first and the view is then cloned. Any
      other input is cloned to contiguous format first and then viewed.
      Returns ``NotImplemented`` for sparse or mkldnn inputs.
      """
      if a.is_sparse or a.is_mkldnn:
          return NotImplemented

      # pyrefly: ignore[bad-argument-count]
      target_shape = utils.infer_size(*shape, a.numel())

      if not is_contiguous_or_false(a):
          contiguous_copy = typing_cast(
              FakeTensor, a.clone(memory_format=torch.contiguous_format)
          )
          return _view_meta(fake_mode, func, contiguous_copy, *target_shape)

      viewed = _view_meta(fake_mode, func, a, *target_shape)
      return typing_cast(
          FakeTensor, viewed.clone(memory_format=torch.contiguous_format)
      )
  @register_op_impl(aten.view.default)
  @register_op_impl(aten._unsafe_view.default)
  def _view_meta(
      fake_mode: FakeTensorMode,
      func: OpOverload,
      a: FakeTensor,
      *shape: Any,
  ) -> FakeTensor:
      """Fake rule for ``aten.view`` and ``aten._unsafe_view``.

      Uses ``_view_unbacked_meta`` when ``backed_size_oblivious`` is on or an
      input has no hint; otherwise defers to the reference reshape helper
      with copies disallowed.
      """
      use_unbacked_path = (
          torch.fx.experimental._config.backed_size_oblivious
          or _view_has_unbacked_input(a, shape)
      )
      if use_unbacked_path:
          viewed = _view_unbacked_meta(a, shape)
      else:
          viewed = torch._refs._reshape_view_helper(a, *shape, allow_copy=False)
      return typing_cast(FakeTensor, viewed)
  @register_op_impl(aten.view_copy.default)
  def _view_meta_copy(
      fake_mode: FakeTensorMode,
      func: OpOverload,
      a: FakeTensor,
      *shape: IntLikeType,
      out: FakeTensor | None = None,
  ) -> FakeTensor:
      """Fake rule for ``aten.view_copy``.

      The view is computed with ``_view_meta``. When ``out`` is supplied the
      view is returned as is; otherwise every tensor in the result is cloned
      to contiguous format.
      """
      viewed = _view_meta(fake_mode, func, a, *shape)
      if out is not None:
          return viewed

      def contiguous_clone(x):
          return x.clone(memory_format=torch.contiguous_format)

      return pytree.tree_map(contiguous_clone, viewed)
  @register_op_impl(aten.repeat_interleave.Tensor)
  def repeat_interleave_tensor(
      fake_mode: FakeTensorMode,
      func: OpOverload,
      repeats: FakeTensor,
      output_size: IntLikeType | None = None,
  ) -> FakeTensor:
      """Fake rule for ``aten.repeat_interleave.Tensor``.

      Returns an empty tensor of length ``output_size`` built from
      ``repeats``. When ``output_size`` is omitted it becomes a fresh
      unbacked SymInt.

      Raises:
          DynamicOutputShapeException: if ``output_size`` is omitted and the
              mode has no shape env or disallows dynamic-output-shape ops.
      """
      if output_size is None:
          shape_env = fake_mode.shape_env
          if shape_env is None or not shape_env.allow_dynamic_output_shape_ops:
              raise DynamicOutputShapeException(func)

          # Avoid importing sympy at a module level
          from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

          output_size = fake_mode.shape_env.create_unbacked_symint()
          _constrain_range_for_size(output_size)
          # TODO: consider a memo
      return repeats.new_empty(output_size)  # type: ignore[return-value]
  672. @register_op_impl(torch.ops.aten.item.default)
  673. @register_op_impl(torch.ops.aten._local_scalar_dense.default)
  674. def local_scalar_dense(
  675. fake_mode: FakeTensorMode, func: OpOverload, arg: FakeTensor
  676. ) -> int | float | bool | torch.SymInt | torch.SymFloat | torch.SymBool:
  677. if (r := arg.item_memo) is not None:
  678. return r
  679. if fake_mode.shape_env is None or (
  680. not fake_mode.shape_env.allow_scalar_outputs
  681. and not fake_mode.allow_scalar_outputs
  682. ):
  683. # Without symints/symfloats, cannot handle this
  684. raise DataDependentOutputException(func)
  685. if is_float_dtype(arg.dtype):
  686. r = fake_mode.shape_env.create_unbacked_symfloat()
  687. elif is_integer_dtype(arg.dtype):
  688. r = fake_mode.shape_env.create_unbacked_symint()
  689. elif is_boolean_dtype(arg.dtype):
  690. r = fake_mode.shape_env.create_unbacked_symbool()
  691. else:
  692. raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
  693. arg.item_memo = r
  694. return r
  695. @register_op_impl(torch.ops.aten.nonzero_numpy.default)
  696. def nonzero_numpy(
  697. fake_mode: FakeTensorMode, func: OpOverload, arg: FakeTensor
  698. ) -> list[FakeTensor]:
  699. return torch.ops.aten.nonzero.default(arg).unbind(1)
  700. @register_op_impl(torch.ops.aten.nonzero.default)
  701. def nonzero(fake_mode: FakeTensorMode, func: OpOverload, arg: FakeTensor) -> FakeTensor:
  702. if (
  703. fake_mode.shape_env is None
  704. or not fake_mode.shape_env.allow_dynamic_output_shape_ops
  705. ):
  706. # Without symints/symfloats, cannot handle this
  707. raise DynamicOutputShapeException(func)
  708. if (nnz := arg.nonzero_memo) is None:
  709. # Avoid importing sympy at a module level
  710. from torch.fx.experimental.symbolic_shapes import (
  711. _constrain_range_for_size,
  712. has_free_symbols,
  713. )
  714. from torch.utils._sympy.numbers import IntInfinity
  715. from torch.utils._sympy.value_ranges import bound_sympy
  716. if not has_free_symbols(arg.numel()) and arg.numel() == 0:
  717. # If numel is zero, then the output size must be zero.
  718. # In this case, we must not allocate an unbacked SymInt,
  719. # because if we do, it will immediately get refined to
  720. # zero, but this will be inconsistent with size oblivious
  721. # tests (which will continue to claim that the unbacked
  722. # symint cannot equal zero). We could also unconditionally
  723. # allocate an unbacked SymInt and not refine its range,
  724. # but this seems more precise.
  725. nnz = 0
  726. else:
  727. nnz = fake_mode.shape_env.create_unbacked_symint()
  728. maxval = sys.maxsize - 1
  729. if not has_free_symbols(arg.numel()):
  730. maxval = int(arg.numel())
  731. else:
  732. prod_node = math.prod(arg.shape).node # type: ignore[union-attr]
  733. prod_range = bound_sympy(
  734. prod_node.expr, prod_node.shape_env.var_to_range
  735. )
  736. if isinstance(prod_range.upper, IntInfinity):
  737. maxval = sys.maxsize - 1
  738. else:
  739. maxval = prod_range.upper
  740. _constrain_range_for_size(nnz, max=maxval)
  741. arg.nonzero_memo = nnz
  742. return arg.new_empty_strided((nnz, arg.dim()), (1, nnz), dtype=torch.int64) # type: ignore[return]
  743. @register_op_impl(torch.ops.aten._padded_dense_to_jagged_forward.default)
  744. def _padded_dense_to_jagged_forward(
  745. fake_mode: FakeTensorMode,
  746. func: OpOverload,
  747. padded: FakeTensor,
  748. offsets: list[FakeTensor],
  749. total_L: IntLikeType | None = None,
  750. ) -> FakeTensor:
  751. # only one jagged dim is supported for now
  752. if len(offsets) != 1:
  753. raise AssertionError(
  754. f"Only one jagged dim is supported, got {len(offsets)} offsets"
  755. )
  756. if not total_L:
  757. if (
  758. fake_mode.shape_env is None
  759. or not fake_mode.shape_env.allow_dynamic_output_shape_ops
  760. ):
  761. # Without symints/symfloats, cannot handle this
  762. raise DynamicOutputShapeException(func)
  763. total_L = fake_mode.shape_env.create_unbacked_symint()
  764. maxval = sys.maxsize - 1
  765. # Avoid importing sympy at a module level
  766. from torch.fx.experimental.symbolic_shapes import (
  767. _constrain_range_for_size,
  768. has_free_symbols,
  769. )
  770. if not has_free_symbols(padded.numel()):
  771. maxval = int(padded.numel())
  772. _constrain_range_for_size(total_L, min=0, max=maxval)
  773. output_shape = (total_L, *padded.shape[2:])
  774. return padded.new_empty(output_shape) # type: ignore[return]
  775. def _compute_slice_index(size: IntLikeType, index: IntLikeType) -> IntLikeType | None:
  776. from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_and
  777. if guard_or_false(sym_and(index >= 0, index <= size)):
  778. return index
  779. elif guard_or_false(sym_and(index < 0, index >= -size)):
  780. return index + size
  781. elif guard_or_false(index < -size):
  782. return 0
  783. elif guard_or_false(index > size):
  784. return size
  785. return None
@register_op_impl(torch.ops.aten.slice.Tensor)
def slice_forward(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    self: FakeTensor,
    dim: int = 0,
    start: int | None = None,
    end: int | None = None,
    step: int = 1,
) -> FakeTensor:
    """Fake implementation of ``aten.slice.Tensor``.

    Computes the size, stride and storage offset of the slice along ``dim``
    and returns ``self.as_strided(...)`` with them.  Whenever the start/end
    bounds or their ordering cannot be established, unbacked SymInts from the
    mode's shape env stand in for the new size and/or the storage offset.

    Raises:
        RuntimeError: ``self`` is 0-dim, or ``step`` is not positive.
        AssertionError: an unbacked SymInt is needed but the mode has no
            shape env.
        NotImplementedError: ``self`` is quantized.
    """
    from torch.fx.experimental.symbolic_shapes import (
        guard_or_false,
        statically_known_true,
    )

    shape_env = fake_mode.shape_env

    ndim = self.dim()
    if ndim == 0:
        raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
    dim = canonicalize_dim(self.dim(), dim)
    sizes = list(self.size())
    strides = list(self.stride())

    if step <= 0:
        raise RuntimeError("slice step must be positive")

    # start, end
    # A missing start means 0; a missing end (or one statically known to be
    # sys.maxsize) means the full dimension length.  _compute_slice_index
    # yields None when the bound cannot be normalized.
    start_index = 0 if start is None else _compute_slice_index(sizes[dim], start)
    end_index = (
        sizes[dim]
        if statically_known_true(end == sys.maxsize) or end is None
        else _compute_slice_index(sizes[dim], end)
    )

    # size
    new_size: IntLikeType | None = None
    if start_index is not None and end_index is not None:
        if guard_or_false(end_index >= start_index):
            # Ceiling division of the covered range by the step.
            new_size = (end_index - start_index + step - 1) // step
        elif guard_or_false(start_index >= end_index):
            new_size = 0

    # create unbacked if case unknown
    if new_size is None:
        if shape_env is None:
            raise AssertionError("Must have shape_env to create symint")
        new_size = shape_env.create_unbacked_symint()
        # The result can be neither negative nor longer than the sliced dim.
        torch._check(new_size >= 0)
        torch._check(new_size <= sizes[dim])

    # stride
    new_stride = strides[dim] * step

    # storage offset
    if start_index is not None:
        storage_offset = self.storage_offset() + start_index * strides[dim]
    else:
        # Unknown start: the offset is unknown as well, only non-negative.
        if shape_env is None:
            raise AssertionError("Must have shape_env to create symint")
        storage_offset = shape_env.create_unbacked_symint()
        torch._check(storage_offset >= 0)

    sizes[dim] = new_size  # type: ignore[unsupported-operation]
    strides[dim] = new_stride
    if self.is_quantized:
        raise NotImplementedError(
            "Slice decomposition for quantized tensors aren't implemented"
        )
    else:
        return self.as_strided(sizes, strides, storage_offset)  # type: ignore[return-value]
  848. @register_op_impl(torch.ops.aten.masked_select.default)
  849. def masked_select(
  850. fake_mode: FakeTensorMode, func: OpOverload, self: FakeTensor, mask: FakeTensor
  851. ) -> FakeTensor:
  852. if (
  853. fake_mode.shape_env is None
  854. or not fake_mode.shape_env.allow_dynamic_output_shape_ops
  855. ):
  856. # Without symints/symfloats, cannot handle this
  857. raise DynamicOutputShapeException(func)
  858. nnz = fake_mode.shape_env.create_unbacked_symint()
  859. # see nonzero for commentary
  860. maxval = sys.maxsize - 1
  861. # Avoid importing sympy at a module level
  862. from torch.fx.experimental.symbolic_shapes import (
  863. _constrain_range_for_size,
  864. has_free_symbols,
  865. )
  866. from torch.utils._sympy.numbers import IntInfinity
  867. from torch.utils._sympy.value_ranges import bound_sympy
  868. # If num elements is expressed symbolically, calculate
  869. # the concrete value based on upper bounds. Otherwise,
  870. # we can set max val directly.
  871. if not has_free_symbols(self.numel()):
  872. num_elements = int(self.numel())
  873. else:
  874. prod_node = math.prod(self.shape).node # type: ignore[union-attr]
  875. prod_range = bound_sympy(prod_node.expr, prod_node.shape_env.var_to_range)
  876. if isinstance(prod_range.upper, IntInfinity):
  877. num_elements = sys.maxsize - 1
  878. else:
  879. num_elements = prod_range.upper
  880. if num_elements > 2:
  881. maxval = num_elements
  882. _constrain_range_for_size(nnz, max=maxval)
  883. return self.new_empty((nnz,)) # type: ignore[return]
  884. @register_op_impl(torch.ops.aten._assert_tensor_metadata.default)
  885. def assert_tensor_metadata(
  886. fake_mode: FakeTensorMode,
  887. func: OpOverload,
  888. t: FakeTensor,
  889. sizes: torch.Size | None = None,
  890. strides: tuple[int, ...] | None = None,
  891. dtype: torch.dtype | None = None,
  892. *,
  893. device: torch.device | None = None,
  894. layout: torch.layout | None = None,
  895. ) -> None:
  896. if sizes is not None:
  897. if t.size() != sizes:
  898. raise AssertionError(
  899. f"Tensor sizes mismatch! Expected: {sizes}, Got: {t.size()}"
  900. )
  901. if strides is not None:
  902. if t.stride() != strides:
  903. raise AssertionError(
  904. f"Tensor strides mismatch! Expected: {strides}, Got: {t.stride()}"
  905. )
  906. if dtype is not None:
  907. if t.dtype != dtype:
  908. raise AssertionError(
  909. f"Tensor dtype mismatch! Expected: {dtype}, Got: {t.dtype}"
  910. )
  911. if layout is not None:
  912. if t.layout != layout:
  913. raise AssertionError(
  914. f"Tensor layout mismatch! Expected: {layout}, Got: {t.layout}"
  915. )
  916. if device is not None:
  917. if t.device != device:
  918. raise AssertionError(
  919. f"Tensor device mismatch! Expected: {device}, Got: {t.device}"
  920. )
# NB: this must be ordered after local_scalar_dense
@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)
def data_dep(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> None:
    """Reject any op tagged ``torch.Tag.data_dependent_output``.

    Registered through a predicate rather than a specific overload, so it
    applies to every op carrying that tag.

    Raises:
        DataDependentOutputException: always.
    """
    raise DataDependentOutputException(func)
  927. # Bool Indices get Expanded as Masks
  928. # See: IndexingUtils.h:expandTensors
  929. def check_no_bool_index_tensors(
  930. func: OpOverload, self: FakeTensor, indices: list[FakeTensor | None]
  931. ) -> None:
  932. for index in indices:
  933. if index is not None and index.dtype in (torch.bool, torch.uint8):
  934. raise DynamicOutputShapeException(func)
  935. def run_and_return_new_tensor_of_input_device(
  936. fake_mode: FakeTensorMode,
  937. func: OpOverload,
  938. args: tuple[Any, ...],
  939. kwargs: dict[str, Any],
  940. ) -> FakeTensor:
  941. # TODO: ref
  942. _, new_kwargs = _normalize_function_or_error(
  943. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  944. )
  945. out_device = new_kwargs["input"].device
  946. with in_kernel_invocation_manager(fake_mode):
  947. out = func(*args, **kwargs)
  948. if not is_noncontiguous_supported(out_device):
  949. out = out.new_empty(out.shape)
  950. if out is new_kwargs["input"]:
  951. return out # copy_
  952. return FakeTensor(fake_mode, out, out_device)
# Operator namespaces that ``is_builtin`` treats as built-in.
_is_builtin_namespaces = ordered_set("aten", "prims", "prim")


def is_builtin(op: OpOverload) -> bool:
    """Return whether ``op.namespace`` is one of ``_is_builtin_namespaces``."""
    return op.namespace in _is_builtin_namespaces
  956. def has_meta(func: OpOverload) -> bool:
  957. return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta")
  958. # These are for the `torch._foreach_...` ops like `torch._foreach_add`.
  959. @register_op_impl(
  960. lambda func: is_builtin(func)
  961. and func.name().startswith("aten::_foreach_")
  962. and has_meta(func)
  963. )
  964. def foreach_run_and_map_input_device(
  965. fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
  966. ) -> list[FakeTensor] | None:
  967. tensor_lists = [
  968. arg
  969. for arg in itertools.chain(args, kwargs.values())
  970. if isinstance(arg, (list, tuple))
  971. and len(arg)
  972. and isinstance(arg[0], torch.Tensor)
  973. ]
  974. try:
  975. with in_kernel_invocation_manager(fake_mode):
  976. out_meta = func(*args, **kwargs)
  977. except NotImplementedError:
  978. return NotImplemented
  979. if not out_meta:
  980. return out_meta
  981. if not tensor_lists:
  982. raise AssertionError("tensor_lists must not be empty")
  983. out_fake = []
  984. for i, meta_t in enumerate(out_meta):
  985. device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists])
  986. out_fake.append(
  987. fake_mode.fake_tensor_converter.from_meta_and_device(
  988. fake_mode, meta_t, device
  989. )
  990. )
  991. return out_fake
  992. # Dont default to default device handling,
  993. # Since op can take in non-zero sized cpu
  994. # index tensors with cuda self
  995. @register_op_impl(aten.index.Tensor)
  996. def index_tensor(
  997. fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
  998. ) -> FakeTensor:
  999. from torch._meta_registrations import meta_index_Tensor
  1000. _, new_kwargs = _normalize_function_or_error(
  1001. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1002. )
  1003. out_device = new_kwargs["input"].device
  1004. # ensure nonzero call goes to fake tensor
  1005. with fake_mode:
  1006. out = meta_index_Tensor(*args, **kwargs)
  1007. return out.to(out_device)
  1008. # Can take mixed meta/non-meta arguments; the meta registration
  1009. # will roughly do the right thing even when given real devices
  1010. @register_op_impl(aten._embedding_bag.default)
  1011. def embedding_bag(
  1012. fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
  1013. ) -> tuple[FakeTensor, FakeTensor, FakeTensor, FakeTensor]:
  1014. from torch._meta_registrations import meta_embedding_bag
  1015. with fake_mode:
  1016. return meta_embedding_bag(*args, **kwargs)
# These ops take in tensors on multiple devices, so don't fall back to the
# default device handling.
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@register_op_impl(aten.slice_scatter.default)
@register_op_impl(aten.diagonal_scatter.default)
def multi_device_op_default(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> FakeTensor:
    """Fake implementation for the ops registered above.

    Delegates to ``run_and_return_new_tensor_of_input_device``, which places
    the result on the device of the op's ``input`` argument.
    """
    return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
  1027. # same with multi_device_op_default, but return the input
  1028. @register_op_impl(aten.copy.out)
  1029. @register_op_impl(aten.slice_scatter.out)
  1030. def multi_device_op_out(
  1031. fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
  1032. ) -> FakeTensor:
  1033. with in_kernel_invocation_manager(fake_mode):
  1034. func(*args, **kwargs)
  1035. _, new_kwargs = _normalize_function_or_error(
  1036. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1037. )
  1038. return new_kwargs["input"]
  1039. @register_op_impl(aten.index_put.default)
  1040. @register_op_impl(aten.index_put_.default)
  1041. def index_put_impl(
  1042. fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
  1043. ) -> FakeTensor:
  1044. _, new_kwargs = _normalize_function_or_error(
  1045. func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
  1046. )
  1047. values = new_kwargs["values"]
  1048. self_device = new_kwargs["input"].fake_device
  1049. torch._check(
  1050. self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
  1051. lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
  1052. )
  1053. out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
  1054. if func is aten.index_put_.default:
  1055. return new_kwargs["input"]
  1056. else:
  1057. return out
@register_op_impl(aten._nested_tensor_from_tensor_list.default)
@register_op_impl(aten._nested_tensor_from_tensor_list.out)
@register_op_impl(aten._nested_view_from_buffer.default)
@register_op_impl(aten._nested_view_from_buffer_copy.default)
def nested_tensors_unsupported(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> None:
    """Reject the nested-tensor ops registered above.

    Raises:
        UnsupportedOperatorException: always.
    """
    raise UnsupportedOperatorException(func)
@register_op_impl(
    [
        x
        for x in _device_not_kwarg_ops
        if x
        not in (
            # these are already registered elsewhere
            aten.is_pinned.default,
            aten.to.device,
            aten.to.prim_Device,
            aten._nested_tensor_from_tensor_list.default,
            aten._nested_tensor_from_tensor_list.out,
        )
    ]
)
def nyi(fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any) -> None:
    """Placeholder for ops in ``_device_not_kwarg_ops`` without a dedicated impl.

    Registered for every op in ``_device_not_kwarg_ops`` except the ones
    excluded in the decorator above.

    Raises:
        AssertionError: ``func`` is in ``_device_not_kwarg_ops``.
    """
    if func in _device_not_kwarg_ops:
        raise AssertionError(f"NYI: {func}")
@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
def conv(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> FakeTensor | tuple[FakeTensor | None, FakeTensor | None, FakeTensor | None]:
    """Fake implementation of ``aten.convolution`` and ``aten.convolution_backward``.

    First decides which memory format the real backend would produce (skipped
    when the batch size has no hint), then runs ``func`` under
    ``in_kernel_invocation_manager`` and wraps the outputs as fake tensors on
    the input's fake device, converted to that memory format.

    For ``convolution_backward`` the first two outputs are converted to the
    chosen memory format; the third is wrapped without a format conversion.
    Outputs that are None are passed through as None.

    NOTE(review): the nesting below was reconstructed from a flattened view of
    the file — confirm block boundaries against the real source.
    """
    _, new_kwargs = _normalize_function_or_error(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    device = new_kwargs["input"].fake_device
    # need to re-enable mode so the tensors report fake device
    with fake_mode:
        # if the input is unsqueezed is done in Convolution.cpp we get segfault
        k = new_kwargs["weight"].ndim
        batch = new_kwargs["input"].shape[0]

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import has_hint

        if not has_hint(batch):
            # TODO: We can make this a little more faithful with best effort
            # channels last detection (but only if it's statically obvious!)
            mem_fmt = None
        else:
            if func is aten.convolution.default:
                conv_backend = torch._C._select_conv_backend(**new_kwargs)
            else:
                # The backward op is queried with bias=None plus the
                # explicit bias_sizes argument.
                conv_backend = torch._C._select_conv_backend(
                    new_kwargs["input"],
                    new_kwargs["weight"],
                    bias=None,
                    stride=new_kwargs["stride"],
                    padding=new_kwargs["padding"],
                    dilation=new_kwargs["dilation"],
                    transposed=new_kwargs["transposed"],
                    output_padding=new_kwargs["output_padding"],
                    groups=new_kwargs["groups"],
                    bias_sizes=new_kwargs["bias_sizes"],
                )
            # Expand 1d -> 2d.
            # Note: Avoid expanding before calling _select_conv_backend,
            # as the function handles 2D expansion internally.
            if (
                k == 3
                and not new_kwargs["input"].is_mkldnn
                and not new_kwargs["input"].is_xpu
            ):
                # Note: Using input.to(memory_format=contiguous) does not work.
                new_kwargs["input"] = new_kwargs["input"].contiguous().unsqueeze(2)
                new_kwargs["weight"] = new_kwargs["weight"].unsqueeze(2)
                if len(new_kwargs["stride"]) == 1:
                    new_kwargs["stride"].insert(0, 1)
                    new_kwargs["padding"].insert(0, 0)
                    new_kwargs["dilation"].insert(0, 1)
                    new_kwargs["output_padding"].insert(0, 0)
            mem_fmt = torch._C._conv_determine_backend_memory_format(
                new_kwargs["input"], new_kwargs["weight"], conv_backend
            )
            # revert 2d -> 1d
            # Undoes the temporary expansion above so ``func`` is later called
            # with the caller's original dimensionality.
            if (
                k == 3
                and not new_kwargs["input"].is_mkldnn
                and not new_kwargs["input"].is_xpu
            ):
                new_kwargs["input"] = new_kwargs["input"].squeeze(2)
                new_kwargs["weight"] = new_kwargs["weight"].squeeze(2)
                if len(new_kwargs["stride"]) == 2:
                    new_kwargs["stride"].pop(0)
                    new_kwargs["padding"].pop(0)
                    new_kwargs["dilation"].pop(0)
                    new_kwargs["output_padding"].pop(0)

    def convert(
        t: torch.Tensor | None, mem_fmt: torch.memory_format | None
    ) -> FakeTensor | None:
        # Wrap one kernel output as a FakeTensor on ``device``, first
        # converting it to ``mem_fmt`` when one was determined.
        if t is None:
            return t
        if mem_fmt is not None:
            # channels last only support 4d, try to expand dim then convert it back later.
            if t.dim() == 3 and mem_fmt == torch.channels_last:
                t = t.unsqueeze(2).to(memory_format=mem_fmt).squeeze(2)
            else:
                t = t.to(memory_format=mem_fmt)
        return FakeTensor(fake_mode, t, device)

    with in_kernel_invocation_manager(fake_mode):
        out = func(**new_kwargs)

        if func is aten.convolution.default:
            return convert(out, mem_fmt)  # type: ignore[return]
        else:
            return (
                convert(out[0], mem_fmt),
                convert(out[1], mem_fmt),
                convert(out[2], None),
            )
  1173. @register_op_impl(torch.ops.aten.bincount.default)
  1174. def bincount(
  1175. fake_mode: FakeTensorMode,
  1176. func: OpOverload,
  1177. inputs: FakeTensor,
  1178. weights: FakeTensor | None = None,
  1179. minlength: IntLikeType = 0,
  1180. ) -> FakeTensor:
  1181. if (
  1182. fake_mode.shape_env is None
  1183. or not fake_mode.shape_env.allow_dynamic_output_shape_ops
  1184. ):
  1185. # Without symints/symfloats, cannot handle this
  1186. raise DynamicOutputShapeException(func)
  1187. new_size = fake_mode.shape_env.create_unbacked_symint()
  1188. from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
  1189. _constrain_range_for_size(new_size)
  1190. torch._check(new_size >= minlength)
  1191. return inputs.new_empty(new_size) # type: ignore[return]
  1192. @register_op_impl(torch.ops.aten._pack_padded_sequence.default)
  1193. def _pack_padded_sequence(
  1194. fake_mode: FakeTensorMode,
  1195. func: OpOverload,
  1196. inputs: FakeTensor,
  1197. lengths: FakeTensor,
  1198. batch_first: bool,
  1199. ) -> tuple[FakeTensor, FakeTensor]:
  1200. if (
  1201. fake_mode.shape_env is None
  1202. or not fake_mode.shape_env.allow_dynamic_output_shape_ops
  1203. ):
  1204. # Without symints/symfloats, cannot handle this
  1205. raise DynamicOutputShapeException(func)
  1206. new_batch_size = fake_mode.shape_env.create_unbacked_symint()
  1207. from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
  1208. _constrain_range_for_size(new_batch_size)
  1209. if not batch_first:
  1210. # Inputs should have shape (batch_size, seq_len, *)
  1211. inputs = inputs.transpose(0, 1) # type: ignore[assignment]
  1212. res_size = inputs.shape[1:]
  1213. packed_data = inputs.new_empty(res_size)
  1214. batch_size = inputs.new_empty((new_batch_size,))
  1215. return (packed_data, batch_size) # type: ignore[return]
  1216. # pyrefly: ignore [implicit-any]
  1217. FAST_OP_IMPLEMENTATIONS = {}
  1218. # Unlike register_op_impl, these don't do the slow iteration for
  1219. # run_impl_check, and these run BEFORE decompositions
  1220. def register_fast_op_impl(
  1221. func: OpOverload,
  1222. ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
  1223. def impl_decorator(op_impl: Callable[_P, _R]) -> Callable[_P, _R]:
  1224. FAST_OP_IMPLEMENTATIONS[func] = op_impl
  1225. return op_impl
  1226. return impl_decorator
  1227. # infer_size_impl in ExpandUtils
  1228. def infer_size(
  1229. a: Sequence[IntLikeType], b: Sequence[IntLikeType]
  1230. ) -> tuple[IntLikeType, ...]:
  1231. from torch.fx.experimental.symbolic_shapes import guard_or_false
  1232. dimsA = len(a)
  1233. dimsB = len(b)
  1234. ndim = max(dimsA, dimsB)
  1235. expandedSizes: list[IntLikeType] = [0] * ndim
  1236. for i in range(ndim - 1, -1, -1):
  1237. offset = ndim - 1 - i
  1238. dimA = dimsA - 1 - offset
  1239. dimB = dimsB - 1 - offset
  1240. sizeA = a[dimA] if dimA >= 0 else 1
  1241. sizeB = b[dimB] if dimB >= 0 else 1
  1242. # NB: It is very important to test for broadcasting, before testing
  1243. # sizeA == sizeB. This is because the broadcasting tests are likely
  1244. # to be statically known (in particular, if sizeA/sizeB is unbacked
  1245. # but size-like, we will unsoundly assume they never equal 1), but
  1246. # the sizeA == sizeB test may not be statically known. However, once
  1247. # we have established that no broadcasting is happening, the
  1248. # sizeA == sizeB is now expect_true and we can defer it as a runtime
  1249. # assert (this works because Python will return the terminal
  1250. # expression of an or statement as-is, without bool()'ing it; if this
  1251. # were not the case, we'd need to write this using torch.sym_or() or
  1252. # something like that).
  1253. torch._check(
  1254. guard_or_false(sizeA == 1) or guard_or_false(sizeB == 1) or sizeA == sizeB,
  1255. lambda: f"The size of tensor a ({sizeA}) "
  1256. f"must match the size of tensor b ({sizeB}) "
  1257. f"at non-singleton dimension {i})",
  1258. )
  1259. expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
  1260. return tuple(expandedSizes)
def make_fast_binary_impl(
    slow_ref: Callable[..., Any],
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
) -> Callable[..., FakeTensor]:
    """Build a fast-path fake implementation for an elementwise binary op.

    Args:
        slow_ref: Reference implementation, called with the original
            arguments under the fake mode whenever the fast path gives up.
        type_promotion_kind: Promotion rule handed to ``elementwise_dtypes``
            when the result dtype cannot be read directly off the operands.

    Returns:
        ``fast_binary_impl(mode, *args, **kwargs)``.  It computes the
        broadcast shape, result dtype and device itself and directly returns
        an empty meta-backed ``FakeTensor`` when all tensor operands are
        definitely contiguous (or definitely channels-last); in every other
        case it falls back to ``slow_ref``.
    """

    def fast_binary_impl(mode: FakeTensorMode, *args: Any, **kwargs: Any) -> FakeTensor:
        def slow(msg: str) -> FakeTensor:
            # Fallback: record why, then run the reference under the mode.
            count_label(f"slow {msg}")
            with mode:
                return slow_ref(*args, **kwargs)

        count_label("attempt fast")

        # Fast path (based off of TensorIterator fast path).
        # Unfortunately, there is no way to easily deduplicate
        # this with either the TensorIterator C++ implementation
        # (which we don't want to SymIntify, and also the algorithm
        # here is slightly different from TensorIterator to allow
        # for broadcasting), nor the PrimTorch implementation
        # (which does not actually implement a fast path.)

        # Only positional arguments are treated as operands.
        operands = args

        # compute_shape
        final_shape: ShapeType | None = None
        for op in operands:
            # Non-tensor operands contribute an empty shape.
            shape: ShapeType = op.shape if isinstance(op, torch.Tensor) else ()
            if final_shape is None:
                final_shape = shape
            # TODO: Minor optimization: track if the shapes
            # were equal so you can skip the equality check
            # below if unnecessary
            # pyrefly: ignore[bad-assignment]
            final_shape = infer_size(final_shape, shape)
        if final_shape is None:
            raise AssertionError("final_shape must not be None")

        from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq

        # Do some extra safety checks to see if the output
        # stride is obvious
        for op in operands:
            if (
                isinstance(op, torch.Tensor)
                and len(op.shape) == len(final_shape)
                # take the slow path if result is not determined.
                and guard_or_false(sym_eq(op.shape, final_shape))  # type: ignore[arg-type]
            ):
                break
        else:
            # if we never break in the for loop above we take the slow path.
            return slow("both tensors nontrivially broadcast")

        # compute_types
        cpu = torch.device("cpu")
        common_device: torch.device = cpu
        common_dtype: torch.dtype | None = None
        has_different_input_dtypes = False
        for op in operands:
            if not isinstance(op, torch.Tensor):
                # Use elementwise_dtypes for the tricky case
                has_different_input_dtypes = True
                continue
            # The first non-cpu tensor device seen becomes the common device.
            if common_device == cpu and op.device.type != "cpu":
                common_device = op.device
            if common_dtype is None:
                if type_promotion_kind != ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
                    has_different_input_dtypes = True
                else:
                    common_dtype = op.dtype
            elif common_dtype != op.dtype:
                has_different_input_dtypes = True

        if has_different_input_dtypes:
            # compute promotion
            # TODO: we don't need the compute type
            _, common_dtype = elementwise_dtypes(
                *operands, type_promotion_kind=type_promotion_kind
            )

        # check all tensors on same device
        # cpu scalars are assumed allow
        current_cpu_scalars_on_non_cpu = 0
        max_cpu_scalars_on_non_cpu = 1  # hard coded atm
        for op in operands:
            if not isinstance(op, torch.Tensor):
                continue
            if common_device != cpu and op.dim() == 0 and op.device == cpu:
                if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
                    return slow("error")
                current_cpu_scalars_on_non_cpu += 1
            elif op.device != common_device:
                return slow("error")

        # compute_fast_setup_type
        definitely_contiguous = True
        definitely_channels_last = True

        # TODO: is_non-overlapping_and_dense not bound from Python
        # no inplace, no out, everything defined

        # When the device does not support non-contiguous tensors this scan is
        # skipped and both flags stay True, so the contiguous branch is taken.
        if is_noncontiguous_supported(common_device):
            for op in operands:
                if not isinstance(op, torch.Tensor):
                    continue
                definitely_contiguous = (
                    definitely_contiguous
                    and is_contiguous_for_memory_format_or_false(
                        op, memory_format=torch.contiguous_format
                    )
                )
                definitely_channels_last = (
                    definitely_channels_last
                    and is_contiguous_for_memory_format_or_false(
                        op, memory_format=torch.channels_last
                    )
                )
        if definitely_contiguous:
            # do contiguous
            count_label("fast is_contiguous")
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.contiguous_format,
                ),
                device=common_device,
            )
        if definitely_channels_last:
            count_label("fast channels_last")
            # do channels last
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.channels_last,
                ),
                device=common_device,
            )

        return slow("no contiguity match")

    return fast_binary_impl
  1393. # disable the python dispatcher to avoid decomposing detach() further
  1394. # (proxy_mode should still decompose detach() though)
  1395. def fast_detach(
  1396. fake_mode: FakeTensorMode, x: FakeTensor, include_real: bool = False
  1397. ) -> FakeTensor:
  1398. with no_python_dispatcher(), in_kernel_invocation_manager(fake_mode):
  1399. out = torch.ops.aten.detach.default(x)
  1400. if include_real:
  1401. return FakeTensor(fake_mode, out, x.device, real_tensor=x.real_tensor)
  1402. return FakeTensor(fake_mode, out, x.device)
  1403. @functools.cache
  1404. def get_fast_op_impls() -> dict[OpOverload, Callable[..., Any]]:
  1405. import torch._refs
  1406. register_fast_op_impl(torch.ops.aten.add.Tensor)(
  1407. make_fast_binary_impl(torch._refs.add)
  1408. )
  1409. register_fast_op_impl(torch.ops.aten.sub.Tensor)(
  1410. make_fast_binary_impl(torch._refs.sub)
  1411. )
  1412. register_fast_op_impl(torch.ops.aten.mul.Tensor)(
  1413. make_fast_binary_impl(torch._refs.mul)
  1414. ) # type: ignore[has-type]
  1415. register_fast_op_impl(torch.ops.aten.div.Tensor)(
  1416. make_fast_binary_impl(
  1417. torch._refs.div,
  1418. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  1419. )
  1420. )
  1421. register_fast_op_impl(torch.ops.aten.detach.default)(fast_detach)
  1422. return FAST_OP_IMPLEMENTATIONS