impl.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721
  1. # mypy: allow-untyped-defs
  2. import dataclasses
  3. import functools
  4. import inspect
  5. import sys
  6. import typing
  7. import warnings
  8. import weakref
  9. import torch
  10. import torch._C as _C
  11. import torch._library.infer_schema
  12. import torch.library as library
  13. from torch._library.infer_schema import infer_schema
  14. from torch.library import get_ctx
  15. from torchgen.model import (
  16. BaseTy,
  17. BaseType,
  18. FunctionSchema,
  19. ListType,
  20. OperatorName,
  21. SchemaKind,
  22. )
  23. from .autograd import autograd_kernel_indirection, construct_autograd_kernel
  24. """
  25. torch._custom_op is deprecated. We shipped a production-ready version of it into torch.library.
  26. Please use those APIs instead.
  27. """
  28. __all__ = ["custom_op", "CustomOp", "get_ctx"]
  29. SUPPORTED_DEVICE_TYPE_TO_KEY = {
  30. "cpu": "CPU",
  31. "cuda": "CUDA",
  32. }
  33. # We will not let users register CustomOps with anything that could look like
  34. # PyTorch internals to avoid confusion.
  35. RESERVED_NS = {
  36. "prim",
  37. "prims",
  38. "aten",
  39. "at",
  40. "torch",
  41. "pytorch",
  42. }
  43. def warn_deprecated():
  44. warnings.warn(
  45. "torch._custom_op is deprecated and will be removed in PyTorch 2.6, please "
  46. "use the equivalent torch.library API instead.",
  47. DeprecationWarning,
  48. stacklevel=2,
  49. )
  50. def custom_op(
  51. qualname: str, manual_schema: typing.Optional[str] = None
  52. ) -> typing.Callable:
  53. r"""
  54. This API is deprecated, please use torch.library.custom_op instead
  55. """
  56. warn_deprecated()
  57. def inner(func):
  58. if not inspect.isfunction(func):
  59. raise ValueError(
  60. f"custom_op(...)(func): Expected `func` to be a Python "
  61. f"function, got: {type(func)}"
  62. )
  63. ns, name = parse_qualname(qualname)
  64. validate_namespace(ns)
  65. if func.__name__ != name:
  66. raise ValueError(
  67. f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
  68. f"to have name '{name}' but got '{func.__name__}'. "
  69. f"Please either change the name of `func` or the qualname that "
  70. f"is passed to `custom_op`"
  71. )
  72. schema = (
  73. infer_schema(func, mutates_args=())
  74. if manual_schema is None
  75. else manual_schema
  76. )
  77. schema_str = f"{name}{schema}"
  78. function_schema = FunctionSchema.parse(schema_str)
  79. validate_schema(function_schema)
  80. if manual_schema is not None:
  81. validate_function_matches_schema(function_schema, func)
  82. lib = library.Library(ns, "FRAGMENT")
  83. lib.define(schema_str)
  84. ophandle = find_ophandle_or_throw(ns, function_schema.name)
  85. result = CustomOp(
  86. lib, ns, function_schema, name, ophandle, _private_access=True
  87. )
  88. result.__name__ = func.__name__ # pyrefly: ignore [bad-assignment]
  89. result.__module__ = func.__module__
  90. result.__doc__ = func.__doc__
  91. library.impl(lib, result._opname, "Autograd")(
  92. autograd_kernel_indirection(weakref.proxy(result))
  93. )
  94. torch._C._dispatch_set_report_error_callback(
  95. ophandle, functools.partial(report_error_callback, weakref.proxy(result))
  96. )
  97. return result
  98. return inner
  99. # Global dictionary holding references to all CustomOp objects
  100. # Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
  101. # Used to query the CustomOp associated with a specific C++ dispatcher operator.
  102. # An example usage is FakeTensor: FakeTensor checks if a specific operator
  103. # has an implementation registered via the CustomOp API.
  104. # Indexed by qualname (e.g. aten::foo)
  105. global_registry: dict[str, "CustomOp"] = {}
  106. class CustomOp:
  107. r"""
  108. This API is deprecated, please use torch.library.custom_op instead
  109. """
  110. def __init__(
  111. self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False
  112. ):
  113. super().__init__()
  114. warn_deprecated()
  115. if not _private_access:
  116. raise RuntimeError(
  117. "The CustomOp constructor is private and we do not guarantee "
  118. "BC for it. Please use custom_op(...) to create a CustomOp object"
  119. )
  120. name = f"{cpp_ns}::{operator_name}"
  121. self._schema = schema
  122. self._cpp_ns = cpp_ns
  123. self._lib: library.Library = lib
  124. self._ophandle: _C._DispatchOperatorHandle = ophandle
  125. # Has the name of the op, e.g. "foo". We cache here for convenience.
  126. self._opname: str = operator_name
  127. # this is _opname but with namespace. e.g. "custom::foo"
  128. self._qualname: str = name
  129. self.__name__ = None # mypy requires this
  130. # NB: Some of these impls are registered as kernels to DispatchKeys.
  131. # Modifying the _impls dict directly won't do anything in that case.
  132. self._impls: dict[str, typing.Optional[FuncAndLocation]] = {}
  133. # See NOTE [CustomOp autograd kernel indirection]
  134. self._registered_autograd_kernel_indirection = False
  135. global_registry[self._qualname] = self
  136. def _register_autograd_kernel_indirection(self):
  137. if self._registered_autograd_kernel_indirection:
  138. raise AssertionError("autograd kernel indirection already registered")
  139. self._lib.impl(
  140. self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd"
  141. )
  142. self._registered_autograd_kernel_indirection = True
  143. # Records the impl and the source location in self._impls
  144. # Note that this doesn't cause torch.library to use the impl, that
  145. # needs to be done in a separate self._lib.impl call.
  146. def _register_impl(self, kind, func, stacklevel=2):
  147. if self._has_impl(kind):
  148. func_and_location = self._impls[kind]
  149. if func_and_location is None:
  150. raise AssertionError("func_and_location is unexpectedly None")
  151. location = func_and_location.location
  152. raise RuntimeError(
  153. f"Attempting to register a {kind} impl for operator {self._qualname} "
  154. f"that already has a {kind} impl registered from Python at "
  155. f"{location}. This is not supported."
  156. )
  157. frame = inspect.getframeinfo(sys._getframe(stacklevel))
  158. location = f"{frame.filename}:{frame.lineno}"
  159. self._impls[kind] = FuncAndLocation(func, location)
  160. def _get_impl(self, kind):
  161. return self._impls[kind]
  162. def _has_impl(self, kind):
  163. return kind in self._impls
  164. def _destroy(self):
  165. # NOTE: [CustomOp lifetime]
  166. # A CustomOp, once created, lives forever. The mechanism is that the
  167. # global registry holds a reference to it. However, to make testing
  168. # easier, we want to be able to destroy CustomOp objects.
  169. # CustomOp._destroy does the job, though it leaves the CustomOp
  170. # in a garbage state.
  171. del self._lib
  172. opnamespace = getattr(torch.ops, self._cpp_ns)
  173. if hasattr(opnamespace, self._opname):
  174. delattr(opnamespace, self._opname)
  175. del global_registry[self._qualname]
  176. def __repr__(self):
  177. return f'<CustomOp(op="{self._qualname}")>'
  178. def __call__(self, *args, **kwargs):
  179. # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
  180. # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
  181. # issues from caching operators that make testing CustomOp difficult).
  182. result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
  183. return result
  184. def impl(
  185. self,
  186. device_types: typing.Union[str, typing.Iterable[str]],
  187. _stacklevel=2,
  188. ) -> typing.Callable:
  189. r"""
  190. This API is deprecated, please use torch.library.custom_op instead
  191. """
  192. if isinstance(device_types, str):
  193. device_types = [device_types]
  194. for device_type in device_types:
  195. validate_device_type(device_type)
  196. def inner(f):
  197. for device_type in set(device_types):
  198. self._check_doesnt_have_library_impl(device_type)
  199. self._register_impl(device_type, f, stacklevel=_stacklevel)
  200. dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
  201. library.impl(self._lib, self._opname, dispatch_key)(f)
  202. return f
  203. return inner
  204. def _check_doesnt_have_library_impl(self, device_type):
  205. if self._has_impl(device_type):
  206. return
  207. key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
  208. if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
  209. raise RuntimeError(
  210. f"impl(..., device_types={device_type}): the operator {self._qualname} "
  211. f"already has an implementation for this device type via a "
  212. f"pre-existing torch.library or TORCH_LIBRARY registration."
  213. )
  214. def impl_factory(self) -> typing.Callable:
  215. r"""Register an implementation for a factory function."""
  216. def inner(f):
  217. self._register_impl("factory", f)
  218. library.impl(self._lib, self._opname, "BackendSelect")(f)
  219. return f
  220. return inner
  221. def impl_abstract(self, _stacklevel=2) -> typing.Callable:
  222. r"""
  223. This API is deprecated, please use torch.library.custom_op instead
  224. """
  225. def inner(f):
  226. self._check_doesnt_have_library_meta_impl()
  227. self._register_impl("abstract", f, stacklevel=_stacklevel)
  228. location = self._get_impl("abstract").location
  229. qualname = self._qualname
  230. # Handle DispatchKey.Meta registration
  231. @functools.wraps(f)
  232. def f_with_ctx(*args, **kwargs):
  233. def error_on_ctx():
  234. raise RuntimeError(
  235. f"Attempted to call get_ctx() for the meta implementation "
  236. f"for {qualname}."
  237. f"You have presumably called get_ctx() because the operator "
  238. f"has a data-dependent output shape; if so, there is no "
  239. f"such meta implementation and this error is the correct "
  240. f"behavior. Otherwise, please remove the call to get_ctx() "
  241. f"in the implementation registered with impl_abstract "
  242. f"at {location}"
  243. )
  244. with torch._library.fake_impl.set_ctx_getter(error_on_ctx):
  245. return f(*args, **kwargs)
  246. self._lib.impl(self._opname, f_with_ctx, "Meta")
  247. return f
  248. return inner
  249. def _check_can_register_backward(self):
  250. def error(detail):
  251. raise RuntimeError(
  252. f"Cannot use torch._custom_ops APIs to register backward "
  253. f"formula for {detail}. Got operator "
  254. f"{self._qualname} with schema: {schema}"
  255. )
  256. schema = self._schema
  257. if schema.kind() != SchemaKind.functional:
  258. error("non-functional operator")
  259. rets = schema.returns
  260. if not schema.returns:
  261. error("operator with no returns")
  262. if len(rets) <= 0:
  263. raise AssertionError(f"expected at least one return, got {len(rets)}")
  264. is_non_mutating_view = any(
  265. r.annotation is not None and not r.annotation.is_write for r in rets
  266. )
  267. if is_non_mutating_view:
  268. error("operator that returns views")
  269. # We make assumptions about the schema's return types.
  270. allowed_return_types = {
  271. BaseType(BaseTy.int): "int",
  272. BaseType(BaseTy.SymInt): "SymInt",
  273. BaseType(BaseTy.bool): "bool",
  274. BaseType(BaseTy.float): "float",
  275. BaseType(BaseTy.Tensor): "Tensor",
  276. ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
  277. }
  278. for ret in schema.returns:
  279. if ret.type in allowed_return_types:
  280. continue
  281. error(
  282. f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})"
  283. )
  284. def _check_doesnt_have_library_autograd_impl(self):
  285. if self._registered_autograd_kernel_indirection:
  286. return
  287. if _C._dispatch_has_kernel_for_dispatch_key(
  288. self._qualname, "CompositeImplicitAutograd"
  289. ):
  290. raise RuntimeError(
  291. f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
  292. f"already has an implementation for this device type via a "
  293. f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
  294. f"CompositeImplicitAutograd operators do not need an autograd formula; "
  295. f"instead, the operator will decompose into its constituents and those "
  296. f"can have autograd formulas defined on them."
  297. )
  298. # We can improve this by adding "all Autograd<BACKEND> keys", but
  299. # realistically people will just be using this API for CPU/CUDA for now.
  300. for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
  301. if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
  302. raise RuntimeError(
  303. f"impl_backward/impl_save_for_backward: "
  304. f"the operator {self._qualname} already has an Autograd kernel "
  305. f"registered to DispatchKey::{key} vi a pre-existing "
  306. f"torch.library or TORCH_LIBRARY registration. Please either "
  307. f"remove those registrations or don't use the torch._custom_ops APIs"
  308. )
  309. def _check_doesnt_have_library_meta_impl(self):
  310. if self._has_impl("abstract"):
  311. return
  312. # If the user's operator is CompositeExplicitAutograd,
  313. # allow them to impl_abstract. This is being pragmatic
  314. # (existing custom ops may have CompositeExplicitAutograd
  315. # registration that don't work with Meta kernels, so this
  316. # gives them an escape hatch).
  317. if _C._dispatch_has_kernel_for_dispatch_key(
  318. self._qualname, "CompositeExplicitAutograd"
  319. ) and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
  320. return
  321. # Otherwise, if the user's already has a Meta kernel or their
  322. # op is CompositeImplicitAutograd or some other alias dispatch key,
  323. # raise.
  324. # Special case for CompositeImplicitAutograd
  325. if _C._dispatch_has_kernel_for_dispatch_key(
  326. self._qualname, "CompositeImplicitAutograd"
  327. ):
  328. raise RuntimeError(
  329. f"impl_abstract(...): the operator {self._qualname} "
  330. f"already has an implementation for this device type via a "
  331. f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
  332. f"CompositeImplicitAutograd operators do not need an abstract impl; "
  333. f"instead, the operator will decompose into its constituents and those "
  334. f"can have abstract impls defined on them."
  335. )
  336. if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
  337. raise RuntimeError(
  338. f"impl_abstract(...): the operator {self._qualname} "
  339. f"already has an DispatchKey::Meta implementation via a "
  340. f"pre-existing torch.library or TORCH_LIBRARY registration. "
  341. f"Please either remove that registration or don't call impl_abstract."
  342. )
  343. # NOTE ["backward", "save_for_backward", and "autograd"]
  344. # As a part of the explicit autograd API, a user must provide us
  345. # a "save_for_backward" function and a "backward" function.
  346. # When both of these have been provided, then we automatically
  347. # construct the "autograd" kernel.
  348. def _register_autograd_kernel(self):
  349. if not self._has_impl("backward"):
  350. raise AssertionError("backward impl must be registered first")
  351. if not self._has_impl("save_for_backward"):
  352. raise AssertionError("save_for_backward impl must be registered first")
  353. kernel = construct_autograd_kernel(
  354. self._schema,
  355. self._output_differentiability,
  356. self,
  357. get_op(self._qualname),
  358. self._get_impl("save_for_backward").func,
  359. self._get_impl("backward").func,
  360. )
  361. self._register_impl("autograd", kernel)
  362. def impl_save_for_backward(self, _stacklevel=2):
  363. r"""Register a function that tells us what to save for backward.
  364. Please see impl_backward for more details.
  365. """
  366. def inner(f):
  367. self._check_can_register_backward()
  368. self._check_doesnt_have_library_autograd_impl()
  369. if not self._registered_autograd_kernel_indirection:
  370. self._register_autograd_kernel_indirection()
  371. self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
  372. if self._has_impl("backward"):
  373. self._register_autograd_kernel()
  374. return inner
  375. def impl_backward(self, output_differentiability=None, _stacklevel=2):
  376. r"""
  377. This API is deprecated, please use torch.library.custom_op instead
  378. """
  379. if output_differentiability is not None:
  380. def yell():
  381. raise RuntimeError(
  382. f"impl_backward(output_differentiability): expected "
  383. f"output_differentiability to be a list of bools with "
  384. f"length equal to the number of outputs of this CustomOp "
  385. f"got: {output_differentiability}"
  386. )
  387. if not isinstance(output_differentiability, list):
  388. yell()
  389. for diff in output_differentiability:
  390. if not isinstance(diff, bool):
  391. yell()
  392. if len(self._schema.returns) != len(output_differentiability):
  393. yell()
  394. def inner(f):
  395. self._check_can_register_backward()
  396. self._check_doesnt_have_library_autograd_impl()
  397. if not self._registered_autograd_kernel_indirection:
  398. self._register_autograd_kernel_indirection()
  399. self._register_impl("backward", f, stacklevel=_stacklevel)
  400. self._output_differentiability = output_differentiability
  401. if self._has_impl("save_for_backward"):
  402. self._register_autograd_kernel()
  403. return inner
  404. @dataclasses.dataclass
  405. class FuncAndLocation:
  406. func: typing.Callable
  407. location: str
  408. def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
  409. overload_name = (
  410. "" if operator_name.overload_name is None else operator_name.overload_name
  411. )
  412. return _C._dispatch_find_schema_or_throw(
  413. f"{cpp_ns}::{str(operator_name.name)}", overload_name
  414. )
  415. def validate_namespace(ns: str) -> None:
  416. if "." in ns:
  417. raise ValueError(
  418. f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
  419. f"valid variable name)"
  420. )
  421. if ns in RESERVED_NS:
  422. raise ValueError(
  423. f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
  424. f"please choose something else. "
  425. )
  426. def validate_schema(schema: FunctionSchema) -> None:
  427. if not torch._library.utils.is_functional_schema(schema):
  428. raise ValueError(
  429. f"custom_op only supports functional operators "
  430. f"(ops that do not mutate any inputs, do not return "
  431. f"views of the inputs, and has at least one return). "
  432. f"Got the following non-functional schema: {schema}"
  433. )
  434. # For simplicity: don't allow self arguments
  435. if schema.arguments.self_arg is not None:
  436. raise ValueError(
  437. f"custom_op does not support arguments named 'self'. Please "
  438. f"rename your argument. Got: {schema}"
  439. )
  440. def parse_qualname(qualname: str) -> tuple[str, str]:
  441. names = qualname.split("::", 1)
  442. if len(names) != 2:
  443. raise ValueError(
  444. f"Expected there to be a namespace in {qualname}, i.e. The "
  445. f"operator name should look something like ns::foo"
  446. )
  447. if "." in names[1]:
  448. raise ValueError(
  449. f"The torch.custom_ops APIs do not handle overloads, "
  450. f"i.e. operator names with '.' in them. "
  451. f"Please name your operator something like ns::foo. "
  452. f"Got: {qualname}"
  453. )
  454. return names[0], names[1]
  455. def validate_device_type(device_type: str) -> None:
  456. if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
  457. raise ValueError(
  458. f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
  459. f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
  460. )
  461. def supported_param(param: inspect.Parameter) -> bool:
  462. return param.kind in (
  463. inspect.Parameter.POSITIONAL_OR_KEYWORD,
  464. inspect.Parameter.KEYWORD_ONLY,
  465. )
  466. def validate_function_matches_schema(
  467. schema: FunctionSchema, func: typing.Callable
  468. ) -> None:
  469. sig = inspect.signature(func)
  470. if not all(supported_param(p) for _, p in sig.parameters.items()):
  471. raise ValueError(
  472. f"custom_op(..., manual_schema)(func): positional-only args, "
  473. f"varargs, and kwargs are not supported. Please rewrite `func` "
  474. f"to not have them. Got `func` with signature: {sig}"
  475. )
  476. if (
  477. any(
  478. p.annotation is not inspect.Parameter.empty
  479. for _, p in sig.parameters.items()
  480. )
  481. or sig.return_annotation is not inspect.Signature.empty
  482. ):
  483. raise ValueError(
  484. f"custom_op(..., manual_schema)(func): When passing in a manual "
  485. f"schema, we expect `func` to have no type annotations to avoid "
  486. f"ambiguity. Got `func` with signature: {sig}"
  487. )
  488. positional = [
  489. (name, param)
  490. for name, param in sig.parameters.items()
  491. if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
  492. ]
  493. kwargonly = [
  494. (name, param)
  495. for name, param in sig.parameters.items()
  496. if param.kind == inspect.Parameter.KEYWORD_ONLY
  497. ]
  498. def error():
  499. raise ValueError(
  500. f"custom_op(..., manual_schema)(func): When passing in a manual "
  501. f"schema, we expect `func`'s signature to match `manual_schema` "
  502. f"(aside from type annotations). "
  503. f"func's signature: {sig}, manual_schema: {schema}"
  504. )
  505. def error_default_args():
  506. raise ValueError(
  507. f"custom_op(..., manual_schema)(func): "
  508. f"neither func nor manual_schema should have default "
  509. f"arguments. Got "
  510. f"func's signature: {sig}, manual_schema: {schema}"
  511. )
  512. def compare(sig_args, schema_args):
  513. if len(sig_args) != len(schema_args):
  514. error()
  515. for (name, param), arg in zip(sig_args, schema_args):
  516. if name != arg.name:
  517. error()
  518. if param.default is not inspect.Parameter.empty or arg.default is not None:
  519. error_default_args()
  520. compare(positional, schema.arguments.flat_positional)
  521. compare(kwargonly, schema.arguments.flat_kwarg_only)
  522. def report_error_callback(custom_op: typing.Any, key: str) -> None:
  523. if key == "Undefined":
  524. raise NotImplementedError(
  525. f"{custom_op}: There were no Tensor inputs to this operator "
  526. f"(e.g. you passed an empty list of Tensors). If your operator is a "
  527. f"factory function (that is, it takes no Tensors and constructs "
  528. f"a new one), then please use CustomOp.impl_factory to register "
  529. f"an implementation for it"
  530. )
  531. if key == "Meta":
  532. raise NotImplementedError(
  533. f"{custom_op}: when running with device='Meta' tensors: there is no "
  534. f"abstract impl registered for this CustomOp. Please register one via "
  535. f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
  536. )
  537. if key in ("CPU", "CUDA"):
  538. device = key.lower()
  539. raise NotImplementedError(
  540. f"{custom_op}: when running with device='{device}' tensors: there is no "
  541. f"{device} impl registered for this CustomOp. Please register one via "
  542. f"CustomOp.impl(device_type='{device}')"
  543. )
  544. raise NotImplementedError(
  545. f"{custom_op}: No implementation for dispatch key {key}. It is likely "
  546. f"that we have not added this functionality yet, please either open an "
  547. f"issue or if you're feeling adventurous, use the low-level "
  548. f"torch.library API"
  549. )
  550. def custom_op_from_existing(op):
  551. ns = op.namespace
  552. lib = torch.library.Library(ns, "FRAGMENT")
  553. name = op.name().split("::")[-1]
  554. schema_str = str(op._schema)
  555. # CustomOp expects the schema string without the namespace
  556. schema_str = schema_str.rsplit("::", maxsplit=1)[-1]
  557. schema = FunctionSchema.parse(schema_str)
  558. return CustomOp(lib, ns, schema, name, op, _private_access=True)
  559. def get_op(qualname):
  560. def error_not_found():
  561. raise ValueError(
  562. f"Could not find the operator {qualname}. Please make sure you have "
  563. f"already registered the operator and (if registered from C++) "
  564. f"loaded it via torch.ops.load_library."
  565. )
  566. ns, name = parse_qualname(qualname)
  567. if not hasattr(torch.ops, ns):
  568. error_not_found()
  569. opnamespace = getattr(torch.ops, ns)
  570. if not hasattr(opnamespace, name):
  571. error_not_found()
  572. packet = getattr(opnamespace, name)
  573. if not hasattr(packet, "default"):
  574. error_not_found()
  575. return packet.default
  576. def _find_custom_op(qualname, also_check_torch_library=False):
  577. if qualname in global_registry:
  578. return global_registry[qualname]
  579. if not also_check_torch_library:
  580. raise RuntimeError(
  581. f'Could not find custom op "{qualname}". Did you register it via '
  582. f"the torch._custom_ops API?"
  583. )
  584. overload = get_op(qualname)
  585. result = custom_op_from_existing(overload)
  586. return result
  587. def get_abstract_impl(qualname):
  588. if qualname not in torch._custom_op.impl.global_registry:
  589. return None
  590. custom_op = torch._custom_op.impl.global_registry[qualname]
  591. if custom_op is None:
  592. return None
  593. if not custom_op._has_impl("abstract"):
  594. return None
  595. return custom_op._get_impl("abstract").func
  596. def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
  597. ns, name = qualname.split("::")
  598. schema_str = f"{name}{schema}"
  599. function_schema = FunctionSchema.parse(schema_str)
  600. validate_schema(function_schema)
  601. tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
  602. lib = library.Library(ns, "FRAGMENT")
  603. lib.define(schema_str, tags=tags)
  604. ophandle = find_ophandle_or_throw(ns, function_schema.name)
  605. result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
  606. result._register_autograd_kernel_indirection()
  607. torch._C._dispatch_set_report_error_callback(
  608. ophandle, functools.partial(report_error_callback, weakref.proxy(result))
  609. )
  610. return get_op(qualname)