| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759 |
- # mypy: allow-untyped-defs
- import contextlib
- import functools
- import inspect
- import re
- import sys
- import traceback
- import weakref
- from collections.abc import Callable, Sequence
- from typing import Any, overload, TYPE_CHECKING, TypeVar, Union
- from typing_extensions import deprecated, ParamSpec
- import torch
- import torch._library as _library
- from torch._library.custom_ops import (
- _cast,
- _maybe_get_opdef,
- custom_op,
- CustomOpDef,
- device_types_t,
- )
- from torch._library.effects import EffectType
- from torch._library.infer_schema import infer_schema # noqa: F401
- from torch._library.triton import triton_op, wrap_triton
- from torch._ops import OpOverload
- from torch.types import _dtype
# Names exported by ``from torch.library import *``.
__all__ = [
    "Library",
    "impl",
    "define",
    "fallthrough_kernel",
    "impl_abstract",
    "register_autocast",
    "register_fake",
    "register_torch_dispatch",
    "register_vmap",
    "get_ctx",
    "get_kernel",
    "custom_op",
    "triton_op",
    "wrap_triton",
    "infer_schema",
]

# Type variables used to annotate decorators that return the decorated
# callable unchanged (see the legacy ``impl`` overload).
_T = TypeVar("_T")
_P = ParamSpec("_P")

# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
# This set is maintained to ensure that two libraries don't try to override the exact same functionality to avoid
# libraries calling into kernels not intended to be called.
_impls: set[str] = set()
# Qualified names (`namespace + "::" + op_name`) of the operators defined
# through ``Library.define``.
_defs: set[str] = set()

# prim is reserved by TorchScript interpreter
_reserved_namespaces = ["prim"]
def fallthrough_kernel():
    """
    Placeholder to hand to ``Library.impl`` when a fallthrough should be registered.

    It only serves as a marker; invoking it directly is an error and always
    raises ``NotImplementedError``.
    """
    raise NotImplementedError("fallthrough_kernel() should never be called.")
class Library:
    """
    A class to create libraries that can be used to register new operators or
    override operators in existing libraries from Python.
    A user can optionally pass in a dispatch keyname if they only want to register
    kernels corresponding to only one specific dispatch key.

    To create a library to override operators in an existing library (with name ns), set the kind to "IMPL".
    To create a new library (with name ns) to register new operators, set the kind to "DEF".
    To create a fragment of a possibly existing library to register operators (and bypass
    the limitation that there is only one library for a given namespace), set the kind to
    "FRAGMENT".

    Args:
        ns: library name
        kind: "DEF", "IMPL", "FRAGMENT"
        dispatch_key: PyTorch dispatch key (default: "")

    Raises:
        ValueError: if ``kind`` is not one of the supported kinds, or if ``ns``
            is a reserved namespace and ``kind`` is "DEF" or "FRAGMENT".
    """

    def __init__(self, ns, kind, dispatch_key=""):
        from torch.fx.operator_schemas import _SCHEMA_TO_SIGNATURE_CACHE

        if kind not in ("IMPL", "DEF", "FRAGMENT"):
            raise ValueError(f"Unsupported kind: {kind}")

        if ns in _reserved_namespaces and kind in ("DEF", "FRAGMENT"):
            raise ValueError(
                f"{ns} is a reserved namespace. Please try creating a library with another name."
            )

        # Record where the library was created (the caller of this constructor)
        # and pass it along to the dispatcher.
        frame = traceback.extract_stack(limit=2)[0]
        filename, lineno = frame.filename, frame.lineno
        self.m: Any | None = torch._C._dispatch_library(
            kind, ns, dispatch_key, filename, lineno
        )
        self.ns = ns
        self._op_defs: set[str] = set()
        self._op_impls: set[str] = set()
        self._registration_handles: list[torch._library.utils.RegistrationHandle] = []
        self.kind = kind
        self.dispatch_key = dispatch_key
        # Use a finalizer to setup the "destructor" instead of __del__.
        # Python __del__ can lead to weird things (globals and locals may already
        # be gone when __del__ actually gets called!). finalizers help the
        # situation because it lets us capture references and keeps them alive
        weakref.finalize(
            self,
            _del_library,
            _impls,
            self._op_impls,
            _defs,
            self._op_defs,
            self._registration_handles,
            self.m,
            _SCHEMA_TO_SIGNATURE_CACHE,
        )

    def __repr__(self):
        return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})"

    @staticmethod
    def _resolve_op_name(op_name, api_name):
        """Return the operator name (with overload suffix, if any) for ``op_name``.

        ``op_name`` is either a string, returned unchanged, or an ``OpOverload``,
        whose schema name and overload name are joined with a ".".
        ``api_name`` is only used to build the error message.

        Raises:
            RuntimeError: if ``op_name`` is neither a str nor an OpOverload.
        """
        if isinstance(op_name, str):
            return op_name
        if isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != "":
                name = name + "." + overload_name
            return name
        raise RuntimeError(
            f"{api_name} should be passed either a name or an OpOverload object "
            "as the first argument"
        )

    def define(self, schema, alias_analysis="", *, tags=()):
        r"""Defines a new operator and its semantics in the ns namespace.

        Args:
            schema: function schema to define a new operator.
            alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
                                       inferred from the schema (default behavior) or not ("CONSERVATIVE").
            tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
                                       operator. Tagging an operator changes the operator's behavior
                                       under various PyTorch subsystems; please read the docs for the
                                       torch.Tag carefully before applying it.

        Returns:
            name of the operator as inferred from the schema.

        Raises:
            RuntimeError: if ``alias_analysis`` is not "", "FROM_SCHEMA" or "CONSERVATIVE".
            AssertionError: if the library has already been destroyed.

        Example::

            >>> my_lib = Library("mylib", "DEF")
            >>> my_lib.define("sum(Tensor self) -> Tensor")
        """
        # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
        # AliasAnalysis type in C++
        if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
            raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}")
        if self.m is None:
            raise AssertionError("Library object has been destroyed")
        if isinstance(tags, torch.Tag):
            tags = (tags,)

        # Everything before the first "(" is the operator name, possibly
        # carrying a ".overload" suffix.
        name = schema.split("(")[0]
        packet_name = name.split(".")[0]
        # Must be checked before calling into the dispatcher below.
        has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(
            getattr(torch.ops, self.ns), packet_name
        )

        result = self.m.define(schema, alias_analysis, tuple(tags))
        qualname = self.ns + "::" + name

        # If the OpOverloadPacket exists already, then this means we're adding a
        # new OpOverload for it. Refresh the packet to include the new OpOverload.
        if has_preexisting_packet:
            ns = getattr(torch.ops, self.ns)
            packet = getattr(ns, packet_name)
            torch._ops._refresh_packet(packet)

        self._op_defs.add(qualname)
        _defs.add(qualname)
        return result

    def _register_fake(self, op_name, fn, _stacklevel=1, *, allow_override=False):
        r"""Registers the fake impl for an operator defined in the library.

        The returned registration handle is stored on the library so that it is
        destroyed together with it.
        """
        # NOTE: both lookups below depend on the call-stack depth, so they must
        # stay directly in this method's frame.
        source = torch._library.utils.get_source(_stacklevel + 1)
        frame = sys._getframe(_stacklevel)
        caller_module = inspect.getmodule(frame)
        # Can be none if you call register_fake from somewhere there isn't a module
        # (e.g. __main__)
        caller_module_name = None if caller_module is None else caller_module.__name__

        # TODO(rzou): We're gonna need to stage this change with torchvision,
        # since torchvision is github first.
        if caller_module_name is not None and caller_module_name.startswith(
            "torchvision."
        ):
            caller_module_name = None

        qualname = f"{self.ns}::{op_name}"
        entry = torch._library.simple_registry.singleton.find(qualname)
        if caller_module_name is not None:
            func_to_register = _check_pystubs_once(fn, qualname, caller_module_name)
        else:
            func_to_register = fn

        handle = entry.fake_impl.register(
            func_to_register, source, lib=self, allow_override=allow_override
        )
        self._registration_handles.append(handle)

    def _register_torch_dispatch_rule(self, op_name, torch_dispatch_class, fn):
        r"""Registers a torch_dispatch rule for the given operator and torch_dispatch_class.

        This allows for open registration to specify the behavior between the operator
        and the torch_dispatch_class without needing to modify the torch_dispatch_class
        or the operator directly.

        The torch_dispatch_class is either a Tensor subclass with `__torch_dispatch__` or a
        TorchDispatchMode.

        If it is a Tensor subclass, we expect fn to have the following signature:
        (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any

        If it is a TorchDispatchMode, we expect fn to have the following signature:
        (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
        """
        qualname = f"{self.ns}::{op_name}"
        entry = torch._library.simple_registry.singleton.find(qualname)
        handle = entry.torch_dispatch_rules.register(torch_dispatch_class, fn)
        self._registration_handles.append(handle)

    def _impl_with_aoti_compile(self, op_name, dispatch_key=""):
        r"""Register the operator to use the AOTI-compiled implementation.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.

        Raises:
            AssertionError: if ``dispatch_key`` does not contain the Dense key, or
                if the library has already been destroyed.
            RuntimeError: if ``op_name`` has an unsupported type, or a kernel was
                already registered from Python for the same operator and key.

        Example::

            >>> my_lib = Library("aten", "IMPL")
            >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
        """
        if dispatch_key == "":
            dispatch_key = self.dispatch_key
        # pyrefly: ignore [bad-argument-type]
        if not torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense):
            raise AssertionError(
                f"dispatch_key {dispatch_key} does not have Dense in its keyset"
            )

        name = self._resolve_op_name(op_name, "_impl_with_aoti_compile")
        # Drop a leading "namespace::" qualifier, if present.
        short_name = name.split("::")[-1]

        key = self.ns + "/" + short_name + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
            raise RuntimeError(
                "This is not allowed since there's already a kernel registered from python overriding "
                f"{short_name}'s behavior for {dispatch_key} dispatch key and {self.ns} namespace."
            )

        if self.m is None:
            raise AssertionError("Library object has been destroyed")
        impl_fn: Callable = self.m.impl_with_aoti_compile
        impl_fn(self.ns, short_name, dispatch_key)

        _impls.add(key)
        self._op_impls.add(key)

    def impl(
        self, op_name, fn, dispatch_key="", *, with_keyset=False, allow_override=False
    ):
        r"""Registers the function implementation for an operator defined in the library.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            fn: function that's the operator implementation for the input dispatch key or :func:`~fallthrough_kernel`
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.
            with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
                         to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.
            allow_override: Flag controlling if we want to override an
                         existing registered kernel implementation. This is by
                         default off, and will error if you're trying to register a
                         kernel to a dispatch key with a kernel already
                         registered.

        Raises:
            TypeError: if ``fn`` is not callable.
            RuntimeError: if ``op_name`` has an unsupported type, if a kernel was
                already registered from Python for the same operator and key
                (and ``allow_override`` is off), or if a Meta kernel is being
                registered for an operator with a CompositeImplicitAutograd kernel.
            AssertionError: if the library has already been destroyed.

        Example::

            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
            >>> my_lib = Library("aten", "IMPL")
            >>> def div_cpu(self, other):
            >>>     return self * (1 / other)
            >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
        """
        if not callable(fn):
            raise TypeError(
                f"Input function is required to be a callable but found type {type(fn)}"
            )
        if dispatch_key == "":
            dispatch_key = self.dispatch_key

        name = self._resolve_op_name(op_name, "impl")
        # Drop a leading "namespace::" qualifier, if present.
        short_name = name.split("::")[-1]

        key = self.ns + "/" + short_name + "/" + dispatch_key
        if (not allow_override) and key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when impl is called but we error out before that)
            raise RuntimeError(
                "This is not allowed since there's already a kernel registered from python overriding "
                f"{short_name}'s behavior for {dispatch_key} dispatch key and {self.ns} namespace."
            )

        if dispatch_key == "Meta":
            dispatcher_op_name = name
            if "::" not in dispatcher_op_name:
                dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}"

            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if torch._C._dispatch_has_kernel_for_dispatch_key(
                dispatcher_op_name, "CompositeImplicitAutograd"
            ):
                raise RuntimeError(
                    f"We should not register a meta kernel directly to the operator '{name}',"
                    " because it has a CompositeImplicitAutograd kernel in core."
                    " Instead we should let the operator decompose, and ensure that we have meta kernels"
                    " for the base ops that it decomposes into."
                )

        if self.m is None:
            raise AssertionError("Library object has been destroyed")
        self.m.impl(
            name,
            # An empty key registers the kernel as CompositeImplicitAutograd.
            dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd",
            fn,
            with_keyset,
        )

        _impls.add(key)
        self._op_impls.add(key)

    def fallback(self, fn, dispatch_key="", *, with_keyset=False):
        r"""Registers the function implementation as the fallback for the given key.

        This function only works for a library with global namespace ("_").

        Args:
            fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel`
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.
            with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
                         to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.

        Raises:
            RuntimeError: if the library's namespace is not "_".
            AssertionError: if no dispatch key is available, or the library has
                already been destroyed.

        Example::

            >>> my_lib = Library("_", "IMPL")
            >>> def fallback_kernel(op, *args, **kwargs):
            >>>     # Handle all autocast ops generically
            >>>     # ...
            >>> my_lib.fallback(fallback_kernel, "Autocast")
        """
        if dispatch_key == "":
            dispatch_key = self.dispatch_key

        if self.ns != "_":
            raise RuntimeError(
                f"""Fallback can only be registered using library fragment on the global namespace "_" but it is {self.ns}"""
            )

        if dispatch_key == "":
            raise AssertionError("dispatch_key must not be empty for fallback")

        if self.m is None:
            raise AssertionError("Library object has been destroyed")

        self.m.fallback(dispatch_key, fn, with_keyset)

    def _register_effectful_op(self, op_name: str, effect: EffectType | None):
        """
        Registers an effect to an operator. This is used to register an op that
        has side effects that is not capturable by the schema.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            effect: The effect of the op.
        """
        from torch._higher_order_ops.effects import (
            _register_effectful_op as hoo_register_effect,
        )

        handle = hoo_register_effect(op_name, effect)
        self._registration_handles.append(handle)

    def _destroy(self):
        """Tear the library down eagerly instead of waiting for the finalizer.

        Resets the underlying dispatcher library, destroys every registration
        handle, forgets this library's entries in the module-level ``_impls``
        set, and drops the cached ``torch.ops.<ns>.<op>`` packets of the
        operators that were defined through this library.
        """
        if self.m is not None:
            self.m.reset()
        self.m = None
        for handle in self._registration_handles:
            handle.destroy()
        self._registration_handles.clear()
        # In-place update of the shared module-level set.
        _impls.difference_update(self._op_impls)
        for qualname in self._op_defs:
            # Delete the cached torch.ops.ns.foo if it was registered.
            # Otherwise, accessing it leads to a segfault.
            # It's possible that we only registered an overload in this Library
            # and another library owns an alive overload.
            # That's OK - the next time torch.ops.ns.foo gets called, it'll be
            # recomputed to point at the right collection of overloads.
            ns, name_with_overload = qualname.split("::")
            packet_name = name_with_overload.split(".")[0]
            if not hasattr(torch.ops, ns):
                continue
            namespace = getattr(torch.ops, ns)
            if not hasattr(namespace, packet_name):
                continue
            delattr(namespace, packet_name)
            namespace._dir.remove(packet_name)
def _del_library(
    captured_impls,
    op_impls,
    captured_defs,
    op_defs,
    registration_handles,
    m,
    schema_to_signature_cache,
):
    """Finalizer body for a ``Library``: undo everything the library registered.

    Evicts the cached signatures of the operators in ``op_defs`` from
    ``schema_to_signature_cache``, removes ``op_impls`` / ``op_defs`` from the
    captured sets (in place), destroys every registration handle and finally
    resets ``m`` when it is not ``None``.
    """
    for qualname in op_defs:
        # Cache keys are (name, overload_name); an op without an overload uses "".
        if "." in qualname:
            base_name, overload = qualname.split(".")
        else:
            base_name, overload = qualname, ""
        schema_to_signature_cache.pop((base_name, overload), None)

    captured_impls -= op_impls
    captured_defs -= op_defs
    for registration in registration_handles:
        registration.destroy()

    if m is not None:
        m.reset()
@contextlib.contextmanager
def _scoped_library(*args, **kwargs):
    """Context manager that yields a ``Library`` and destroys it on exit.

    All arguments are forwarded to the ``Library`` constructor. The library's
    ``_destroy`` method is called when the ``with`` block exits, whether it
    exits normally or through an exception.
    """
    # Build the library before entering the try block: if the constructor
    # raises there is nothing to destroy, and a `finally` referring to the
    # unbound `lib` would replace the real error with an UnboundLocalError.
    lib = Library(*args, **kwargs)
    try:
        yield lib
    finally:
        lib._destroy()
# Libraries created implicitly by the module-level registration functions
# (when the caller passes no ``lib``); holding a reference here keeps them
# from being garbage collected.
_keep_alive: list[Library] = []


# Shape of a schema string without the operator name, e.g. "(Tensor x) -> Tensor".
NAMELESS_SCHEMA = re.compile(r"\(.*\) -> .*")
@functools.singledispatch
def define(qualname, schema, *, lib=None, tags=()):
    r"""Defines a new operator.

    In PyTorch, defining an op (short for "operator") is a two step-process:
    - we need to define the op (by providing an operator name and schema)
    - we need to implement behavior for how the operator interacts with
      various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

    This entrypoint defines the custom operator (the first step)
    you must then perform the second step by calling various
    ``impl_*`` APIs, like :func:`torch.library.impl` or
    :func:`torch.library.register_fake`.

    Args:
        qualname (str): The qualified name for the operator. Should be
            a string that looks like "namespace::name", e.g. "aten::sin".
            Operators in PyTorch need a namespace to
            avoid name collisions; a given operator may only be created once.
            If you are writing a Python library, we recommend the namespace to
            be the name of your top-level module.
        schema (str): The schema of the operator. E.g. "(Tensor x) -> Tensor"
            for an op that accepts one Tensor and returns one Tensor. It does
            not contain the operator name (that is passed in ``qualname``).
        lib (Optional[Library]): If provided, the lifetime of this operator
            will be tied to the lifetime of the Library object.
        tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
            operator. Tagging an operator changes the operator's behavior
            under various PyTorch subsystems; please read the docs for the
            torch.Tag carefully before applying it.

    Raises:
        ValueError: if ``qualname`` is not a str, or if ``schema`` does not
            look like a schema without an operator name.

    Example::
        >>> import torch
        >>> import numpy as np
        >>>
        >>> # Define the operator
        >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
        >>>
        >>> # Add implementations for the operator
        >>> @torch.library.impl("mylib::sin", "cpu")
        >>> def f(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Call the new operator from torch.ops.
        >>> x = torch.randn(3)
        >>> y = torch.ops.mylib.sin(x)
        >>> assert torch.allclose(y, x.sin())

    """
    if not isinstance(qualname, str):
        raise ValueError(
            f"define(qualname, schema): expected qualname "
            f"to be instance of str, got {type(qualname)}"
        )
    namespace, name = torch._library.utils.parse_namespace(qualname)
    # Validate the schema before creating anything, so that an invalid schema
    # does not leave an unused FRAGMENT library behind in `_keep_alive`.
    if not NAMELESS_SCHEMA.fullmatch(schema):
        raise ValueError(
            f"define(qualname, schema, ...): expected schema "
            f'to look like e.g. "(Tensor x) -> Tensor" but '
            f'got "{schema}"'
        )
    if lib is None:
        lib = Library(namespace, "FRAGMENT")
        _keep_alive.append(lib)
    lib.define(name + schema, alias_analysis="", tags=tags)
@define.register
def _(lib: Library, schema, alias_analysis=""):
    """Legacy form of torch.library.define, selected when the first argument is a Library.

    Retained for backward compatibility. Returns a decorator that defines
    ``schema`` on ``lib``, registers the decorated function as its
    implementation, and hands the function back unchanged.
    """

    def wrap(f):
        op_name = lib.define(schema, alias_analysis)
        lib.impl(op_name, f)
        return f

    return wrap
@overload
def impl(
    qualname: str,
    types: str | Sequence[str],
    func: None = None,
    *,
    lib: Library | None = None,
) -> Callable[[Callable[..., object]], None]: ...


@overload
def impl(
    qualname: str,
    types: str | Sequence[str],
    func: Callable[..., object],
    *,
    lib: Library | None = None,
) -> None: ...


# Deprecated BC API
@overload
def impl(
    lib: Library,
    name: str,
    dispatch_key: str = "",
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ...


@functools.singledispatch
def impl(
    qualname: str,
    types: str | Sequence[str],
    func: Callable[_P, _T] | None = None,
    *,
    lib: Library | None = None,
) -> object:
    """Register an implementation of this operator for one or more device types.

    Passing "default" as ``types`` registers the implementation as the
    default one for ALL device types. Only do so when the implementation
    genuinely works on every device type, e.g. because it is written purely
    in terms of built-in PyTorch operators.

    The API can be used directly or as a decorator. Other decorators may be
    stacked with it as long as they return a function and sit inside (below)
    this one, see Example 2.

    Some valid types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".

    Args:
        qualname (str): Should be a string that looks like "namespace::operator_name".
        types (str | Sequence[str]): The device types to register an impl to.
        func (Callable | None): The implementation to register. When omitted,
            a decorator that performs the registration is returned instead.
        lib (Optional[Library]): If provided, the lifetime of this registration
            will be tied to the lifetime of the Library object.

    Examples:
        >>> # xdoctest: +SKIP("Requires Python <= 3.11")
        >>> import torch
        >>> import numpy as np
        >>> # Example 1: Register function.
        >>> # Define the operator
        >>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
        >>>
        >>> # Add implementations for the cpu device
        >>> @torch.library.impl("mylib::mysin", "cpu")
        >>> def f(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> x = torch.randn(3)
        >>> y = torch.ops.mylib.mysin(x)
        >>> assert torch.allclose(y, x.sin())
        >>>
        >>> # Example 2: Register function with decorator.
        >>> def custom_decorator(func):
        >>>     def wrapper(*args, **kwargs):
        >>>         return func(*args, **kwargs) + 1
        >>>     return wrapper
        >>>
        >>> # Define the operator
        >>> torch.library.define("mylib::sin_plus_one", "(Tensor x) -> Tensor")
        >>>
        >>> # Add implementations for the operator
        >>> @torch.library.impl("mylib::sin_plus_one", "cpu")
        >>> @custom_decorator
        >>> def f(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Call the new operator from torch.ops.
        >>> x = torch.randn(3)
        >>>
        >>> y1 = torch.ops.mylib.sin_plus_one(x)
        >>> y2 = torch.sin(x) + 1
        >>> assert torch.allclose(y1, y2)
    """
    # The public entry point never disables dynamo around the kernel.
    return _impl(qualname, types, func, lib=lib, disable_dynamo=False)
if not TYPE_CHECKING:

    @impl.register
    def _(
        lib: Library, name: str, dispatch_key: str = ""
    ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
        """Legacy torch.library.impl API, selected when the first argument is a Library.

        Kept around for BC. Returns a decorator that registers the decorated
        function on ``lib`` and returns it unchanged.
        """

        def wrap(f: Callable[_P, _T]) -> Callable[_P, _T]:
            lib.impl(name, f, dispatch_key)
            return f

        return wrap
@overload
def _impl(
    qualname: str,
    types: str | Sequence[str],
    func: None = None,
    *,
    lib: Library | None = None,
    disable_dynamo: bool = False,
) -> Callable[[Callable[..., object]], None]: ...


@overload
def _impl(
    qualname: str,
    types: str | Sequence[str],
    func: Callable[..., object],
    *,
    lib: Library | None = None,
    disable_dynamo: bool = False,
) -> None: ...


def _impl(
    qualname: str,
    types: str | Sequence[str],
    func: Callable[..., object] | None = None,
    *,
    lib: Library | None = None,
    disable_dynamo: bool = False,
) -> Callable[[Callable[..., object]], None] | None:
    # Shared worker behind impl(); see impl() for the user-facing contract.
    # With `disable_dynamo` set, the kernel is wrapped in torch._disable_dynamo
    # before it is registered.
    if isinstance(types, str):
        types = (types,)

    # We also support passing a DispatchKey to impl. Please prefer using
    # the higher-level torch.library APIs and only pass DispatchKey to
    # torch.library.impl with caution (or even better, don't use this
    # option and file an issue on GitHub for what you need).
    # We don't advertise this to users because
    # it is very easy to shoot yourself in the foot.
    # Anything that does not parse as a DispatchKey is treated as a device type.
    keys = {
        typ if torch._C._parse_dispatch_key(typ) else _device_type_to_key(typ)
        for typ in types
    }

    def register_(func: Callable[..., object]) -> None:
        namespace, _ = torch._library.utils.parse_namespace(qualname)
        if lib is not None:
            use_lib = lib
        else:
            # No library supplied: create a fragment and keep it alive.
            use_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(use_lib)

        kernel = func
        if disable_dynamo:

            @torch._disable_dynamo
            def func_no_dynamo(*args, **kwargs):
                return func(*args, **kwargs)

            kernel = func_no_dynamo

        for key in keys:
            use_lib.impl(qualname, kernel, key)

    if func is None:
        # Decorator usage: hand back the registration function.
        return register_
    register_(func)
    return None
def _device_type_to_key(device_type: str) -> str:
    """Translate a device type string into the dispatch key to register under.

    The special value "default" maps to "CompositeExplicitAutograd"; every
    other value is resolved through ``torch._C._dispatch_key_for_device``.
    """
    if device_type != "default":
        return torch._C._dispatch_key_for_device(device_type)
    # This is technically not correct, because although all device_type
    # DispatchKeys are included in CompositeExplicitAutograd,
    # not everything in CompositeExplicitAutograd is associated with a
    # device_type. I don't really care that much about the difference.
    return "CompositeExplicitAutograd"
@deprecated(
    "`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that "
    "instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.",
    category=FutureWarning,
)
def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
    r"""Deprecated alias of :func:`torch.library.register_fake` (renamed in PyTorch 2.4).

    Please call :func:`torch.library.register_fake` instead.
    """
    # When a function is passed directly, one extra stack level is reported
    # to register_fake; in decorator form the level is passed through as-is.
    stacklevel = _stacklevel if func is None else _stacklevel + 1
    return register_fake(qualname, func, lib=lib, _stacklevel=stacklevel)
# Ways an operator can be referred to by the registration APIs below: its
# name as a string, an OpOverload, or a CustomOpDef. The last two are given
# as strings (forward references) rather than as the classes themselves.
_op_identifier = Union[
    str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"
]
def register_kernel(
    op: _op_identifier,
    device_types: device_types_t,
    func: Callable | None = None,
    /,
    *,
    lib: Library | None = None,
):
    """Register an implementation for a device type for this operator.

    Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
    This API may be used as a decorator.

    Args:
        op (str | OpOverload): The operator to register an impl to.
        device_types (str | None | Sequence[str]): The device_types to register an impl to.
            If None, we will register to all device types -- please only use
            this option if your implementation is truly device-type-agnostic.
        func (Callable): The function to register as the implementation for
            the given device types.
        lib (Optional[Library]): If provided, the lifetime of this registration
            will be tied to the lifetime of the Library object.

    Raises:
        ValueError: if ``op`` is not a str, an OpOverload or a CustomOpDef.

    Examples::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch import Tensor
        >>> from torch.library import custom_op
        >>> import numpy as np
        >>>
        >>> # Create a custom op that works on cpu
        >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np)
        >>>
        >>> # Add implementations for the cuda device
        >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
        >>> def _(x):
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> x_cpu = torch.randn(3)
        >>> x_cuda = x_cpu.cuda()
        >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
        >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())

    """
    accepted_types = (
        str,
        torch._ops.OpOverload,
        torch._library.custom_ops.CustomOpDef,
    )
    if not isinstance(op, accepted_types):
        raise ValueError(
            f"register_kernel({op}): got unexpected type for op: {type(op)}"
        )

    # Normalise an OpOverload to its name before looking up a CustomOpDef.
    if isinstance(op, torch._ops.OpOverload):
        op = op._name

    # Operators backed by a CustomOpDef register through that object.
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_kernel(device_types, func)

    if not isinstance(op, str):
        raise AssertionError(f"op must be str at this point, got {type(op).__name__}")

    # `None` means "all device types".
    target_types = "CompositeExplicitAutograd" if device_types is None else device_types
    return _impl(op, target_types, func, lib=lib, disable_dynamo=True)
def register_autocast(
    op: _op_identifier,
    device_type: str,
    cast_inputs: _dtype,
    /,
    *,
    lib: Library | None = None,
):
    r"""Register an autocast dispatch rule for this custom op.

    Valid `device_type` include: "cpu" and "cuda".

    Args:
        op (str | OpOverload | CustomOpDef): The operator to register an autocast
            dispatch rule to.
        device_type(str): Device type to use. 'cuda' or 'cpu'.
            The type is the same as the `type` attribute of a :class:`torch.device`.
            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
        cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region,
            casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors
            are not affected), then executes custom op with autocast disabled.
        lib (Optional[Library]): If provided, the lifetime of this registration
            will be tied to the lifetime of the Library object. If omitted, a
            FRAGMENT Library for the op's namespace is created and appended to
            ``_keep_alive``. Not forwarded when a custom-op definition is found
            for ``op`` (registration is delegated to it instead).

    Raises:
        ValueError: If ``op`` has an unsupported type or ``device_type`` is
            neither "cpu" nor "cuda".

    Examples::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch import Tensor
        >>> from torch.library import custom_op
        >>>
        >>> # Create a custom op that works on cuda
        >>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
        >>> def my_sin(x: Tensor) -> Tensor:
        >>>     return torch.sin(x)
        >>>
        >>> # Register autocast dispatch rule for the cuda device
        >>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
        >>>
        >>> x = torch.randn(3, dtype=torch.float32, device="cuda")
        >>> with torch.autocast("cuda", dtype=torch.float16):
        >>>     y = torch.ops.mylib.my_sin(x)
        >>> assert y.dtype == torch.float16

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(
            f"register_autocast({op}): got unexpected type for op: {type(op)}"
        )
    if device_type not in ["cpu", "cuda"]:
        raise ValueError(f"Unknown device type: {device_type}")

    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_autocast(device_type, cast_inputs)

    if not isinstance(op, str):
        raise AssertionError(f"op must be str at this point, got {type(op).__name__}")
    qualname = op
    _op = torch._library.utils.lookup_op(qualname)

    namespace, opname = torch._library.utils.parse_namespace(qualname)
    if lib is None:
        lib = Library(namespace, "FRAGMENT")
        _keep_alive.append(lib)

    def _maybe_override_py_impl(op: torch._ops.OpOverload, dispatch_key):
        # Decorator factory: drop any existing Python kernel for `dispatch_key`
        # before installing the new one through `op.py_impl`.
        def inner(kernel):
            if op.has_kernel_for_dispatch_key(dispatch_key):
                # NOTE(review): has_kernel_for_dispatch_key presumably can be
                # true for a kernel that has no entry in py_kernels (e.g. one
                # registered outside of Python); pop with a default so that
                # case does not surface as a KeyError.
                op.py_kernels.pop(dispatch_key, None)
            return op.py_impl(dispatch_key)(kernel)

        return inner

    @_maybe_override_py_impl(_op, torch._C.DispatchKey.AutocastCPU)
    @_maybe_override_py_impl(_op, torch._C.DispatchKey.AutocastCUDA)
    def _autocast_py_impl(*args, **kwargs):
        if len(kwargs) != 0:
            raise AssertionError("Custom ops do not support kwargs yet.")
        autocast_keyset = torch._C.DispatchKeySet(
            torch._C.DispatchKey.AutocastCPU
        ) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA)
        # Re-dispatch with both autocast keys excluded, on the cast inputs.
        with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
            return _op(*_cast(args, device_type, cast_inputs))

    def kernel(_, *args, **kwargs):
        # The first positional argument is the keyset (with_keyset=True below);
        # it is not needed here.
        if len(kwargs) != 0:
            raise AssertionError("Custom ops do not support kwargs yet.")
        return _autocast_py_impl(*args, **kwargs)

    if device_type == "cuda":
        return lib.impl(opname, kernel, "AutocastCUDA", with_keyset=True)
    else:
        # device_type is "cpu"
        return lib.impl(opname, kernel, "AutocastCPU", with_keyset=True)
def register_fake(
    op: _op_identifier,
    func: Callable | None = None,
    /,
    *,
    lib: Library | None = None,
    _stacklevel: int = 1,
    allow_override: bool = False,
):
    r"""Register a FakeTensor implementation ("fake impl") for this operator.

    Also sometimes known as a "meta kernel", "abstract impl".

    A "FakeTensor implementation" specifies the behavior of this operator on
    Tensors that carry no data ("FakeTensor"). Given some input Tensors with
    certain properties (sizes/strides/storage_offset/device), it specifies
    what the properties of the output Tensors are.

    The FakeTensor implementation has the same signature as the operator.
    It is run for both FakeTensors and meta tensors. To write a FakeTensor
    implementation, assume that all Tensor inputs to the operator are
    regular CPU/CUDA/Meta tensors, but they do not have storage, and
    you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
    The FakeTensor implementation must consist of only PyTorch operations
    (and may not directly access the storage or data of any input or
    intermediate Tensors).

    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

    Args:
        op: Operator name (along with the overload), OpOverload object, or
            CustomOpDef.
        func: Fake tensor implementation. If omitted, a decorator is returned.
        lib (Optional[Library]): Library to register the fake tensor to.
        allow_override: Flag controlling if we want to override an
            existing registered fake impl. This is by default off,
            and will error if you're trying to register a fake impl to
            an operator that already has a fake impl. This also only
            applies if the custom operator was not created via
            torch.library.custom_op, as overriding an existing fake
            impl is already allowed there.

    Raises:
        ValueError: If ``op`` is not a str, OpOverload or CustomOpDef.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Example 1: an operator without data-dependent output shape
        >>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
        >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        >>>     raise NotImplementedError("Implementation goes here")
        >>>
        >>> @torch.library.register_fake("mylib::custom_linear")
        >>> def _(x, weight, bias):
        >>>     assert x.dim() == 2
        >>>     assert weight.dim() == 2
        >>>     assert bias.dim() == 1
        >>>     assert x.shape[1] == weight.shape[1]
        >>>     assert weight.shape[0] == bias.shape[0]
        >>>     assert x.device == weight.device
        >>>
        >>>     return (x @ weight.t()) + bias
        >>>
        >>> with torch._subclasses.fake_tensor.FakeTensorMode():
        >>>     x = torch.randn(2, 3)
        >>>     w = torch.randn(3, 3)
        >>>     b = torch.randn(3)
        >>>     y = torch.ops.mylib.custom_linear(x, w, b)
        >>>
        >>> assert y.shape == (2, 3)
        >>>
        >>> # Example 2: an operator with data-dependent output shape
        >>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
        >>> def custom_nonzero(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy(force=True)
        >>>     res = np.stack(np.nonzero(x_np), axis=1)
        >>>     return torch.tensor(res, device=x.device)
        >>>
        >>> @torch.library.register_fake("mylib::custom_nonzero")
        >>> def _(x):
        >>>     # Number of nonzero-elements is data-dependent.
        >>>     # Since we cannot peek at the data in a fake impl,
        >>>     # we use the ctx object to construct a new symint that
        >>>     # represents the data-dependent size.
        >>>     ctx = torch.library.get_ctx()
        >>>     nnz = ctx.new_dynamic_size()
        >>>     shape = [nnz, x.dim()]
        >>>     result = x.new_empty(shape, dtype=torch.int64)
        >>>     return result
        >>>
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>>
        >>> x = torch.tensor([0, 1, 2, 3, 4, 0])
        >>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
        >>> trace.print_readable()
        >>>
        >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))

    """
    accepted_op_types = (
        str,
        torch._ops.OpOverload,
        torch._library.custom_ops.CustomOpDef,
    )
    if not isinstance(op, accepted_op_types):
        raise ValueError(f"register_fake({op}): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name

    # Ops backed by a custom-op definition register the fake impl themselves.
    custom_opdef = _maybe_get_opdef(op)
    if custom_opdef is not None:
        if func is not None:
            return custom_opdef.register_fake(func)
        return custom_opdef.register_fake

    if not isinstance(op, str):
        raise AssertionError(f"op must be str at this point, got {type(op).__name__}")

    # A direct call (func given) uses a stack level one higher than the
    # decorator form.
    base_stacklevel = _stacklevel if func is None else _stacklevel + 1

    def register(fake_impl):
        namespace, op_name = torch._library.utils.parse_namespace(op)
        if lib is not None:
            target_lib = lib
        else:
            target_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(target_lib)
        target_lib._register_fake(
            op_name,
            fake_impl,
            _stacklevel=base_stacklevel + 1,
            allow_override=allow_override,
        )
        return fake_impl

    if func is not None:
        return register(func)
    return register
def _register_effectful_op(
    op: _op_identifier,
    effect: EffectType | None,
    *,
    lib: Library | None = None,
) -> None:
    r"""
    To specify that an operator has side-effects, we must register an effect
    type for the operator. This will prevent graph passes in torch.compile from
    reordering operations with the same effect type.

    Args:
        op: Operator name (along with the overload), OpOverload object, or
            CustomOpDef.
        effect: Effect type to register. None means the operator is not effectful.
        lib: Library to register the effect through. If omitted, a FRAGMENT
            Library for the op's namespace is created and appended to
            ``_keep_alive``. Not used when a custom-op definition is found
            for ``op``.

    Raises:
        ValueError: If ``op`` is not a str, OpOverload or CustomOpDef.
    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(
            f"register_effectful_op({op}): got unexpected type for op: {type(op)}"
        )
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        opdef.register_effect(effect)
        # Stop here, like the other register_* functions do after delegating.
        # Falling through would raise the AssertionError below for a
        # CustomOpDef argument, and would register the effect a second time
        # for a string op that is backed by a custom-op definition.
        return
    if not isinstance(op, str):
        raise AssertionError(f"op must be str at this point, got {type(op).__name__}")
    namespace, _ = torch._library.utils.parse_namespace(op)
    if lib is None:
        use_lib = Library(namespace, "FRAGMENT")
        _keep_alive.append(use_lib)
    else:
        use_lib = lib
    use_lib._register_effectful_op(op, effect)
def register_autograd(
    op: _op_identifier,
    backward: Callable,
    /,
    *,
    setup_context: Callable | None = None,
    lib=None,
) -> None:
    r"""Register a backward formula for this custom op.

    In order for an operator to work with autograd, you need to register
    a backward formula:
    1. You must tell us how to compute gradients during the backward pass
    by providing us a "backward" function.
    2. If you need any values from the forward to compute gradients, you can
    use `setup_context` to save values for backward.

    ``backward`` runs during the backward pass. It accepts ``(ctx, *grads)``:
    - ``grads`` is one or more gradients. The number of gradients matches
    the number of outputs of the operator.
    The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
    :class:`torch.autograd.Function`. The semantics of ``backward`` are the
    same as :meth:`torch.autograd.Function.backward`.

    ``setup_context(ctx, inputs, output)`` runs during the forward pass.
    Please save quantities needed for backward onto the ``ctx`` object via
    either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
    or assigning them as attributes of ``ctx``. If your custom op has
    kwarg-only arguments, we expect the signature of ``setup_context``
    to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.

    Both ``setup_context`` and ``backward`` must be traceable. That is,
    they may not directly access :meth:`torch.Tensor.data_ptr` and they must
    not depend on or mutate global state. If you need a non-traceable backward,
    you can make it a separate custom_op that you call inside ``backward``.

    If you need different autograd behavior on different devices, then we
    recommend creating two different custom operators, one for each device
    that needs different behavior, and switching between them at runtime.

    Args:
        op: Operator name (along with the overload), OpOverload object, or
            CustomOpDef.
        backward: The backward function described above.
        setup_context: Optional function that saves values for ``backward``.
        lib: Library to register the autograd kernel on. If omitted, a
            FRAGMENT Library for the op's namespace is created and appended
            to ``_keep_alive``. Not used when a custom-op definition is found
            for ``op``.

    Raises:
        ValueError: If ``op`` is not a str, OpOverload or CustomOpDef.
        RuntimeError: If the operator's schema is not functional.
        NotImplementedError: If the operator has kwarg-only Tensor arguments.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> def setup_context(ctx, inputs, output) -> Tensor:
        >>>     x, = inputs
        >>>     ctx.save_for_backward(x)
        >>>
        >>> def backward(ctx, grad):
        >>>     x, = ctx.saved_tensors
        >>>     return grad * x.cos()
        >>>
        >>> torch.library.register_autograd(
        ...     "mylib::numpy_sin", backward, setup_context=setup_context
        ... )
        >>>
        >>> x = torch.randn(3, requires_grad=True)
        >>> y = numpy_sin(x)
        >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
        >>> assert torch.allclose(grad_x, x.cos())
        >>>
        >>> # Example with a keyword-only arg
        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
        >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = x_np * val
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
        >>>     ctx.val = keyword_only_inputs["val"]
        >>>
        >>> def backward(ctx, grad):
        >>>     return grad * ctx.val
        >>>
        >>> torch.library.register_autograd(
        ...     "mylib::numpy_mul", backward, setup_context=setup_context
        ... )
        >>>
        >>> x = torch.randn(3, requires_grad=True)
        >>> y = numpy_mul(x, val=3.14)
        >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
        >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))

    """
    accepted_op_types = (
        str,
        torch._ops.OpOverload,
        torch._library.custom_ops.CustomOpDef,
    )
    if not isinstance(op, accepted_op_types):
        raise ValueError(
            f"register_autograd({op}): got unexpected type for op: {type(op)}"
        )
    if isinstance(op, torch._ops.OpOverload):
        op = op._name

    # Ops backed by a custom-op definition register autograd themselves.
    custom_opdef = _maybe_get_opdef(op)
    if custom_opdef is not None:
        custom_opdef.register_autograd(backward, setup_context=setup_context)
        return

    if not isinstance(op, str):
        raise AssertionError(f"op must be str at this point, got {type(op).__name__}")

    # From here on `op` is the resolved operator object, not its name.
    qualname = op
    op = torch._library.utils.lookup_op(qualname)
    schema = op._schema

    if not _library.utils.is_functional_schema(schema):
        raise RuntimeError(
            f"Cannot register autograd formula for non-functional operator "
            f"{op} with schema {schema}. Please create "
            f"a functional operator and register an autograd formula for that."
        )
    if _library.utils.has_kwarg_only_tensors(schema):
        raise NotImplementedError(
            f"register_autograd with kwarg-only Tensor args. In the original "
            f"definition of the op, please make your tensors not kwarg-only. "
            f"Got: {schema}"
        )

    autograd_info = _library.autograd.Info(backward, setup_context)
    autograd_kernel = _library.autograd.make_autograd_impl(op, autograd_info)

    namespace, opname = torch._library.utils.parse_namespace(qualname)
    if lib is None:
        lib = Library(namespace, "FRAGMENT")
        _keep_alive.append(lib)
    lib.impl(opname, autograd_kernel, "Autograd", with_keyset=True)
def register_torch_dispatch(
    op: _op_identifier,
    torch_dispatch_class: Any,
    func: Callable | None = None,
    /,
    *,
    lib: Library | None = None,
):
    r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.

    This allows for open registration to specify the behavior between the operator
    and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
    or the operator directly.

    The ``torch_dispatch_class`` is either a Tensor subclass with ``__torch_dispatch__`` or a
    TorchDispatchMode.

    If it is a Tensor subclass, we expect ``func`` to have the following signature:
    ``(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any``

    If it is a TorchDispatchMode, we expect ``func`` to have the following signature:
    ``(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any``

    ``args`` and ``kwargs`` will have been normalized the same way they are
    in ``__torch_dispatch__`` (see :ref:`torch-dispatch-calling-convention`).

    This API may be used as a decorator: when ``func`` is omitted, a decorator
    that performs the registration and returns the decorated function is returned.

    Examples:

        >>> import torch
        >>>
        >>> @torch.library.custom_op("mylib::foo", mutates_args={})
        >>> def foo(x: torch.Tensor) -> torch.Tensor:
        >>>     return x.clone()
        >>>
        >>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
        >>>     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        >>>         return func(*args, **kwargs)
        >>>
        >>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
        >>> def _(mode, func, types, args, kwargs):
        >>>     x, = args
        >>>     return x + 1
        >>>
        >>> x = torch.randn(3)
        >>> y = foo(x)
        >>> assert torch.allclose(y, x)
        >>>
        >>> with MyMode():
        >>>     y = foo(x)
        >>> assert torch.allclose(y, x + 1)

    """
    accepted_op_types = (
        str,
        torch._ops.OpOverload,
        torch._library.custom_ops.CustomOpDef,
    )
    if not isinstance(op, accepted_op_types):
        raise ValueError(
            f"register_torch_dispatch({op}): got unexpected type for op: {type(op)}"
        )
    if isinstance(op, torch._ops.OpOverload):
        op = op._name

    # Ops backed by a custom-op definition register the rule themselves.
    custom_opdef = _maybe_get_opdef(op)
    if custom_opdef is not None:
        return custom_opdef.register_torch_dispatch(torch_dispatch_class, func)

    if not isinstance(op, str):
        raise AssertionError(f"op must be str at this point, got {type(op).__name__}")

    def register(rule):
        namespace, op_name = torch._library.utils.parse_namespace(op)
        if lib is not None:
            target_lib = lib
        else:
            target_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(target_lib)
        target_lib._register_torch_dispatch_rule(op_name, torch_dispatch_class, rule)
        return rule

    if func is not None:
        return register(func)
    return register
def register_vmap(
    op: _op_identifier,
    func: Callable | None = None,
    /,
    *,
    lib=None,
):
    r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.

    This API may be used as a decorator (see examples). In the decorator form
    the decorated function is returned unchanged, so the decorated name stays
    bound to it.

    In order for an operator to work with :func:`torch.vmap`, you may need to register a
    vmap implementation in the following signature:

        ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,

    where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
    We do not support kwarg-only Tensor args.

    It specifies how do we compute the batched version of ``op`` given inputs with an additional
    dimension (specified by ``in_dims``).

    For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
    if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer
    specifying what dimension of the Tensor is being vmapped over.

    ``info`` is a collection of additional metadata that may be helpful:
    ``info.batch_size`` specifies the size of the dimension being vmapped over, while
    ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.

    The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
    ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
    per output that specifies if the output has the vmapped dimension and what index it is in.

    Raises:
        ValueError: If ``op`` is not a str, OpOverload or CustomOpDef.
        NotImplementedError: If the operator has kwarg-only Tensor arguments.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>> from typing import Tuple
        >>>
        >>> def to_numpy(tensor):
        >>>     return tensor.cpu().numpy()
        >>>
        >>> lib = torch.library.Library("mylib", "FRAGMENT")
        >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
        >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
        >>>     x_np = to_numpy(x)
        >>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
        >>>     return torch.tensor(x_np ** 3, device=x.device), dx
        >>>
        >>> def numpy_cube_vmap(info, in_dims, x):
        >>>     result = numpy_cube(x)
        >>>     return result, (in_dims[0], in_dims[0])
        >>>
        >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
        >>>
        >>> x = torch.randn(3)
        >>> torch.vmap(numpy_cube)(x)
        >>>
        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
        >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
        >>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
        >>>
        >>> @torch.library.register_vmap("mylib::numpy_mul")
        >>> def numpy_mul_vmap(info, in_dims, x, y):
        >>>     x_bdim, y_bdim = in_dims
        >>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        >>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        >>>     result = x * y
        >>>     result = result.movedim(-1, 0)
        >>>     return result, 0
        >>>
        >>>
        >>> x = torch.randn(3)
        >>> y = torch.randn(3)
        >>> torch.vmap(numpy_mul)(x, y)

    .. note::
        The vmap function should aim to preserve the semantics of the entire custom operator.
        That is, ``grad(vmap(op))`` should be replaceable with a ``grad(map(op))``.

        If your custom operator has any custom behavior in the backward pass, please
        keep this in mind.

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_vmap({op}): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_vmap(func)
    if not isinstance(op, str):
        raise AssertionError(f"op must be str at this point, got {type(op).__name__}")
    # From here on `op` is the resolved operator object, not its name.
    qualname = op
    op = torch._library.utils.lookup_op(qualname)
    schema = op._schema
    if _library.utils.has_kwarg_only_tensors(schema):
        raise NotImplementedError(
            f"register_vmap with kwarg-only Tensor args. In the original "
            f"definition of the op, please make your tensors not kwarg-only. "
            f"Got: {schema}"
        )

    def register(func):
        # `lib` is rebound below when it was not supplied by the caller.
        nonlocal lib

        namespace, opname = torch._library.utils.parse_namespace(qualname)
        if lib is None:
            lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(lib)

        from torch._functorch.autograd_function import custom_function_call_vmap_helper
        from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter

        def wrapped_func(keyset, *args, **kwargs):
            interpreter = retrieve_current_functorch_interpreter()
            return custom_function_call_vmap_helper(
                # pyrefly: ignore[bad-argument-type]
                interpreter,
                func,
                op,
                *args,
                **kwargs,
            )

        lib.impl(opname, wrapped_func, "FuncTorchBatched", with_keyset=True)
        # Hand the user's function back so that decorator usage does not
        # rebind the decorated name to None (matches register_fake and
        # register_torch_dispatch).
        return func

    if func is None:
        return register
    else:
        return register(func)
- # If the op was defined in C++, then we want to make sure there was an
- # m.set_python_module(module, ...) call and that the module is the
- # same as the module that called torch.library.register_fake.
- def _check_pystubs_once(func, qualname, actual_module_name):
- checked = False
- def inner(*args, **kwargs):
- nonlocal checked
- if checked:
- return func(*args, **kwargs)
- op = torch._library.utils.lookup_op(qualname)
- if op._defined_in_python:
- checked = True
- return func(*args, **kwargs)
- maybe_pystub = torch._C._dispatch_pystub(
- op._schema.name, op._schema.overload_name
- )
- if maybe_pystub is None:
- if torch._library.utils.requires_set_python_module():
- namespace = op.namespace
- cpp_filename = op._handle.debug()
- raise RuntimeError(
- f"Operator '{qualname}' was defined in C++ and has a Python "
- f"fake impl. In this situation, we require there to also be a "
- f'companion C++ `m.set_python_module("{actual_module_name}")` '
- f"call, but we could not find one. Please add that to "
- f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
- f"operator was registered in ({cpp_filename})"
- )
- else:
- pystub_module = maybe_pystub[0]
- if actual_module_name != pystub_module:
- cpp_filename = op._handle.debug()
- raise RuntimeError(
- f"Operator '{qualname}' specified that its python fake impl "
- f"is in the Python module '{pystub_module}' but it was actually found "
- f"in '{actual_module_name}'. Please either move the fake impl "
- f"or correct the m.set_python_module call ({cpp_filename})"
- )
- checked = True
- return func(*args, **kwargs)
- return inner
- # NOTE [ctx inside the fake implementation]
- # If a user has an operator with data-dependent output shape, then when writing
- # a fake implementation they must query the current ctx and use methods on the
- # ctx to construct a new unbacked symint.
- #
- # This is done via us setting the global_ctx_getter function every time a fake
- # implementation is invoked.
- def get_ctx() -> "torch._library.fake_impl.FakeImplCtx":
- """get_ctx() returns the current AbstractImplCtx object.
- Calling ``get_ctx()`` is only valid inside of an fake impl
- (see :func:`torch.library.register_fake` for more usage details.
- """
- return torch._library.fake_impl.global_ctx_getter()
- def get_kernel(
- op: _op_identifier, dispatch_key: str | torch.DispatchKey
- ) -> torch._C._SafeKernelFunction:
- """Returns the computed kernel for a given operator and dispatch key.
- This function retrieves the kernel that would be executed for a given
- operator and dispatch key combination. The returned SafeKernelFunction
- can be used to call the kernel in a boxed fashion. The intended use
- case for this function is to retrieve the original kernel for a given
- dispatch key and then register another kernel to the same dispatch key
- that calls into the original kernel for certain cases.
- Args:
- op: Operator name (along with the overload) or OpOverload object
- Can be a string (e.g., "aten::add.Tensor"), an OpOverload, or a CustomOpDef.
- dispatch_key (str | torch.DispatchKey): The dispatch key to get the kernel for.
- Can be a string (e.g., "CPU", "CUDA") or a DispatchKey enum value.
- Returns:
- torch._C._SafeKernelFunction: A safe kernel function that can be used to
- call the kernel.
- Raises:
- RuntimeError: If the operator does not exist.
- Example:
- >>> # Get the CPU kernel for torch.add
- >>> kernel = torch.library.get_kernel("aten::add.Tensor", "CPU")
- >>>
- >>> # You can also use DispatchKey enum
- >>> kernel = torch.library.get_kernel("aten::add.Tensor", torch.DispatchKey.CPU)
- >>>
- >>> # Or use an OpOverload directly
- >>> kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU")
- >>>
- >>> # Example: Using get_kernel in a custom op with conditional dispatch
- >>> # Get the original kernel for torch.sin
- >>> original_sin_kernel = torch.library.get_kernel("aten::sin", "CPU")
- >>>
- >>> # If input has negative values, use original sin, otherwise return zeros
- >>> def conditional_sin_impl(dispatch_keys, x):
- >>> if (x < 0).any():
- >>> return original_sin_kernel.call_boxed(dispatch_keys, x)
- >>> else:
- >>> return torch.zeros_like(x)
- >>>
- >>> lib = torch.library.Library("aten", "IMPL")
- >>> # with_keyset=True so the first argument to the impl is the current DispatchKeySet
- >>> which needs to be the first argument to ``kernel.call_boxed``
- >>> lib.impl("sin", conditional_sin_impl, "CPU", with_keyset=True)
- >>>
- >>> # Test the conditional behavior
- >>> x_positive = torch.tensor([1.0, 2.0])
- >>> x_mixed = torch.tensor([-1.0, 2.0])
- >>> torch.sin(x_positive)
- tensor([0., 0.])
- >>> torch.sin(x_mixed)
- tensor([-0.8415, 0.9093])
- """
- if not isinstance(op, (str, torch._ops.OpOverload)):
- raise ValueError(f"get_kernel({op}): got unexpected type for op: {type(op)}")
- if isinstance(op, torch._ops.OpOverload):
- op = op._name
- if isinstance(dispatch_key, str):
- try:
- dispatch_key = torch._C.DispatchKey.__members__[dispatch_key]
- except KeyError:
- raise ValueError(f"Invalid dispatch key: {dispatch_key}") from None
- return torch._C._dispatch_get_computed_kernel_for_dispatch_key(op, dispatch_key)
- _OPCHECK_DEFAULT_UTILS = (
- "test_schema",
- "test_autograd_registration",
- "test_faketensor",
- "test_aot_dispatch_dynamic",
- )
def opcheck(
    op: torch._ops.OpOverload | torch._ops.OpOverloadPacket | CustomOpDef,
    args: tuple[Any, ...],
    kwargs: dict[str, Any] | None = None,
    *,
    test_utils: str | Sequence[str] = _OPCHECK_DEFAULT_UTILS,
    raise_exception: bool = True,
    atol=None,
    rtol=None,
) -> dict[str, str]:
    """Given an operator and some sample arguments, tests if the operator is
    registered correctly.

    That is, when you use the torch.library/TORCH_LIBRARY APIs to create a
    custom op, you specified metadata (e.g. mutability info) about the custom op
    and these APIs require that the functions you pass them satisfy certain
    properties (e.g. no data pointer access in the fake/meta/abstract kernel).
    ``opcheck`` tests these metadata and properties.

    Concretely, we test the following:

    - test_schema: If the schema matches the implementation of
      the operator. For example: if the schema specifies a Tensor is mutated,
      then we check the implementation mutates the Tensor. If the schema
      specifies that we return a new Tensor, then we check that the
      implementation returns a new Tensor (instead of an existing one or
      a view of an existing one).
    - test_autograd_registration: If the operator supports training
      (autograd): we check that its autograd formula is registered via
      torch.library.register_autograd or a manual registration to one
      or more DispatchKey::Autograd keys. Any other DispatchKey-based
      registrations may lead to undefined behavior.
    - test_faketensor: If the operator has a FakeTensor kernel
      (and if it is correct). The FakeTensor kernel is necessary (
      but not sufficient) for the operator to work with PyTorch compilation
      APIs (torch.compile/export/FX). We check that a FakeTensor kernel
      (also sometimes known as a meta kernel) was registered for the
      operator and that it is correct. This test takes the result of
      running the operator on real tensors and the result of running
      the operator on FakeTensors and checks that they have the same
      Tensor metadata (sizes/strides/dtype/device/etc).
    - test_aot_dispatch_dynamic: If the operator has correct behavior
      with PyTorch compilation APIs (torch.compile/export/FX).
      This checks that the outputs (and gradients, if applicable) are the
      same under eager-mode PyTorch and torch.compile.
      This test is a superset of ``test_faketensor`` and is an e2e test;
      other things it tests are that the operator supports
      functionalization and that the backward pass (if it exists) also
      supports FakeTensor and functionalization.

    For best results, please call ``opcheck`` multiple times with a
    representative set of inputs. If your operator supports
    autograd, please use ``opcheck`` with inputs with ``requires_grad = True``;
    if your operator supports multiple devices (e.g. CPU and CUDA), please
    use ``opcheck`` with inputs on all supported devices.

    Args:
        op: The operator. Must either be a function decorated with
            :func:`torch.library.custom_op` or an OpOverload/OpOverloadPacket
            found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo)
        args: The args to the operator
        kwargs: The kwargs to the operator
        test_utils: Tests that we should run. Default: all of them.
            Example: ("test_schema", "test_faketensor")
        raise_exception: If we should raise an exception on the first
            error. If False, we will return a dict with information
            on if each test passed or not.
        atol (Optional[float]): Absolute tolerance for floating point comparisons.
            If specified ``rtol`` must also be specified.
            If omitted, default values based on the ``dtype`` are selected
            (see the table in :func:`torch.testing.assert_close`).
        rtol (Optional[float]): Relative tolerance for floating point comparisons.
            If specified ``atol`` must also be specified.
            If omitted, default values based on the ``dtype`` are selected
            (see the table in :func:`torch.testing.assert_close`).

    .. warning::

        opcheck and :func:`torch.autograd.gradcheck` test different things;
        opcheck tests if your usage of torch.library APIs is correct while
        :func:`torch.autograd.gradcheck` tests if your autograd formula is
        mathematically correct. Use both to test custom ops that support
        gradient computation.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
        >>> def numpy_mul(x: Tensor, y: float) -> Tensor:
        >>>     x_np = x.numpy(force=True)
        >>>     z_np = x_np * y
        >>>     return torch.from_numpy(z_np).to(x.device)
        >>>
        >>> @numpy_mul.register_fake
        >>> def _(x, y):
        >>>     return torch.empty_like(x)
        >>>
        >>> def setup_context(ctx, inputs, output):
        >>>     y, = inputs
        >>>     ctx.y = y
        >>>
        >>> def backward(ctx, grad):
        >>>     return grad * ctx.y, None
        >>>
        >>> numpy_mul.register_autograd(backward, setup_context=setup_context)
        >>>
        >>> sample_inputs = [
        >>>     (torch.randn(3), 3.14),
        >>>     (torch.randn(2, 3, device='cuda'), 2.718),
        >>>     (torch.randn(1, 10, requires_grad=True), 1.234),
        >>>     (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
        >>> ]
        >>>
        >>> for args in sample_inputs:
        >>>     torch.library.opcheck(numpy_mul, args)

    """
    # Imported inside the function rather than at module import time.
    import torch.testing._internal.optests as optests

    check_options = {
        "test_utils": test_utils,
        "raise_exception": raise_exception,
        "rtol": rtol,
        "atol": atol,
    }
    return optests.opcheck(op, args, kwargs, **check_options)
|