| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958 |
- # mypy: allow-untyped-defs
- import collections
- import inspect
- import logging
- import warnings
- import weakref
- from collections.abc import Callable, Iterable, Sequence
- from contextlib import contextmanager
- from typing import Any, Optional, overload, Union
- import torch
- from torch import _C, _ops, Tensor
- from torch.types import _dtype
- from torch.utils._exposed_in import exposed_in
- from . import autograd, utils
- from .effects import EffectType
- # Type alias for a ``device_types`` argument: None, a single device-type
- # string (e.g. "cpu"), or a sequence of such strings.
- device_types_t = Optional[Union[str, Sequence[str]]]
- # Module-level logger, named after this module.
- log = logging.getLogger(__name__)
@overload
def custom_op(
    name: str,
    fn: None = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    device_types: device_types_t = None,
    schema: Optional[str] = None,
    tags: Optional[Sequence[_C.Tag]] = None,
) -> Callable[[Callable[..., object]], "CustomOpDef"]: ...


@overload
def custom_op(
    name: str,
    fn: Callable[..., object],
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    device_types: device_types_t = None,
    schema: Optional[str] = None,
    tags: Optional[Sequence[_C.Tag]] = None,
) -> "CustomOpDef": ...


@exposed_in("torch.library")
def custom_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    device_types: device_types_t = None,
    schema: Optional[str] = None,
    tags: Optional[Sequence[_C.Tag]] = None,
) -> Union[Callable[[Callable[..., object]], "CustomOpDef"], "CustomOpDef"]:
    """Wraps a function into custom operator.

    Reasons why you may want to create a custom op include:
    - Wrapping a third-party library or custom kernel to work with PyTorch
      subsystems like Autograd.
    - Preventing torch.compile/export/FX tracing from peeking inside your function.

    This API is used as a decorator around a function (please see examples).
    The provided function must have type hints; these are needed to interface
    with PyTorch's various subsystems.

    Args:
        name (str): A name for the custom op that looks like "{namespace}::{name}",
            e.g. "mylib::my_linear". The name is used as the op's stable identifier
            in PyTorch subsystems (e.g. torch.export, FX graphs).
            To avoid name collisions, please use your project name as the namespace;
            e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
        mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
            This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
            it pessimistically assumes that all inputs to the operator are being mutated.
        device_types (str | None | Sequence[str]): The device type(s) the function
            is valid for. If no device type is provided, then the function
            is used as the default implementation for all device types.
            Examples: "cpu", "cuda".
            When registering a device-specific implementation for an operator that accepts no Tensors,
            we require the operator to have a "device: torch.device argument".
        schema (str | None): A schema string for the operator. If None
            (recommended) we'll infer a schema for the operator from its type
            annotations. We recommend letting us infer a schema unless you
            have a specific reason not to.
            Example: "(Tensor x, int y) -> (Tensor, Tensor)".
        tags (Sequence[torch.Tag] | None): Extra tags forwarded to the
            operator definition.

    Raises:
        ValueError: If ``name`` is not of the form "{namespace}::{name}", or if
            an explicit ``schema`` marks a different set of arguments as
            mutated than ``mutates_args`` does.

    .. note::
        We recommend not passing in a ``schema`` arg and instead letting us infer
        it from the type annotations. It is error-prone to write your own schema.
        You may wish to provide your own schema if our interpretation of
        the type annotation is not what you want.
        For more info on how to write a schema string, see
        `here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func>`_

    Examples::
        >>> import torch
        >>> from torch import Tensor
        >>> from torch.library import custom_op
        >>> import numpy as np
        >>>
        >>> @custom_op("mylib::numpy_sin", mutates_args=())
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> x = torch.randn(3)
        >>> y = numpy_sin(x)
        >>> assert torch.allclose(y, x.sin())
        >>>
        >>> # Example of a custom op that only works for one device type.
        >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
        >>> def numpy_sin_cpu(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np)
        >>>
        >>> x = torch.randn(3)
        >>> y = numpy_sin_cpu(x)
        >>> assert torch.allclose(y, x.sin())
        >>>
        >>> # Example of a custom op that mutates an input
        >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
        >>> def numpy_sin_inplace(x: Tensor) -> None:
        >>>     x_np = x.numpy()
        >>>     np.sin(x_np, out=x_np)
        >>>
        >>> x = torch.randn(3)
        >>> expected = x.sin()
        >>> numpy_sin_inplace(x)
        >>> assert torch.allclose(x, expected)
        >>>
        >>> # Example of a factory function
        >>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
        >>> def bar(device: torch.device) -> Tensor:
        >>>     return torch.ones(3)
        >>>
        >>> bar("cpu")

    """

    def inner(fn: Callable[..., object]) -> CustomOpDef:
        # Imported locally, as in the original, so that ``torch.library`` is
        # resolved at call time rather than at module import time.
        import torch

        if schema is None:
            schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
        else:
            schema_str = schema

        # Exactly one "::" separator is required; report anything else with a
        # descriptive message instead of a bare tuple-unpacking error.
        name_parts = name.split("::")
        if len(name_parts) != 2:
            raise ValueError(
                f'Expected `name` to look like "{{namespace}}::{{name}}", '
                f"got: {name!r}"
            )
        namespace, opname = name_parts
        result = CustomOpDef(namespace, opname, schema_str, fn, tags)
        if schema is not None:
            # Check that schema's alias annotations match those of `mutates_args`.
            expected = set()
            for arg in result._opoverload._schema.arguments:
                if arg.alias_info is not None and arg.alias_info.is_write:
                    expected.add(arg.name)
            if expected != set(mutates_args):
                raise ValueError(
                    f"Attempted to create a custom op with `mutates_args={mutates_args}` "
                    f"and `schema={schema}`. The schema suggests that the op mutates {expected} "
                    f"which is different from what was provided to us in `mutates_args`. "
                    f"Please make these consistent."
                )
        result.register_kernel(device_types)(fn)
        return result

    # Decorator usage: ``fn`` is supplied later.
    if fn is None:
        return inner
    return inner(fn)
class CustomOpDef:
    """CustomOpDef is a wrapper around a function that turns it into a custom op.

    It has various methods for registering additional behavior for this
    custom op.

    You should not instantiate CustomOpDef directly; instead, use the
    :func:`torch.library.custom_op` API.
    """

    def __init__(
        self,
        namespace: str,
        name: str,
        schema: str,
        fn: Callable,
        tags: Optional[Sequence[_C.Tag]] = None,
    ) -> None:
        # Fields used to interface with the PyTorch dispatcher
        self._namespace = namespace
        self._name = name
        self._schema = schema
        self._tags = tags if tags is not None else []

        # The function the op was created from; also used as the fallback
        # when a device-specific kernel is disabled.
        self._init_fn = fn

        # Registered behaviors. The dispatcher-facing closures look these up
        # at call time, which is what lets a later registration replace an
        # earlier one.
        self._backend_fns: dict[Union[str, None], Callable] = {}
        self._abstract_fn: Optional[Callable] = None
        self._setup_context_fn: Optional[Callable] = None
        self._backward_fn: Optional[Callable] = None
        self._torch_dispatch_fns: dict[type, Callable] = {}
        self._vmap_fn: Optional[Callable] = None
        self._autocast_cuda_dtype: Optional[_dtype] = None
        self._autocast_cpu_dtype: Optional[_dtype] = None

        self._lib = get_library_allowing_overwrite(self._namespace, self._name)
        self._register_to_dispatcher(self._tags)
        # Device types whose registered kernel is currently disabled.
        self._disabled_kernel: set = set()
        self._used_triton_kernels: list[Any] = []
        OPDEFS[self._qualname] = self

    @property
    def _qualname(self) -> str:
        """The fully qualified operator name, ``"{namespace}::{name}"``."""
        return f"{self._namespace}::{self._name}"

    def __repr__(self) -> str:
        return f"<CustomOpDef({self._qualname})>"

    @contextmanager
    def set_kernel_enabled(self, device_type: str, enabled: bool = True):
        """
        Disable or re-enable an already registered kernel for this custom operator.

        This is a context manager: the previous enabled/disabled state for
        ``device_type`` is restored on exit.

        If the kernel is already disabled/enabled, this is a no-op (a warning
        is logged).

        Note:
            If a kernel is first disabled and then registered, it is disabled until enabled again.

        Args:
            device_type (str): The device type to disable/enable the kernel for.
            enabled (bool): Whether to enable (``True``) or disable (``False``) the kernel.

        Example:
            >>> inp = torch.randn(1)
            >>>
            >>> # define custom op `f`.
            >>> @custom_op("mylib::f", mutates_args=())
            >>> def f(x: Tensor) -> Tensor:
            >>>     return torch.zeros(1)
            >>>
            >>> print(f(inp))  # tensor([0.]), default kernel
            >>>
            >>> @f.register_kernel("cpu")
            >>> def _(x):
            >>>     return torch.ones(1)
            >>>
            >>> print(f(inp))  # tensor([1.]), CPU kernel
            >>>
            >>> # temporarily disable the CPU kernel
            >>> with f.set_kernel_enabled("cpu", enabled = False):
            >>>     print(f(inp))  # tensor([0.]) with CPU kernel disabled

        """
        action = "enable" if enabled else "disable"
        originally_disabled = device_type in self._disabled_kernel
        if device_type not in self._backend_fns:
            log.warning(
                "Attempted to %s kernel for %s but no kernel was registered for this device type.",
                action,
                device_type,
            )

        if not enabled:
            if originally_disabled:
                log.warning(
                    "Attempted to disable kernel for %s but it was already disabled.",
                    device_type,
                )
            else:
                self._disabled_kernel.add(device_type)
        else:  # enable the kernel
            if not originally_disabled:
                log.warning(
                    "Attempted to enable kernel for %s but it was already enabled.",
                    device_type,
                )
            else:
                self._disabled_kernel.remove(device_type)

        try:
            yield
        finally:
            # restore original state
            if originally_disabled:
                self._disabled_kernel.add(device_type)
            else:
                self._disabled_kernel.discard(device_type)

    def _make_backend_impl(self, device_type: Optional[str]) -> Callable:
        """Build the dispatcher-facing kernel for one specific ``device_type``.

        ``device_type`` is bound as this method's argument, so each returned
        closure keeps its own value rather than sharing a loop variable.
        """

        def backend_impl(*args, **kwargs):
            # Look the kernel up at call time so re-registration takes effect.
            result = self._backend_fns[device_type](*args, **kwargs)

            def get_module():
                fn = self._backend_fns[device_type]
                return inspect.getmodule(fn)

            schema = self._opoverload._schema
            if not schema._is_view_op():
                utils._c_check_aliasing_constraint(
                    self._name,
                    args,
                    kwargs,
                    result,
                    get_module,
                )
            return result

        return backend_impl

    def _make_kernel_wrapper(self, device_type: Optional[str], fn: Callable) -> Callable:
        """Wrap ``fn`` so it falls back to the original function while the
        kernel for this specific ``device_type`` is disabled."""

        # Wrap function to choose between the default implementation or the device-specific
        # implementation depending on if the kernel is disabled.
        @torch._disable_dynamo
        def wrapped_fn(*args, **kwargs):
            if device_type in self._disabled_kernel:
                return self._init_fn(*args, **kwargs)
            else:
                return fn(*args, **kwargs)

        return wrapped_fn

    def register_kernel(
        self, device_types: device_types_t, fn: Optional[Callable] = None, /
    ) -> Callable:
        """Register an implementation for a device type for this operator.

        Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
        This API may be used as a decorator.

        Args:
            device_types (str | None | Sequence[str]): The device type(s) to
                register an impl to. ``None`` registers the function under the
                "CompositeExplicitAutograd" dispatch key.
            fn (Callable): The function to register as the implementation for
                the given device types.

        Raises:
            ValueError: If ``device_types`` is given, the operator takes no
                Tensor arguments, and it has no ``device: torch.device``
                argument.

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> import torch
            >>> from torch import Tensor
            >>> from torch.library import custom_op
            >>> import numpy as np
            >>>
            >>> # Create a custom op that works on cpu
            >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
            >>> def numpy_sin(x: Tensor) -> Tensor:
            >>>     x_np = x.numpy()
            >>>     y_np = np.sin(x_np)
            >>>     return torch.from_numpy(y_np)
            >>>
            >>> # Add implementations for the cuda device
            >>> @numpy_sin.register_kernel("cuda")
            >>> def _(x):
            >>>     x_np = x.cpu().numpy()
            >>>     y_np = np.sin(x_np)
            >>>     return torch.from_numpy(y_np).to(device=x.device)
            >>>
            >>> x_cpu = torch.randn(3)
            >>> x_cuda = x_cpu.cuda()
            >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
            >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())

        """

        def inner(fn):
            if device_types is None or isinstance(device_types, str):
                dtypes: list[Union[str, None]] = [device_types]
            else:
                dtypes = list(device_types)
            for device_type in dtypes:
                if device_type not in self._backend_fns:
                    # The closures are built by factory methods so each one
                    # is bound to its own device type. Defining them inline
                    # here would make every closure observe the last value
                    # of the loop variable.
                    backend_impl = self._make_backend_impl(device_type)
                    if device_type is None:
                        self._lib.impl(
                            self._name, backend_impl, "CompositeExplicitAutograd"
                        )
                    else:
                        self._lib.impl(
                            self._name,
                            backend_impl,
                            _C._dispatch_key_for_device(device_type),
                        )

                self._backend_fns[device_type] = self._make_kernel_wrapper(
                    device_type, fn
                )
            return fn

        if device_types is not None and not utils.has_tensor_arg(
            self._opoverload._schema
        ):
            device_arg_index = utils.get_device_arg_index(self._opoverload._schema)
            if device_arg_index is None:
                raise ValueError(
                    "Functions without tensor inputs are required to have a `device: torch.device` argument"
                )
            self._register_backend_select_dispatcher(device_arg_index)

        # See NOTE: [Supporting decorator and non-decorator usage]
        if fn is None:
            return inner
        return inner(fn)

    def register_fake(self, fn: Callable, /) -> Callable:
        r"""Register a FakeTensor implementation for this custom op.

        This is necessary to get the operator to work efficiently with torch.compile.

        The Fake impl (sometimes also known as a meta kernel or abstract impl)
        specifies the behavior of this operator on Tensors that carry no data.
        Given some input Tensors with certain properties
        (sizes/strides/storage_offset/device), it specifies what the properties of
        the output Tensors are.

        Please see :func:`torch.library.register_fake` for more details.

        Args:
            fn (Callable): The function to register as the FakeTensor
                implementation.

        Returns:
            ``fn`` unchanged, so this can be used as a decorator.

        Examples:
            >>> import torch
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> # Example 1: an operator without data-dependent output shape
            >>> @torch.library.custom_op("mylib::linear", mutates_args=())
            >>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
            >>>     return (x @ weight.t()) + bias
            >>>
            >>> @linear.register_fake
            >>> def _(x, weight, bias):
            >>>     assert x.dim() == 2
            >>>     assert weight.dim() == 2
            >>>     assert bias.dim() == 1
            >>>     assert x.shape[1] == weight.shape[1]
            >>>     assert weight.shape[0] == bias.shape[0]
            >>>     assert x.device == weight.device
            >>>     return x.new_empty(x.size(0), weight.size(0))
            >>>
            >>> x = torch.randn(2, 2)
            >>> weight = torch.randn(2, 2)
            >>> bias = torch.randn(2)
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
            >>> out = torch.compile(linear, fullgraph=True)(x, weight, bias)
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
            >>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias))
            >>>
            >>> # Example 2: an operator with data-dependent output shape
            >>> @torch.library.custom_op("mylib::nonzero", mutates_args=())
            >>> def nonzero(x: Tensor) -> Tensor:
            >>>     x_np = x.cpu().numpy()
            >>>     res = np.stack(np.nonzero(x_np), axis=1)
            >>>     return torch.tensor(res, device=x.device)
            >>>
            >>> @nonzero.register_fake
            >>> def _(x):
            >>>     # Number of nonzero-elements is data-dependent.
            >>>     # Since we cannot peek at the data in an abstract impl,
            >>>     # we use the ctx object to construct a new symint that
            >>>     # represents the data-dependent size.
            >>>     ctx = torch.library.get_ctx()
            >>>     nnz = ctx.new_dynamic_size()
            >>>     shape = [nnz, x.dim()]
            >>>     result = x.new_empty(shape, dtype=torch.int64)
            >>>     return result
            >>>
            >>> x = torch.tensor([0, 1, 2, 0, 0, 1])
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
            >>> out = torch.compile(nonzero, fullgraph=True)(x)
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
            >>> assert torch.allclose(out, x.nonzero())

        """
        self._abstract_fn = fn
        return fn

    def register_effect(self, effect: Optional[EffectType]) -> None:
        """Forward ``effect`` for this op to the owning library's
        ``_register_effectful_op`` under the op's qualified name."""
        self._lib._register_effectful_op(self._qualname, effect)

    def register_torch_dispatch(
        self, torch_dispatch_class: Any, fn: Optional[Callable] = None, /
    ) -> Callable:
        r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.

        This allows for open registration to specify the behavior between the operator
        and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
        or the operator directly.

        Please see :func:`torch.library.register_torch_dispatch` for examples and more details.
        """

        def register(fn):
            if torch_dispatch_class not in self._torch_dispatch_fns:
                # Register a trampoline once; it looks the rule up at call
                # time so a later registration replaces the earlier one.
                def inner(*args, **kwargs):
                    return self._torch_dispatch_fns[torch_dispatch_class](
                        *args, **kwargs
                    )

                self._lib._register_torch_dispatch_rule(
                    self._name, torch_dispatch_class, inner
                )
            self._torch_dispatch_fns[torch_dispatch_class] = fn
            return fn

        if fn is None:
            return register
        else:
            return register(fn)

    def register_autograd(
        self,
        backward: Callable,
        /,
        *,
        setup_context: Optional[Callable] = None,
    ) -> None:
        r"""Register a backward formula for this custom op.

        In order for an operator to work with autograd, you need to register
        a backward formula:
        1. You must tell us how to compute gradients during the backward pass
        by providing us a "backward" function.
        2. If you need any values from the forward to compute gradients, you can
        use `setup_context` to save values for backward.

        ``backward_fn`` runs during the backward pass. It accepts ``(ctx, *grads)``:
        - ``grads`` is one or more gradients. The number of gradients matches
        the number of outputs of the operator.
        The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
        :class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the
        same as :meth:`torch.autograd.Function.backward`.

        ``setup_context(ctx, inputs, output)`` runs during the forward pass.
        Please save quantities needed for backward onto the ``ctx`` object via
        either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
        or assigning them as attributes of ``ctx``. If your custom op has
        kwarg-only arguments, we expect the signature of ``setup_context``
        to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.

        Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is,
        they may not directly access :meth:`torch.Tensor.data_ptr` and they must
        not depend on or mutate global state. If you need a non-traceable backward,
        you can make it a separate custom_op that you call inside ``backward_fn``.

        If you need different autograd behavior on different devices, then we
        recommend creating two different custom operators, one for each device
        that needs different behavior, and switching between them at runtime.

        Raises:
            RuntimeError: If the operator's schema is not functional.

        Examples:
            >>> import torch
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
            >>> def numpy_sin(x: Tensor) -> Tensor:
            >>>     x_np = x.cpu().numpy()
            >>>     y_np = np.sin(x_np)
            >>>     return torch.from_numpy(y_np).to(device=x.device)
            >>>
            >>> def setup_context(ctx, inputs, output) -> None:
            >>>     x, = inputs
            >>>     ctx.save_for_backward(x)
            >>>
            >>> def backward(ctx, grad):
            >>>     x, = ctx.saved_tensors
            >>>     return grad * x.cos()
            >>>
            >>> numpy_sin.register_autograd(backward, setup_context=setup_context)
            >>>
            >>> x = torch.randn(3, requires_grad=True)
            >>> y = numpy_sin(x)
            >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
            >>> assert torch.allclose(grad_x, x.cos())
            >>>
            >>> # Example with a keyword-only arg
            >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
            >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
            >>>     x_np = x.cpu().numpy()
            >>>     y_np = x_np * val
            >>>     return torch.from_numpy(y_np).to(device=x.device)
            >>>
            >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> None:
            >>>     ctx.val = keyword_only_inputs["val"]
            >>>
            >>> def backward(ctx, grad):
            >>>     return grad * ctx.val
            >>>
            >>> numpy_mul.register_autograd(backward, setup_context=setup_context)
            >>>
            >>> x = torch.randn(3, requires_grad=True)
            >>> y = numpy_mul(x, val=3.14)
            >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
            >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))

        """
        schema = self._opoverload._schema
        if not utils.is_functional_schema(schema, allow_valid_view=True):
            raise RuntimeError(
                f"Cannot register autograd formula for non-functional operator "
                f"{self} with schema {schema}. Please create "
                f"a functional operator and register an autograd formula for that."
            )

        self._backward_fn = backward
        self._setup_context_fn = setup_context

    def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None:
        """Define the operator in ``self._lib`` and install its fake, Autograd
        and (for mutable schemas) ADInplaceOrView kernels.

        Raises:
            NotImplementedError: If the schema has kwarg-only Tensor arguments.
        """
        lib = self._lib
        schema_str = self._name + self._schema
        cpp_schema = _C.parse_schema(schema_str)
        if utils.has_kwarg_only_tensors(cpp_schema):
            # If you want to support this, the progression is:
            # - supporting kwarg-only Tensors that are non-differentiable
            # - supporting kwarg-only Tensors (regardless of differentiability)
            raise NotImplementedError(
                f"custom_op with kwarg-only Tensor args. Please make your "
                f"tensors not kwarg-only. Got: {schema_str}"
            )

        lib.define(
            schema_str,
            tags=[_C.Tag.pt2_compliant_tag, *tags],
        )
        self._opoverload = utils.lookup_op(self._qualname)

        def fake_impl(*args, **kwargs):
            if self._abstract_fn is None:
                if utils.can_generate_trivial_fake_impl(self._opoverload):
                    return None
                raise RuntimeError(
                    f"There was no fake impl registered for {self}. "
                    f"This is necessary for torch.compile/export/fx tracing to work. "
                    f"Please use `{self._init_fn.__name__}.register_fake` to add an "
                    f"fake impl."
                )
            return self._abstract_fn(*args, **kwargs)

        lib._register_fake(self._name, fake_impl, _stacklevel=4)

        autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
        lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)

        schema = self._opoverload._schema

        if schema._is_view_op() or schema.is_mutable:
            lib.m.register_ad_inplace_or_view_fallback(self._name)  # type: ignore[union-attr]

        if schema.is_mutable:
            mutated_idxs, mutated_keys = utils.mutated_args_kwargs(schema)

            # Capture the kernel computed for ADInplaceOrView *before*
            # overriding it below, so the override can delegate to it.
            original_kernel = torch._C._dispatch_get_computed_kernel_for_dispatch_key(
                f"{lib.ns}::{self._name}", "ADInplaceOrView"
            )

            def adinplaceorview_impl(keyset, *args, **kwargs):
                # Handle the mutated idx the user gave us explicitly
                for idx in mutated_idxs:
                    increment_version(args[idx])
                for key in mutated_keys:
                    increment_version(kwargs[key])
                # Handle view + mutation that are in the schema
                return original_kernel.call_boxed(keyset, *args, **kwargs)

            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    message="Warning only once for all operators",
                    category=UserWarning,
                )
                lib.impl(
                    self._name,
                    adinplaceorview_impl,
                    "ADInplaceOrView",
                    with_keyset=True,
                )

    def _register_backend_select_dispatcher(self, device_arg_index: int):
        """
        Switch on the device argument to select the correct backend to dispatch to.
        """

        def backend_select(keyset, *args, **kwargs):
            # The device is read positionally from ``args``.
            device = args[device_arg_index].type
            if device not in self._backend_fns:
                raise RuntimeError(
                    f"{self._name} does not have a kernel registered for {device}. "
                    "Please use register_kernel to do so."
                )
            dispatch_key = _C._dispatch_key_for_device(device)
            dispatch_key = getattr(_C.DispatchKey, dispatch_key)
            return self._opoverload.redispatch(
                _C.DispatchKeySet(dispatch_key), *args, **kwargs
            )

        self._lib.impl(self._name, backend_select, "BackendSelect", with_keyset=True)

    def __call__(self, *args, **kwargs):
        return self._opoverload(*args, **kwargs)

    def register_vmap(
        self,
        func: Optional[Callable] = None,
    ):
        r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.

        This API may be used as a decorator.

        In order for an operator to work with :func:`torch.vmap`, you may need to register a
        vmap implementation in the following signature:

            ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,

        where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.

        It specifies how do we compute the batched version of ``op`` given inputs with an additional
        dimension (specified by ``in_dims``).

        For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
        if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer
        specifying what dimension of the Tensor is being vmapped over.

        ``info`` is a collection of additional metadata that may be helpful:
        ``info.batch_size`` specifies the size of the dimension being vmapped over, while
        ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.

        The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
        ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
        per output that specifies if the output has the vmapped dimension and what index it is in.

        Returns:
            ``func`` when it is given (so the decorated name stays bound to
            the function), otherwise a decorator that registers its argument.

        Examples:
            >>> import torch
            >>> import numpy as np
            >>> from torch import Tensor
            >>> from typing import Tuple
            >>>
            >>> def to_numpy(tensor):
            >>>     return tensor.cpu().numpy()
            >>>
            >>> lib = torch.library.Library("mylib", "FRAGMENT")
            >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
            >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
            >>>     x_np = to_numpy(x)
            >>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
            >>>     return torch.tensor(x_np ** 3, device=x.device), dx
            >>>
            >>> def numpy_cube_vmap(info, in_dims, x):
            >>>     result = numpy_cube(x)
            >>>     return result, (in_dims[0], in_dims[0])
            >>>
            >>> numpy_cube.register_vmap(numpy_cube_vmap)
            >>>
            >>> x = torch.randn(3)
            >>> torch.vmap(numpy_cube)(x)
            >>>
            >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
            >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
            >>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
            >>>
            >>> @numpy_mul.register_vmap
            >>> def numpy_mul_vmap(info, in_dims, x, y):
            >>>     x_bdim, y_bdim = in_dims
            >>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
            >>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
            >>>     result = x * y
            >>>     result = result.movedim(-1, 0)
            >>>     return result, 0
            >>>
            >>>
            >>> x = torch.randn(3)
            >>> y = torch.randn(3)
            >>> torch.vmap(numpy_mul)(x, y)
        """
        from torch._functorch.autograd_function import custom_function_call_vmap_helper
        from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter

        def register(func):
            # Only the first registration installs a dispatcher kernel; it
            # reads ``self._vmap_fn`` at call time, so later registrations
            # simply replace the stored function.
            need_register = self._vmap_fn is None
            self._vmap_fn = func

            if need_register:

                def wrapped_func(keyset, *args, **kwargs):
                    interpreter = retrieve_current_functorch_interpreter()
                    return custom_function_call_vmap_helper(
                        # pyrefly: ignore[bad-argument-type]
                        interpreter,
                        # pyrefly: ignore[bad-argument-type]
                        self._vmap_fn,
                        self._opoverload,
                        *args,
                        **kwargs,
                    )

                self._lib.impl(
                    self._name, wrapped_func, "FuncTorchBatched", with_keyset=True
                )

            # Hand the function back so that decorator usage does not rebind
            # the decorated name to ``None``.
            return func

        if func is None:
            return register
        else:
            return register(func)

    def register_autocast(
        self,
        device_type: str,
        cast_inputs: _dtype,
    ):
        r"""Register an autocast dispatch rule for this custom op.

        Valid `device_type` include: "cpu" and "cuda".

        Calling this again for the same ``device_type`` replaces the target
        dtype used by the already-registered rule.

        Args:
            device_type(str):  Device type to use. 'cuda' or 'cpu'.
                The type is the same as the `type` attribute of a :class:`torch.device`.
                Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
            cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region,
                casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors
                are not affected), then executes custom op with autocast disabled.

        Returns:
            The autocast kernel function built for this call.

        Raises:
            ValueError: If ``device_type`` is not a string or is not one of
                "cpu" / "cuda".

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> import torch
            >>> from torch import Tensor
            >>> from torch.library import custom_op
            >>>
            >>> # Create a custom op that works on cuda
            >>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
            >>> def my_sin(x: Tensor) -> Tensor:
            >>>     return torch.sin(x)
            >>>
            >>> # Register autocast dispatch rule for the cuda device
            >>> my_sin.register_autocast("cuda", torch.float16)
            >>>
            >>> x = torch.randn(3, dtype=torch.float32, device="cuda")
            >>> with torch.autocast("cuda", dtype=torch.float16):
            >>>     y = torch.ops.mylib.my_sin(x)
            >>> assert y.dtype == torch.float16

        """
        if not isinstance(device_type, str):
            raise ValueError(
                f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
            )
        if device_type not in ["cpu", "cuda"]:
            raise ValueError(f"Unknown device type: {device_type}")

        need_register_cuda = self._autocast_cuda_dtype is None
        need_register_cpu = self._autocast_cpu_dtype is None
        if device_type == "cuda":
            self._autocast_cuda_dtype = cast_inputs
        else:
            self._autocast_cpu_dtype = cast_inputs

        def kernel(_, *args, **kwargs):
            if len(kwargs) != 0:
                raise AssertionError(
                    f"Custom ops do not support kwargs yet, got {list(kwargs.keys())}"
                )
            # Read the target dtype at call time rather than closing over
            # ``cast_inputs``: the dispatcher only ever holds the first
            # kernel registered for a device, so this is what lets a later
            # ``register_autocast`` call for that device take effect.
            if device_type == "cuda":
                target_dtype = self._autocast_cuda_dtype
            else:
                target_dtype = self._autocast_cpu_dtype
            autocast_keyset = torch._C.DispatchKeySet(
                torch._C.DispatchKey.AutocastCPU
            ) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA)
            with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
                return self._opoverload(*_cast(args, device_type, target_dtype))

        # Install a dispatcher kernel only on the first registration for a
        # given device.
        if need_register_cuda and self._autocast_cuda_dtype:
            self._lib.impl(self._name, kernel, "AutocastCUDA", with_keyset=True)
        elif need_register_cpu and self._autocast_cpu_dtype:
            self._lib.impl(self._name, kernel, "AutocastCPU", with_keyset=True)

        return kernel
- # TODO: Merge this function with torch.amp.autocast_mode._cast, and refactor it
- # into a utility function once custom ops support arbitrary input types.
- def _cast(value, device_type: str, dtype: _dtype):
- if isinstance(value, torch.Tensor):
- is_eligible = (
- value.is_floating_point()
- and value.device.type == device_type
- and (value.dtype is not torch.float64)
- )
- return value.to(dtype) if is_eligible else value
- elif isinstance(value, (str, bytes)):
- return value
- elif isinstance(value, collections.abc.Iterable):
- iterable = (_cast(v, device_type, dtype) for v in value)
- if isinstance(value, (list, tuple)):
- return type(value)(iterable)
- else:
- return iterable
- else:
- return value
- def increment_version(val: Any) -> None:
- if isinstance(val, Tensor):
- torch.autograd.graph.increment_version(val)
- elif isinstance(val, (tuple, list)):
- for v in val:
- if isinstance(v, Tensor):
- torch.autograd.graph.increment_version(v)
- # NOTE: [Supporting decorator and non-decorator usage]
- #
- # Some APIs may be both used as a decorator and not as a decorator.
- # For example:
- #
- # >>> def fn(x):
- # >>> return x.sin()
- # >>>
- # >>> # Usage 1: not as a decorator
- # >>> numpy_sin.register_kernel("cuda", fn)
- # >>>
- # >>> # Usage 2: as a decorator
- # >>> @numpy_sin.register_kernel("cuda")
- # >>> def fn2(x):
- # >>> return x.sin()
- #
- # The way we support this is that `register_kernel` accepts an optional `fn`.
- # If `fn` is provided (Usage 1), then we know that the user is using it not
- # as a decorator.
- # If `fn` is not provided (Usage 2), then `register_kernel` needs to return a
- # decorator.
- # Maps an op's qualified name ("namespace::name") to the Library fragment
- # that currently holds its registrations (see get_library_allowing_overwrite).
- OPDEF_TO_LIB: dict[str, "torch.library.Library"] = {}
- # Maps an op's qualified name to its CustomOpDef. Values are held weakly, so
- # an entry does not by itself keep a CustomOpDef alive.
- OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
def get_library_allowing_overwrite(
    namespace: str, name: str
) -> "torch.library.Library":
    """Return a fresh FRAGMENT library for ``namespace::name``.

    If a library was previously handed out for the same qualified name, it is
    destroyed and forgotten first, so the operator can be defined again. The
    new library is recorded in ``OPDEF_TO_LIB`` before being returned.
    """
    qualname = f"{namespace}::{name}"

    stale = OPDEF_TO_LIB.get(qualname)
    if stale is not None:
        stale._destroy()
        del OPDEF_TO_LIB[qualname]

    fresh = torch.library.Library(namespace, "FRAGMENT")  # noqa: TOR901
    OPDEF_TO_LIB[qualname] = fresh
    return fresh
def _maybe_get_opdef(
    op: Union[CustomOpDef, _ops.OpOverload, str],
) -> Optional[CustomOpDef]:
    """Return the CustomOpDef associated with ``op``, or None if there is none.

    Args:
        op: A CustomOpDef (returned as-is), an OpOverload (looked up by its
            ``_name``), or a qualified operator name string.

    Raises:
        AssertionError: If ``op`` is none of the accepted types.
    """
    if isinstance(op, CustomOpDef):
        return op
    if isinstance(op, _ops.OpOverload):
        op = op._name
    if not isinstance(op, str):
        raise AssertionError(f"op must be str, got {type(op)}")
    # OPDEFS holds its values weakly. A single ``get`` avoids the window
    # between a membership test and a subscript in which the entry could be
    # collected and the subscript would raise KeyError.
    return OPDEFS.get(op)
|