custom_ops.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958
  1. # mypy: allow-untyped-defs
  2. import collections
  3. import inspect
  4. import logging
  5. import warnings
  6. import weakref
  7. from collections.abc import Callable, Iterable, Sequence
  8. from contextlib import contextmanager
  9. from typing import Any, Optional, overload, Union
  10. import torch
  11. from torch import _C, _ops, Tensor
  12. from torch.types import _dtype
  13. from torch.utils._exposed_in import exposed_in
  14. from . import autograd, utils
  15. from .effects import EffectType
  16. device_types_t = Optional[Union[str, Sequence[str]]]
  17. log = logging.getLogger(__name__)
  18. @overload
  19. def custom_op(
  20. name: str,
  21. fn: None = None,
  22. /,
  23. *,
  24. mutates_args: Union[str, Iterable[str]],
  25. device_types: device_types_t = None,
  26. schema: Optional[str] = None,
  27. tags: Optional[Sequence[_C.Tag]] = None,
  28. ) -> Callable[[Callable[..., object]], "CustomOpDef"]: ...
  29. @overload
  30. def custom_op(
  31. name: str,
  32. fn: Callable[..., object],
  33. /,
  34. *,
  35. mutates_args: Union[str, Iterable[str]],
  36. device_types: device_types_t = None,
  37. schema: Optional[str] = None,
  38. tags: Optional[Sequence[_C.Tag]] = None,
  39. ) -> "CustomOpDef": ...
  40. @exposed_in("torch.library")
  41. def custom_op(
  42. name: str,
  43. fn: Optional[Callable] = None,
  44. /,
  45. *,
  46. mutates_args: Union[str, Iterable[str]],
  47. device_types: device_types_t = None,
  48. schema: Optional[str] = None,
  49. tags: Optional[Sequence[_C.Tag]] = None,
  50. ) -> Union[Callable[[Callable[..., object]], "CustomOpDef"], "CustomOpDef"]:
  51. """Wraps a function into custom operator.
  52. Reasons why you may want to create a custom op include:
  53. - Wrapping a third-party library or custom kernel to work with PyTorch
  54. subsystems like Autograd.
  55. - Preventing torch.compile/export/FX tracing from peeking inside your function.
  56. This API is used as a decorator around a function (please see examples).
  57. The provided function must have type hints; these are needed to interface
  58. with PyTorch's various subsystems.
  59. Args:
  60. name (str): A name for the custom op that looks like "{namespace}::{name}",
  61. e.g. "mylib::my_linear". The name is used as the op's stable identifier
  62. in PyTorch subsystems (e.g. torch.export, FX graphs).
  63. To avoid name collisions, please use your project name as the namespace;
  64. e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
  65. mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
  66. This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
  67. it pessimistically assumes that all inputs to the operator are being mutated.
  68. device_types (str | None | Sequence[str]): The device type(s) the function
  69. is valid for. If no device type is provided, then the function
  70. is used as the default implementation for all device types.
  71. Examples: "cpu", "cuda".
  72. When registering a device-specific implementation for an operator that accepts no Tensors,
  73. we require the operator to have a "device: torch.device argument".
  74. schema (str | None): A schema string for the operator. If None
  75. (recommended) we'll infer a schema for the operator from its type
  76. annotations. We recommend letting us infer a schema unless you
  77. have a specific reason not to.
  78. Example: "(Tensor x, int y) -> (Tensor, Tensor)".
  79. .. note::
  80. We recommend not passing in a ``schema`` arg and instead letting us infer
  81. it from the type annotations. It is error-prone to write your own schema.
  82. You may wish to provide your own schema if our interpretation of
  83. the type annotation is not what you want.
  84. For more info on how to write a schema string, see
  85. `here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func>`_
  86. Examples::
  87. >>> import torch
  88. >>> from torch import Tensor
  89. >>> from torch.library import custom_op
  90. >>> import numpy as np
  91. >>>
  92. >>> @custom_op("mylib::numpy_sin", mutates_args=())
  93. >>> def numpy_sin(x: Tensor) -> Tensor:
  94. >>> x_np = x.cpu().numpy()
  95. >>> y_np = np.sin(x_np)
  96. >>> return torch.from_numpy(y_np).to(device=x.device)
  97. >>>
  98. >>> x = torch.randn(3)
  99. >>> y = numpy_sin(x)
  100. >>> assert torch.allclose(y, x.sin())
  101. >>>
  102. >>> # Example of a custom op that only works for one device type.
  103. >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
  104. >>> def numpy_sin_cpu(x: Tensor) -> Tensor:
  105. >>> x_np = x.numpy()
  106. >>> y_np = np.sin(x_np)
  107. >>> return torch.from_numpy(y_np)
  108. >>>
  109. >>> x = torch.randn(3)
  110. >>> y = numpy_sin_cpu(x)
  111. >>> assert torch.allclose(y, x.sin())
  112. >>>
  113. >>> # Example of a custom op that mutates an input
  114. >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
  115. >>> def numpy_sin_inplace(x: Tensor) -> None:
  116. >>> x_np = x.numpy()
  117. >>> np.sin(x_np, out=x_np)
  118. >>>
  119. >>> x = torch.randn(3)
  120. >>> expected = x.sin()
  121. >>> numpy_sin_inplace(x)
  122. >>> assert torch.allclose(x, expected)
  123. >>>
  124. >>> # Example of a factory function
  125. >>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
  126. >>> def bar(device: torch.device) -> Tensor:
  127. >>> return torch.ones(3)
  128. >>>
  129. >>> bar("cpu")
  130. """
  131. def inner(fn: Callable[..., object]) -> CustomOpDef:
  132. import torch
  133. if schema is None:
  134. schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
  135. else:
  136. schema_str = schema
  137. namespace, opname = name.split("::")
  138. result = CustomOpDef(namespace, opname, schema_str, fn, tags)
  139. if schema is not None:
  140. # Check that schema's alias annotations match those of `mutates_args`.
  141. expected = set()
  142. for arg in result._opoverload._schema.arguments:
  143. if arg.alias_info is not None and arg.alias_info.is_write:
  144. expected.add(arg.name)
  145. if expected != set(mutates_args):
  146. raise ValueError(
  147. f"Attempted to create a custom op with `mutates_args={mutates_args}` "
  148. f"and `schema={schema}. The schema suggests that the op mutates {expected}"
  149. f"which is different from what was provided to us in `mutates_args`. "
  150. f"Please make these consistent."
  151. )
  152. result.register_kernel(device_types)(fn)
  153. return result
  154. if fn is None:
  155. return inner
  156. return inner(fn)
  157. class CustomOpDef:
  158. """CustomOpDef is a wrapper around a function that turns it into a custom op.
  159. It has various methods for registering additional behavior for this
  160. custom op.
  161. You should not instantiate CustomOpDef directly; instead, use the
  162. :func:`torch.library.custom_op` API.
  163. """
  164. def __init__(
  165. self,
  166. namespace: str,
  167. name: str,
  168. schema: str,
  169. fn: Callable,
  170. tags: Optional[Sequence[_C.Tag]] = None,
  171. ) -> None:
  172. # Fields used to interface with the PyTorch dispatcher
  173. self._namespace = namespace
  174. self._name = name
  175. self._schema = schema
  176. self._tags = tags if tags is not None else []
  177. self._init_fn = fn
  178. self._backend_fns: dict[Union[str, None], Callable] = {}
  179. self._abstract_fn: Optional[Callable] = None
  180. self._setup_context_fn: Optional[Callable] = None
  181. self._backward_fn: Optional[Callable] = None
  182. self._torch_dispatch_fns: dict[type, Callable] = {}
  183. self._vmap_fn: Optional[Callable] = None
  184. self._autocast_cuda_dtype: Optional[_dtype] = None
  185. self._autocast_cpu_dtype: Optional[_dtype] = None
  186. self._lib = get_library_allowing_overwrite(self._namespace, self._name)
  187. self._register_to_dispatcher(self._tags)
  188. self._disabled_kernel: set = set()
  189. self._used_triton_kernels: list[Any] = list()
  190. OPDEFS[self._qualname] = self
  191. @property
  192. def _qualname(self) -> str:
  193. return f"{self._namespace}::{self._name}"
  194. def __repr__(self) -> str:
  195. return f"<CustomOpDef({self._qualname})>"
  196. @contextmanager
  197. def set_kernel_enabled(self, device_type: str, enabled: bool = True):
  198. """
  199. Disable or re-enable an already registered kernel for this custom operator.
  200. If the kernel is already disabled/enabled, this is a no-op.
  201. Note:
  202. If a kernel is first disabled and then registered, it is disabled until enabled again.
  203. Args:
  204. device_type (str): The device type to disable/enable the kernel for.
  205. disable (bool): Whether to disable or enable the kernel.
  206. Example:
  207. >>> inp = torch.randn(1)
  208. >>>
  209. >>> # define custom op `f`.
  210. >>> @custom_op("mylib::f", mutates_args=())
  211. >>> def f(x: Tensor) -> Tensor:
  212. >>> return torch.zeros(1)
  213. >>>
  214. >>> print(f(inp)) # tensor([0.]), default kernel
  215. >>>
  216. >>> @f.register_kernel("cpu")
  217. >>> def _(x):
  218. >>> return torch.ones(1)
  219. >>>
  220. >>> print(f(inp)) # tensor([1.]), CPU kernel
  221. >>>
  222. >>> # temporarily disable the CPU kernel
  223. >>> with f.set_kernel_enabled("cpu", enabled = False):
  224. >>> print(f(inp)) # tensor([0.]) with CPU kernel disabled
  225. """
  226. action = "enable" if enabled else "disable"
  227. originally_disabled = device_type in self._disabled_kernel
  228. if device_type not in self._backend_fns:
  229. log.warning(
  230. "Attempted to %s kernel for %s but no kernel was registered for this device type.",
  231. action,
  232. device_type,
  233. )
  234. if not enabled:
  235. if originally_disabled:
  236. log.warning(
  237. "Attempted to disable kernel for %s but it was already disabled.",
  238. device_type,
  239. )
  240. else:
  241. self._disabled_kernel.add(device_type)
  242. else: # enable the kernel
  243. if not originally_disabled:
  244. log.warning(
  245. "Attempted to enable kernel for %s but it was already enabled.",
  246. device_type,
  247. )
  248. else:
  249. self._disabled_kernel.remove(device_type)
  250. try:
  251. yield
  252. finally:
  253. # restore original state
  254. if originally_disabled:
  255. self._disabled_kernel.add(device_type)
  256. else:
  257. self._disabled_kernel.discard(device_type)
  258. def register_kernel(
  259. self, device_types: device_types_t, fn: Optional[Callable] = None, /
  260. ) -> Callable:
  261. """Register an implementation for a device type for this operator.
  262. Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
  263. This API may be used as a decorator.
  264. Args:
  265. fn (Callable): The function to register as the implementation for
  266. the given device types.
  267. device_types (str | Sequence[str]): The device device_types to register an impl to.
  268. Examples::
  269. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  270. >>> import torch
  271. >>> from torch import Tensor
  272. >>> from torch.library import custom_op
  273. >>> import numpy as np
  274. >>>
  275. >>> # Create a custom op that works on cpu
  276. >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
  277. >>> def numpy_sin(x: Tensor) -> Tensor:
  278. >>> x_np = x.numpy()
  279. >>> y_np = np.sin(x_np)
  280. >>> return torch.from_numpy(y_np)
  281. >>>
  282. >>> # Add implementations for the cuda device
  283. >>> @numpy_sin.register_kernel("cuda")
  284. >>> def _(x):
  285. >>> x_np = x.cpu().numpy()
  286. >>> y_np = np.sin(x_np)
  287. >>> return torch.from_numpy(y_np).to(device=x.device)
  288. >>>
  289. >>> x_cpu = torch.randn(3)
  290. >>> x_cuda = x_cpu.cuda()
  291. >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
  292. >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
  293. """
  294. def inner(fn):
  295. if device_types is None or isinstance(device_types, str):
  296. dtypes: list[Union[str, None]] = [device_types]
  297. else:
  298. dtypes = list(device_types)
  299. for device_type in dtypes:
  300. if device_type not in self._backend_fns:
  301. def backend_impl(*args, **kwargs):
  302. result = self._backend_fns[device_type](*args, **kwargs)
  303. def get_module():
  304. fn = self._backend_fns[device_type]
  305. return inspect.getmodule(fn)
  306. schema = self._opoverload._schema
  307. if not schema._is_view_op():
  308. utils._c_check_aliasing_constraint(
  309. self._name,
  310. args,
  311. kwargs,
  312. result,
  313. get_module,
  314. )
  315. return result
  316. if device_type is None:
  317. self._lib.impl(
  318. self._name, backend_impl, "CompositeExplicitAutograd"
  319. )
  320. else:
  321. self._lib.impl(
  322. self._name,
  323. backend_impl,
  324. _C._dispatch_key_for_device(device_type),
  325. )
  326. # Wrap function to choose between the default implementation or the device-specific
  327. # implementation depending on if the kernel is disabled.
  328. @torch._disable_dynamo
  329. def wrapped_fn(*args, **kwargs):
  330. if device_type in self._disabled_kernel:
  331. return self._init_fn(*args, **kwargs)
  332. else:
  333. return fn(*args, **kwargs)
  334. self._backend_fns[device_type] = wrapped_fn
  335. return fn
  336. if device_types is not None and not utils.has_tensor_arg(
  337. self._opoverload._schema
  338. ):
  339. device_arg_index = utils.get_device_arg_index(self._opoverload._schema)
  340. if device_arg_index is None:
  341. raise ValueError(
  342. "Functions without tensor inputs are required to have a `device: torch.device` argument"
  343. )
  344. self._register_backend_select_dispatcher(device_arg_index)
  345. # See NOTE: [Supporting decorator and non-decorator usage]
  346. if fn is None:
  347. return inner
  348. return inner(fn)
  349. def register_fake(self, fn: Callable, /) -> Callable:
  350. r"""Register a FakeTensor implementation for this custom op.
  351. This is necessary to get the operator to work efficiently with torch.compile.
  352. The Fake impl (sometimes also known as a meta kernel or abstract impl)
  353. specifies the behavior of this operator on Tensors that carry no data.
  354. Given some input Tensors with certain properties
  355. (sizes/strides/storage_offset/device), it specifies what the properties of
  356. the output Tensors are.
  357. Please see :func:`torch.library.register_fake` for more details.
  358. Args:
  359. fn (Callable): The function to register as the FakeTensor
  360. implementation.
  361. Examples:
  362. >>> import torch
  363. >>> import numpy as np
  364. >>> from torch import Tensor
  365. >>>
  366. >>> # Example 1: an operator without data-dependent output shape
  367. >>> @torch.library.custom_op("mylib::linear", mutates_args=())
  368. >>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
  369. >>> return (x @ weight.t()) + bias
  370. >>>
  371. >>> @linear.register_fake
  372. >>> def _(x, weight, bias):
  373. >>> assert x.dim() == 2
  374. >>> assert weight.dim() == 2
  375. >>> assert bias.dim() == 1
  376. >>> assert x.shape[1] == weight.shape[1]
  377. >>> assert weight.shape[0] == bias.shape[0]
  378. >>> assert x.device == weight.device
  379. >>> return x.new_empty(x.size(0), weight.size(0))
  380. >>>
  381. >>> x = torch.randn(2, 2)
  382. >>> weight = torch.randn(2, 2)
  383. >>> bias = torch.randn(2)
  384. >>> # xdoctest: +SKIP("Requires Python <= 3.11")
  385. >>> out = torch.compile(linear, fullgraph=True)(x, weight, bias)
  386. >>> # xdoctest: +SKIP("Requires Python <= 3.11")
  387. >>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias))
  388. >>>
  389. >>> # Example 2: an operator with data-dependent output shape
  390. >>> @torch.library.custom_op("mylib::nonzero", mutates_args=())
  391. >>> def nonzero(x: Tensor) -> Tensor:
  392. >>> x_np = x.cpu().numpy()
  393. >>> res = np.stack(np.nonzero(x_np), axis=1)
  394. >>> return torch.tensor(res, device=x.device)
  395. >>>
  396. >>> @nonzero.register_fake
  397. >>> def _(x):
  398. >>> # Number of nonzero-elements is data-dependent.
  399. >>> # Since we cannot peek at the data in an abstract impl,
  400. >>> # we use the ctx object to construct a new symint that
  401. >>> # represents the data-dependent size.
  402. >>> ctx = torch.library.get_ctx()
  403. >>> nnz = ctx.new_dynamic_size()
  404. >>> shape = [nnz, x.dim()]
  405. >>> result = x.new_empty(shape, dtype=torch.int64)
  406. >>> return result
  407. >>>
  408. >>> x = torch.tensor([0, 1, 2, 0, 0, 1])
  409. >>> # xdoctest: +SKIP("Requires Python <= 3.11")
  410. >>> out = torch.compile(nonzero, fullgraph=True)(x)
  411. >>> # xdoctest: +SKIP("Requires Python <= 3.11")
  412. >>> assert torch.allclose(out, x.nonzero())
  413. """
  414. self._abstract_fn = fn
  415. return fn
  416. def register_effect(self, effect: Optional[EffectType]) -> None:
  417. self._lib._register_effectful_op(self._qualname, effect)
  418. def register_torch_dispatch(
  419. self, torch_dispatch_class: Any, fn: Optional[Callable] = None, /
  420. ) -> Callable:
  421. r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.
  422. This allows for open registration to specify the behavior between the operator
  423. and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
  424. or the operator directly.
  425. Please see :func:`torch.library.register_torch_dispatch` for examples and more details.
  426. """
  427. def register(fn):
  428. if torch_dispatch_class not in self._torch_dispatch_fns:
  429. def inner(*args, **kwargs):
  430. return self._torch_dispatch_fns[torch_dispatch_class](
  431. *args, **kwargs
  432. )
  433. self._lib._register_torch_dispatch_rule(
  434. self._name, torch_dispatch_class, inner
  435. )
  436. self._torch_dispatch_fns[torch_dispatch_class] = fn
  437. return fn
  438. if fn is None:
  439. return register
  440. else:
  441. return register(fn)
  442. def register_autograd(
  443. self,
  444. backward: Callable,
  445. /,
  446. *,
  447. setup_context: Optional[Callable] = None,
  448. ) -> None:
  449. r"""Register a backward formula for this custom op.
  450. In order for an operator to work with autograd, you need to register
  451. a backward formula:
  452. 1. You must tell us how to compute gradients during the backward pass
  453. by providing us a "backward" function.
  454. 2. If you need any values from the forward to compute gradients, you can
  455. use `setup_context` to save values for backward.
  456. ``backward_fn`` runs during the backward pass. It accepts ``(ctx, *grads)``:
  457. - ``grads`` is one or more gradients. The number of gradients matches
  458. the number of outputs of the operator.
  459. The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
  460. :class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the
  461. same as :meth:`torch.autograd.Function.backward`.
  462. ``setup_context(ctx, inputs, output)`` runs during the forward pass.
  463. Please save quantities needed for backward onto the ``ctx`` object via
  464. either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
  465. or assigning them as attributes of ``ctx``. If your custom op has
  466. kwarg-only arguments, we expect the signature of ``setup_context``
  467. to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.
  468. Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is,
  469. they may not directly access :meth:`torch.Tensor.data_ptr` and they must
  470. not depend on or mutate global state. If you need a non-traceable backward,
  471. you can make it a separate custom_op that you call inside ``backward_fn``.
  472. If you need different autograd behavior on different devices, then we
  473. recommend creating two different custom operators, one for each device
  474. that needs different behavior, and switching between them at runtime.
  475. Examples:
  476. >>> import torch
  477. >>> import numpy as np
  478. >>> from torch import Tensor
  479. >>>
  480. >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
  481. >>> def numpy_sin(x: Tensor) -> Tensor:
  482. >>> x_np = x.cpu().numpy()
  483. >>> y_np = np.sin(x_np)
  484. >>> return torch.from_numpy(y_np).to(device=x.device)
  485. >>>
  486. >>> def setup_context(ctx, inputs, output) -> Tensor:
  487. >>> x, = inputs
  488. >>> ctx.save_for_backward(x)
  489. >>>
  490. >>> def backward(ctx, grad):
  491. >>> x, = ctx.saved_tensors
  492. >>> return grad * x.cos()
  493. >>>
  494. >>> numpy_sin.register_autograd(backward, setup_context=setup_context)
  495. >>>
  496. >>> x = torch.randn(3, requires_grad=True)
  497. >>> y = numpy_sin(x)
  498. >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
  499. >>> assert torch.allclose(grad_x, x.cos())
  500. >>>
  501. >>> # Example with a keyword-only arg
  502. >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
  503. >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
  504. >>> x_np = x.cpu().numpy()
  505. >>> y_np = x_np * val
  506. >>> return torch.from_numpy(y_np).to(device=x.device)
  507. >>>
  508. >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
  509. >>> ctx.val = keyword_only_inputs["val"]
  510. >>>
  511. >>> def backward(ctx, grad):
  512. >>> return grad * ctx.val
  513. >>>
  514. >>> numpy_mul.register_autograd(backward, setup_context=setup_context)
  515. >>>
  516. >>> x = torch.randn(3, requires_grad=True)
  517. >>> y = numpy_mul(x, val=3.14)
  518. >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
  519. >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
  520. """
  521. schema = self._opoverload._schema
  522. if not utils.is_functional_schema(schema, allow_valid_view=True):
  523. raise RuntimeError(
  524. f"Cannot register autograd formula for non-functional operator "
  525. f"{self} with schema {schema}. Please create "
  526. f"a functional operator and register an autograd formula for that."
  527. )
  528. self._backward_fn = backward
  529. self._setup_context_fn = setup_context
  530. def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None:
  531. lib = self._lib
  532. schema_str = self._name + self._schema
  533. cpp_schema = _C.parse_schema(schema_str)
  534. if utils.has_kwarg_only_tensors(cpp_schema):
  535. # If you want to support this, the progression is:
  536. # - supporting kwarg-only Tensors that are non-differentiable
  537. # - supporting kwarg-only Tensors (regardless of differentiability)
  538. raise NotImplementedError(
  539. f"custom_op with kwarg-only Tensor args. Please make your "
  540. f"tensors not kwarg-only. Got: {schema_str}"
  541. )
  542. lib.define(
  543. schema_str,
  544. tags=[_C.Tag.pt2_compliant_tag, *tags],
  545. )
  546. self._opoverload = utils.lookup_op(self._qualname)
  547. def fake_impl(*args, **kwargs):
  548. if self._abstract_fn is None:
  549. if utils.can_generate_trivial_fake_impl(self._opoverload):
  550. return None
  551. raise RuntimeError(
  552. f"There was no fake impl registered for {self}. "
  553. f"This is necessary for torch.compile/export/fx tracing to work. "
  554. f"Please use `{self._init_fn.__name__}.register_fake` to add an "
  555. f"fake impl."
  556. )
  557. return self._abstract_fn(*args, **kwargs)
  558. lib._register_fake(self._name, fake_impl, _stacklevel=4)
  559. autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
  560. lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)
  561. schema = self._opoverload._schema
  562. if schema._is_view_op() or schema.is_mutable:
  563. lib.m.register_ad_inplace_or_view_fallback(self._name) # type: ignore[union-attr]
  564. if schema.is_mutable:
  565. mutated_idxs, mutated_keys = utils.mutated_args_kwargs(schema)
  566. original_kernel = torch._C._dispatch_get_computed_kernel_for_dispatch_key(
  567. f"{lib.ns}::{self._name}", "ADInplaceOrView"
  568. )
  569. def adinplaceorview_impl(keyset, *args, **kwargs):
  570. # Handle the mutated idx the user gave us explicitly
  571. for idx in mutated_idxs:
  572. increment_version(args[idx])
  573. for key in mutated_keys:
  574. increment_version(kwargs[key])
  575. # Handle view + mutation that are in the schema
  576. return original_kernel.call_boxed(keyset, *args, **kwargs)
  577. with warnings.catch_warnings():
  578. warnings.filterwarnings(
  579. "ignore",
  580. message="Warning only once for all operators",
  581. category=UserWarning,
  582. )
  583. lib.impl(
  584. self._name,
  585. adinplaceorview_impl,
  586. "ADInplaceOrView",
  587. with_keyset=True,
  588. )
  589. def _register_backend_select_dispatcher(self, device_arg_index: int):
  590. """
  591. Switch on the device argument to select the correct backend to dispatch to.
  592. """
  593. def backend_select(keyset, *args, **kwargs):
  594. device = args[device_arg_index].type
  595. if device not in self._backend_fns:
  596. raise RuntimeError(
  597. f"{self._name} does not have a kernel registered for {device}. "
  598. "Please use register_kernel to do so."
  599. )
  600. dispatch_key = _C._dispatch_key_for_device(device)
  601. dispatch_key = getattr(_C.DispatchKey, dispatch_key)
  602. return self._opoverload.redispatch(
  603. _C.DispatchKeySet(dispatch_key), *args, **kwargs
  604. )
  605. self._lib.impl(self._name, backend_select, "BackendSelect", with_keyset=True)
  606. def __call__(self, *args, **kwargs):
  607. return self._opoverload(*args, **kwargs)
  608. def register_vmap(
  609. self,
  610. func: Optional[Callable] = None,
  611. ):
  612. r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.
  613. This API may be used as a decorator.
  614. In order for an operator to work with :func:`torch.vmap`, you may need to register a
  615. vmap implementation in the following signature:
  616. ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,
  617. where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
  618. It specifies how do we compute the batched version of ``op`` given inputs with an additional
  619. dimension (specified by ``in_dims``).
  620. For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
  621. if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer
  622. specifying what dimension of the Tensor is being vmapped over.
  623. ``info`` is a collection of additional metadata that may be helpful:
  624. ``info.batch_size`` specifies the size of the dimension being vmapped over, while
  625. ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.
  626. The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
  627. ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
  628. per output that specifies if the output has the vmapped dimension and what index it is in.
  629. Examples:
  630. >>> import torch
  631. >>> import numpy as np
  632. >>> from torch import Tensor
  633. >>> from typing import Tuple
  634. >>>
  635. >>> def to_numpy(tensor):
  636. >>> return tensor.cpu().numpy()
  637. >>>
  638. >>> lib = torch.library.Library("mylib", "FRAGMENT")
  639. >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
  640. >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
  641. >>> x_np = to_numpy(x)
  642. >>> dx = torch.tensor(3 * x_np ** 2, device=x.device)
  643. >>> return torch.tensor(x_np ** 3, device=x.device), dx
  644. >>>
  645. >>> def numpy_cube_vmap(info, in_dims, x):
  646. >>> result = numpy_cube(x)
  647. >>> return result, (in_dims[0], in_dims[0])
  648. >>>
  649. >>> numpy_cube.register_vmap(numpy_cube_vmap)
  650. >>>
  651. >>> x = torch.randn(3)
  652. >>> torch.vmap(numpy_cube)(x)
  653. >>>
  654. >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
  655. >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
  656. >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
  657. >>>
  658. >>> @numpy_mul.register_vmap
  659. >>> def numpy_mul_vmap(info, in_dims, x, y):
  660. >>> x_bdim, y_bdim = in_dims
  661. >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
  662. >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
  663. >>> result = x * y
  664. >>> result = result.movedim(-1, 0)
  665. >>> return result, 0
  666. >>>
  667. >>>
  668. >>> x = torch.randn(3)
  669. >>> y = torch.randn(3)
  670. >>> torch.vmap(numpy_mul)(x, y)
  671. """
  672. from torch._functorch.autograd_function import custom_function_call_vmap_helper
  673. from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter
  674. def register(func):
  675. need_register = self._vmap_fn is None
  676. self._vmap_fn = func
  677. if need_register:
  678. def wrapped_func(keyset, *args, **kwargs):
  679. interpreter = retrieve_current_functorch_interpreter()
  680. return custom_function_call_vmap_helper(
  681. # pyrefly: ignore[bad-argument-type]
  682. interpreter,
  683. # pyrefly: ignore[bad-argument-type]
  684. self._vmap_fn,
  685. self._opoverload,
  686. *args,
  687. **kwargs,
  688. )
  689. self._lib.impl(
  690. self._name, wrapped_func, "FuncTorchBatched", with_keyset=True
  691. )
  692. if func is None:
  693. return register
  694. else:
  695. return register(func)
  696. def register_autocast(
  697. self,
  698. device_type: str,
  699. cast_inputs: _dtype,
  700. ):
  701. r"""Register an autocast dispatch rule for this custom op.
  702. Valid `device_type` include: "cpu" and "cuda".
  703. Args:
  704. op (str | OpOverload): The operator to register an autocast dispatch rule to.
  705. device_type(str): Device type to use. 'cuda' or 'cpu'.
  706. The type is the same as the `type` attribute of a :class:`torch.device`.
  707. Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
  708. cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region,
  709. casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors
  710. are not affected), then executes custom op with autocast disabled.
  711. lib (Optional[Library]): If provided, the lifetime of this registration
  712. Examples::
  713. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  714. >>> import torch
  715. >>> from torch import Tensor
  716. >>> from torch.library import custom_op
  717. >>>
  718. >>> # Create a custom op that works on cuda
  719. >>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
  720. >>> def my_sin(x: Tensor) -> Tensor:
  721. >>> return torch.sin(x)
  722. >>>
  723. >>> # Register autocast dispatch rule for the cuda device
  724. >>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
  725. >>>
  726. >>> x = torch.randn(3, dtype=torch.float32, device="cuda")
  727. >>> with torch.autocast("cuda", dtype=torch.float16):
  728. >>> y = torch.ops.mylib.my_sin(x)
  729. >>> assert y.dtype == torch.float16
  730. """
  731. if not isinstance(device_type, str):
  732. raise ValueError(
  733. f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
  734. )
  735. if device_type not in ["cpu", "cuda"]:
  736. raise ValueError(f"Unknown device type: {device_type}")
  737. need_register_cuda = self._autocast_cuda_dtype is None
  738. need_register_cpu = self._autocast_cpu_dtype is None
  739. if device_type == "cuda":
  740. self._autocast_cuda_dtype = cast_inputs
  741. else:
  742. self._autocast_cpu_dtype = cast_inputs
  743. def kernel(_, *args, **kwargs):
  744. if len(kwargs) != 0:
  745. raise AssertionError(
  746. f"Custom ops do not support kwargs yet, got {list(kwargs.keys())}"
  747. )
  748. autocast_keyset = torch._C.DispatchKeySet(
  749. torch._C.DispatchKey.AutocastCPU
  750. ) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA)
  751. with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
  752. return self._opoverload(*_cast(args, device_type, cast_inputs))
  753. if need_register_cuda and self._autocast_cuda_dtype:
  754. self._lib.impl(self._name, kernel, "AutocastCUDA", with_keyset=True)
  755. elif need_register_cpu and self._autocast_cpu_dtype:
  756. self._lib.impl(self._name, kernel, "AutocastCPU", with_keyset=True)
  757. return kernel
  758. # TODO: Merge this function with torch.amp.autocast_mode._cast, and refactor it
  759. # into a utility function once custom ops support arbitrary input types.
  760. def _cast(value, device_type: str, dtype: _dtype):
  761. if isinstance(value, torch.Tensor):
  762. is_eligible = (
  763. value.is_floating_point()
  764. and value.device.type == device_type
  765. and (value.dtype is not torch.float64)
  766. )
  767. return value.to(dtype) if is_eligible else value
  768. elif isinstance(value, (str, bytes)):
  769. return value
  770. elif isinstance(value, collections.abc.Iterable):
  771. iterable = (_cast(v, device_type, dtype) for v in value)
  772. if isinstance(value, (list, tuple)):
  773. return type(value)(iterable)
  774. else:
  775. return iterable
  776. else:
  777. return value
  778. def increment_version(val: Any) -> None:
  779. if isinstance(val, Tensor):
  780. torch.autograd.graph.increment_version(val)
  781. elif isinstance(val, (tuple, list)):
  782. for v in val:
  783. if isinstance(v, Tensor):
  784. torch.autograd.graph.increment_version(v)
  785. # NOTE: [Supporting decorator and non-decorator usage]
  786. #
  787. # Some APIs may be both used as a decorator and not as a decorator.
  788. # For example:
  789. #
  790. # >>> def fn(x):
  791. # >>> return x.sin()
  792. # >>>
  793. # >>> # Usage 1: not as a decorator
  794. # >>> numpy_sin.register_kernel("cuda", fn)
  795. # >>>
  796. # >>> # Usage 2: as a decorator
  797. # >>> @numpy_sin.register_kernel("cuda")
  798. # >>> def fn2(x):
  799. # >>> return x.sin()
  800. #
  801. # The way we support this is that `register_kernel` accepts an optional `fn`.
  802. # If `fn` is provided (Usage 1), then we know that the user is using it not
  803. # as a decorator.
  804. # If `fn` is not provided (Usage 2), then `register_kernel` needs to return a
  805. # decorator.
  806. OPDEF_TO_LIB: dict[str, "torch.library.Library"] = {}
  807. OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
  808. def get_library_allowing_overwrite(
  809. namespace: str, name: str
  810. ) -> "torch.library.Library":
  811. qualname = f"{namespace}::{name}"
  812. if qualname in OPDEF_TO_LIB:
  813. OPDEF_TO_LIB[qualname]._destroy()
  814. del OPDEF_TO_LIB[qualname]
  815. lib = torch.library.Library(namespace, "FRAGMENT") # noqa: TOR901
  816. OPDEF_TO_LIB[qualname] = lib
  817. return lib
  818. def _maybe_get_opdef(
  819. op: Union[CustomOpDef, _ops.OpOverload, str],
  820. ) -> Optional[CustomOpDef]:
  821. if isinstance(op, CustomOpDef):
  822. return op
  823. if isinstance(op, _ops.OpOverload):
  824. op = op._name
  825. if not isinstance(op, str):
  826. raise AssertionError(f"op must be str, got {type(op)}")
  827. if op in OPDEFS:
  828. return OPDEFS[op]
  829. return None