wrappers.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509
  1. # mypy: allow-untyped-defs
  2. import inspect
  3. import types
  4. import warnings
  5. from collections.abc import Callable, Sequence
  6. from functools import wraps
  7. from types import GenericAlias
  8. from typing import NamedTuple, Optional, overload, TypeVar, Union
  9. from typing_extensions import ParamSpec
  10. import torch
  11. import torch._prims_common as utils
  12. from torch._prims_common import (
  13. CustomOutParamAnnotation,
  14. ELEMENTWISE_TYPE_PROMOTION_KIND,
  15. Number,
  16. NumberType,
  17. ShapeType,
  18. TensorLike,
  19. TensorLikeType,
  20. )
  21. from torch.utils import _pytree as pytree
  22. from torch.utils._pytree import tree_flatten, tree_unflatten
# Generic placeholders used by the decorators in this module so that a wrapped
# function's parameter list (_P) and return type (_T) carry through to the
# decorated result, e.g. ``Callable[_P, _T] -> Callable[_P, _T]``.
_T = TypeVar("_T")
_P = ParamSpec("_P")
  25. @overload
  26. # pyrefly: ignore [bad-return]
  27. def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
  28. pass
  29. @overload
  30. # pyrefly: ignore [bad-return]
  31. def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType:
  32. pass
  33. @overload
  34. # pyrefly: ignore [bad-return]
  35. def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence:
  36. pass
  37. @overload
  38. def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None:
  39. pass
  40. # TODO: implement ref.cast with an option to enforce safe casting
  41. def _maybe_convert_to_dtype(a, dtype):
  42. if isinstance(a, TensorLike):
  43. if a.dtype != dtype:
  44. return a.to(dtype)
  45. return a
  46. if isinstance(a, Number):
  47. return utils.dtype_to_type_ctor(dtype)(a) # type: ignore[arg-type]
  48. if isinstance(a, Sequence):
  49. return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)
  50. # Passthrough None because some functions wrapped with type promotion
  51. # wrapper might have optional args
  52. if a is None:
  53. return None
  54. raise ValueError(
  55. f"Received unsupported type {type(a)}. Expected TensorLike, Number, or Sequence."
  56. )
  57. def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
  58. if not isinstance(a, Number):
  59. msg = f"Found unknown type {type(a)} when trying to convert scalars!"
  60. raise ValueError(msg)
  61. if not utils.is_weakly_lesser_type(type(a), typ):
  62. msg = f"Scalar {a} of type {type(a)} cannot be safely cast to type {typ}!"
  63. raise ValueError(msg)
  64. return typ(a)
  65. def _annotation_has_type(*, typ, annotation):
  66. if hasattr(annotation, "__args__"):
  67. for a in annotation.__args__:
  68. if _annotation_has_type(typ=typ, annotation=a):
  69. return True
  70. return False
  71. return typ is annotation
  72. class elementwise_type_promotion_wrapper:
  73. """
  74. Adds elementwise type promotion to a Python reference implementation.
  75. Takes two kwargs, type_promoting_args and type_promotion_kind.
  76. type_promoting_args must be a string Sequence specifying the argument names of all
  77. arguments that participate in type promotion (and should be type promoted). If the
  78. arg specifies a Sequence-type then every element of the Sequence will participate in
  79. type promotion.
  80. type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND.
  81. See its documentation for details.
  82. The return_dtype will be coerced to the wrapped function's dtype arg if it is available and
  83. not None.
  84. Other type promotion behavior, like validating the Python type of scalar arguments, must
  85. be handled separately.
  86. """
  87. def __init__(
  88. self,
  89. *,
  90. type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
  91. type_promoting_args: Optional[Sequence[str]] = None,
  92. ):
  93. self.type_promoting_arg_names = type_promoting_args
  94. self.type_promotion_kind = type_promotion_kind
  95. def __call__(self, fn: Callable) -> Callable:
  96. sig = inspect.signature(fn)
  97. # TorchDynamo tracing of inspect causes fake tensor dynamo_wrapped tests to fail
  98. # PYTORCH_TEST_WITH_DYNAMO=1 python test/test_fake_tensor.py FakeTensorTest.test_basic
  99. @torch._disable_dynamo
  100. @wraps(fn)
  101. def _fn(*args, **kwargs):
  102. bound = sig.bind(*args, **kwargs)
  103. type_promoting_args = tuple(
  104. bound.arguments[x]
  105. for x in self.type_promoting_arg_names # type: ignore[union-attr]
  106. if x in bound.arguments
  107. )
  108. flattened_type_promoting_args = pytree.arg_tree_leaves(*type_promoting_args)
  109. compute_dtype, result_dtype = utils.elementwise_dtypes(
  110. *flattened_type_promoting_args,
  111. type_promotion_kind=self.type_promotion_kind,
  112. )
  113. promoted_args = {
  114. x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
  115. for x in self.type_promoting_arg_names # type: ignore[union-attr]
  116. if x in bound.arguments
  117. }
  118. bound.arguments.update(promoted_args)
  119. result = fn(**bound.arguments)
  120. # Override the return_dtype if a dtype arg is present and not None
  121. if "dtype" in bound.arguments:
  122. maybe_dtype = bound.arguments["dtype"]
  123. if maybe_dtype: # dtype cannot be None
  124. result_dtype = maybe_dtype
  125. if isinstance(result, TensorLike):
  126. return _maybe_convert_to_dtype(result, result_dtype)
  127. if isinstance(result, Sequence):
  128. return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result)
  129. raise AssertionError(f"Unhandled result type: {type(result)}")
  130. _fn.__signature__ = sig # type: ignore[attr-defined]
  131. return _fn
  132. # Returns True if resize is necessary
  133. def _resize_output_check(out: TensorLikeType, shape: ShapeType):
  134. # If the shapes are correct there's nothing to do
  135. if utils.same_shape(out.shape, shape):
  136. return False
  137. if out.numel() != 0:
  138. msg = (
  139. f"An output with one or more elements was resized since it had shape {str(out.shape)} "
  140. "which does not match the required output shape {str(shape)}. "
  141. "This behavior is deprecated, and in a future PyTorch release outputs will not "
  142. "be resized unless they have zero elements. "
  143. "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
  144. )
  145. warnings.warn(msg, stacklevel=2)
  146. return True
  147. # TODO: handle tuples of tensors
  148. def _maybe_resize_out(
  149. out: TensorLikeType,
  150. shape: ShapeType,
  151. memory_format: Optional[torch.memory_format] = None,
  152. ):
  153. if _resize_output_check(out, shape):
  154. return out.resize_(shape, memory_format=memory_format)
  155. else:
  156. return out
  157. def is_cpu_scalar(x: TensorLikeType) -> bool:
  158. return x.dim() == 0 and x.device.type == "cpu"
  159. def check_copy_devices(*, copy_from: TensorLikeType, copy_to: TensorLikeType) -> None:
  160. if copy_from.device != copy_to.device:
  161. msg = (
  162. f"Attempting to copy from device {copy_from.device} "
  163. f"to device {copy_to.device}, but cross-device copies are not allowed!"
  164. )
  165. raise RuntimeError(msg)
  166. def _safe_copy_out(
  167. *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
  168. ):
  169. # Checks same device
  170. if not is_cpu_scalar(copy_from):
  171. check_copy_devices(copy_from=copy_from, copy_to=copy_to)
  172. # Checks safe cast
  173. if exact_dtype:
  174. torch._check(
  175. copy_from.dtype == copy_to.dtype,
  176. lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
  177. f"but got {copy_to.dtype} instead",
  178. )
  179. else:
  180. torch._check(
  181. utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
  182. lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
  183. "but this can't be cast because it is not safe!",
  184. )
  185. return copy_to.copy_(copy_from)
  186. def out_wrapper(
  187. *out_names: str,
  188. exact_dtype: bool = False,
  189. pass_is_out: bool = False,
  190. preserve_memory_format: bool = False,
  191. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  192. # The wrapped function needs to convert the output parameters to ensure
  193. # compatibility between the Python API (which always uses "out" as the
  194. # parameter name and may be a tuple) and the Aten API (which may have
  195. # multiple output parameters and use different parameter names such as
  196. # "grad_input", "indices" or "values".)
  197. default_out_names = ("out",)
  198. if len(out_names) == 0:
  199. # Use default in out name
  200. out_names = default_out_names
  201. is_tensor = len(out_names) == 1
  202. def maybe_compute_memory_format(t):
  203. return utils.suggest_memory_format(t) if preserve_memory_format else None
  204. def _out_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]:
  205. """
  206. Adds the out parameter to a Python reference.
  207. """
  208. out_type = (
  209. TensorLikeType
  210. if is_tensor
  211. else GenericAlias(
  212. tuple, tuple(TensorLikeType for _ in range(len(out_names)))
  213. )
  214. )
  215. # For backward compatibility - should be able to remove once PEP585
  216. # conversion is complete.
  217. bc_out_type = (
  218. TensorLikeType
  219. if is_tensor
  220. else types.GenericAlias(
  221. tuple, tuple(TensorLikeType for _ in range(len(out_names)))
  222. )
  223. )
  224. return_type = (
  225. TensorLikeType
  226. if is_tensor
  227. else NamedTuple(
  228. f"return_types_{fn.__name__}",
  229. # pyrefly: ignore [bad-argument-count]
  230. [(o, TensorLikeType) for o in out_names],
  231. )
  232. )
  233. sig = inspect.signature(fn)
  234. factory_kwargs = ("device", "dtype")
  235. is_factory_fn = all(p in sig.parameters for p in factory_kwargs)
  236. @wraps(fn)
  237. def _fn(*args: _P.args, **kwargs: _P.kwargs):
  238. out = kwargs.pop("out", None)
  239. if is_factory_fn and out is not None:
  240. for k in factory_kwargs:
  241. out_attr = getattr(out, k)
  242. if k not in kwargs:
  243. kwargs[k] = out_attr
  244. def maybe_check_copy_devices(out):
  245. if isinstance(out, TensorLike) and isinstance(args[0], TensorLike):
  246. check_copy_devices(copy_from=args[0], copy_to=out)
  247. if isinstance(out, (tuple, list)):
  248. for o in out:
  249. maybe_check_copy_devices(o)
  250. else:
  251. maybe_check_copy_devices(out)
  252. if pass_is_out:
  253. result = fn(*args, is_out=(out is not None), **kwargs) # type: ignore[arg-type]
  254. else:
  255. result = fn(*args, **kwargs)
  256. if result is NotImplemented:
  257. return NotImplemented
  258. if not (
  259. (isinstance(result, TensorLike) and is_tensor)
  260. or (
  261. isinstance(result, tuple) # type: ignore[arg-type]
  262. and len(result) == len(out_names) # type: ignore[arg-type]
  263. )
  264. or (
  265. fn.__name__ == "unbind" and isinstance(result, (list, tuple)) # type: ignore[arg-type]
  266. )
  267. ):
  268. raise AssertionError(
  269. f"Unexpected result type: {type(result)}, is_tensor={is_tensor}, "
  270. f"out_names={out_names}"
  271. )
  272. # unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829
  273. if out is not None:
  274. # Naively you might expect this assert to be true, but
  275. # it's not:
  276. #
  277. # assert type(out) is type(result)
  278. #
  279. # The reason is that functions under this wrapper can
  280. # get registered to the Meta dispatch key, and that
  281. # means they can be executed in a context where tensor
  282. # subclasses are disabled (with no_dispatch), which is a
  283. # handy way for an is-a tensor subclass (e.g.,
  284. # FakeTensor) to have the normal meta backend create a
  285. # meta tensor, to be wrapped once it gets returned.
  286. # In this situation, you will get a FakeTensor as
  287. # the output tensor, but not the result--which will
  288. # be a normal meta tensor, but this is perfectly
  289. # harmless.
  290. if is_tensor and fn.__name__ != "unbind":
  291. if not isinstance(out, TensorLike):
  292. raise AssertionError(
  293. f"out must be TensorLike, got {type(out)}"
  294. ) # mypy
  295. # These two operations are done in-place
  296. _maybe_resize_out(
  297. out,
  298. result.shape, # type: ignore[union-attr]
  299. maybe_compute_memory_format(result),
  300. )
  301. _safe_copy_out(
  302. copy_from=result, # type: ignore[arg-type]
  303. copy_to=out,
  304. exact_dtype=exact_dtype,
  305. )
  306. else:
  307. if fn.__name__ != "unbind":
  308. if not isinstance(out, tuple):
  309. raise AssertionError(f"out must be tuple, got {type(out)}") # type: ignore[arg-type] # mypy
  310. else:
  311. if not isinstance(out, (list, tuple)):
  312. raise AssertionError(
  313. f"out must be list or tuple, got {type(out)}"
  314. ) # type: ignore[arg-type] # mypy
  315. torch._check_type(
  316. len(out) == len(result), # type: ignore[arg-type]
  317. lambda: f"expected tuple of {len(result)} elements but got {len(out)}", # type: ignore[arg-type]
  318. )
  319. for r, o in zip(result, out): # type: ignore[arg-type]
  320. # These two operations are done in-place
  321. _maybe_resize_out(o, r.shape, maybe_compute_memory_format(r))
  322. _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype) # type: ignore[arg-type]
  323. else:
  324. out = result
  325. # mypy does not see through the definition of out_type given that it's in a different scope
  326. return out if is_tensor else return_type(*out) # type: ignore[operator]
  327. out_param = inspect.Parameter(
  328. "out",
  329. kind=inspect.Parameter.KEYWORD_ONLY,
  330. default=None,
  331. annotation=out_type,
  332. )
  333. # Mark that the function now returns a tuple
  334. if not (
  335. isinstance(sig.return_annotation, (str, TypeVar))
  336. or sig.return_annotation in (sig.empty, out_type, bc_out_type)
  337. ):
  338. raise AssertionError(
  339. f"Unexpected return annotation: {sig.return_annotation}, "
  340. f"expected str, TypeVar, empty, {out_type}, or {bc_out_type}"
  341. )
  342. params = *sig.parameters.values(), out_param
  343. # If there's a Parameter.VAR_KEYWORD parameter (like **kwds), it must appear
  344. # after the out= parameter, which is Parameter.KEYWORD_ONLY. Sorting by
  345. # Parameter.kind guarantees that all the parameters are in legal order.
  346. params = sorted(params, key=lambda p: p.kind)
  347. _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
  348. parameters=params,
  349. return_annotation=return_type, # type: ignore[arg-type]
  350. )
  351. _fn.__annotations__ = dict(getattr(fn, "__annotations__", {}))
  352. _fn.__annotations__["out"] = out_type
  353. _fn.__annotations__["return"] = return_type
  354. # In the special case of having a single tensor out parameter with a
  355. # name other than out, add a special annotation to name the parameter
  356. if is_tensor and out_names != default_out_names:
  357. _fn.__annotations__[CustomOutParamAnnotation] = out_names[0]
  358. # Add an indicator attribute that can be used in special cases
  359. # where having a function wrapped by `out_wrapper` is not desirable e.g.
  360. # jit
  361. _fn._torch_decompositions_out_wrapper = ( # type: ignore[attr-defined]
  362. f"This function is wrapped by {out_wrapper.__module__}.out_wrapper"
  363. )
  364. return _fn
  365. return _out_wrapper
  366. def _maybe_remove_out_wrapper(fn: Callable):
  367. return inspect.unwrap(
  368. fn,
  369. stop=lambda f: not hasattr(f, "_torch_decompositions_out_wrapper"),
  370. )
  371. def backwards_not_supported(prim):
  372. def redispatch_prim(args, kwargs):
  373. with torch._C._AutoDispatchBelowAutograd():
  374. return prim(*args, **kwargs)
  375. class BackwardsNotSupported(torch.autograd.Function):
  376. @staticmethod
  377. # pyrefly: ignore [bad-override]
  378. def forward(ctx, args_spec, *flat_args):
  379. args, kwargs = tree_unflatten(flat_args, args_spec) # type: ignore[arg-type]
  380. return redispatch_prim(args, kwargs)
  381. @staticmethod
  382. def backward(ctx, *args):
  383. raise RuntimeError("backwards not supported on prim")
  384. @wraps(prim)
  385. def _autograd_impl(*args, **kwargs):
  386. flat_args, args_spec = tree_flatten((args, kwargs))
  387. if torch.is_grad_enabled() and any(
  388. a.requires_grad for a in flat_args if isinstance(a, torch.Tensor)
  389. ):
  390. # TODO: There is a subtle bug here: prims like copy_to
  391. # return their input argument after mutating it; and custom
  392. # autograd function will incorrectly turn the result into
  393. # a view which will fail test_python_ref_executor tests.
  394. # At the moment, we sidestep this by observing that the
  395. # unit tests don't ever try to run the executor with
  396. # autograd, so we don't exercise the buggy case, but if
  397. # you ever want to feed autograd through this, be aware
  398. # of it! We need a way of properly implementing autograd
  399. # for mutating operations in Python to do this.
  400. return BackwardsNotSupported.apply(args_spec, *flat_args)
  401. else:
  402. return redispatch_prim(args, kwargs)
  403. return _autograd_impl
  404. # TODO: when tracing this will add torch tensors and not TensorMeta objects
  405. # to the trace -- we should fix this by adding a tracing context and NumberMeta classes
  406. # TODO: this wrapper is currently untested
  407. def elementwise_unary_scalar_wrapper(
  408. fn: Callable[_P, _T],
  409. ) -> Callable[_P, Union[_T, NumberType]]:
  410. """
  411. Allows unary operators that accept tensors to work with Python numbers.
  412. """
  413. sig = inspect.signature(fn)
  414. @wraps(fn)
  415. def _fn(*args, **kwargs):
  416. if len(args) > 0 and isinstance(args[0], Number):
  417. dtype = utils.type_to_dtype(type(args[0]))
  418. args_ = list(args)
  419. args_[0] = torch.tensor(args[0], dtype=dtype)
  420. # pyrefly: ignore [invalid-param-spec]
  421. result = fn(*args_, **kwargs)
  422. if not isinstance(result, torch.Tensor):
  423. raise AssertionError(f"Expected torch.Tensor, got {type(result)}")
  424. return result.item()
  425. # pyrefly: ignore [invalid-param-spec]
  426. return fn(*args, **kwargs)
  427. _fn.__signature__ = sig # type: ignore[attr-defined]
  428. # pyrefly: ignore [bad-return]
  429. return _fn