| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693 |
- from __future__ import annotations
- import functools
- import itertools
- import math
- import operator
- import sys
- from functools import reduce
- from typing import Any, cast as typing_cast, TYPE_CHECKING, TypeVar, Union
- from typing_extensions import ParamSpec
- import torch
- import torch._custom_op
- import torch._logging
- import torch._prims_common as utils
- from torch._dispatch.python import no_python_dispatcher
- from torch._ops import OpOverload
- from torch._prims_common import (
- canonicalize_dim,
- elementwise_dtypes,
- ELEMENTWISE_TYPE_PROMOTION_KIND,
- is_boolean_dtype,
- is_contiguous,
- is_contiguous_for_memory_format_or_false,
- is_contiguous_or_false,
- is_float_dtype,
- is_integer_dtype,
- make_contiguous_strides_for,
- ShapeType,
- )
- from torch._subclasses.fake_tensor import (
- DataDependentOutputException,
- DynamicOutputShapeException,
- FakeTensor,
- in_kernel_invocation_manager,
- run_fallback_kernel,
- UnsupportedOperatorException,
- )
- from torch.fx.operator_schemas import _normalize_function_or_error
- from torch.utils._stats import count_label
- if TYPE_CHECKING:
- from collections.abc import Callable, Sequence
- from torch._subclasses.fake_tensor import FakeTensorMode
- from torch.types import IntLikeType
# Annotation alias for values that may be either a FakeTensor or a plain
# torch.Tensor.
FakeTensorLike = Union[FakeTensor, torch.Tensor]
# Type variables for the generic helpers below: _P/_R keep the decorated
# function's signature in register_op_impl, _T is the element type of
# ordered_set.
_P = ParamSpec("_P")
_R = TypeVar("_R")
_T = TypeVar("_T")
# Short alias for the pytree utilities (tree_flatten / tree_map).
pytree = torch.utils._pytree
# NOTE(review): get_fast_op_impls and has_meta are exported here but are not
# defined in this part of the module; presumably defined further down.
__all__ = [
    "op_implementations_checks",
    "get_fast_op_impls",
    "stride_incorrect_op",
    "has_meta",
]
# Exact-match registry: OpOverload -> fake implementation. Filled by
# register_op_impl when it is given an OpOverload (or a list/tuple of them).
# pyrefly: ignore [implicit-any]
op_implementations_dict = {}
# Predicate registry: (check, impl) pairs appended by register_op_impl when it
# is given a callable predicate. The exact-match dict above is itself consulted
# through one of these checks (dispatch_to_op_implementations_dict).
# pyrefly: ignore [implicit-any]
op_implementations_checks = []
# Shorthand namespace for aten op overloads.
aten = torch._ops.ops.aten
def ordered_set(*items: _T) -> dict[_T, bool]:
    """Build an insertion-ordered "set" as a dict keyed by ``items``.

    Duplicates keep the position of their first occurrence, and every value is
    ``True``. Membership tests are therefore O(1) while iteration order
    follows the argument order.
    """
    return {item: True for item in items}
def is_noncontiguous_supported(device: torch.device) -> bool:
    """Report whether the backend of ``device`` supports non-contiguous tensors.

    Every device type is assumed to support them except ``"hpu"``.
    """
    if device.type == "hpu":
        return False
    return True
# Factory ops that take a template tensor (the ``*_like`` and ``new_*``
# families). ``constructors`` is registered for each of these directly, and for
# them it defaults the output device to the template tensor's device instead of
# CPU.
_like_tensor_constructors = ordered_set(
    aten.empty_like.default,
    aten.empty_like.out,
    aten.full_like.default,
    aten.full_like.out,
    aten.ones_like.default,
    aten.ones_like.out,
    aten.rand_like.default,
    aten.rand_like.generator,
    aten.rand_like.out,
    aten.rand_like.generator_out,
    aten.randn_like.default,
    aten.randn_like.generator,
    aten.randn_like.out,
    aten.randn_like.generator_out,
    aten.randint_like.default,
    aten.randint_like.generator,
    aten.randint_like.Tensor,
    aten.randint_like.Tensor_generator,
    aten.randint_like.Tensor_out,
    aten.randint_like.Tensor_generator_out,
    aten.randint_like.out,
    aten.randint_like.generator_out,
    aten.randint_like.low_dtype,
    aten.randint_like.low_generator_dtype,
    aten.randint_like.low_dtype_out,
    aten.randint_like.low_generator_dtype_out,
    aten.zeros_like.default,
    aten.zeros_like.out,
    aten.new_empty.default,
    aten.new_empty.out,
    aten.new_empty_strided.default,
    aten.new_empty_strided.out,
    aten.new_full.default,
    aten.new_full.out,
    aten.new_zeros.default,
    aten.new_zeros.out,
    aten.new_ones.default,
    aten.new_ones.out,
)
# Ops whose device is not supplied through a ``device`` keyword argument
# (judging by the name and members). NOTE(review): not referenced in this part
# of the module; confirm its exact semantics against its users.
_device_not_kwarg_ops = ordered_set(
    aten._resize_output_.default,
    aten._nested_tensor_from_tensor_list.default,
    aten._nested_tensor_from_tensor_list.out,
    aten.pin_memory.default,
    aten.to.device,
    aten.to.prim_Device,
    aten.is_pinned.default,
    aten._pin_memory.default,
    aten._pin_memory.out,
    aten._resize_output.default,
    aten._resize_output.out,
)
# Constructors that do not take ``device`` as a keyword argument;
# ``constructors`` raises AssertionError for them. This op is never actually
# used. NOTE(review): the entry is the overload packet rather than an
# OpOverload, so the ``func in ...`` test in ``constructors`` presumably never
# matches; confirm.
_non_kwarg_device_constructors = (aten._list_to_tensor,)
def contains_tensor_types(type_: Any) -> bool:
    """Return True if the JIT type ``type_`` is, or transitively contains, a Tensor.

    ``type_`` is a ``torch._C`` type object, such as the ``type`` of an op
    schema argument. Composite types (e.g. lists or optionals) are searched
    recursively through ``containedTypes()``.
    """
    if type_.isSubtypeOf(torch._C.TensorType.get()):
        return True
    for inner in type_.containedTypes():
        if contains_tensor_types(inner):
            return True
    return False
@functools.cache
def _is_tensor_constructor(func: OpOverload) -> bool:
    """Decide whether ``func`` is a plain tensor factory op.

    An op qualifies when none of its schema arguments involve a Tensor type and
    it has exactly one return, of type Tensor. Results are cached per overload.

    Raises:
        AssertionError: if ``func`` is not an ``OpOverload``.
    """
    if not isinstance(func, OpOverload):
        raise AssertionError(f"func must be an OpOverload, got {type(func)}")
    schema = func._schema
    takes_tensor = any(contains_tensor_types(arg.type) for arg in schema.arguments)
    if takes_tensor:
        return False
    returns = schema.returns
    # TODO: no real reason to restrict multiple outputs
    return len(returns) == 1 and returns[0].type is torch._C.TensorType.get()
def register_op_impl(
    run_impl_check: Callable[[OpOverload], bool]
    | OpOverload
    | list[OpOverload]
    | tuple[OpOverload, ...],
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    """Return a decorator that registers the decorated function as a fake impl.

    ``run_impl_check`` selects the ops handled by the implementation:

    * an ``OpOverload``: stored in ``op_implementations_dict``; registering the
      same overload twice raises ``AssertionError``;
    * a list or tuple: each element is registered individually by these rules;
    * any other callable: appended to ``op_implementations_checks`` as a
      ``(predicate, impl)`` pair.

    Anything else raises ``AssertionError``. The decorated function itself is
    returned unchanged.
    """

    def impl_decorator(op_impl: Callable[_P, _R]) -> Callable[_P, _R]:
        if isinstance(run_impl_check, (list, tuple)):
            for single in run_impl_check:
                register_op_impl(single)(op_impl)
        elif isinstance(run_impl_check, OpOverload):
            if run_impl_check in op_implementations_dict:
                raise AssertionError(f"duplicate registration: {run_impl_check}")
            op_implementations_dict[run_impl_check] = op_impl
        elif callable(run_impl_check):
            op_implementations_checks.append((run_impl_check, op_impl))
        else:
            raise AssertionError(
                f"run_impl_check must be callable, got {type(run_impl_check)}"
            )
        return op_impl

    return impl_decorator
def _is_op_registered_to_fake_rule(op: OpOverload) -> bool:
    """Report whether ``op`` has an exact-match entry in ``op_implementations_dict``.

    Predicate-based registrations in ``op_implementations_checks`` are not
    consulted.
    """
    registered = op in op_implementations_dict
    return registered
def _deregister_op_impl(op: OpOverload) -> None:
    """Undo the registrations of ``op``.

    Drops ``op``'s entry from ``op_implementations_dict`` (if present) and the
    first ``(check, impl)`` pair in ``op_implementations_checks`` whose check
    is ``op`` itself (identity comparison).
    """
    op_implementations_dict.pop(op, None)
    entry = next((pair for pair in op_implementations_checks if pair[0] is op), None)
    if entry is not None:
        op_implementations_checks.remove(entry)
@register_op_impl(op_implementations_dict.__contains__)
def dispatch_to_op_implementations_dict(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> Any:
    """Forward ``func`` to its exact-match implementation.

    Registered as a predicate entry so that every op present in
    ``op_implementations_dict`` is routed to the implementation stored there.
    """
    impl = op_implementations_dict[func]
    return impl(fake_mode, func, *args, **kwargs)
@register_op_impl(_is_tensor_constructor)
@register_op_impl(list(_like_tensor_constructors))
def constructors(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> FakeTensor:
    """Fake implementation shared by tensor factory ops.

    ``func`` is run on the meta device. The result is wrapped in a
    ``FakeTensor`` that reports one of the following devices:

    * the ``device`` argument, when one was given;
    * otherwise, for ops in ``_like_tensor_constructors``, the device of the
      ``input`` template tensor;
    * otherwise CPU.

    Raises:
        AssertionError: if ``func`` is in ``_non_kwarg_device_constructors``.
        UnsupportedOperatorException: if a ``names`` keyword argument is given.
    """
    if func in _non_kwarg_device_constructors:
        raise AssertionError(
            f"func must not be in _non_kwarg_device_constructors, got {func}"
        )
    _, normalized = _normalize_function_or_error(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    if "names" in kwargs:
        # REASON: "torch.compile doesn't support named tensors"
        raise UnsupportedOperatorException(func)

    if func in _like_tensor_constructors:
        # The template tensor is passed positionally and supplies the default
        # device.
        template = normalized.pop("input")
        fallback_device = template.device
        # TODO: file issue
        positional = (template,)
    else:
        # cpu is default device if none is specified
        fallback_device = torch.device("cpu")
        positional = ()

    requested_device = normalized.pop("device", None)
    out_device = fallback_device if requested_device is None else requested_device
    # Materialize on meta; only the FakeTensor wrapper carries out_device.
    normalized["device"] = torch.device("meta")
    # _like constructors receive fake tensor inputs, hence the kernel
    # invocation context (possibly also why non-like ones could fail).
    with in_kernel_invocation_manager(fake_mode):
        meta_result = func(*positional, **normalized)
    return FakeTensor(fake_mode, meta_result, out_device)
@register_op_impl(aten.is_pinned.default)
def non_kwarg_is_pinned(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> bool:
    """Fake implementation of ``aten.is_pinned``.

    Only the ``input`` tensor is forwarded. The ``device`` argument is
    deliberately dropped because it is deprecated and not actually used by
    ``is_pinned``.
    """
    _args, normalized = _normalize_function_or_error(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    tensor = normalized.pop("input")
    with in_kernel_invocation_manager(fake_mode):
        pinned = func(tensor)
    return pinned
# Legacy profiler ops return Tensors but do not follow the tensor-constructor
# pattern: they take string arguments, so no device/dtype parameters may be
# added to them.
@register_op_impl(torch.ops.profiler._record_function_enter.default)
def _record_function_enter(
    fake_mode: FakeTensorMode, func: OpOverload, name: str, args: object | None = None
) -> FakeTensor:
    """Fake implementation of ``profiler._record_function_enter``.

    The real op is executed to obtain a real handle tensor. The returned value
    is a CPU ``FakeTensor`` backed by a meta tensor with the handle's
    properties.
    """
    with in_kernel_invocation_manager(fake_mode):
        handle = func(name, args)
    return FakeTensor(
        fake_mode, torch.empty_like(handle, device="meta"), torch.device("cpu")
    )
@register_op_impl(torch.ops.profiler._record_function_exit.default)
def _record_function_exit(
    fake_mode: FakeTensorMode, func: OpOverload, handle: Any
) -> None:
    """Fake implementation of ``profiler._record_function_exit``: a no-op.

    The real op returns nothing (void), and fake tensors need no work here, so
    ``handle`` is ignored.
    """
    return None
@register_op_impl(torch.ops.profiler._record_function_enter_new.default)
def _record_function_enter_new(
    fake_mode: FakeTensorMode, func: OpOverload, name: str, args: object | None = None
) -> Any:
    """Fake implementation of ``profiler._record_function_enter_new``.

    The real op returns a custom class object, not a tensor, so its result is
    passed through unwrapped.
    """
    with in_kernel_invocation_manager(fake_mode):
        result = func(name, args)
    return result
@register_op_impl(aten.to.prim_Device)
@register_op_impl(aten.to.device)
def non_kwarg_to(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> FakeTensor:
    """Fake implementation of the ``aten.to`` overloads with a positional device.

    The conversion runs on the meta device. The result is converted back into
    a fake tensor on the requested device, or on the input's device when the
    requested device is falsy (e.g. ``None``).
    """
    _args, normalized = _normalize_function_or_error(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    out_device = normalized["device"] or normalized["input"].device
    normalized["device"] = torch.device("meta")
    source = normalized.pop("input")
    with in_kernel_invocation_manager(fake_mode):
        converted = func(source, **normalized)
    # TODO: I think this does the wrong thing if converted is the input itself
    return fake_mode.fake_tensor_converter.from_meta_and_device(
        fake_mode, converted, out_device
    )
def stride_incorrect_op(op: OpOverload) -> bool:
    """Predicate selecting ops whose meta kernels compute incorrect strides.

    No op is currently flagged, so this always returns ``False``.
    """
    del op  # unused: the set of affected ops is currently empty
    return False
# Workaround for operators (selected by stride_incorrect_op) whose meta
# implementations produce incorrect strides.
@register_op_impl(stride_incorrect_op)
def workaround_stride_incorrect_op(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> FakeTensor:
    """Obtain correct strides by running the real kernel when that is possible.

    When fallback kernels are allowed and no top-level argument is symbolic
    (a SymInt/SymFloat/SymBool, or a FakeTensor with symbolic sizes or
    strides), the op is executed via ``run_fallback_kernel``. In every other
    case ``UnsupportedOperatorException`` is raised.
    """

    def _is_symbolic_value(value: object) -> bool:
        if isinstance(value, FakeTensor):
            return value._has_symbolic_sizes_strides
        return isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool))

    # For static shapes, we can fall back to eager for the real strides
    if fake_mode.allow_fallback_kernels:
        top_level = itertools.chain(args, kwargs.values())
        if not any(_is_symbolic_value(value) for value in top_level):
            flat_args, args_spec = pytree.tree_flatten((args, kwargs))
            return run_fallback_kernel(
                fake_mode,
                func,
                flat_args,
                args_spec,
                # TODO: refactor to lambda so we don't instantiate extra errors
                # before calling
                RuntimeError("Cannot run fallback kernel for stride_incorrect_op"),
            )
    raise UnsupportedOperatorException(func)
# Don't use the default device handling: the device of `the_template` is
# ignored by this op, so the call is forwarded unchanged.
@register_op_impl(aten.resize_as_.default)
def resize_as_(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> FakeTensor:
    """Fake implementation of ``aten.resize_as_``.

    The call is forwarded as-is inside the kernel invocation context.
    """
    with in_kernel_invocation_manager(fake_mode):
        resized = func(*args, **kwargs)
    return resized
@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
def _sparse_coo_tensor_with_dims_and_tensors(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> FakeTensor:
    """Treat the sparse COO factory like any other constructor.

    All work is delegated to ``constructors``.
    """
    result = constructors(fake_mode, func, *args, **kwargs)
    return result
# index.Tensor, nonzero and repeat_interleave.Tensor are data-dependent only in
# some conditions, so they are excluded from this catch-all and get dedicated
# implementations instead (nonzero and repeat_interleave.Tensor in this module;
# index.Tensor presumably elsewhere).
def _always_dynamic_output_shape(func: OpOverload) -> bool:
    """True for ops tagged ``dynamic_output_shape`` other than the excluded ones."""
    if torch.Tag.dynamic_output_shape not in func.tags:
        return False
    return func not in (
        aten.index.Tensor,
        aten.nonzero.default,
        aten.repeat_interleave.Tensor,
    )


@register_op_impl(_always_dynamic_output_shape)
def dyn_shape(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> None:
    """Reject ``func``: its output shape cannot be determined without data."""
    raise DynamicOutputShapeException(func)
def _unique(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    arg: FakeTensor,
    dim: int | None,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
    *,
    unique_consecutive: bool = False,
) -> tuple[FakeTensor, FakeTensor, FakeTensor]:
    """Shared fake implementation for the unique family of ops.

    The number of unique entries (``nnz``) is data-dependent:

    * for a statically empty input it is 0;
    * otherwise it is a fresh unbacked SymInt, constrained to at most the
      static element count (or the size along ``dim``) when that is known,
      and to ``sys.maxsize - 1`` otherwise.

    When ``dim`` is None the symbol is memoized on ``arg`` (``unique_memo`` or
    ``unique_consecutive_memo``) and reused by later calls.

    Args:
        fake_mode: active fake mode; its shape env must allow dynamic output
            shape ops.
        func: the op being faked (reported in exceptions).
        arg: input tensor.
        dim: dimension to deduplicate along, or None for the flattened input.
            Must be non-negative, because the output shape is built from
            ``arg.shape[:dim]`` and ``arg.shape[dim + 1:]``.
        sorted: not read here; kept for signature parity with the ops.
        return_inverse: materialize the inverse-indices output.
        return_counts: materialize the counts output.
        unique_consecutive: selects which memo slot is used when dim is None.

    Returns:
        ``(values, inverse, counts)``. ``inverse`` and ``counts`` are empty
        (size 0) unless requested. For a CPU input with ``dim`` given, both are
        always materialized.

    Raises:
        DynamicOutputShapeException: if dynamic output shapes are unsupported.
    """
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    nnz = arg.unique_consecutive_memo if unique_consecutive else arg.unique_memo

    # Do not use a memo for unique_dim
    if dim is not None or nnz is None:
        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        if not has_free_symbols(arg.numel()) and arg.numel() == 0:
            # If numel is zero, then the output size must be zero.
            # In this case, we must not allocate an unbacked SymInt,
            # because if we do, it will immediately get refined to
            # zero, but this will be inconsistent with size oblivious
            # tests (which will continue to claim that the unbacked
            # symint cannot equal zero). We could also unconditionally
            # allocate an unbacked SymInt and not refine its range,
            # but this seems more precise.
            nnz = 0
        else:
            nnz = fake_mode.shape_env.create_unbacked_symint()

            # Upper bound: the static number of candidates when known.
            maxval = sys.maxsize - 1

            numel = arg.numel() if dim is None else arg.size(dim)
            if not has_free_symbols(numel):
                maxval = int(numel)

            _constrain_range_for_size(nnz, max=maxval)

        # Only the flattened variants are memoized.
        if dim is None:
            if unique_consecutive:
                arg.unique_consecutive_memo = nnz
            else:
                arg.unique_memo = nnz

    if dim is None:
        # pyrefly: ignore[no-matching-overload]
        ret = [arg.new_empty((nnz,))]
    else:
        # Replace the size of `dim` with nnz; this slicing assumes dim >= 0.
        # pyrefly: ignore[no-matching-overload]
        ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])]

    # With a dim on CPU, inverse and counts are materialized even if not
    # requested.
    return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu")
    if return_inverse or return_if_dim_and_cpu:
        inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],))
    else:
        inverse = arg.new_empty(0)
    ret.append(inverse)

    if return_counts or return_if_dim_and_cpu:
        counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],))
    else:
        counts = arg.new_empty(0)
    ret.append(counts)

    return tuple(ret)
@register_op_impl(aten._unique2.default)
def unique2(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    arg: FakeTensor,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
) -> tuple[FakeTensor, FakeTensor, FakeTensor]:
    """Fake implementation of ``aten._unique2`` (unique over the flattened input).

    Delegates to ``_unique`` with ``dim=None``.
    """
    return _unique(
        fake_mode,
        func,
        arg,
        None,
        sorted=sorted,
        return_inverse=return_inverse,
        return_counts=return_counts,
    )
@register_op_impl(aten.select.int)
def meta_select(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    self: FakeTensor,
    dim: int,
    index: IntLikeType,
) -> FakeTensor:
    """Fake implementation of ``aten.select.int``, built on ``as_strided``.

    ``dim`` is removed from the size and stride lists. The storage offset
    advances by ``index`` times the stride of ``dim`` (``index + size`` for a
    negative index).

    If the sign of ``index`` cannot be decided without guarding, the offset
    becomes a fresh unbacked SymInt. That requires scalar outputs to be
    allowed; otherwise ``DataDependentOutputException`` is raised. Sparse
    inputs return ``NotImplemented``.
    """
    from torch.fx.experimental.symbolic_shapes import guard_or_false

    if self.is_sparse:
        return NotImplemented

    ndim = self.dim()
    torch._check_index(
        ndim != 0,
        lambda: "select() cannot be applied to a 0-dim tensor.",
    )

    if dim < 0:
        dim += ndim
    dim_size = self.size(dim)
    sizes = list(self.size())
    strides = list(self.stride())

    if guard_or_false(index >= 0):
        offset = self.storage_offset() + index * strides[dim]
    elif guard_or_false(index < 0):
        offset = self.storage_offset() + (index + dim_size) * strides[dim]
    else:
        # index is data-dependent: it could mean index or index + size, so the
        # storage offset itself becomes a new data-dependent symbol.
        shape_env = fake_mode.shape_env
        if shape_env is None or not (
            shape_env.allow_scalar_outputs or fake_mode.allow_scalar_outputs
        ):
            raise DataDependentOutputException(func)
        offset = shape_env.create_unbacked_symint()

    del sizes[dim]
    del strides[dim]
    if offset is None:
        raise AssertionError("new_storage_offset must not be None")
    # pyrefly: ignore[bad-return]
    return self.as_strided(sizes, strides, offset)
@register_op_impl(aten.unique_dim.default)
def unique_dim(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    arg: FakeTensor,
    dim: int,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
) -> tuple[FakeTensor, FakeTensor, FakeTensor]:
    """Fake implementation of ``aten.unique_dim``; delegates to ``_unique``.

    A negative ``dim`` is wrapped into ``[0, ndim)`` with ``canonicalize_dim``
    (a 0-dim tensor is treated as having one dimension). An out-of-range
    negative ``dim`` raises ``IndexError``, as eager dim wrapping does. A
    plain ``%`` would instead silently map it onto a different dimension
    (e.g. ``-3`` on a 2-D tensor would become ``1``).
    """
    return _unique(
        fake_mode,
        func,
        arg,
        # normalize dim to be non-negative
        dim if dim >= 0 else canonicalize_dim(arg.ndim, dim),
        sorted,
        return_inverse,
        return_counts,
    )
@register_op_impl(aten.unique_consecutive.default)
def unique_consecutive(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    arg: FakeTensor,
    return_inverse: bool = False,
    return_counts: bool = False,
    dim: int | None = None,
) -> tuple[FakeTensor, FakeTensor, FakeTensor]:
    """Fake implementation of ``aten.unique_consecutive``.

    Delegates to ``_unique`` with ``sorted=False`` and
    ``unique_consecutive=True``.

    A negative ``dim`` is first wrapped into ``[0, ndim)`` with
    ``canonicalize_dim``; out-of-range values raise ``IndexError``. This
    matters because ``_unique`` builds the output shape from
    ``arg.shape[:dim]`` and ``arg.shape[dim + 1:]``, which is only correct for
    a non-negative ``dim``. For example, ``dim=-1`` on a ``(2, 3)`` input would
    otherwise yield shape ``(2, nnz, 2, 3)``.
    """
    if dim is not None and dim < 0:
        dim = canonicalize_dim(arg.ndim, dim)
    return _unique(
        fake_mode,
        func,
        arg,
        dim,
        False,
        return_inverse,
        return_counts,
        unique_consecutive=True,
    )
# This function is a Python port of computeStride_impl in TensorUtils.cpp.
def _compute_stride(
    old_shape: Sequence[IntLikeType],
    old_stride: Sequence[IntLikeType],
    new_shape: Sequence[IntLikeType],
    size_oblivious: bool = False,
) -> list[IntLikeType] | None:
    """Compute strides for viewing ``(old_shape, old_stride)`` as ``new_shape``.

    Returns the view's strides, or ``None`` when ``new_shape`` cannot be
    expressed as a view of the existing memory layout.

    The old dimensions are scanned from innermost to outermost and grouped
    into chunks that are laid out contiguously. Each chunk must be covered
    exactly by a run of new dimensions whose sizes multiply to the chunk's
    element count.

    Args:
        old_shape: sizes of the source tensor.
        old_stride: strides of the source tensor.
        new_shape: requested sizes. Presumably fully resolved (no -1);
            ``_view_unbacked_meta`` resolves it with ``utils.infer_size``
            before calling.
        size_oblivious: if True, symbolic comparisons are resolved with
            ``guard_or_false`` / ``guard_or_true`` (a fixed fallback answer
            instead of a guard). If False, the raw comparison results are used
            as booleans.
    """
    from torch.fx.experimental.symbolic_shapes import (
        guard_or_false,
        guard_or_true,
        sym_eq,
    )

    # Only avoid guarding in size-oblivious mode.
    def maybe_guard_or_false(x: Any) -> Any:
        if size_oblivious:
            return guard_or_false(x)
        return x

    def maybe_guard_or_true(x: Any) -> Any:
        if size_oblivious:
            return guard_or_true(x)
        return x

    # 0-dim source: every new dimension gets stride 1.
    if len(old_shape) == 0:
        return [1] * len(new_shape)

    numel = reduce(operator.mul, old_shape, 1)
    zero_numel = maybe_guard_or_false(numel == 0)
    # Empty tensor viewed with an identical shape keeps its strides.
    if zero_numel and maybe_guard_or_false(sym_eq(old_shape, new_shape)):
        return list(old_stride)

    new_stride: list[IntLikeType] = [0] * len(new_shape)
    if zero_numel:
        # Empty tensor: no elements are addressed, so contiguous-style strides
        # are produced (size-0 dims count as size 1 in the running product).
        for view_d in range(len(new_shape) - 1, -1, -1):
            if view_d == len(new_shape) - 1:
                new_stride[view_d] = 1
            else:
                new_stride[view_d] = (
                    max(new_shape[view_d + 1], 1) * new_stride[view_d + 1]
                )
        return new_stride

    view_d = len(new_shape) - 1
    # Stride of the innermost dimension of the current chunk.
    # Annotate type here to support type checking
    chunk_base_stride: IntLikeType = old_stride[-1]
    # Elements in the current old-shape chunk.
    tensor_numel: IntLikeType = 1
    # Elements covered so far by new dims for the current chunk.
    view_numel: IntLikeType = 1
    for tensor_d in range(len(old_shape) - 1, -1, -1):
        tensor_numel *= old_shape[tensor_d]
        # A chunk ends at the outermost dim, or where the next-outer dim is not
        # laid out right after this chunk (size-1 dims never end a chunk).
        if tensor_d == 0 or (
            maybe_guard_or_true(old_shape[tensor_d - 1] != 1)
            and maybe_guard_or_true(
                old_stride[tensor_d - 1] != tensor_numel * chunk_base_stride
            )
        ):
            # Assign strides to new dims until they cover the chunk; trailing
            # size-1 new dims are absorbed as well.
            while view_d >= 0 and (
                maybe_guard_or_true(view_numel < tensor_numel)
                or maybe_guard_or_false(new_shape[view_d] == 1)
            ):
                new_stride[view_d] = view_numel * chunk_base_stride
                view_numel *= new_shape[view_d]
                view_d -= 1
            # The new dims must span the chunk exactly, otherwise no view exists.
            if maybe_guard_or_true(view_numel != tensor_numel):
                return None
            if tensor_d > 0:
                # Start the next chunk.
                chunk_base_stride = old_stride[tensor_d - 1]
                tensor_numel = 1
                view_numel = 1
    # Leftover new dims mean the shapes are incompatible.
    if view_d != -1:
        return None
    return new_stride
- def _view_has_unbacked_input(
- a: torch.Tensor, shape: ShapeType | tuple[ShapeType]
- ) -> bool:
- from torch.fx.experimental.symbolic_shapes import has_hint
- shape = utils.extract_shape_from_varargs(shape, validate=False)
- return (
- any(not has_hint(s) for s in a.size())
- or any(not has_hint(s) for s in a.stride())
- or any(not has_hint(s) for s in shape)
- )
def _view_unbacked_meta(
    a: torch.Tensor,
    shape: ShapeType | tuple[ShapeType],
    size_oblivious_enabled: bool = True,
) -> torch.Tensor:
    """Compute ``a.view(shape)`` in a way that tolerates symbolic sizes.

    ``_view_meta`` uses this path when ``backed_size_oblivious`` is on or the
    inputs contain values without hints.

    Args:
        a: tensor to view.
        shape: target shape, possibly in varargs form and possibly containing
            ``-1`` (resolved with ``utils.infer_size``).
        size_oblivious_enabled: when True, the contiguity and stride checks
            avoid guarding (``is_contiguous_or_false`` and size-oblivious
            ``_compute_stride``). When False they are evaluated directly. The
            function retries itself with False after a failed size-oblivious
            attempt (see the end of the body).

    Returns:
        A tensor with the requested shape (``view_of``, ``as_strided`` or
        unsqueeze/squeeze results).

    Raises:
        ValueError: if no valid strides exist for the view.
        torch._check failures: if element counts do not match.
    """
    from torch._prims import view_of
    from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq

    # Creates a valid shape
    shape = utils.extract_shape_from_varargs(shape, validate=False)

    # Reshape may be given a shape with a -1 length
    # This indicates that the dimension's length should be inferred
    shape = utils.infer_size(shape, a.numel())

    # Special-cases reshaping zero dim tensors
    if a.ndim == 0:
        _a = a
        # Every target dim must be 1; add one trailing dim per entry.
        for length in shape:
            torch._check(length == 1)
            _a = torch._refs.unsqueeze(_a, -1)
        if _a is a:
            return view_of(a)
        else:
            return _a  # type: ignore[return-value]

    # Special-cases reshaping to zero dim tensors
    if len(shape) == 0:
        _a = a
        # Every source dim must be 1; squeeze them away one by one.
        for length in a.shape:
            torch._check(length == 1)
            _a = torch._refs.squeeze(_a, -1)
        if _a is a:
            return view_of(a)
        else:
            return _a  # type: ignore[return-value]

    shape_numel = reduce(operator.mul, shape, 1)

    torch._check(
        a.numel() == shape_numel,
        lambda: f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
    )

    # Same shape: a plain alias suffices.
    if len(shape) == len(a.shape) and guard_or_false(sym_eq(shape, a.shape)):
        return view_of(a)

    # Contiguous input: contiguous strides for the new shape are valid.
    if is_contiguous_or_false(a) if size_oblivious_enabled else is_contiguous(a):
        strides = make_contiguous_strides_for(shape)
        return a.as_strided(shape, strides)  # type: ignore[return-value]

    new_strides = _compute_stride(
        a.size(), a.stride(), shape, size_oblivious=size_oblivious_enabled
    )

    if new_strides is not None:
        return a.as_strided(shape, new_strides)  # type: ignore[return-value]

    # If we fail to do size oblivious view, and backed_size_oblivious was on,
    # then we redo everything by looking at hints and guarding instead of failing.
    # Also if the expression has unbacked symbols, then we run again with size_oblivious_enabled=False
    # to throw a data dependent error.

    if size_oblivious_enabled and (
        torch.fx.experimental._config.backed_size_oblivious
        or _view_has_unbacked_input(a, shape)
    ):
        return _view_unbacked_meta(a, shape, size_oblivious_enabled=False)

    msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
    raise ValueError(msg)
@register_op_impl(aten._reshape_copy.default)
def _reshape_copy(
    fake_mode: FakeTensorMode, func: OpOverload, a: FakeTensor, *shape: Any
) -> FakeTensor | Exception:
    """Fake implementation of ``aten._reshape_copy``: a reshape into fresh memory.

    ``shape`` is resolved (including a ``-1`` entry) against ``a.numel()``.

    * A contiguous input is viewed first, and the view is then cloned to
      contiguous memory.
    * Any other input is cloned to contiguous memory first, then viewed.

    Sparse and MKLDNN inputs return ``NotImplemented``.
    """
    if a.is_sparse or a.is_mkldnn:
        return NotImplemented

    # pyrefly: ignore[bad-argument-count]
    target = utils.infer_size(*shape, a.numel())
    if not is_contiguous_or_false(a):
        dense = typing_cast(FakeTensor, a.clone(memory_format=torch.contiguous_format))
        return _view_meta(fake_mode, func, dense, *target)
    viewed = _view_meta(fake_mode, func, a, *target)
    return typing_cast(FakeTensor, viewed.clone(memory_format=torch.contiguous_format))
@register_op_impl(aten.view.default)
@register_op_impl(aten._unsafe_view.default)
def _view_meta(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    a: FakeTensor,
    *shape: Any,
) -> FakeTensor:
    """Fake implementation of ``aten.view`` / ``aten._unsafe_view``.

    The symbolic-shape-tolerant ``_view_unbacked_meta`` is used when
    ``backed_size_oblivious`` is on or any input lacks a hint. Otherwise the
    regular ``torch._refs._reshape_view_helper`` is used, with copies
    disallowed.
    """
    use_unbacked_path = (
        torch.fx.experimental._config.backed_size_oblivious
        or _view_has_unbacked_input(a, shape)
    )
    if not use_unbacked_path:
        return typing_cast(
            FakeTensor, torch._refs._reshape_view_helper(a, *shape, allow_copy=False)
        )
    return typing_cast(FakeTensor, _view_unbacked_meta(a, shape))
@register_op_impl(aten.view_copy.default)
def _view_meta_copy(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    a: FakeTensor,
    *shape: IntLikeType,
    out: FakeTensor | None = None,
) -> FakeTensor:
    """Fake implementation of ``aten.view_copy``.

    The view is computed with ``_view_meta``. Without ``out``, every tensor in
    the result is cloned into contiguous memory. When ``out`` is given, the
    view itself is returned.
    """
    viewed = _view_meta(fake_mode, func, a, *shape)
    if out is None:
        return pytree.tree_map(
            lambda t: t.clone(memory_format=torch.contiguous_format), viewed
        )
    return viewed
@register_op_impl(aten.repeat_interleave.Tensor)
def repeat_interleave_tensor(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    repeats: FakeTensor,
    output_size: IntLikeType | None = None,
) -> FakeTensor:
    """Fake implementation of ``aten.repeat_interleave.Tensor``.

    The output is a 1-D tensor of length ``output_size``. When that is not
    given it depends on the data, so a fresh size-like unbacked SymInt is
    used. This requires a shape env that allows dynamic output shape ops;
    otherwise ``DynamicOutputShapeException`` is raised.
    """
    if output_size is not None:
        # TODO: consider a memo
        return repeats.new_empty(output_size)  # type: ignore[return-value]

    shape_env = fake_mode.shape_env
    if shape_env is None or not shape_env.allow_dynamic_output_shape_ops:
        raise DynamicOutputShapeException(func)
    output_size = shape_env.create_unbacked_symint()

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

    _constrain_range_for_size(output_size)
    # TODO: consider a memo
    return repeats.new_empty(output_size)  # type: ignore[return-value]
@register_op_impl(torch.ops.aten.item.default)
@register_op_impl(torch.ops.aten._local_scalar_dense.default)
def local_scalar_dense(
    fake_mode: FakeTensorMode, func: OpOverload, arg: FakeTensor
) -> int | float | bool | torch.SymInt | torch.SymFloat | torch.SymBool:
    """Fake implementation of ``item`` / ``_local_scalar_dense``.

    The scalar value is unknown, so a fresh unbacked symbol is returned. Its
    kind follows ``arg.dtype``: SymFloat for float, SymInt for integer,
    SymBool for bool. The symbol is memoized on ``arg.item_memo``, so repeated
    reads of the same tensor return the same symbol.

    Raises:
        DataDependentOutputException: if there is no shape env, or neither it
            nor the fake mode allows scalar outputs.
        NotImplementedError: for any other dtype.
    """
    cached = arg.item_memo
    if cached is not None:
        return cached

    shape_env = fake_mode.shape_env
    if shape_env is None or not (
        shape_env.allow_scalar_outputs or fake_mode.allow_scalar_outputs
    ):
        # Without symints/symfloats, cannot handle this
        raise DataDependentOutputException(func)

    dtype = arg.dtype
    if is_float_dtype(dtype):
        symbol = shape_env.create_unbacked_symfloat()
    elif is_integer_dtype(dtype):
        symbol = shape_env.create_unbacked_symint()
    elif is_boolean_dtype(dtype):
        symbol = shape_env.create_unbacked_symbool()
    else:
        raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")

    arg.item_memo = symbol
    return symbol
@register_op_impl(torch.ops.aten.nonzero_numpy.default)
def nonzero_numpy(
    fake_mode: FakeTensorMode, func: OpOverload, arg: FakeTensor
) -> list[FakeTensor]:
    """Fake implementation of ``aten.nonzero_numpy``.

    Produces one index tensor per input dimension by unbinding the columns of
    ``aten.nonzero``'s result.
    """
    indices = torch.ops.aten.nonzero.default(arg)
    return indices.unbind(1)
@register_op_impl(torch.ops.aten.nonzero.default)
def nonzero(fake_mode: FakeTensorMode, func: OpOverload, arg: FakeTensor) -> FakeTensor:
    """Fake implementation of ``aten.nonzero``.

    Returns an int64 tensor of shape ``(nnz, arg.dim())`` with strides
    ``(1, nnz)``. ``nnz`` (the number of nonzero elements) is data-dependent:

    * 0 for a statically empty input;
    * otherwise a fresh unbacked SymInt, bounded above by the input's element
      count (the static value when known, else the upper bound of its
      symbolic value range, else ``sys.maxsize - 1``).

    ``nnz`` is memoized on ``arg.nonzero_memo``, so repeated calls on the same
    tensor share one symbol.

    Raises:
        DynamicOutputShapeException: if there is no shape env or it does not
            allow dynamic output shape ops.
    """
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    if (nnz := arg.nonzero_memo) is None:
        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )
        from torch.utils._sympy.numbers import IntInfinity
        from torch.utils._sympy.value_ranges import bound_sympy

        if not has_free_symbols(arg.numel()) and arg.numel() == 0:
            # If numel is zero, then the output size must be zero.
            # In this case, we must not allocate an unbacked SymInt,
            # because if we do, it will immediately get refined to
            # zero, but this will be inconsistent with size oblivious
            # tests (which will continue to claim that the unbacked
            # symint cannot equal zero). We could also unconditionally
            # allocate an unbacked SymInt and not refine its range,
            # but this seems more precise.
            nnz = 0
        else:
            nnz = fake_mode.shape_env.create_unbacked_symint()

            maxval = sys.maxsize - 1

            if not has_free_symbols(arg.numel()):
                maxval = int(arg.numel())
            else:
                # Symbolic element count: use the upper end of its value range.
                prod_node = math.prod(arg.shape).node  # type: ignore[union-attr]
                prod_range = bound_sympy(
                    prod_node.expr, prod_node.shape_env.var_to_range
                )
                if isinstance(prod_range.upper, IntInfinity):
                    maxval = sys.maxsize - 1
                else:
                    maxval = prod_range.upper

            _constrain_range_for_size(nnz, max=maxval)

        arg.nonzero_memo = nnz

    # Column-major layout: strides (1, nnz) for shape (nnz, ndim).
    return arg.new_empty_strided((nnz, arg.dim()), (1, nnz), dtype=torch.int64)  # type: ignore[return]
@register_op_impl(torch.ops.aten._padded_dense_to_jagged_forward.default)
def _padded_dense_to_jagged_forward(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    padded: FakeTensor,
    offsets: list[FakeTensor],
    total_L: IntLikeType | None = None,
) -> FakeTensor:
    """Fake kernel for ``aten._padded_dense_to_jagged_forward``.

    The output has shape ``(total_L, *padded.shape[2:])``. When ``total_L``
    is not supplied it is data dependent, so an unbacked SymInt bounded by
    ``padded.numel()`` (when that is concrete) is allocated for it.

    Raises:
        AssertionError: if more than one offsets tensor is given.
        DynamicOutputShapeException: if ``total_L`` must be invented but the
            ShapeEnv is missing or disallows dynamic-output-shape ops.
    """
    # only one jagged dim is supported for now
    if len(offsets) != 1:
        raise AssertionError(
            f"Only one jagged dim is supported, got {len(offsets)} offsets"
        )

    # Compare against None explicitly: ``not total_L`` would call bool() on a
    # SymInt (installing a guard, or failing outright for an unbacked SymInt)
    # and would wrongly treat an explicit ``total_L=0`` as "unknown".
    if total_L is None:
        if (
            fake_mode.shape_env is None
            or not fake_mode.shape_env.allow_dynamic_output_shape_ops
        ):
            # Without symints/symfloats, cannot handle this
            raise DynamicOutputShapeException(func)

        total_L = fake_mode.shape_env.create_unbacked_symint()

        maxval = sys.maxsize - 1

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        # The jagged length cannot exceed the number of padded elements.
        if not has_free_symbols(padded.numel()):
            maxval = int(padded.numel())

        _constrain_range_for_size(total_L, min=0, max=maxval)

    output_shape = (total_L, *padded.shape[2:])
    return padded.new_empty(output_shape)  # type: ignore[return]
- def _compute_slice_index(size: IntLikeType, index: IntLikeType) -> IntLikeType | None:
- from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_and
- if guard_or_false(sym_and(index >= 0, index <= size)):
- return index
- elif guard_or_false(sym_and(index < 0, index >= -size)):
- return index + size
- elif guard_or_false(index < -size):
- return 0
- elif guard_or_false(index > size):
- return size
- return None
@register_op_impl(torch.ops.aten.slice.Tensor)
def slice_forward(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    self: FakeTensor,
    dim: int = 0,
    start: int | None = None,
    end: int | None = None,
    step: int = 1,
) -> FakeTensor:
    """Fake kernel for ``aten.slice.Tensor``, producing an ``as_strided`` view.

    The new size along ``dim`` is ``ceil((end - start) / step)`` when the
    normalized bounds can be ordered, ``0`` when ``start >= end``, and
    otherwise a fresh unbacked SymInt constrained to ``[0, sizes[dim]]``.
    When ``start`` cannot be normalized, the storage offset is also a fresh
    unbacked SymInt constrained to be non-negative.

    Raises:
        RuntimeError: for 0-dim tensors or a non-positive ``step``.
        AssertionError: if an unbacked SymInt is needed but there is no
            ShapeEnv.
        NotImplementedError: for quantized tensors.
    """
    from torch.fx.experimental.symbolic_shapes import (
        guard_or_false,
        statically_known_true,
    )

    shape_env = fake_mode.shape_env

    ndim = self.dim()
    if ndim == 0:
        raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
    dim = canonicalize_dim(self.dim(), dim)
    sizes = list(self.size())
    strides = list(self.stride())

    if step <= 0:
        raise RuntimeError("slice step must be positive")

    # start, end
    # _compute_slice_index returns None when the bound cannot be normalized.
    # An ``end`` equal to sys.maxsize (presumably the "open-ended" sentinel
    # callers pass for ``x[a:]`` -- TODO confirm) or None means "to the end".
    start_index = 0 if start is None else _compute_slice_index(sizes[dim], start)
    end_index = (
        sizes[dim]
        if statically_known_true(end == sys.maxsize) or end is None
        else _compute_slice_index(sizes[dim], end)
    )

    # size
    new_size: IntLikeType | None = None
    if start_index is not None and end_index is not None:
        if guard_or_false(end_index >= start_index):
            # ceil division of the span by step
            new_size = (end_index - start_index + step - 1) // step
        elif guard_or_false(start_index >= end_index):
            new_size = 0

    # create unbacked if case unknown
    if new_size is None:
        if shape_env is None:
            raise AssertionError("Must have shape_env to create symint")
        new_size = shape_env.create_unbacked_symint()
        torch._check(new_size >= 0)
        torch._check(new_size <= sizes[dim])

    # stride
    new_stride = strides[dim] * step

    # storage offset
    if start_index is not None:
        storage_offset = self.storage_offset() + start_index * strides[dim]
    else:
        if shape_env is None:
            raise AssertionError("Must have shape_env to create symint")
        storage_offset = shape_env.create_unbacked_symint()
        torch._check(storage_offset >= 0)

    sizes[dim] = new_size  # type: ignore[unsupported-operation]
    strides[dim] = new_stride
    if self.is_quantized:
        raise NotImplementedError(
            "Slice decomposition for quantized tensors aren't implemented"
        )
    else:
        return self.as_strided(sizes, strides, storage_offset)  # type: ignore[return-value]
@register_op_impl(torch.ops.aten.masked_select.default)
def masked_select(
    fake_mode: FakeTensorMode, func: OpOverload, self: FakeTensor, mask: FakeTensor
) -> FakeTensor:
    """Fake kernel for ``aten.masked_select``.

    The output length depends on the values in ``mask``, so it is an unbacked
    SymInt. Its upper bound is the element count of the broadcast of ``self``
    and ``mask``. Bounding by ``self.numel()`` alone would be unsound when
    ``mask`` broadcasts ``self`` to a larger shape.

    Raises:
        DynamicOutputShapeException: if there is no ShapeEnv, or it does not
            allow dynamic-output-shape ops.
    """
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    nnz = fake_mode.shape_env.create_unbacked_symint()

    # see nonzero for commentary
    maxval = sys.maxsize - 1

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import (
        _constrain_range_for_size,
        has_free_symbols,
    )
    from torch.utils._sympy.numbers import IntInfinity
    from torch.utils._sympy.value_ranges import bound_sympy

    # ``self`` and ``mask`` are broadcast against each other, so the number of
    # selectable elements is the numel of the broadcast shape.
    broadcast_numel = math.prod(infer_size(self.shape, mask.shape))

    # If num elements is expressed symbolically, calculate
    # the concrete value based on upper bounds. Otherwise,
    # we can set max val directly.
    if not has_free_symbols(broadcast_numel):
        num_elements = int(broadcast_numel)
    else:
        prod_node = broadcast_numel.node  # type: ignore[union-attr]
        prod_range = bound_sympy(prod_node.expr, prod_node.shape_env.var_to_range)
        if isinstance(prod_range.upper, IntInfinity):
            num_elements = sys.maxsize - 1
        else:
            num_elements = prod_range.upper
    if num_elements > 2:
        maxval = num_elements

    _constrain_range_for_size(nnz, max=maxval)

    return self.new_empty((nnz,))  # type: ignore[return]
@register_op_impl(torch.ops.aten._assert_tensor_metadata.default)
def assert_tensor_metadata(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    t: FakeTensor,
    sizes: torch.Size | None = None,
    strides: tuple[int, ...] | None = None,
    dtype: torch.dtype | None = None,
    *,
    device: torch.device | None = None,
    layout: torch.layout | None = None,
) -> None:
    """Fake kernel for ``aten._assert_tensor_metadata``.

    Checks each expectation that is not None against ``t``, in the order
    sizes, strides, dtype, layout, device.

    Raises:
        AssertionError: on the first mismatching property.
    """
    # Sizes and strides are compared as plain tuples. The expected value may
    # arrive as a list, and a tuple (``torch.Size`` included) never compares
    # equal to a list, even when the elements match.
    if sizes is not None:
        if tuple(t.size()) != tuple(sizes):
            raise AssertionError(
                f"Tensor sizes mismatch! Expected: {sizes}, Got: {t.size()}"
            )
    if strides is not None:
        if tuple(t.stride()) != tuple(strides):
            raise AssertionError(
                f"Tensor strides mismatch! Expected: {strides}, Got: {t.stride()}"
            )
    if dtype is not None:
        if t.dtype != dtype:
            raise AssertionError(
                f"Tensor dtype mismatch! Expected: {dtype}, Got: {t.dtype}"
            )
    if layout is not None:
        if t.layout != layout:
            raise AssertionError(
                f"Tensor layout mismatch! Expected: {layout}, Got: {t.layout}"
            )
    if device is not None:
        if t.device != device:
            raise AssertionError(
                f"Tensor device mismatch! Expected: {device}, Got: {t.device}"
            )
# NB: this must be ordered after local_scalar_dense
@register_op_impl(lambda op: torch.Tag.data_dependent_output in op.tags)
def data_dep(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> None:
    """Reject every op tagged ``data_dependent_output``.

    Always raises ``DataDependentOutputException`` for ``func``. Arguments
    are ignored.
    """
    error = DataDependentOutputException(func)
    raise error
- # Bool Indices get Expanded as Masks
- # See: IndexingUtils.h:expandTensors
- def check_no_bool_index_tensors(
- func: OpOverload, self: FakeTensor, indices: list[FakeTensor | None]
- ) -> None:
- for index in indices:
- if index is not None and index.dtype in (torch.bool, torch.uint8):
- raise DynamicOutputShapeException(func)
def run_and_return_new_tensor_of_input_device(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> FakeTensor:
    """Run ``func`` in kernel-invocation mode and wrap the result on ``input``'s device.

    If the input's device does not support non-contiguous tensors, the result
    is replaced by a fresh empty tensor of the same shape. If ``func``
    returned its ``input`` argument itself, that object is returned unwrapped.
    """
    # TODO: ref
    _, normalized = _normalize_function_or_error(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    self_arg = normalized["input"]
    target_device = self_arg.device

    with in_kernel_invocation_manager(fake_mode):
        result = func(*args, **kwargs)
    if not is_noncontiguous_supported(target_device):
        result = result.new_empty(result.shape)

    # copy_: the op handed back its input, which is already a fake tensor.
    if result is self_arg:
        return result
    return FakeTensor(fake_mode, result, target_device)
_is_builtin_namespaces = ordered_set("aten", "prims", "prim")


def is_builtin(op: OpOverload) -> bool:
    """Return True when ``op`` lives in the ``aten``, ``prims`` or ``prim`` namespace."""
    namespace = op.namespace
    return namespace in _is_builtin_namespaces
def has_meta(func: OpOverload) -> bool:
    """Return whether ``func`` has a computed kernel for the Meta dispatch key."""
    qualified_name = func.name()
    dispatch_key = "Meta"
    return torch._C._dispatch_has_computed_kernel_for_dispatch_key(
        qualified_name, dispatch_key
    )
# These are for the `torch._foreach_...` ops like `torch._foreach_add`.
@register_op_impl(
    lambda func: is_builtin(func)
    and func.name().startswith("aten::_foreach_")
    and has_meta(func)
)
def foreach_run_and_map_input_device(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> list[FakeTensor] | None:
    """Run a built-in ``_foreach_*`` op on meta and place each output on its inputs' device.

    Returns ``NotImplemented`` when the meta kernel raises
    ``NotImplementedError``. A falsy result (e.g. None or an empty list) is
    returned unchanged. Otherwise the i-th output is wrapped on the common
    device of the i-th element of every tensor-list argument.
    """
    # Any argument that is a non-empty list/tuple starting with a tensor is
    # treated as one of the op's tensor lists.
    tensor_lists = []
    for candidate in (*args, *kwargs.values()):
        if (
            isinstance(candidate, (list, tuple))
            and len(candidate)
            and isinstance(candidate[0], torch.Tensor)
        ):
            tensor_lists.append(candidate)

    try:
        with in_kernel_invocation_manager(fake_mode):
            out_meta = func(*args, **kwargs)
    except NotImplementedError:
        return NotImplemented

    if not out_meta:
        return out_meta

    if not tensor_lists:
        raise AssertionError("tensor_lists must not be empty")

    def _device_for(position: int) -> torch.device:
        # Common device across the position-th entry of every tensor list.
        common, _ = FakeTensor._find_common_device(
            func, [tl[position] for tl in tensor_lists]
        )
        return common

    converter = fake_mode.fake_tensor_converter
    return [
        converter.from_meta_and_device(fake_mode, meta_out, _device_for(position))
        for position, meta_out in enumerate(out_meta)
    ]
# Don't use the default device handling, since this op can take
# non-zero-sized CPU index tensors together with a CUDA self.
@register_op_impl(aten.index.Tensor)
def index_tensor(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> FakeTensor:
    """Compute ``aten.index.Tensor`` via its meta function and move the result to ``input``'s device."""
    from torch._meta_registrations import meta_index_Tensor

    _, normalized = _normalize_function_or_error(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    target_device = normalized["input"].device

    # ensure nonzero call goes to fake tensor
    with fake_mode:
        result = meta_index_Tensor(*args, **kwargs)
    return result.to(target_device)
# Can take mixed meta/non-meta arguments; the meta registration
# will roughly do the right thing even when given real devices
@register_op_impl(aten._embedding_bag.default)
def embedding_bag(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> tuple[FakeTensor, FakeTensor, FakeTensor, FakeTensor]:
    """Evaluate ``aten._embedding_bag`` by running its meta registration under ``fake_mode``."""
    from torch._meta_registrations import meta_embedding_bag as _meta_impl

    with fake_mode:
        outputs = _meta_impl(*args, **kwargs)
    return outputs
# takes in multiple devices, so don't use the default device handling
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@register_op_impl(aten.slice_scatter.default)
@register_op_impl(aten.diagonal_scatter.default)
def multi_device_op_default(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> FakeTensor:
    """Run ``func`` and return its result placed on the ``input`` argument's device."""
    return run_and_return_new_tensor_of_input_device(
        fake_mode=fake_mode, func=func, args=args, kwargs=kwargs
    )
# same as multi_device_op_default, but for out= variants: the result is the
# ``out`` tensor the op writes into
@register_op_impl(aten.copy.out)
@register_op_impl(aten.slice_scatter.out)
def multi_device_op_out(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> FakeTensor:
    """Fake kernel for out= overloads that may mix devices.

    Runs ``func`` under ``in_kernel_invocation_manager`` and returns the
    ``out`` argument. These overloads declare their result as aliasing
    ``out`` (``Tensor(a!) out -> Tensor(a!)``). Returning ``self``
    (normalized to ``"input"``) would hand callers the wrong tensor.
    """
    with in_kernel_invocation_manager(fake_mode):
        func(*args, **kwargs)

    _, new_kwargs = _normalize_function_or_error(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    return new_kwargs["out"]
@register_op_impl(aten.index_put.default)
@register_op_impl(aten.index_put_.default)
def index_put_impl(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> FakeTensor:
    """Fake kernel for ``index_put`` / ``index_put_``.

    ``values`` must be on ``self``'s fake device unless it is a single-element
    0-d tensor. The in-place overload returns ``self``; the functional one
    returns a new tensor on ``self``'s device.
    """
    _, normalized = _normalize_function_or_error(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    values = normalized["values"]
    self_tensor = normalized["input"]
    self_device = self_tensor.fake_device

    torch._check(
        self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
        lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
    )

    result = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
    return self_tensor if func is aten.index_put_.default else result
@register_op_impl(aten._nested_tensor_from_tensor_list.default)
@register_op_impl(aten._nested_tensor_from_tensor_list.out)
@register_op_impl(aten._nested_view_from_buffer.default)
@register_op_impl(aten._nested_view_from_buffer_copy.default)
def nested_tensors_unsupported(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> None:
    """Always raise ``UnsupportedOperatorException`` for these nested-tensor ops."""
    error = UnsupportedOperatorException(func)
    raise error
@register_op_impl(
    [
        op
        for op in _device_not_kwarg_ops
        if op
        not in (
            # these are already registered elsewhere
            aten.is_pinned.default,
            aten.to.device,
            aten.to.prim_Device,
            aten._nested_tensor_from_tensor_list.default,
            aten._nested_tensor_from_tensor_list.out,
        )
    ]
)
def nyi(fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any) -> None:
    """Fail loudly for ops in ``_device_not_kwarg_ops`` that have no fake impl yet."""
    if func not in _device_not_kwarg_ops:
        return
    raise AssertionError(f"NYI: {func}")
@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
def conv(
    fake_mode: FakeTensorMode, func: OpOverload, *args: Any, **kwargs: Any
) -> FakeTensor | tuple[FakeTensor | None, FakeTensor | None, FakeTensor | None]:
    """Fake kernel for ``aten.convolution`` and ``aten.convolution_backward``.

    Runs the op under ``in_kernel_invocation_manager``. When the batch size
    has a hint, it asks the conv backend selector which memory format the
    real kernel would produce, and converts the outputs to match. The
    backward's third output (grad of bias) is never converted. Outputs are
    wrapped as FakeTensors on ``input``'s fake device.
    """
    _, new_kwargs = _normalize_function_or_error(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    device = new_kwargs["input"].fake_device
    # need to re-enable mode so the tensors report fake device
    with fake_mode:
        # if the input is unsqueezed is done in Convolution.cpp we get segfault
        k = new_kwargs["weight"].ndim
        batch = new_kwargs["input"].shape[0]

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import has_hint

        if not has_hint(batch):
            # TODO: We can make this a little more faithful with best effort
            # channels last detection (but only if it's statically obvious!)
            mem_fmt = None
        else:
            if func is aten.convolution.default:
                conv_backend = torch._C._select_conv_backend(**new_kwargs)
            else:
                # The backward schema has no bias tensor; pass its sizes instead.
                conv_backend = torch._C._select_conv_backend(
                    new_kwargs["input"],
                    new_kwargs["weight"],
                    bias=None,
                    stride=new_kwargs["stride"],
                    padding=new_kwargs["padding"],
                    dilation=new_kwargs["dilation"],
                    transposed=new_kwargs["transposed"],
                    output_padding=new_kwargs["output_padding"],
                    groups=new_kwargs["groups"],
                    bias_sizes=new_kwargs["bias_sizes"],
                )

            # Expand 1d -> 2d.
            # Note: Avoid expanding before calling _select_conv_backend,
            # as the function handles 2D expansion internally.
            # NOTE(review): the stride/padding/dilation/output_padding lists
            # are mutated in place here and restored below; they may be the
            # same list objects the caller passed in -- confirm.
            if (
                k == 3
                and not new_kwargs["input"].is_mkldnn
                and not new_kwargs["input"].is_xpu
            ):
                # Note: Using input.to(memory_format=contiguous) does not work.
                new_kwargs["input"] = new_kwargs["input"].contiguous().unsqueeze(2)
                new_kwargs["weight"] = new_kwargs["weight"].unsqueeze(2)
                if len(new_kwargs["stride"]) == 1:
                    new_kwargs["stride"].insert(0, 1)
                    new_kwargs["padding"].insert(0, 0)
                    new_kwargs["dilation"].insert(0, 1)
                    new_kwargs["output_padding"].insert(0, 0)
            mem_fmt = torch._C._conv_determine_backend_memory_format(
                new_kwargs["input"], new_kwargs["weight"], conv_backend
            )
            # revert 2d -> 1d
            if (
                k == 3
                and not new_kwargs["input"].is_mkldnn
                and not new_kwargs["input"].is_xpu
            ):
                new_kwargs["input"] = new_kwargs["input"].squeeze(2)
                new_kwargs["weight"] = new_kwargs["weight"].squeeze(2)
                if len(new_kwargs["stride"]) == 2:
                    new_kwargs["stride"].pop(0)
                    new_kwargs["padding"].pop(0)
                    new_kwargs["dilation"].pop(0)
                    new_kwargs["output_padding"].pop(0)

    def convert(
        t: torch.Tensor | None, mem_fmt: torch.memory_format | None
    ) -> FakeTensor | None:
        """Restride ``t`` to ``mem_fmt`` (if given) and wrap it on ``device``; None passes through."""
        if t is None:
            return t
        if mem_fmt is not None:
            # channels last only support 4d, try to expand dim then convert it back later.
            if t.dim() == 3 and mem_fmt == torch.channels_last:
                t = t.unsqueeze(2).to(memory_format=mem_fmt).squeeze(2)
            else:
                t = t.to(memory_format=mem_fmt)
        return FakeTensor(fake_mode, t, device)

    with in_kernel_invocation_manager(fake_mode):
        out = func(**new_kwargs)

        if func is aten.convolution.default:
            return convert(out, mem_fmt)  # type: ignore[return]
        else:
            # backward returns (grad_input, grad_weight, grad_bias).
            return (
                convert(out[0], mem_fmt),
                convert(out[1], mem_fmt),
                convert(out[2], None),
            )
@register_op_impl(torch.ops.aten.bincount.default)
def bincount(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    inputs: FakeTensor,
    weights: FakeTensor | None = None,
    minlength: IntLikeType = 0,
) -> FakeTensor:
    """Fake kernel for ``aten.bincount``.

    The output length depends on the values in ``inputs``, so it is an
    unbacked size-like SymInt constrained to be ``>= minlength``.

    The output dtype matches the eager kernels: int64 counts when no weights
    are given, float32 for float32 weights, float64 for any other weight
    dtype. It is not, in general, the dtype of ``inputs``.

    Raises:
        DynamicOutputShapeException: if there is no ShapeEnv, or it does not
            allow dynamic-output-shape ops.
    """
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    new_size = fake_mode.shape_env.create_unbacked_symint()

    from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

    _constrain_range_for_size(new_size)
    torch._check(new_size >= minlength)

    if weights is None:
        out_dtype = torch.int64
    elif weights.dtype == torch.float32:
        out_dtype = torch.float32
    else:
        out_dtype = torch.float64
    return inputs.new_empty(new_size, dtype=out_dtype)  # type: ignore[return]
@register_op_impl(torch.ops.aten._pack_padded_sequence.default)
def _pack_padded_sequence(
    fake_mode: FakeTensorMode,
    func: OpOverload,
    inputs: FakeTensor,
    lengths: FakeTensor,
    batch_first: bool,
) -> tuple[FakeTensor, FakeTensor]:
    """Fake kernel for ``aten._pack_padded_sequence``.

    Mirrors the eager kernel, which works on time-major ``(T, B, *)`` input
    and transposes only when ``batch_first`` is True. It returns:

    * ``data``: shape ``(sum(lengths), *input.shape[2:])``.
    * ``batch_sizes``: shape ``(max(lengths),)``, created from ``lengths``'
      dtype and device.

    Both leading sizes depend on the values in ``lengths``, so each is an
    unbacked size-like SymInt.

    Raises:
        DynamicOutputShapeException: if there is no ShapeEnv, or it does not
            allow dynamic-output-shape ops.
    """
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    # Length of batch_sizes, i.e. the longest sequence length.
    new_batch_size = fake_mode.shape_env.create_unbacked_symint()

    from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

    _constrain_range_for_size(new_batch_size)

    if batch_first:
        # Bring inputs to (seq_len, batch_size, *), as the eager kernel does.
        inputs = inputs.transpose(0, 1)  # type: ignore[assignment]

    # Number of packed timesteps across the whole batch, i.e. sum(lengths).
    total_steps = fake_mode.shape_env.create_unbacked_symint()
    _constrain_range_for_size(total_steps)

    packed_data = inputs.new_empty((total_steps, *inputs.shape[2:]))
    batch_sizes = lengths.new_empty((new_batch_size,))
    return (packed_data, batch_sizes)  # type: ignore[return]
# Maps an OpOverload to its fast-path implementation.
# pyrefly: ignore [implicit-any]
FAST_OP_IMPLEMENTATIONS = {}


# Unlike register_op_impl, these don't do the slow iteration for
# run_impl_check, and these run BEFORE decompositions
def register_fast_op_impl(
    func: OpOverload,
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    """Return a decorator that records its target as the fast impl of ``func``.

    The decorated callable is stored in ``FAST_OP_IMPLEMENTATIONS[func]``,
    overwriting any previous entry, and is returned unchanged.
    """

    def record(op_impl: Callable[_P, _R]) -> Callable[_P, _R]:
        FAST_OP_IMPLEMENTATIONS.update({func: op_impl})
        return op_impl

    return record
- # infer_size_impl in ExpandUtils
- def infer_size(
- a: Sequence[IntLikeType], b: Sequence[IntLikeType]
- ) -> tuple[IntLikeType, ...]:
- from torch.fx.experimental.symbolic_shapes import guard_or_false
- dimsA = len(a)
- dimsB = len(b)
- ndim = max(dimsA, dimsB)
- expandedSizes: list[IntLikeType] = [0] * ndim
- for i in range(ndim - 1, -1, -1):
- offset = ndim - 1 - i
- dimA = dimsA - 1 - offset
- dimB = dimsB - 1 - offset
- sizeA = a[dimA] if dimA >= 0 else 1
- sizeB = b[dimB] if dimB >= 0 else 1
- # NB: It is very important to test for broadcasting, before testing
- # sizeA == sizeB. This is because the broadcasting tests are likely
- # to be statically known (in particular, if sizeA/sizeB is unbacked
- # but size-like, we will unsoundly assume they never equal 1), but
- # the sizeA == sizeB test may not be statically known. However, once
- # we have established that no broadcasting is happening, the
- # sizeA == sizeB is now expect_true and we can defer it as a runtime
- # assert (this works because Python will return the terminal
- # expression of an or statement as-is, without bool()'ing it; if this
- # were not the case, we'd need to write this using torch.sym_or() or
- # something like that).
- torch._check(
- guard_or_false(sizeA == 1) or guard_or_false(sizeB == 1) or sizeA == sizeB,
- lambda: f"The size of tensor a ({sizeA}) "
- f"must match the size of tensor b ({sizeB}) "
- f"at non-singleton dimension {i})",
- )
- expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
- return tuple(expandedSizes)
def make_fast_binary_impl(
    slow_ref: Callable[..., Any],
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
) -> Callable[..., FakeTensor]:
    """Build a fast-path fake implementation for a binary elementwise op.

    The returned function computes the output's shape (by broadcasting),
    dtype, device and memory format directly. It falls back to running
    ``slow_ref`` under the fake mode when:
      * no tensor operand already has the final broadcast shape;
      * devices are incompatible (more than one CPU 0-d tensor alongside a
        non-CPU device, or any other device mismatch);
      * the output layout is neither definitely contiguous nor definitely
        channels-last.

    Args:
        slow_ref: reference implementation used on the slow path.
        type_promotion_kind: promotion rule passed to ``elementwise_dtypes``.
            Any non-DEFAULT kind always goes through ``elementwise_dtypes``.
    """

    def fast_binary_impl(mode: FakeTensorMode, *args: Any, **kwargs: Any) -> FakeTensor:
        """Fast-path body; ``args`` are the op's operands (tensors or scalars)."""

        def slow(msg: str) -> FakeTensor:
            # Record why the fast path was abandoned, then run the reference
            # implementation with ``mode`` active.
            count_label(f"slow {msg}")
            with mode:
                return slow_ref(*args, **kwargs)

        count_label("attempt fast")

        # Fast path (based off of TensorIterator fast path).
        # Unfortunately, there is no way to easily deduplicate
        # this with either the TensorIterator C++ implementation
        # (which we don't want to SymIntify, and also the algorithm
        # here is slightly different from TensorIterator to allow
        # for broadcasting), nor the PrimTorch implementation
        # (which does not actually implement a fast path.)

        operands = args

        # compute_shape
        # Non-tensor operands contribute shape () to the broadcast.
        final_shape: ShapeType | None = None
        for op in operands:
            shape: ShapeType = op.shape if isinstance(op, torch.Tensor) else ()
            if final_shape is None:
                final_shape = shape
            # TODO: Minor optimization: track if the shapes
            # were equal so you can skip the equality check
            # below if unnecessary
            # pyrefly: ignore[bad-assignment]
            final_shape = infer_size(final_shape, shape)
        if final_shape is None:
            raise AssertionError("final_shape must not be None")

        from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq

        # Do some extra safety checks to see if the output
        # stride is obvious
        # Require at least one tensor operand whose shape already equals the
        # broadcast shape (for/else: the else runs only if no break happened).
        for op in operands:
            if (
                isinstance(op, torch.Tensor)
                and len(op.shape) == len(final_shape)
                # take the slow path if result is not determined.
                and guard_or_false(sym_eq(op.shape, final_shape))  # type: ignore[arg-type]
            ):
                break
        else:
            # if we never break in the for loop above we take the slow path.
            return slow("both tensors nontrivially broadcast")

        # compute_types
        # The common device is the first non-CPU tensor device, else CPU.
        cpu = torch.device("cpu")
        common_device: torch.device = cpu
        common_dtype: torch.dtype | None = None
        has_different_input_dtypes = False
        for op in operands:
            if not isinstance(op, torch.Tensor):
                # Use elementwise_dtypes for the tricky case
                has_different_input_dtypes = True
                continue
            if common_device == cpu and op.device.type != "cpu":
                common_device = op.device
            if common_dtype is None:
                if type_promotion_kind != ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
                    has_different_input_dtypes = True
                else:
                    common_dtype = op.dtype
            elif common_dtype != op.dtype:
                has_different_input_dtypes = True

        if has_different_input_dtypes:
            # compute promotion
            # TODO: we don't need the compute type
            _, common_dtype = elementwise_dtypes(
                *operands, type_promotion_kind=type_promotion_kind
            )

        # check all tensors on same device
        # cpu scalars are assumed allow
        # (at most max_cpu_scalars_on_non_cpu CPU 0-d tensors may accompany a
        # non-CPU common device; anything else off-device takes the slow path)
        current_cpu_scalars_on_non_cpu = 0
        max_cpu_scalars_on_non_cpu = 1  # hard coded atm
        for op in operands:
            if not isinstance(op, torch.Tensor):
                continue
            if common_device != cpu and op.dim() == 0 and op.device == cpu:
                if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
                    return slow("error")
                current_cpu_scalars_on_non_cpu += 1
            elif op.device != common_device:
                return slow("error")

        # compute_fast_setup_type
        # Both flags start True; on devices that support non-contiguous
        # tensors they are ANDed across all tensor operands. Otherwise the
        # output is treated as contiguous.
        definitely_contiguous = True
        definitely_channels_last = True

        # TODO: is_non-overlapping_and_dense not bound from Python
        # no inplace, no out, everything defined

        if is_noncontiguous_supported(common_device):
            for op in operands:
                if not isinstance(op, torch.Tensor):
                    continue
                definitely_contiguous = (
                    definitely_contiguous
                    and is_contiguous_for_memory_format_or_false(
                        op, memory_format=torch.contiguous_format
                    )
                )
                definitely_channels_last = (
                    definitely_channels_last
                    and is_contiguous_for_memory_format_or_false(
                        op, memory_format=torch.channels_last
                    )
                )
        if definitely_contiguous:
            # do contiguous
            count_label("fast is_contiguous")
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.contiguous_format,
                ),
                device=common_device,
            )
        if definitely_channels_last:
            count_label("fast channels_last")
            # do channels last
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.channels_last,
                ),
                device=common_device,
            )

        return slow("no contiguity match")

    return fast_binary_impl
# disable the python dispatcher to avoid decomposing detach() further
# (proxy_mode should still decompose detach() though)
def fast_detach(
    fake_mode: FakeTensorMode, x: FakeTensor, include_real: bool = False
) -> FakeTensor:
    """Detach ``x`` directly and wrap the result as a FakeTensor on ``x``'s device.

    When ``include_real`` is True, ``x.real_tensor`` is forwarded to the new
    FakeTensor. Otherwise that argument is not passed at all.
    """
    with no_python_dispatcher(), in_kernel_invocation_manager(fake_mode):
        detached = torch.ops.aten.detach.default(x)
    extra_kwargs = {"real_tensor": x.real_tensor} if include_real else {}
    return FakeTensor(fake_mode, detached, x.device, **extra_kwargs)
@functools.cache
def get_fast_op_impls() -> dict[OpOverload, Callable[..., Any]]:
    """Register the fast-path implementations and return ``FAST_OP_IMPLEMENTATIONS``.

    ``functools.cache`` makes the registration happen only on the first call.
    Every call returns the same module-level dict.
    """
    import torch._refs

    aten_ops = torch.ops.aten

    # Binary elementwise ops using the default type-promotion rule.
    for overload, ref in (
        (aten_ops.add.Tensor, torch._refs.add),
        (aten_ops.sub.Tensor, torch._refs.sub),
        (aten_ops.mul.Tensor, torch._refs.mul),
    ):
        register_fast_op_impl(overload)(make_fast_binary_impl(ref))

    # True division promotes integer inputs to floating point.
    register_fast_op_impl(aten_ops.div.Tensor)(
        make_fast_binary_impl(
            torch._refs.div,
            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        )
    )
    register_fast_op_impl(aten_ops.detach.default)(fast_detach)
    return FAST_OP_IMPLEMENTATIONS
|