overrides.py 103 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133
  1. """
  2. Python implementation of ``__torch_function__``
  3. While most of the torch API and handling for ``__torch_function__`` happens
  4. at the C++ level, some of the torch API is written in Python so we need
  5. Python-level handling for ``__torch_function__`` overrides as well. The main
  6. developer-facing functions in this file are handle_torch_function and
  7. has_torch_function. See torch/functional.py and test/test_overrides.py
  8. for usage examples.
  9. Note
  10. ----
  11. Heavily inspired by NumPy's ``__array_function__`` (see:
  12. https://github.com/pytorch/pytorch/issues/24015 and
  13. https://www.numpy.org/neps/nep-0018-array-function-protocol.html
  14. )
  15. If changing this file in a way that can affect ``__torch_function__`` overhead,
  16. please report the benchmarks in ``benchmarks/overrides_benchmark``. See the
  17. instructions in the ``README.md`` in that directory.
  18. """
  19. import __future__ # noqa: F404
  20. import collections
  21. import contextlib
  22. import functools
  23. import sys
  24. import types
  25. import warnings
  26. from collections.abc import Callable, Iterable
  27. from functools import wraps
  28. from typing import Any, TypeVar
  29. from typing_extensions import ParamSpec
  30. import torch
  31. from torch._C import (
  32. _add_docstr,
  33. _get_function_stack_at,
  34. _has_torch_function,
  35. _has_torch_function_unary,
  36. _has_torch_function_variadic,
  37. _is_torch_function_mode_enabled,
  38. _len_torch_function_stack,
  39. _pop_torch_function_stack,
  40. _push_on_torch_function_stack,
  41. )
# Public API of this module: the names exported by ``from torch.overrides import *``.
# Keep this a literal list of strings so static analyzers can read it.
__all__ = [
    "get_ignored_functions",
    "get_overridable_functions",
    "get_testing_overrides",
    "handle_torch_function",
    "has_torch_function",
    "resolve_name",
    "is_tensor_like",
    "is_tensor_method_or_property",
    "wrap_torch_function",
    "enable_reentrant_dispatch",
]

# Type variables used by signature-preserving decorators in this file (e.g.
# ``_disable_user_warnings``): ``_P`` captures the wrapped function's parameter
# list and ``_R`` its return type.
_P = ParamSpec("_P")
_R = TypeVar("_R")
  56. def _disable_user_warnings(
  57. func: Callable[_P, _R],
  58. regex: str = ".*is deprecated, please use.*",
  59. module: str = "torch",
  60. ) -> Callable[_P, _R]:
  61. """
  62. Decorator that temporarily disables ``UserWarning``s for the given ``module`` if the warning message matches the
  63. given ``regex`` pattern.
  64. Arguments
  65. ---------
  66. func : function
  67. Function to disable the warnings for.
  68. regex : str
  69. A regex pattern compilable by ``re.compile``. This is used to match the ``UserWarning`` message.
  70. module : str
  71. The python module to which the filtering should be restricted.
  72. Returns
  73. -------
  74. function
  75. The wrapped function.
  76. """
  77. @wraps(func)
  78. def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
  79. with warnings.catch_warnings():
  80. warnings.filterwarnings(
  81. "ignore", category=UserWarning, message=regex, module=module
  82. )
  83. return func(*args, **kwargs)
  84. return wrapper
  85. @functools.cache
  86. @_disable_user_warnings
  87. def get_ignored_functions() -> set[Callable]:
  88. """
  89. Return public functions that cannot be overridden by ``__torch_function__``.
  90. Returns
  91. -------
  92. set[Callable]
  93. A tuple of functions that are publicly available in the torch API but cannot
  94. be overridden with ``__torch_function__``. Mostly this is because none of the
  95. arguments of these functions are tensors or tensor-likes.
  96. Examples
  97. --------
  98. >>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
  99. True
  100. >>> torch.add in torch.overrides.get_ignored_functions()
  101. False
  102. """
  103. Tensor = torch.Tensor
  104. functions = {
  105. torch.typename,
  106. torch.is_tensor,
  107. torch.is_storage,
  108. torch.set_default_tensor_type,
  109. torch.set_default_device,
  110. torch.get_default_device,
  111. torch.set_rng_state,
  112. torch.get_rng_state,
  113. torch.manual_seed,
  114. torch.initial_seed,
  115. torch.seed,
  116. torch.save,
  117. torch.load,
  118. torch.set_printoptions,
  119. torch.fork,
  120. torch.get_default_dtype,
  121. torch.get_num_interop_threads,
  122. torch.get_num_threads,
  123. torch.init_num_threads,
  124. torch.import_ir_module,
  125. torch.import_ir_module_from_buffer,
  126. torch.is_anomaly_enabled,
  127. torch.is_anomaly_check_nan_enabled,
  128. torch.is_grad_enabled,
  129. torch.merge_type_from_type_comment,
  130. torch.parse_ir,
  131. torch.parse_schema,
  132. torch.parse_type_comment,
  133. torch.set_anomaly_enabled,
  134. torch.set_flush_denormal,
  135. torch.set_num_interop_threads,
  136. torch.set_num_threads,
  137. torch.wait,
  138. torch.as_tensor,
  139. torch.from_numpy,
  140. torch.tensor,
  141. torch.default_generator,
  142. torch.has_cuda,
  143. torch.has_cudnn,
  144. torch.has_lapack,
  145. torch.device,
  146. torch.dtype,
  147. torch.finfo,
  148. torch.has_mkl,
  149. torch.has_mps,
  150. torch.has_mkldnn,
  151. torch.has_openmp,
  152. torch.iinfo,
  153. torch.memory_format,
  154. torch.qscheme,
  155. torch.set_grad_enabled,
  156. torch.no_grad,
  157. torch.enable_grad,
  158. torch.inference_mode,
  159. torch.is_inference_mode_enabled,
  160. torch.layout,
  161. torch.align_tensors,
  162. torch.arange,
  163. torch.as_strided,
  164. torch.bartlett_window,
  165. torch.blackman_window,
  166. torch.broadcast_shapes,
  167. torch.can_cast,
  168. torch.compile,
  169. torch.cudnn_affine_grid_generator,
  170. torch.cudnn_batch_norm,
  171. torch.cudnn_convolution,
  172. torch.cudnn_convolution_transpose,
  173. torch.cudnn_convolution_relu,
  174. torch.cudnn_convolution_add_relu,
  175. torch.cudnn_grid_sampler,
  176. torch.cudnn_is_acceptable,
  177. torch.miopen_ctc_loss,
  178. torch.empty,
  179. torch.empty_permuted,
  180. torch.empty_strided,
  181. torch.empty_quantized,
  182. torch.export.export,
  183. torch.export.load,
  184. torch.export.register_dataclass,
  185. torch.export.save,
  186. torch.eye,
  187. torch.fft.fftfreq,
  188. torch.fft.rfftfreq,
  189. torch.from_file,
  190. torch.full,
  191. torch.fill,
  192. torch.hamming_window,
  193. torch.hann_window,
  194. torch.kaiser_window,
  195. torch.linspace,
  196. torch.logspace,
  197. torch.mkldnn_adaptive_avg_pool2d,
  198. torch.mkldnn_convolution,
  199. torch.mkldnn_max_pool2d,
  200. torch.mkldnn_max_pool3d,
  201. torch.mkldnn_linear_backward_weights,
  202. torch.mkldnn_rnn_layer,
  203. torch.normal,
  204. torch.ones,
  205. torch.promote_types,
  206. torch.rand,
  207. torch.rand_like,
  208. torch.randn,
  209. torch.randn_like,
  210. torch.randint,
  211. torch.randint_like,
  212. torch.randperm,
  213. torch.range,
  214. torch.result_type,
  215. torch.scalar_tensor,
  216. torch.sparse_coo_tensor,
  217. torch.sparse_compressed_tensor,
  218. torch.sparse_csr_tensor,
  219. torch.sparse_csc_tensor,
  220. torch.sparse_bsr_tensor,
  221. torch.sparse_bsc_tensor,
  222. torch.sym_constrain_range,
  223. torch.sym_constrain_range_for_size,
  224. torch.sym_fresh_size,
  225. torch.tril_indices,
  226. torch.triu_indices,
  227. torch.vander,
  228. torch.zeros,
  229. torch._jit_internal.boolean_dispatch,
  230. torch.nn.functional.assert_int_or_pair,
  231. torch.nn.functional.upsample,
  232. torch.nn.functional.upsample_bilinear,
  233. torch.nn.functional.upsample_nearest,
  234. torch.nn.functional.has_torch_function,
  235. torch.nn.functional.has_torch_function_unary,
  236. torch.nn.functional.has_torch_function_variadic,
  237. torch.nn.functional.handle_torch_function,
  238. torch.nn.functional.grouped_mm,
  239. torch.nn.functional.scaled_grouped_mm,
  240. torch.nn.functional.scaled_mm,
  241. torch.nn.functional.sigmoid,
  242. torch.nn.functional.hardsigmoid,
  243. torch.nn.functional.tanh,
  244. torch.nn.functional._canonical_mask,
  245. torch.nn.functional._none_or_dtype,
  246. # Doesn't actually take or return tensor arguments
  247. torch.nn.init.calculate_gain,
  248. # These are deprecated; don't test them
  249. torch.nn.init.uniform,
  250. torch.nn.init.normal,
  251. torch.nn.init.constant,
  252. torch.nn.init.eye,
  253. torch.nn.init.dirac,
  254. torch.nn.init.xavier_uniform,
  255. torch.nn.init.xavier_normal,
  256. torch.nn.init.kaiming_uniform,
  257. torch.nn.init.kaiming_normal,
  258. torch.nn.init.orthogonal,
  259. torch.nn.init.sparse,
  260. torch.nested.to_padded_tensor,
  261. has_torch_function,
  262. handle_torch_function,
  263. torch.set_autocast_enabled,
  264. torch.is_autocast_enabled,
  265. torch.set_autocast_dtype,
  266. torch.get_autocast_dtype,
  267. torch.clear_autocast_cache,
  268. torch.set_autocast_cpu_enabled,
  269. torch.is_autocast_cpu_enabled,
  270. torch.set_autocast_xla_enabled,
  271. torch.is_autocast_xla_enabled,
  272. torch.set_autocast_ipu_enabled,
  273. torch.is_autocast_ipu_enabled,
  274. torch.set_autocast_cpu_dtype,
  275. torch.get_autocast_cpu_dtype,
  276. torch.set_autocast_ipu_dtype,
  277. torch.get_autocast_ipu_dtype,
  278. torch.get_autocast_gpu_dtype,
  279. torch.set_autocast_gpu_dtype,
  280. torch.get_autocast_xla_dtype,
  281. torch.set_autocast_xla_dtype,
  282. torch.autocast_increment_nesting,
  283. torch.autocast_decrement_nesting,
  284. torch.is_autocast_cache_enabled,
  285. torch.set_autocast_cache_enabled,
  286. torch.nn.functional.hardswish,
  287. torch.is_vulkan_available,
  288. torch.are_deterministic_algorithms_enabled,
  289. torch.use_deterministic_algorithms,
  290. torch.is_deterministic_algorithms_warn_only_enabled,
  291. torch.set_deterministic_debug_mode,
  292. torch.get_device_module,
  293. torch.get_deterministic_debug_mode,
  294. torch.set_float32_matmul_precision,
  295. torch.get_float32_matmul_precision,
  296. torch.unify_type_list,
  297. torch.is_warn_always_enabled,
  298. torch.set_warn_always,
  299. torch.vitals_enabled,
  300. torch.set_vital,
  301. torch.read_vitals,
  302. torch.vmap,
  303. torch.cond,
  304. torch.frombuffer,
  305. torch.asarray,
  306. torch._functional_sym_constrain_range,
  307. torch._make_dep_token,
  308. Tensor.__delitem__,
  309. Tensor.__dir__,
  310. Tensor.__getattribute__,
  311. Tensor.__init__,
  312. Tensor.__iter__,
  313. Tensor.__init_subclass__,
  314. Tensor.__delattr__,
  315. Tensor.__setattr__,
  316. Tensor.__torch_function__,
  317. Tensor.__torch_dispatch__,
  318. Tensor.__new__,
  319. Tensor.__class__,
  320. Tensor.__subclasshook__,
  321. Tensor.__hash__,
  322. Tensor.as_subclass,
  323. Tensor.eig,
  324. Tensor.lstsq,
  325. Tensor.reinforce,
  326. Tensor.new,
  327. Tensor.new_tensor,
  328. Tensor.new_empty,
  329. Tensor.new_empty_strided,
  330. Tensor.new_zeros,
  331. Tensor.new_ones,
  332. Tensor.new_full,
  333. Tensor._make_subclass,
  334. Tensor.solve,
  335. Tensor.symeig,
  336. Tensor.stride,
  337. Tensor.unflatten,
  338. Tensor.to_sparse_coo,
  339. Tensor.to_sparse_csr,
  340. Tensor.to_sparse_csc,
  341. Tensor.to_sparse_bsr,
  342. Tensor.to_sparse_bsc,
  343. Tensor._to_sparse,
  344. Tensor._to_sparse_csr,
  345. Tensor._to_sparse_csc,
  346. Tensor._to_sparse_bsr,
  347. Tensor._to_sparse_bsc,
  348. Tensor._typed_storage,
  349. Tensor._reduce_ex_internal,
  350. Tensor._fix_weakref,
  351. Tensor._view_func,
  352. Tensor._view_func_unsafe,
  353. Tensor._rev_view_func_unsafe,
  354. Tensor._dtensor__new__,
  355. Tensor._make_wrapper_subclass,
  356. Tensor._python_dispatch.__get__,
  357. Tensor._has_symbolic_sizes_strides.__get__,
  358. Tensor._conj,
  359. Tensor._conj_physical,
  360. Tensor._lazy_clone,
  361. Tensor._neg_view,
  362. Tensor._is_zerotensor,
  363. Tensor._is_all_true,
  364. Tensor._is_any_true,
  365. Tensor._addmm_activation,
  366. Tensor.to_padded_tensor,
  367. Tensor._use_count,
  368. }
  369. if sys.version_info >= (3, 14):
  370. functions.add(Tensor.__annotate__)
  371. return functions
  372. @functools.cache
  373. def get_default_nowrap_functions() -> set[Callable]:
  374. """
  375. Return public functions that do not wrap in a subclass when invoked by
  376. the default ``Tensor.__torch_function__`` that preserves subclasses. Typically,
  377. these functions represent field accesses (i.e., retrieving a Tensor that
  378. is stored somewhere on the Tensor) as opposed to computation. Users of
  379. these functions expect object identity to be preserved over multiple accesses
  380. (e.g., ``a.grad is a.grad``) which cannot be upheld if we're wrapping on
  381. the fly every time (furthermore, the tensor stored here might already be
  382. the subclass, in which case wrapping really ought not to happen).
  383. Not ALL property accessors have this property; for example ``Tensor.T`` actually
  384. just creates a new transposed tensor on the fly, and so we SHOULD interpose on
  385. these calls (you need to check the implementation of the function to see if
  386. this is the case or not). Additionally, if a property accessor doesn't return a Tensor,
  387. it doesn't have to be on this list (though it is harmless if it is).
  388. """
  389. Tensor = torch.Tensor
  390. return {
  391. Tensor._base.__get__,
  392. Tensor.grad.__get__,
  393. Tensor._grad.__get__,
  394. }
  395. @functools.cache
  396. @_disable_user_warnings
  397. def get_testing_overrides() -> dict[Callable, Callable]:
  398. """Return a dict containing dummy overrides for all overridable functions
  399. Returns
  400. -------
  401. Dict[Callable, Callable]
  402. A dictionary that maps overridable functions in the PyTorch API to
  403. lambda functions that have the same signature as the real function
  404. and unconditionally return -1. These lambda functions are useful
  405. for testing API coverage for a type that defines ``__torch_function__``.
  406. Examples
  407. --------
  408. >>> import inspect
  409. >>> my_add = torch.overrides.get_testing_overrides()[torch.add]
  410. >>> inspect.signature(my_add)
  411. <Signature (input, other, out=None)>
  412. """
  413. # Every function in the PyTorchAPI that can be overridden needs an entry
  414. # in this dict.
  415. #
  416. # Optimally we would use inspect to get the function signature and define
  417. # the lambda function procedurally but that is blocked by generating
  418. # function signatures for native kernels that can be consumed by inspect.
  419. # See Issue #28233.
  420. Tensor = torch.Tensor
  421. ret: dict[Callable, Callable] = {
  422. torch.abs: lambda input, out=None: -1,
  423. torch.absolute: lambda input, out=None: -1,
  424. torch.adaptive_avg_pool1d: lambda input, output_size: -1,
  425. torch.adaptive_max_pool1d: lambda inputs, output_size: -1,
  426. torch.acos: lambda input, out=None: -1,
  427. torch.adjoint: lambda input: -1,
  428. torch.arccos: lambda input, out=None: -1,
  429. torch.acosh: lambda input, out=None: -1,
  430. torch.arccosh: lambda input, out=None: -1,
  431. torch.add: lambda input, other, out=None: -1,
  432. torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
  433. torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1,
  434. torch.addcmul: lambda input, tensor1, tensor2, value=1, out=None: -1,
  435. torch.addmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
  436. torch.addmv: lambda input, mat, vec, beta=1, alpha=1, out=None: -1,
  437. torch.addr: lambda input, vec1, vec2, beta=1, alpha=1, out=None: -1,
  438. torch.affine_grid_generator: lambda theta, size, align_corners: -1,
  439. torch.all: lambda input, dim=None: -1,
  440. torch.allclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
  441. torch.alpha_dropout: lambda input, p, train, inplace=False: -1,
  442. torch.amax: lambda input, dim=None: -1,
  443. torch.amin: lambda input, dim=None: -1,
  444. torch.aminmax: lambda input, dim=None, keepdim=False, out=None: -1,
  445. torch.angle: lambda input, out=None: -1,
  446. torch.any: lambda input, dim=None, keepdim=False, out=None: -1,
  447. torch.argmax: lambda input: -1,
  448. torch.argmin: lambda input: -1,
  449. torch.argsort: lambda input, dim=None: -1,
  450. torch.asin: lambda input, out=None: -1,
  451. torch._assert_async: lambda input, msg: -1,
  452. torch.arcsin: lambda input, out=None: -1,
  453. torch.asinh: lambda input, out=None: -1,
  454. torch.arcsinh: lambda input, out=None: -1,
  455. torch.atan: lambda input, out=None: -1,
  456. torch.arctan: lambda input, out=None: -1,
  457. torch.atan2: lambda input, other, out=None: -1,
  458. torch.arctan2: lambda input, other, out=None: -1,
  459. torch.atanh: lambda input, out=None: -1,
  460. torch.arctanh: lambda input, out=None: -1,
  461. torch.atleast_1d: lambda *tensors: -1,
  462. torch.atleast_2d: lambda *tensors: -1,
  463. torch.atleast_3d: lambda *tensors: -1,
  464. torch.avg_pool1d: lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True: -1,
  465. torch.baddbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
  466. torch.batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled: -1,
  467. torch.batch_norm_backward_elemt: lambda grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count_tensor: -1,
  468. torch.batch_norm_backward_reduce: lambda grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g: -1,
  469. torch.batch_norm_elemt: lambda input, weight, bias, mean, invstd, eps: -1,
  470. torch.batch_norm_gather_stats: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
  471. torch.batch_norm_gather_stats_with_counts: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
  472. torch.batch_norm_stats: lambda input, eps: -1,
  473. torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
  474. torch.bernoulli: lambda input, generator=None, out=None: -1,
  475. torch.bilinear: lambda input1, input2, weight, bias: -1,
  476. torch.binary_cross_entropy_with_logits: (
  477. lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
  478. ),
  479. torch.bincount: lambda input, weights=None, minlength=0: -1,
  480. torch.binomial: lambda count, prob, generator=None: -1,
  481. torch.bitwise_and: lambda input, other, out=None: -1,
  482. torch.bitwise_not: lambda input, out=None: -1,
  483. torch.bitwise_or: lambda input, other, out=None: -1,
  484. torch.bitwise_xor: lambda input, other, out=None: -1,
  485. torch.bitwise_left_shift: lambda input, other, out=None: -1,
  486. torch.bitwise_right_shift: lambda input, other, out=None: -1,
  487. torch.block_diag: lambda *tensors: -1,
  488. torch.bmm: lambda input, mat2, out_dtype=None, out=None: -1,
  489. torch.broadcast_tensors: lambda *tensors: -1,
  490. torch.broadcast_to: lambda self, size: -1,
  491. torch.bucketize: lambda input, boundaries, out_int32=False, right=False, out=None: -1,
  492. torch.cartesian_prod: lambda *tensors: -1,
  493. torch.cat: lambda tensors, dim=0, out=None: -1,
  494. torch.concat: lambda tensors, dim=0, out=None: -1, # alias for torch.cat
  495. torch.concatenate: lambda tensors, dim=0, out=None: -1, # alias for torch.concatenate
  496. torch.cdist: lambda x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary": -1,
  497. torch.ceil: lambda input, out=None: -1,
  498. torch.celu: lambda input, alpha=1.0, inplace=False: -1,
  499. torch.chain_matmul: lambda *matrices, out=None: -1,
  500. torch.channel_shuffle: lambda input, groups: -1,
  501. torch.cholesky: lambda input, upper=False, out=None: -1,
  502. torch.linalg.cholesky: lambda input, out=None: -1,
  503. torch.linalg.cholesky_ex: lambda input, check_errors=False, out=None: -1,
  504. torch.cholesky_inverse: lambda input, upper=False, out=None: -1,
  505. torch.cholesky_solve: lambda input1, input2, upper=False, out=None: -1,
  506. torch.choose_qparams_optimized: lambda input, numel, n_bins, ratio, bit_width: -1,
  507. torch.chunk: lambda input, chunks, dim=0: -1,
  508. torch.clamp: lambda input, min=None, max=None, out=None: -1,
  509. torch.clip: lambda input, min=None, max=None, out=None: -1,
  510. torch.clamp_min: lambda input, min, out=None: -1,
  511. torch.clamp_max: lambda input, max, out=None: -1,
  512. torch.column_stack: lambda tensors, out=None: -1,
  513. torch.cov: lambda input, correction=1, fweights=None, aweights=None: -1,
  514. torch.clone: lambda input: -1,
  515. torch.combinations: lambda input, r=2, with_replacement=False: -1,
  516. torch.complex: lambda real, imag: -1,
  517. torch.copysign: lambda input, other, out=None: -1,
  518. torch.polar: lambda abs, ang: -1,
  519. torch.linalg.cond: lambda input, ord=None: -1,
  520. torch.conj: lambda input, out=None: -1,
  521. torch.conj_physical: lambda input, out=None: -1,
  522. torch.resolve_conj: lambda input, out=None: -1,
  523. torch.resolve_neg: lambda input, out=None: -1,
  524. torch.constant_pad_nd: lambda input, pad, value=0: -1,
  525. torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
  526. torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
  527. torch.conv3d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
  528. torch.convolution: lambda input, weight, bias, stride, padding, dilation, transposed, output_adding, groups: -1,
  529. torch.conv_tbc: lambda input, weight, bias, pad=0: -1,
  530. torch.conv_transpose1d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
  531. torch.conv_transpose2d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
  532. torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
  533. torch.corrcoef: lambda input: -1,
  534. torch.cos: lambda input, out=None: -1,
  535. torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1,
  536. torch.cosh: lambda input, out=None: -1,
  537. torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
  538. torch.count_nonzero: lambda input: -1,
  539. torch.cross: lambda input, other, dim=None, out=None: -1,
  540. torch.linalg.cross: lambda input, other, dim=-1, out=None: -1,
  541. torch.ctc_loss: (
  542. lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
  543. ),
  544. torch.cummax: lambda input, dim, out=None: -1,
  545. torch.cummin: lambda input, dim, out=None: -1,
  546. torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
  547. torch.cumsum: lambda input, dim, out=None, dtype=None: -1,
  548. torch.cumulative_trapezoid: lambda y, x=None, dim=-1: -1,
  549. torch.logcumsumexp: lambda input, dim, out=None: -1,
  550. torch.deg2rad: lambda input, out=None: -1,
  551. torch.dequantize: lambda input: -1,
  552. torch.det: lambda input: -1,
  553. torch.linalg.det: lambda input: -1, # alias for torch.det # type: ignore[attr-defined]
  554. torch.detach: lambda input: -1,
  555. torch.diag: lambda input, diagonal=0, out=None: -1,
  556. torch.diag_embed: lambda input, diagonal=0, out=None: -1,
  557. torch.diagflat: lambda input, offset=0: -1,
  558. torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1,
  559. torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
  560. torch.linalg.diagonal: lambda input, offset=0, dim1=-2, dim2=-1: -1,
  561. torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1,
  562. torch.as_strided_scatter: lambda self, src, size, stride, storage_offset=None: -1,
  563. torch.digamma: lambda input, out=None: -1,
  564. torch.dist: lambda input, other, p=2: -1,
  565. torch.div: lambda input, other, rounding_mode=None, out=None: -1,
  566. torch.divide: lambda input, other, rounding_mode=None, out=None: -1,
  567. torch.dot: lambda input, other, out=None: -1,
  568. torch.dropout: lambda input, p, train, inplace=False: -1,
  569. torch.dsmm: lambda input, mat2, out_dtype=None: -1,
  570. torch.hsmm: lambda mat1, mat2: -1,
  571. torch.dsplit: lambda input, indices_or_sections: -1,
  572. torch.dstack: lambda tensors, out=None: -1,
  573. torch.linalg.eig: lambda input, out=None: -1,
  574. torch.linalg.eigvals: lambda input, out=None: -1,
  575. torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
  576. torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
  577. torch.einsum: lambda equation, *operands: -1,
  578. torch.embedding: (
  579. lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950
  580. ),
  581. torch.embedding_bag: (
  582. lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, padding_idx=None: -1 # noqa: B950
  583. ),
  584. torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
  585. torch.eq: lambda input, other, out=None: -1,
  586. torch.equal: lambda input, other: -1,
  587. torch.erf: lambda input, out=None: -1,
  588. torch.erfc: lambda input, out=None: -1,
  589. torch.erfinv: lambda input, out=None: -1,
  590. torch.exp: lambda input, out=None: -1,
  591. torch.exp2: lambda input, out=None: -1,
  592. torch.expm1: lambda input, out=None: -1,
  593. torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
  594. torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
  595. torch.fused_moving_avg_obs_fake_quant: (
  596. lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1 # noqa: B950
  597. ),
  598. torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias, output: -1,
  599. torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias, output: -1,
  600. torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1, # noqa: B950
  601. torch.fbgemm_linear_int8_weight_fp32_activation: (
  602. lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1
  603. ),
  604. torch.fbgemm_linear_quantize_weight: lambda input: -1,
  605. torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
  606. torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
  607. torch.feature_alpha_dropout: lambda input, p, train: -1,
  608. torch.feature_dropout: lambda input, p, train: -1,
  609. torch.fft.ifft: lambda input, n=None, dim=-1, norm=None: -1,
  610. torch.fft.rfft: lambda input, n=None, dim=-1, norm=None: -1,
  611. torch.fft.irfft: lambda input, n=None, dim=-1, norm=None: -1,
  612. torch.fft.hfft: lambda input, n=None, dim=-1, norm=None: -1,
  613. torch.fft.ihfft: lambda input, n=None, dim=-1, norm=None: -1,
  614. torch.fft.hfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  615. torch.fft.ihfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  616. torch.fft.hfftn: lambda input, s=None, dim=-1, norm=None: -1,
  617. torch.fft.ihfftn: lambda input, s=None, dim=-1, norm=None: -1,
  618. torch.fft.fftn: lambda input, s=None, dim=None, norm=None: -1,
  619. torch.fft.ifftn: lambda input, s=None, dim=None, norm=None: -1,
  620. torch.fft.rfftn: lambda input, s=None, dim=None, norm=None: -1,
  621. torch.fft.irfftn: lambda input, s=None, dim=None, norm=None: -1,
  622. torch.fft.fft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  623. torch.fft.ifft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  624. torch.fft.rfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  625. torch.fft.irfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
  626. torch.fft.fftshift: lambda input, dim=None: -1,
  627. torch.fft.ifftshift: lambda input, dim=None: -1,
  628. torch.fft.fft: lambda input, n=None, dim=-1, norm=None: -1,
  629. torch.fix: lambda input, out=None: -1,
  630. torch.flatten: lambda input, start_dim=0, end_dim=-1: -1,
  631. torch.flip: lambda input, dims: -1,
  632. torch.fliplr: lambda input: -1,
  633. torch.flipud: lambda input: -1,
  634. torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1,
  635. torch.floor: lambda input, out=None: -1,
  636. torch.floor_divide: lambda input, other: -1,
  637. torch.float_power: lambda input, exponent, out=None: -1,
  638. torch.fmod: lambda input, other, out=None: -1,
  639. torch.frac: lambda input, out=None: -1,
  640. torch.frexp: lambda input, out=None: -1,
  641. torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, # noqa: B950
  642. torch._functional_assert_async: lambda input, msg, dep_token: -1,
  643. torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
  644. torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
  645. torch.gcd: lambda input, other, out=None: -1,
  646. torch.ge: lambda input, other, out=None: -1,
  647. torch.get_device: lambda input: -1,
  648. torch.greater_equal: lambda input, other, out=None: -1,
  649. torch.geqrf: lambda input, out=None: -1,
  650. torch.i0: lambda input, out=None: -1,
  651. torch.inner: lambda input, other, out=None: -1,
  652. torch.outer: lambda input, vec2, out=None: -1,
  653. torch.ger: lambda input, vec2, out=None: -1, # alias for torch.outer
  654. torch.gradient: lambda input, spacing=None, dim=None, edge_order=1: -1,
  655. torch.grid_sampler: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
  656. torch.grid_sampler_2d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
  657. torch.grid_sampler_3d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
  658. torch.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05, cudnn_enabled=True: -1,
  659. torch.gru: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
  660. torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
  661. torch.gt: lambda input, other, out=None: -1,
  662. torch.greater: lambda input, other, out=None: -1,
  663. torch.hardshrink: lambda input, lambd=0.5: -1,
  664. torch.hash_tensor: lambda input, dim=None, keepdim=False, mode=0, out=None: -1,
  665. torch.heaviside: lambda input, values, out=None: -1,
  666. torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950
  667. torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
  668. torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
  669. torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,
  670. torch.linalg.householder_product: lambda input, tau: -1,
  671. torch.hspmm: lambda mat1, mat2, out=None: -1,
  672. torch.hsplit: lambda input, indices_or_sections: -1,
  673. torch.hstack: lambda tensors, out=None: -1,
  674. torch.hypot: lambda input, other, out=None: -1,
  675. torch.igamma: lambda input, other, out=None: -1,
  676. torch.igammac: lambda input, other, out=None: -1,
  677. torch.imag: lambda input, out=None: -1,
  678. torch.index_add: lambda input, dim, index, source: -1,
  679. torch.index_copy: lambda input, dim, index, source: -1,
  680. torch.index_put: lambda input, indices, values, accumulate=False: -1,
  681. torch.index_select: lambda input, dim, index, out=None: -1,
  682. torch.index_fill: lambda input, dim, index, value: -1,
  683. torch.index_reduce: lambda input, dim, index, source, reduce, include_input=True: -1,
  684. torch.isfinite: lambda tensor: -1,
  685. torch.isin: lambda e, te, assume_unique=False, invert=False: -1,
  686. torch.isinf: lambda tensor: -1,
  687. torch.isreal: lambda tensor: -1,
  688. torch.isposinf: lambda input, out=None: -1,
  689. torch.isneginf: lambda input, out=None: -1,
  690. torch.instance_norm: (
  691. lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, cudnn_enabled: -1
  692. ),
  693. torch.int_repr: lambda input: -1,
  694. torch.inverse: lambda input, out=None: -1,
  695. torch.linalg.inv: lambda input, out=None: -1,
  696. torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1,
  697. torch.is_complex: lambda input: -1,
  698. torch.is_conj: lambda input: -1,
  699. torch.is_neg: lambda input: -1,
  700. torch.is_distributed: lambda input: -1,
  701. torch.is_inference: lambda input: -1,
  702. torch.is_floating_point: lambda input: -1,
  703. torch.is_nonzero: lambda input: -1,
  704. torch.is_same_size: lambda input, other: -1,
  705. torch.is_signed: lambda input: -1,
  706. torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
  707. torch.isnan: lambda input: -1,
  708. torch.istft: (
  709. lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=None, length=None, return_complex=False: -1 # noqa: B950
  710. ),
  711. torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,
  712. torch.kron: lambda input, other: -1,
  713. torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
  714. torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1,
  715. torch.linalg.ldl_factor: lambda input, hermitian=False, out=None: -1,
  716. torch.linalg.ldl_solve: lambda LD, pivots, B, hermitian=False, out=None: -1,
  717. torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05, cudnn_enabled=True: -1,
  718. torch.lcm: lambda input, other, out=None: -1,
  719. torch.ldexp: lambda input, other, out=None: -1,
  720. torch.le: lambda input, other, out=None: -1,
  721. torch.less_equal: lambda input, other, out=None: -1,
  722. torch.lerp: lambda input, end, weight, out=None: -1,
  723. torch.lgamma: lambda input, out=None: -1,
  724. torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None, tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1, # noqa: B950
  725. torch.log: lambda input, out=None: -1,
  726. torch.log_softmax: lambda input, dim, dtype=None: -1,
  727. torch.log10: lambda input, out=None: -1,
  728. torch.log1p: lambda input, out=None: -1,
  729. torch.log2: lambda input, out=None: -1,
  730. torch.logaddexp: lambda input, other, out=None: -1,
  731. torch.logaddexp2: lambda input, other, out=None: -1,
  732. torch.logdet: lambda input: -1,
  733. torch.xlogy: lambda x, y, out=None: -1,
  734. torch.logical_and: lambda input, other, out=None: -1,
  735. torch.logical_not: lambda input, out=None: -1,
  736. torch.logical_or: lambda input, other, out=None: -1,
  737. torch.logical_xor: lambda input, other, out=None: -1,
  738. torch.logit: lambda input, eps=None: -1,
  739. torch.logsumexp: lambda input, names, keepdim=False, out=None: -1,
  740. torch.lstm: lambda data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional: -1,
  741. torch.lstm_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
  742. torch.lt: lambda input, other, out=None: -1,
  743. torch.less: lambda input, other, out=None: -1,
  744. torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
  745. torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
  746. torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1, # type: ignore[attr-defined] # noqa: B950
  747. torch.masked_fill: lambda input, mask, value: -1,
  748. torch.masked_scatter: lambda input, mask, source: -1,
  749. torch.masked_select: lambda input, mask, out=None: -1,
  750. torch.matmul: lambda input, other, out=None: -1,
  751. torch.linalg.lu: lambda input, pivot=True, out=None: -1,
  752. torch.linalg.lu_factor: lambda input, pivot=True, out=None: -1,
  753. torch.linalg.lu_factor_ex: lambda input, pivot=True, check_errors=False, out=None: -1,
  754. torch.linalg.lu_solve: lambda LU, pivots, B, left=True, adjoint=False, out=None: -1,
  755. torch.linalg.matmul: lambda input, other, out=None: -1, # alias for torch.matmul
  756. torch.matrix_power: lambda input, n: -1,
  757. torch.linalg.matrix_power: lambda input, n, out=None: -1,
  758. torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1,
  759. torch.linalg.multi_dot: lambda tensors, out=None: -1,
  760. torch.matrix_exp: lambda input: -1,
  761. torch.linalg.matrix_exp: lambda input: -1,
  762. torch.max: lambda input, out=None: -1,
  763. torch.maximum: lambda input, other, out=None: -1,
  764. torch.fmax: lambda input, other, out=None: -1,
  765. torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
  766. torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
  767. torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
  768. torch.max_pool1d_with_indices: (
  769. lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
  770. ),
  771. torch.mean: lambda input, dim=None: -1,
  772. torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
  773. torch.median: lambda input, dim=None: -1,
  774. torch.nanmedian: lambda input, dim=None: -1,
  775. torch.meshgrid: lambda *tensors, **kwargs: -1,
  776. torch.min: lambda input, out=None: -1,
  777. torch.minimum: lambda input, other, out=None: -1,
  778. torch.fmin: lambda input, other, out=None: -1,
  779. torch.miopen_batch_norm: (
  780. lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon: -1
  781. ),
  782. torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1, # noqa: B950
  783. torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1,
  784. torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1,
  785. torch.miopen_convolution_transpose: (
  786. lambda input, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic: -1
  787. ),
  788. torch.miopen_depthwise_convolution: (
  789. lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1
  790. ),
  791. torch.miopen_rnn: (
  792. lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state: -1 # noqa: B950
  793. ),
  794. torch.mm: lambda input, mat2, out_dtype=None, out=None: -1,
  795. torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
  796. torch.movedim: lambda input, source, destination: -1,
  797. torch.moveaxis: lambda input, source, destination: -1,
  798. torch.msort: lambda input, descending=False, out=None: -1,
  799. torch.mul: lambda input, other, out=None: -1,
  800. torch.multiply: lambda input, other, out=None: -1,
  801. torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,
  802. torch.mv: lambda input, vec, out=None: -1,
  803. torch.mvlgamma: lambda input, p: -1,
  804. torch.narrow: lambda input, dim, start, length: -1,
  805. torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1,
  806. torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
  807. torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1,
  808. torch.native_dropout: lambda input, p, train: -1,
  809. torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
  810. torch._fused_rms_norm: lambda input, normalized_shape, weight=None, eps=1e-05: -1,
  811. torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
  812. torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
  813. torch.native_channel_shuffle: lambda input, groups: -1,
  814. torch.ne: lambda input, other, out=None: -1,
  815. torch.not_equal: lambda input, other, out=None: -1,
  816. torch.neg: lambda input, out=None: -1,
  817. torch.negative: lambda input, out=None: -1,
  818. torch.nextafter: lambda input, other, out=None: -1,
  819. torch.nn.functional.adaptive_avg_pool2d: lambda input, output_size: -1,
  820. torch.nn.functional.adaptive_avg_pool3d: lambda input, output_size: -1,
  821. torch.nn.functional.adaptive_max_pool1d: lambda input, output_size, return_indices=False: -1,
  822. torch.nn.functional.adaptive_max_pool1d_with_indices: lambda input, output_size, return_indices=False: -1,
  823. torch.nn.functional.adaptive_max_pool2d: lambda input, output_size, return_indices=False: -1,
  824. torch.nn.functional.adaptive_max_pool2d_with_indices: lambda input, output_size, return_indices=False: -1,
  825. torch.nn.functional.adaptive_max_pool3d: lambda input, output_size, return_indices=False: -1,
  826. torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
  827. torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
  828. torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
  829. torch.nn.functional.avg_pool2d: (
  830. lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950
  831. ),
  832. torch.nn.functional.avg_pool3d: (
  833. lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950
  834. ),
  835. torch.nn.functional.batch_norm: (
  836. lambda input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: -1
  837. ),
  838. torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
  839. torch.nn.functional.binary_cross_entropy: (
  840. lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
  841. ),
  842. torch.nn.functional.binary_cross_entropy_with_logits: (
  843. lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
  844. ),
  845. torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
  846. torch.nn.functional.cosine_embedding_loss: (
  847. lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
  848. ),
  849. torch.nn.functional.cross_entropy: (
  850. lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", label_smoothing=0.0: -1 # noqa: B950
  851. ),
  852. torch.nn.functional.ctc_loss: (
  853. lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
  854. ),
  855. torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
  856. torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1,
  857. torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
  858. torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
  859. torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
  860. torch.nn.functional.embedding: (
  861. lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950
  862. ),
  863. torch.nn.functional.embedding_bag: (
  864. lambda input, weight, offsets=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None: -1 # noqa: B950
  865. ),
  866. torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
  867. torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
  868. torch.nn.functional.fractional_max_pool2d: (
  869. lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
  870. ),
  871. torch.nn.functional.fractional_max_pool2d_with_indices: (
  872. lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
  873. ),
  874. torch.nn.functional.fractional_max_pool3d: (
  875. lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
  876. ),
  877. torch.nn.functional.fractional_max_pool3d_with_indices: (
  878. lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
  879. ),
  880. torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction="mean": -1,
  881. torch.nn.functional.gelu: lambda input, approximate="none": -1,
  882. torch.nn.functional.glu: lambda input, dim=-1: -1,
  883. torch.nn.functional.grid_sample: lambda input, grid, mode="bilinear", padding_mode="zeros", align_corners=None: -1, # noqa: B950
  884. torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
  885. torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
  886. torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
  887. torch.nn.functional.hardtanh: lambda input, min_val=-1.0, max_val=1.0, inplace=False: -1,
  888. torch.nn.functional.hinge_embedding_loss: (
  889. lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1
  890. ),
  891. torch.nn.functional.instance_norm: (
  892. lambda input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05: -1 # noqa: B950
  893. ),
  894. torch.nn.functional.interpolate: (
  895. lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1 # noqa: B950
  896. ),
  897. torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1, # noqa: B950
  898. torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", weight=None: -1,
  899. torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
  900. torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
  901. torch.nn.functional.linear: lambda input, weight, bias=None: -1,
  902. torch.nn.functional.local_response_norm: lambda input, size, alpha=0.0001, beta=0.75, k=1.0: -1,
  903. torch.nn.functional.log_softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
  904. torch.nn.functional.logsigmoid: lambda input: -1,
  905. torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
  906. torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
  907. torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
  908. torch.nn.functional.margin_ranking_loss: (
  909. lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
  910. ),
  911. torch.nn.functional.max_pool1d: (
  912. lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
  913. ),
  914. torch.nn.functional.max_pool1d_with_indices: (
  915. lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
  916. ),
  917. torch.nn.functional.max_pool2d: (
  918. lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
  919. ),
  920. torch.nn.functional.max_pool2d_with_indices: (
  921. lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
  922. ),
  923. torch.nn.functional.max_pool3d: (
  924. lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
  925. ),
  926. torch.nn.functional.max_pool3d_with_indices: (
  927. lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
  928. ),
  929. torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
  930. torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
  931. torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
  932. torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", weight=None: -1,
  933. torch.nn.functional.multi_head_attention_forward: (
  934. lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1 # noqa: B950
  935. ),
  936. torch.nn.functional.multi_margin_loss: (
  937. lambda input, target, p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction="mean": -1
  938. ),
  939. torch.nn.functional.multilabel_margin_loss: (
  940. lambda input, target, size_average=None, reduce=None, reduction="mean": -1
  941. ),
  942. torch.nn.functional.multilabel_soft_margin_loss: (
  943. lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
  944. ),
  945. torch.nn.functional.nll_loss: (
  946. lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean": -1
  947. ),
  948. torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
  949. torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
  950. torch.nn.functional.pad: lambda input, pad, mode="constant", value=0: -1,
  951. torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
  952. torch.nn.functional.poisson_nll_loss: (
  953. lambda input, target, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction="mean": -1 # noqa: B950
  954. ),
  955. torch.nn.functional.prelu: lambda input, weight: -1,
  956. torch.nn.functional.relu: lambda input, inplace=False: -1,
  957. torch.nn.functional.relu6: lambda input, inplace=False: -1,
  958. torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
  959. torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1, # noqa: B950
  960. torch.nn.functional.selu: lambda input, inplace=False: -1,
  961. torch.nn.functional.silu: lambda input, inplace=False: -1,
  962. torch.nn.functional.mish: lambda input, inplace=False: -1,
  963. torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
  964. torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1, # noqa: B950
  965. torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0, weight=None: -1,
  966. torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950
  967. torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
  968. torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
  969. torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
  970. torch.nn.functional.softshrink: lambda input, lambd=0.5: -1,
  971. torch.nn.functional.softsign: lambda input: -1,
  972. torch.nn.functional.tanhshrink: lambda input: -1,
  973. torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
  974. torch.nn.functional.triplet_margin_loss: (
  975. lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950
  976. ),
  977. torch.nn.functional.triplet_margin_with_distance_loss: (
  978. lambda anchor, positive, negative, *, distance_function=None, margin=1.0, swap=False, reduction="mean": -1
  979. ),
  980. torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
  981. torch.nn.init.uniform_: lambda tensor, a=0.0, b=1.0, generator=None: -1,
  982. torch.nn.init.normal_: lambda tensor, mean=0.0, std=1.0, generator=None: -1,
  983. torch.nn.init.constant_: lambda tensor, val: -1,
  984. torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None: -1, # noqa: B950
  985. torch.nonzero: lambda input, as_tuple=False: -1,
  986. torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
  987. torch.argwhere: lambda input: -1,
  988. torch.norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
  989. torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
  990. torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
  991. torch.linalg.matrix_norm: lambda input, ord="fro", dim=(
  992. -2,
  993. -1,
  994. ), keepdim=False, out=None, dtype=None: -1,
  995. torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
  996. torch.nuclear_norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
  997. torch.numel: lambda input: -1,
  998. torch.orgqr: lambda input, tau: -1,
  999. torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
  1000. torch.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
  1001. torch.permute: lambda self, dim: -1,
  1002. torch.pca_lowrank: lambda input, q=None, center=True, niter=2: -1,
  1003. torch.pdist: lambda input, p=2: -1,
  1004. torch.pinverse: lambda input, rcond=1e-15: -1,
  1005. torch.linalg.pinv: lambda input, rcond=1e-15, hermitian=False: -1,
  1006. torch.pixel_shuffle: lambda input, upscale_factor: -1,
  1007. torch.pixel_unshuffle: lambda input, downscale_factor: -1,
  1008. torch.poisson: lambda input, generator=None: -1,
  1009. torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1,
  1010. torch.polygamma: lambda input, n, out=None: -1,
  1011. torch.positive: lambda input, out=None: -1,
  1012. torch.prelu: lambda input, weight: -1,
  1013. torch.ones_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
  1014. torch.pow: lambda input, exponent, out=None: -1,
  1015. torch.prod: lambda input, dtype=None: -1,
  1016. torch.put: lambda input, index, source, accumulate=False: -1,
  1017. torch.q_per_channel_axis: lambda input: -1,
  1018. torch.q_per_channel_scales: lambda input: -1,
  1019. torch.q_per_channel_zero_points: lambda input: -1,
  1020. torch.q_scale: lambda input: -1,
  1021. torch.q_zero_point: lambda input: -1,
  1022. torch.qr: lambda input, some=True, out=None: -1,
  1023. torch.linalg.qr: lambda input, mode="reduced", out=None: -1,
  1024. torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
  1025. torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
  1026. torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
  1027. torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
  1028. torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
  1029. torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
  1030. torch.quantized_gru_cell: (
  1031. lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
  1032. ),
  1033. torch.quantized_lstm_cell: (
  1034. lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
  1035. ),
  1036. torch.quantized_max_pool1d: (
  1037. lambda input, kernel_size, stride=(), padding=(0,), dilation=(
  1038. 1,
  1039. ), ceil_mode=False: -1
  1040. ),
  1041. torch.quantized_max_pool2d: (
  1042. lambda input, kernel_size, stride=(), padding=(0, 0), dilation=(
  1043. 1,
  1044. 1,
  1045. ), ceil_mode=False: -1
  1046. ),
  1047. torch.quantized_max_pool3d: (
  1048. lambda input, kernel_size, stride=(), padding=(0, 0, 0), dilation=(
  1049. 1,
  1050. 1,
  1051. 1,
  1052. ), ceil_mode=False: -1
  1053. ),
  1054. torch.quantized_rnn_relu_cell: (
  1055. lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
  1056. ),
  1057. torch.quantized_rnn_tanh_cell: (
  1058. lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
  1059. ),
  1060. torch.rad2deg: lambda input, out=None: -1,
  1061. torch.ravel: lambda input: -1,
  1062. torch.real: lambda input, out=None: -1,
  1063. torch.vdot: lambda input, other, out=None: -1,
  1064. torch.linalg.vecdot: lambda input, other, dim=-1, out=None: -1,
  1065. torch.view_as_real: lambda input: -1,
  1066. torch.view_as_complex: lambda input: -1,
  1067. torch.reciprocal: lambda input, out=None: -1,
  1068. torch.relu: lambda input, inplace=False: -1,
  1069. torch.remainder: lambda input, other, out=None: -1,
  1070. torch.renorm: lambda input, p, dim, maxnorm, out=None: -1,
  1071. torch.repeat_interleave: lambda input, dim=None: -1,
  1072. torch.reshape: lambda input, shape: -1,
  1073. torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
  1074. torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950
  1075. torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
  1076. torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950
  1077. torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
  1078. torch.roll: lambda input, shifts, dims=None: -1,
  1079. torch.rot90: lambda input, k=1, dims=(0, 1): -1,
  1080. torch.round: lambda input, out=None: -1,
  1081. torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack
  1082. torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
  1083. torch.rrelu: lambda input, lower=1.0 / 8, upper=1.0 / 3, training=False, inplace=False: -1,
  1084. torch.rsqrt: lambda input, out=None: -1,
  1085. torch.rsub: lambda input, other, alpha=1: -1,
  1086. torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
  1087. torch.scatter: lambda input, dim, index, src: -1,
  1088. torch.scatter_add: lambda input, dim, index, src: -1,
  1089. torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
  1090. torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
  1091. torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1, # noqa: B950
  1092. torch.select: lambda input, dim, index: -1,
  1093. torch.select_scatter: lambda input, src, dim, index: -1,
  1094. torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
  1095. torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1,
  1096. torch.selu: lambda input, inplace=False: -1,
  1097. torch.sigmoid: lambda input, out=None: -1,
  1098. torch.sign: lambda input, out=None: -1,
  1099. torch.signbit: lambda input, out=None: -1,
  1100. torch.sgn: lambda input, out=None: -1,
  1101. torch.sin: lambda input, out=None: -1,
  1102. torch.sinc: lambda input, out=None: -1,
  1103. torch.sinh: lambda input, out=None: -1,
  1104. torch.slogdet: lambda input: -1,
  1105. torch.linalg.slogdet: lambda input: -1,
  1106. torch.smm: lambda input, mat2, out_dtype=None: -1,
  1107. torch.spmm: lambda input, mat2, out_dtype=None: -1,
  1108. torch.softmax: lambda input, dim, dtype=None: -1,
  1109. torch.linalg.solve: lambda A, B, left=True, out=None: -1,
  1110. torch.linalg.solve_ex: lambda A, B, left=True, check_errors=False, out=None: -1,
  1111. torch.sort: lambda input, dim=-1, descending=False, *, stable=False, out=None: -1,
  1112. torch.split: lambda tensor, split_size_or_sections, dim=0: -1,
  1113. torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
  1114. torch.sqrt: lambda input, out=None: -1,
  1115. torch.square: lambda input, out=None: -1,
  1116. torch.squeeze: lambda input, dim=None, out=None: -1,
  1117. torch.sspaddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
  1118. torch.stack: lambda tensors, dim=0, out=None: -1,
  1119. torch.std: lambda input, dim=None: -1,
  1120. torch.std_mean: lambda input, dim=None: -1,
  1121. torch.stft: (
  1122. lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None, align_to_window=None: -1 # noqa: B950
  1123. ),
  1124. torch.sub: lambda input, other, out=None: -1,
  1125. torch.subtract: lambda input, other, out=None: -1,
  1126. torch.sum: lambda input, dim=None: -1,
  1127. torch.sym_float: lambda input: -1,
  1128. torch.sym_int: lambda input: -1,
  1129. torch.sym_max: lambda a, b: -1,
  1130. torch.sym_min: lambda a, b: -1,
  1131. torch.sym_not: lambda input: -1,
  1132. torch.sym_ite: lambda a, b, c: -1,
  1133. torch.sym_sum: lambda args: -1,
  1134. torch._sym_sqrt: lambda input: -1,
  1135. torch._sym_cos: lambda input: -1,
  1136. torch._sym_cosh: lambda input: -1,
  1137. torch._sym_sin: lambda input: -1,
  1138. torch._sym_sinh: lambda input: -1,
  1139. torch._sym_tan: lambda input: -1,
  1140. torch._sym_tanh: lambda input: -1,
  1141. torch._sym_asin: lambda input: -1,
  1142. torch._sym_acos: lambda input: -1,
  1143. torch._sym_atan: lambda input: -1,
  1144. torch.nansum: lambda input, dim=None: -1,
  1145. torch.svd: lambda input, some=True, compute_uv=True, out=None: -1,
  1146. torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1,
  1147. torch.linalg.svd: lambda input, full_matrices=True, out=None: -1,
  1148. torch.linalg.svdvals: lambda input, out=None: -1,
  1149. torch.swapaxes: lambda input, dim0, dim1: -1,
  1150. torch.swapdims: lambda input, axis0, axis1: -1,
  1151. torch.special.airy_ai: lambda input: -1,
  1152. torch.special.bessel_j0: lambda input: -1,
  1153. torch.special.bessel_j1: lambda input: -1,
  1154. torch.special.bessel_y0: lambda input: -1,
  1155. torch.special.bessel_y1: lambda input: -1,
  1156. torch.special.chebyshev_polynomial_t: lambda input, n, out=None: -1,
  1157. torch.special.chebyshev_polynomial_u: lambda input, n, out=None: -1,
  1158. torch.special.chebyshev_polynomial_v: lambda input, n, out=None: -1,
  1159. torch.special.chebyshev_polynomial_w: lambda input, n, out=None: -1,
  1160. torch.special.digamma: lambda input: -1,
  1161. torch.special.entr: lambda input: -1,
  1162. torch.special.erf: lambda input: -1,
  1163. torch.special.erfc: lambda input: -1,
  1164. torch.special.erfcx: lambda input: -1,
  1165. torch.special.erfinv: lambda input: -1,
  1166. torch.special.exp2: lambda input: -1,
  1167. torch.special.expit: lambda input: -1,
  1168. torch.special.expm1: lambda input: -1,
  1169. torch.special.gammainc: lambda input, other, out=None: -1,
  1170. torch.special.gammaincc: lambda input, other, out=None: -1,
  1171. torch.special.gammaln: lambda input: -1,
  1172. torch.special.hermite_polynomial_h: lambda input, n, out=None: -1,
  1173. torch.special.hermite_polynomial_he: lambda input, n, out=None: -1,
  1174. torch.special.i0: lambda input: -1,
  1175. torch.special.i0e: lambda input: -1,
  1176. torch.special.i1: lambda input: -1,
  1177. torch.special.i1e: lambda input: -1,
  1178. torch.special.laguerre_polynomial_l: lambda input, n, out=None: -1,
  1179. torch.special.legendre_polynomial_p: lambda input, n, out=None: -1,
  1180. torch.special.log1p: lambda input: -1,
  1181. torch.special.log_ndtr: lambda input: -1,
  1182. torch.special.log_softmax: lambda input, dim, dtype=None: -1,
  1183. torch.special.logit: lambda input: -1,
  1184. torch.special.logsumexp: lambda input, dim, keepdim=False, out=None: -1,
  1185. torch.special.modified_bessel_i0: lambda input: -1,
  1186. torch.special.modified_bessel_i1: lambda input: -1,
  1187. torch.special.modified_bessel_k0: lambda input: -1,
  1188. torch.special.modified_bessel_k1: lambda input: -1,
  1189. torch.special.multigammaln: lambda input, p: -1,
  1190. torch.special.ndtr: lambda input: -1,
  1191. torch.special.ndtri: lambda input: -1,
  1192. torch.special.polygamma: lambda input, n, out=None: -1,
  1193. torch.special.psi: lambda input: -1,
  1194. torch.special.round: lambda input: -1,
  1195. torch.special.scaled_modified_bessel_k0: lambda input: -1,
  1196. torch.special.scaled_modified_bessel_k1: lambda input: -1,
  1197. torch.special.shifted_chebyshev_polynomial_t: lambda input, n, out=None: -1,
  1198. torch.special.shifted_chebyshev_polynomial_u: lambda input, n, out=None: -1,
  1199. torch.special.shifted_chebyshev_polynomial_v: lambda input, n, out=None: -1,
  1200. torch.special.shifted_chebyshev_polynomial_w: lambda input, n, out=None: -1,
  1201. torch.special.sinc: lambda input: -1,
  1202. torch.special.softmax: lambda input, dim, dtype=None: -1,
  1203. torch.special.spherical_bessel_j0: lambda input: -1,
  1204. torch.special.xlog1py: lambda input, other, out=None: -1,
  1205. torch.special.xlogy: lambda input, other, out=None: -1,
  1206. torch.special.zeta: lambda self, other, out=None: -1,
  1207. torch.t: lambda input: -1,
  1208. torch.take: lambda input, index: -1,
  1209. torch.take_along_dim: lambda input, indices, dim=None, out=None: -1,
  1210. torch.tan: lambda input, out=None: -1,
  1211. torch.tanh: lambda input, out=None: -1,
  1212. torch.linalg.tensorinv: lambda a, ind=2: -1,
  1213. torch.linalg.tensorsolve: lambda a, b, dims=None: -1,
  1214. torch.tensordot: lambda a, b, dims=2, out=None: -1,
  1215. torch.tensor_split: lambda input, indices_or_sections, dim=0: -1,
  1216. torch.threshold: lambda input, threshold, value, inplace=False: -1,
  1217. torch.tile: lambda input, dims: -1,
  1218. torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1,
  1219. torch.trace: lambda input: -1,
  1220. torch.transpose: lambda input, dim0, dim1: -1,
  1221. torch.trapz: lambda y, x=None, dim=-1: -1,
  1222. torch.trapezoid: lambda y, x=None, dim=-1: -1,
  1223. torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
  1224. torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
  1225. torch.tril: lambda input, diagonal=0, out=None: -1,
  1226. torch.triplet_margin_loss: (
  1227. lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950
  1228. ),
  1229. torch.triu: lambda input, diagonal=0, out=None: -1,
  1230. torch.true_divide: lambda input, other: -1,
  1231. torch.trunc: lambda input, out=None: -1,
  1232. torch.unbind: lambda input, dim=0: -1,
  1233. torch.unflatten: lambda input, dim, sizes, names: -1,
  1234. torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1,
  1235. torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1,
  1236. torch.unravel_index: lambda indices, shape: -1,
  1237. torch.unsafe_chunk: lambda input, chunks, dim=0: -1,
  1238. torch.unsafe_split: lambda tensor, split_size_or_sections, dim=0: -1,
  1239. torch.unsafe_split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
  1240. torch.unsqueeze: lambda input, dim, out=None: -1,
  1241. torch.linalg.vander: lambda x, N=None: -1,
  1242. torch.var: lambda input, dim=None: -1,
  1243. torch.var_mean: lambda input, dim=None: -1,
  1244. torch.vsplit: lambda input, indices_or_sections: -1,
  1245. torch.vstack: lambda tensors, out=None: -1,
  1246. torch.where: lambda condition, x=None, y=None: -1,
  1247. torch._wrapped_linear_prepack: lambda weight, weight_scale, weight_zero_point, bias : -1,
  1248. torch._wrapped_quantized_linear_prepacked: (
  1249. lambda input, input_scale, input_zero_point, prepacked, out_scale, out_zero_point, out_channel : -1 # noqa: B950
  1250. ),
  1251. torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
  1252. torch._fw_primal_copy: lambda self, level: -1,
  1253. torch._make_dual_copy: lambda primal, tangent, level: -1,
  1254. torch.view_as_real_copy: lambda self: -1,
  1255. torch.view_as_complex_copy: lambda self: -1,
  1256. torch._conj_copy: lambda self: -1,
  1257. torch._neg_view_copy: lambda self: -1,
  1258. torch.as_strided_copy: lambda self, size, stride, storage_offset=None: -1,
  1259. torch._sparse_broadcast_to_copy: lambda self, size: -1,
  1260. torch.diagonal_copy: lambda self, offset=0, dim1=0, dim2=1: -1,
  1261. torch.expand_copy: lambda self, size, *, implicit=False: -1,
  1262. torch.narrow_copy: lambda self, dim, start, length: -1,
  1263. torch.permute_copy: lambda self, dims: -1,
  1264. torch._reshape_alias_copy: lambda self, size, stride: -1,
  1265. torch.select_copy: lambda self, dim, index: -1,
  1266. torch.detach_copy: lambda self: -1,
  1267. torch.slice_copy: lambda self, dim=0, start=None, end=None, step=1: -1,
  1268. torch.split_copy: lambda self, split_size, dim=0: -1,
  1269. torch.split_with_sizes_copy: lambda self, split_sizes, dim=0: -1,
  1270. torch.squeeze_copy: lambda self, dim: -1,
  1271. torch.t_copy: lambda self: -1,
  1272. torch.transpose_copy: lambda self, dim0, dim1: -1,
  1273. torch.unsqueeze_copy: lambda self, dim: -1,
  1274. torch._indices_copy: lambda self: -1,
  1275. torch._values_copy: lambda self: -1,
  1276. torch.indices_copy: lambda self: -1,
  1277. torch.values_copy: lambda self: -1,
  1278. torch.crow_indices_copy: lambda self: -1,
  1279. torch.col_indices_copy: lambda self: -1,
  1280. torch.ccol_indices_copy: lambda self: -1,
  1281. torch.row_indices_copy: lambda self: -1,
  1282. torch.unbind_copy: lambda self, dim=0: -1,
  1283. torch.view_copy: lambda self, dtype: -1,
  1284. torch.unfold_copy: lambda self, dimension, size, step: -1,
  1285. torch.alias_copy: lambda self: -1,
  1286. Tensor.__floordiv__: lambda self, other: -1,
  1287. Tensor.__rfloordiv__: lambda self, other: -1,
  1288. Tensor.__ifloordiv__: lambda self, other: -1,
  1289. Tensor.__truediv__: lambda self, other: -1,
  1290. Tensor.__rtruediv__: lambda self, other: -1,
  1291. Tensor.__itruediv__: lambda self, other: -1,
  1292. Tensor.__lshift__: lambda self, other: -1,
  1293. Tensor.__rlshift__: lambda self, other: -1,
  1294. Tensor.__ilshift__: lambda self, other: -1,
  1295. Tensor.__rshift__: lambda self, other: -1,
  1296. Tensor.__rrshift__: lambda self, other: -1,
  1297. Tensor.__irshift__: lambda self, other: -1,
  1298. Tensor.__and__: lambda self, other: -1,
  1299. Tensor.__or__: lambda self, other: -1,
  1300. Tensor.__xor__: lambda self, other: -1,
  1301. Tensor.__float__: lambda self: -1,
  1302. Tensor.__complex__: lambda self: -1,
  1303. Tensor.__array__: lambda self, dtype: -1,
  1304. Tensor.__bool__: lambda self: -1,
  1305. Tensor.__contains__: lambda self, other: -1,
  1306. Tensor.__neg__: lambda self: -1,
  1307. Tensor.__invert__: lambda self: -1,
  1308. Tensor.__mod__: lambda self, other: -1,
  1309. Tensor.__rmod__: lambda self, other: -1,
  1310. Tensor.__imod__: lambda self, other: -1,
  1311. Tensor.__array_wrap__: lambda self, array: -1,
  1312. Tensor.__getitem__: lambda self, idx: -1,
  1313. Tensor.__deepcopy__: lambda self, memo: -1,
  1314. Tensor.__int__: lambda self: -1,
  1315. Tensor.__long__: lambda self: -1,
  1316. Tensor.__index__: lambda self: -1,
  1317. Tensor.__len__: lambda self: -1,
  1318. Tensor.__format__: lambda self, format_spec: -1,
  1319. Tensor.__reduce_ex__: lambda self, proto: -1,
  1320. Tensor.__reversed__: lambda self: -1,
  1321. Tensor.__repr__: lambda self, *, tensor_contents=None: -1,
  1322. Tensor.__setitem__: lambda self, k, v: -1,
  1323. Tensor.__setstate__: lambda self, d: -1,
  1324. Tensor.T.__get__: lambda self: -1,
  1325. Tensor.H.__get__: lambda self: -1,
  1326. Tensor.mT.__get__: lambda self: -1,
  1327. Tensor.mH.__get__: lambda self: -1,
  1328. Tensor._backward_hooks.__get__: lambda self: -1,
  1329. Tensor._post_accumulate_grad_hooks.__get__: lambda self: -1,
  1330. Tensor._base.__get__: lambda self: -1,
  1331. Tensor._cdata.__get__: lambda self: -1,
  1332. Tensor.grad.__get__: lambda self: -1,
  1333. Tensor._grad.__get__: lambda self: -1,
  1334. Tensor._grad_fn.__get__: lambda self: -1,
  1335. Tensor.grad_fn.__get__: lambda self: -1,
  1336. Tensor.grad_dtype.__get__: lambda self: -1,
  1337. Tensor._version.__get__: lambda self: -1,
  1338. Tensor._autocast_to_reduced_precision: lambda self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype: -1,
  1339. Tensor._autocast_to_full_precision: lambda self, cuda_enabled, cpu_enabled: -1,
  1340. Tensor._clear_non_serializable_cached_data: lambda self: -1,
  1341. Tensor.data.__get__: lambda self: -1,
  1342. Tensor.device.__get__: lambda self: -1,
  1343. Tensor.dtype.__get__: lambda self: -1,
  1344. Tensor.is_cuda.__get__: lambda self: -1,
  1345. Tensor.is_cpu.__get__: lambda self: -1,
  1346. Tensor.is_xla.__get__: lambda self: -1,
  1347. Tensor.is_xpu.__get__: lambda self: -1,
  1348. Tensor.is_ipu.__get__: lambda self: -1,
  1349. Tensor.is_leaf.__get__: lambda self: -1,
  1350. Tensor.retains_grad.__get__: lambda self: -1,
  1351. Tensor.is_meta.__get__: lambda self: -1,
  1352. Tensor.is_mps.__get__: lambda self: -1,
  1353. Tensor.is_mtia.__get__: lambda self: -1,
  1354. Tensor.is_nested.__get__: lambda self: -1,
  1355. Tensor.is_maia.__get__: lambda self: -1,
  1356. Tensor.is_mkldnn.__get__: lambda self: -1,
  1357. Tensor.is_quantized.__get__: lambda self: -1,
  1358. Tensor.is_sparse.__get__: lambda self: -1,
  1359. Tensor.is_sparse_csr.__get__: lambda self: -1,
  1360. Tensor.is_vulkan.__get__: lambda self: -1,
  1361. Tensor.itemsize.__get__: lambda self: -1,
  1362. Tensor.layout.__get__: lambda self: -1,
  1363. Tensor.name.__get__: lambda self: -1,
  1364. Tensor.names.__get__: lambda self: -1,
  1365. Tensor.nbytes.__get__: lambda self: -1,
  1366. Tensor.ndim.__get__: lambda self: -1,
  1367. Tensor.output_nr.__get__: lambda self: -1,
  1368. Tensor.requires_grad.__get__: lambda self: -1,
  1369. Tensor.shape.__get__: lambda self: -1,
  1370. Tensor.volatile.__get__: lambda self: -1,
  1371. Tensor.real.__get__: lambda self: -1,
  1372. Tensor.imag.__get__: lambda self: -1,
  1373. Tensor.__cuda_array_interface__.__get__: lambda self: -1,
  1374. Tensor.type: lambda self, dtype=None, non_blocking=False, **kwargs: -1,
  1375. Tensor._dimI: lambda self: -1,
  1376. Tensor._dimV: lambda self: -1,
  1377. Tensor._indices: lambda self: -1,
  1378. Tensor._is_view: lambda self: -1,
  1379. Tensor._nnz: lambda self: -1,
  1380. Tensor.crow_indices: lambda self: -1,
  1381. Tensor.col_indices: lambda self: -1,
  1382. Tensor.ccol_indices: lambda self: -1,
  1383. Tensor.row_indices: lambda self: -1,
  1384. Tensor._update_names: lambda self, names, inplace: -1,
  1385. Tensor._values: lambda self: -1,
  1386. Tensor.adjoint: lambda self: -1,
  1387. Tensor.align_as: lambda self, other: -1,
  1388. Tensor.align_to: lambda self, order, ellipsis_idx: -1,
  1389. Tensor.apply_: lambda self, callable: -1,
  1390. Tensor.as_strided: lambda self, size, stride: -1,
  1391. Tensor.as_strided_: lambda self, size, stride: -1,
  1392. Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False, inputs=None: -1,
  1393. Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1,
  1394. Tensor.bool: lambda self, memory_format=torch.preserve_format: -1,
  1395. Tensor.byte: lambda self, memory_format=torch.preserve_format: -1,
  1396. Tensor.char: lambda self, memory_format=torch.preserve_format: -1,
  1397. Tensor.cauchy_: lambda self, median=0, sigma=1, *, generator=None: -1,
  1398. Tensor.coalesce: lambda self: -1,
  1399. Tensor._coalesced_: lambda self, coalesced: -1,
  1400. Tensor.contiguous: lambda self, memory_format=torch.contiguous_format: -1,
  1401. Tensor.copy_: lambda self, src, non_blocking=False: -1,
  1402. Tensor.cpu: lambda self, memory_format=torch.preserve_format: -1,
  1403. Tensor.cuda: lambda self, memory_format=torch.preserve_format: -1,
  1404. Tensor.mtia: lambda self, memory_format=torch.preserve_format: -1,
  1405. Tensor.xpu: lambda self, memory_format=torch.preserve_format: -1,
  1406. Tensor.ipu: lambda self, memory_format=torch.preserve_format: -1,
  1407. Tensor.data_ptr: lambda self: -1,
  1408. Tensor.dense_dim: lambda self: -1,
  1409. Tensor.diagonal_scatter: lambda self, src, offset=0, dim1=0, dim2=1: -1,
  1410. Tensor.dim: lambda self: -1,
  1411. Tensor.dim_order: lambda self, ambiguity_check=False: -1,
  1412. Tensor.double: lambda self, memory_format=torch.preserve_format: -1,
  1413. Tensor.cdouble: lambda self, memory_format=torch.preserve_format: -1,
  1414. Tensor.element_size: lambda self: -1,
  1415. Tensor.expand: lambda self, size: -1,
  1416. Tensor.expand_as: lambda self, other: -1,
  1417. Tensor.exponential_: lambda self, lambd=1, *, generator=None: -1,
  1418. Tensor.fill_: lambda self, value: -1,
  1419. Tensor.fill_diagonal_: lambda self, value: -1,
  1420. Tensor.float: lambda self, memory_format=torch.preserve_format: -1,
  1421. Tensor.cfloat: lambda self, memory_format=torch.preserve_format: -1,
  1422. Tensor.geometric_: lambda self, p, *, generator=None: -1,
  1423. Tensor.get_device: lambda self: -1,
  1424. Tensor.half: lambda self, memory_format=torch.preserve_format: -1,
  1425. Tensor.chalf: lambda self, memory_format=torch.preserve_format: -1,
  1426. Tensor.has_names: lambda self: -1,
  1427. Tensor.indices: lambda self: -1,
  1428. Tensor.int: lambda self, memory_format=torch.preserve_format: -1,
  1429. Tensor.is_coalesced: lambda self: -1,
  1430. Tensor.is_contiguous: lambda self: -1,
  1431. Tensor.is_inference: lambda self: -1,
  1432. Tensor.is_pinned: lambda self: -1,
  1433. Tensor.is_set_to: lambda self, tensor: -1,
  1434. Tensor.is_shared: lambda self: -1,
  1435. Tensor.item: lambda self: -1,
  1436. Tensor.log_normal_: lambda self, mean=1, std=2, *, generator=None: -1,
  1437. Tensor.log_softmax: lambda self, dim: -1,
  1438. Tensor.long: lambda self, memory_format=torch.preserve_format: -1,
  1439. Tensor.map_: lambda self, tensor, callable: -1,
  1440. Tensor.map2_: lambda self, x, y, callable: -1,
  1441. Tensor.mm: lambda self, mat2, out_dtype=None: -1,
  1442. Tensor.module_load: lambda self, other, assign=False: -1,
  1443. Tensor.narrow_copy: lambda self, dimension, start, length: -1,
  1444. Tensor.ndimension: lambda self: -1,
  1445. Tensor.nelement: lambda self: -1,
  1446. Tensor._nested_tensor_size: lambda self: -1,
  1447. Tensor._nested_tensor_storage_offsets: lambda self: -1,
  1448. Tensor._nested_tensor_strides: lambda self: -1,
  1449. Tensor.normal_: lambda self: -1,
  1450. Tensor.numpy: lambda self: -1,
  1451. Tensor.permute: lambda self, dim: -1,
  1452. Tensor.pin_memory: lambda self: -1,
  1453. Tensor.put_: lambda self, indices, tensor, accumulate=False: -1,
  1454. Tensor.qscheme: lambda self: -1,
  1455. Tensor.random_: lambda self, from_=0, to=None, *, generator=None: -1,
  1456. Tensor.record_stream: lambda self, stream: -1,
  1457. Tensor.refine_names: lambda self, names: -1,
  1458. Tensor.register_hook: lambda self, hook: -1,
  1459. Tensor.register_post_accumulate_grad_hook: lambda self, hook: -1,
  1460. Tensor.rename: lambda self, name: -1,
  1461. Tensor.repeat: lambda self, *size: -1,
  1462. Tensor.requires_grad_: lambda self, requires_grad=True: -1,
  1463. Tensor.reshape_as: lambda self, other: -1,
  1464. Tensor.resize: lambda self, *size: -1,
  1465. Tensor.resize_: lambda self, size: -1,
  1466. Tensor.resize_as: lambda self, other: -1,
  1467. Tensor.resize_as_sparse_: lambda self, other: -1,
  1468. Tensor.retain_grad: lambda self: -1,
  1469. Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1,
  1470. Tensor.select_scatter: lambda self, src, dim, index: -1,
  1471. Tensor.share_memory_: lambda self: -1,
  1472. Tensor.short: lambda self, memory_format=torch.preserve_format: -1,
  1473. Tensor.size: lambda self: -1,
  1474. Tensor.slice_scatter: lambda self, src, dim=0, start=None, end=None, step=1: -1,
  1475. Tensor.sparse_dim: lambda self: -1,
  1476. Tensor.sparse_mask: lambda self, mask: -1,
  1477. Tensor._sparse_mask_projection: lambda self, mask, accumulate_matches=False: -1,
  1478. Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1,
  1479. Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1,
  1480. Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1,
  1481. Tensor.storage: lambda self: -1,
  1482. Tensor.untyped_storage: lambda self: -1,
  1483. Tensor.storage_offset: lambda self: -1,
  1484. Tensor.storage_type: lambda self: -1,
  1485. Tensor.sum_to_size: lambda self, size: -1,
  1486. Tensor.tile: lambda self, *reps: -1,
  1487. Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1,
  1488. Tensor.to_dense: lambda self, dtype=None, *, masked_grad=None: -1,
  1489. Tensor._to_dense: lambda self, dtype=None, masked_grad=None: -1,
  1490. Tensor.to_sparse: lambda self: -1,
  1491. Tensor.tolist: lambda self: -1,
  1492. Tensor.to_mkldnn: lambda self: -1,
  1493. Tensor.type_as: lambda self, other: -1,
  1494. Tensor.unfold: lambda self, dimension, size, step: -1,
  1495. Tensor.uniform_: lambda self, from_=0, to=1: -1,
  1496. Tensor.values: lambda self: -1,
  1497. Tensor.view: lambda self, shape: -1,
  1498. Tensor.view_as: lambda self, other: -1,
  1499. Tensor.zero_: lambda self: -1,
  1500. Tensor.__dlpack__: lambda self, stream=None, max_version=None, dl_device=None, copy=None: -1,
  1501. Tensor.__dlpack_device__: lambda self: -1,
  1502. Tensor.index: lambda self, a, b: -1,
  1503. torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
  1504. } # fmt: skip
  1505. privateuse1_backend_name = (
  1506. torch.utils.backend_registration._privateuse1_backend_name
  1507. )
  1508. if hasattr(Tensor, privateuse1_backend_name):
  1509. ret[getattr(Tensor, privateuse1_backend_name)] = (
  1510. lambda self, device=None, non_blocking=False, **kwargs: -1
  1511. )
  1512. ret[getattr(Tensor, f"is_{privateuse1_backend_name}").__get__] = lambda self: -1
  1513. ret2 = {}
  1514. ignored = get_ignored_functions()
  1515. for k, v in ret.items():
  1516. # Generate methods like __add__ and add_ by default from add
  1517. names = [
  1518. k.__name__, # Default method
  1519. k.__name__ + "_", # Inplace variant
  1520. "__" + k.__name__ + "__", # Dunder method
  1521. "__i" + k.__name__ + "__", # Inplace dunder method
  1522. "__r" + k.__name__ + "__", # Reverse dunder method
  1523. ]
  1524. if k.__name__.startswith("bitwise_"):
  1525. # bitwise_<op> have dunder methods of the form __<op>__
  1526. # And so on.
  1527. subname = k.__name__[len("bitwise_") :]
  1528. names.extend(
  1529. ["__" + subname + "__", "__i" + subname + "__", "__r" + subname + "__"]
  1530. )
  1531. for name in names:
  1532. func = getattr(Tensor, name, None)
  1533. if callable(func) and func not in ret and func not in ignored:
  1534. ret2[func] = v
  1535. ret.update(ret2)
  1536. return ret
  1537. def wrap_torch_function(dispatcher: Callable):
  1538. """Wraps a given function with ``__torch_function__`` -related functionality.
  1539. Parameters
  1540. ----------
  1541. dispatcher: Callable
  1542. A callable that returns an iterable of Tensor-likes passed into the function.
  1543. Note
  1544. ----
  1545. This decorator may reduce the performance of your code. Generally, it's enough to express
  1546. your code as a series of functions that, themselves, support __torch_function__. If you
  1547. find yourself in the rare situation where this is not the case, e.g. if you're wrapping a
  1548. low-level library and you also need it to work for Tensor-likes, then this function is available.
  1549. Examples
  1550. --------
  1551. >>> def dispatcher(a): # Must have the same signature as func
  1552. ... return (a,)
  1553. >>> @torch.overrides.wrap_torch_function(dispatcher)
  1554. >>> def func(a): # This will make func dispatchable by __torch_function__
  1555. ... return a + 0
  1556. """
  1557. def inner(func):
  1558. @functools.wraps(func)
  1559. def wrapped(*args, **kwargs):
  1560. relevant_args = dispatcher(*args, **kwargs)
  1561. if has_torch_function(relevant_args):
  1562. return handle_torch_function(wrapped, relevant_args, *args, **kwargs)
  1563. return func(*args, **kwargs)
  1564. return wrapped
  1565. return inner
  1566. def _get_overloaded_args(
  1567. relevant_args: Iterable[Any],
  1568. get_type_fn: Callable[[Any], type] | None = None,
  1569. ) -> list[Any]:
  1570. """Returns a list of arguments on which to call __torch_function__.
  1571. Checks arguments in relevant_args for __torch_function__ implementations,
  1572. storing references to the arguments and their types in overloaded_args and
  1573. overloaded_types in order of calling precedence. Only distinct types are
  1574. considered. If a type is a subclass of another type it will have higher
  1575. precedence, otherwise the precedence order is the same as the order of
  1576. arguments in relevant_args, that is, from left-to-right in the argument list.
  1577. The precedence-determining algorithm implemented in this function is
  1578. described in `NEP-0018`_.
  1579. See torch::append_overloaded_arg for the equivalent function in the C++
  1580. implementation.
  1581. Parameters
  1582. ----------
  1583. relevant_args : iterable of array-like
  1584. Iterable of array-like arguments to check for __torch_function__
  1585. methods.
  1586. get_type_fn : callable, optional
  1587. Function to call on each argument in relevant_args to get its type.
  1588. Returns
  1589. -------
  1590. overloaded_args : list
  1591. Arguments from relevant_args on which to call __torch_function__
  1592. methods, in the order in which they should be called.
  1593. .. _NEP-0018:
  1594. https://numpy.org/neps/nep-0018-array-function-protocol.html
  1595. """
  1596. if get_type_fn is None:
  1597. get_type_fn = type
  1598. # If torch function is not enabled, there are no overloaded types
  1599. if not torch._C._is_torch_function_enabled():
  1600. return []
  1601. # Runtime is O(num_arguments * num_unique_types)
  1602. overloaded_types: set[type] = set()
  1603. overloaded_args: list[Any] = []
  1604. for arg in relevant_args:
  1605. arg_type = get_type_fn(arg)
  1606. # We only collect arguments if they have a unique type, which ensures
  1607. # reasonable performance even with a long list of possibly overloaded
  1608. # arguments.
  1609. #
  1610. # NB: Important to exclude _disabled_torch_function_impl, otherwise
  1611. # https://github.com/pytorch/pytorch/issues/64687
  1612. if (
  1613. arg_type not in overloaded_types
  1614. and hasattr(arg_type, "__torch_function__")
  1615. and arg_type.__torch_function__
  1616. is not torch._C._disabled_torch_function_impl
  1617. ):
  1618. # Create lists explicitly for the first type (usually the only one
  1619. # done) to avoid setting up the iterator for overloaded_args.
  1620. if overloaded_types:
  1621. overloaded_types.add(arg_type)
  1622. # By default, insert argument at the end, but if it is
  1623. # subclass of another argument, insert it before that argument.
  1624. # This ensures "subclasses before superclasses".
  1625. index = len(overloaded_args)
  1626. for i, old_arg in enumerate(overloaded_args):
  1627. if issubclass(arg_type, get_type_fn(old_arg)):
  1628. index = i
  1629. break
  1630. overloaded_args.insert(index, arg)
  1631. else:
  1632. overloaded_types = {arg_type}
  1633. overloaded_args = [arg]
  1634. return overloaded_args
  1635. def handle_torch_function(
  1636. public_api: Callable,
  1637. relevant_args: Iterable[Any],
  1638. *args,
  1639. **kwargs,
  1640. ) -> Any:
  1641. """Implement a function with checks for ``__torch_function__`` overrides.
  1642. See torch::autograd::handle_torch_function for the equivalent of this
  1643. function in the C++ implementation.
  1644. Arguments
  1645. ---------
  1646. public_api : function
  1647. Function exposed by the public torch API originally called like
  1648. ``public_api(*args, **kwargs)`` on which arguments are now being
  1649. checked.
  1650. relevant_args : iterable
  1651. Iterable of arguments to check for __torch_function__ methods.
  1652. args : tuple
  1653. Arbitrary positional arguments originally passed into ``public_api``.
  1654. kwargs : tuple
  1655. Arbitrary keyword arguments originally passed into ``public_api``.
  1656. Returns
  1657. -------
  1658. object
  1659. Result from calling ``implementation`` or an ``__torch_function__``
  1660. method, as appropriate.
  1661. Raises
  1662. ------
  1663. TypeError : if no implementation is found.
  1664. Example
  1665. -------
  1666. >>> def func(a):
  1667. ... if has_torch_function_unary(a):
  1668. ... return handle_torch_function(func, (a,), a)
  1669. ... return a + 0
  1670. """
  1671. # Check for __torch_function__ methods.
  1672. overloaded_args = _get_overloaded_args(relevant_args)
  1673. # overloaded_args already have unique types.
  1674. types = tuple(map(type, overloaded_args))
  1675. # Check for __torch_function__ mode.
  1676. if _is_torch_function_mode_enabled():
  1677. # if we're here, the mode must be set to a TorchFunctionStackMode
  1678. # this unsets it and calls directly into TorchFunctionStackMode's torch function
  1679. with _pop_mode_temporarily() as mode:
  1680. result = mode.__torch_function__(public_api, types, args, kwargs)
  1681. if result is not NotImplemented:
  1682. return result
  1683. # Call overrides
  1684. for overloaded_arg in overloaded_args:
  1685. # This call needs to become a classmethod call in the future.
  1686. # See https://github.com/pytorch/pytorch/issues/63767
  1687. torch_func_method = overloaded_arg.__torch_function__
  1688. if (
  1689. hasattr(torch_func_method, "__self__")
  1690. and torch_func_method.__self__ is overloaded_arg
  1691. and torch_func_method is not torch._C._disabled_torch_function_impl
  1692. ):
  1693. warnings.warn(
  1694. "Defining your `__torch_function__ as a plain method is deprecated and "
  1695. "will be an error in future, please define it as a classmethod.",
  1696. DeprecationWarning,
  1697. stacklevel=2,
  1698. )
  1699. # Use `public_api` instead of `implementation` so __torch_function__
  1700. # implementations can do equality/identity comparisons.
  1701. result = torch_func_method(public_api, types, args, kwargs)
  1702. if result is not NotImplemented:
  1703. return result
  1704. func_name = f"{public_api.__module__}.{public_api.__name__}"
  1705. msg = (
  1706. f"no implementation found for '{func_name}' on types that implement "
  1707. f"__torch_function__: {[type(arg) for arg in overloaded_args]}"
  1708. )
  1709. if _is_torch_function_mode_enabled():
  1710. msg += f" nor in mode {_get_current_function_mode()}"
  1711. raise TypeError(msg)
  1712. has_torch_function = _add_docstr(
  1713. _has_torch_function,
  1714. r"""Check for __torch_function__ implementations in the elements of an iterable
  1715. or if a __torch_function__ mode is enabled. Considers exact ``Tensor`` s
  1716. and ``Parameter`` s non-dispatchable. Use this to guard a call to
  1717. :func:`handle_torch_function`; don't use it to test if something
  1718. is Tensor-like, use :func:`is_tensor_like` instead.
  1719. Arguments
  1720. ---------
  1721. relevant_args : iterable
  1722. Iterable or arguments to check for __torch_function__ methods.
  1723. Returns
  1724. -------
  1725. bool
  1726. True if any of the elements of relevant_args have __torch_function__
  1727. implementations, False otherwise.
  1728. See Also
  1729. ________
  1730. torch.is_tensor_like
  1731. Checks if something is a Tensor-like, including an exact ``Tensor``.
  1732. """,
  1733. )
  1734. has_torch_function_unary = _add_docstr(
  1735. _has_torch_function_unary,
  1736. r"""Special case of `has_torch_function` for single inputs.
  1737. Instead of:
  1738. `has_torch_function((t,))`
  1739. call:
  1740. `has_torch_function_unary(t)`
  1741. which skips unnecessary packing and unpacking work.
  1742. """,
  1743. )
  1744. has_torch_function_variadic = _add_docstr(
  1745. _has_torch_function_variadic,
  1746. r"""Special case of `has_torch_function` that skips tuple creation.
  1747. This uses the METH_FASTCALL protocol introduced in Python 3.7
  1748. Instead of:
  1749. `has_torch_function((a, b))`
  1750. call:
  1751. `has_torch_function_variadic(a, b)`
  1752. which skips unnecessary packing and unpacking work.
  1753. """,
  1754. )
  1755. @functools.cache
  1756. def _get_overridable_functions() -> tuple[
  1757. dict[Any, list[Callable]], dict[Callable, str]
  1758. ]:
  1759. overridable_funcs = collections.defaultdict(list)
  1760. index = {}
  1761. tested_namespaces = [
  1762. ("torch", torch, torch.__all__),
  1763. ("torch.functional", torch.functional, torch.functional.__all__),
  1764. ("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)),
  1765. ("torch.nn.init", torch.nn.init, dir(torch.nn.init)),
  1766. ("torch.Tensor", torch.Tensor, dir(torch.Tensor)),
  1767. ("torch.linalg", torch.linalg, dir(torch.linalg)),
  1768. ("torch.fft", torch.fft, dir(torch.fft)),
  1769. ("torch.special", torch.special, dir(torch.special)),
  1770. ]
  1771. for namespace_str, namespace, ns_funcs in tested_namespaces:
  1772. for func_name in ns_funcs:
  1773. ignore = False
  1774. # ignore private functions or functions that are deleted in torch.__init__
  1775. if namespace is not torch.Tensor:
  1776. if func_name.startswith("__"):
  1777. continue
  1778. elif func_name.startswith("_"):
  1779. ignore = True
  1780. elif func_name.endswith("_"):
  1781. ignore = True
  1782. elif not func_name[0].islower():
  1783. ignore = True
  1784. elif func_name == "unique_dim":
  1785. continue
  1786. else:
  1787. func = getattr(namespace, func_name)
  1788. if getattr(object, func_name, None) == func:
  1789. continue
  1790. if func_name == "__weakref__":
  1791. continue
  1792. func = getattr(namespace, func_name)
  1793. if namespace is torch.Tensor and getattr(object, func_name, None) == func:
  1794. continue
  1795. # ignore re-exported modules
  1796. if isinstance(func, types.ModuleType):
  1797. continue
  1798. # ignore __future__ imports
  1799. if isinstance(func, __future__._Feature):
  1800. continue
  1801. if not callable(func) and hasattr(func, "__get__"):
  1802. index[func.__get__] = f"{namespace_str}.{func_name}.__get__"
  1803. index[func.__set__] = f"{namespace_str}.{func_name}.__set__"
  1804. if ignore:
  1805. continue
  1806. if func.__get__ in get_ignored_functions():
  1807. msg = (
  1808. "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
  1809. "but still has an explicit override"
  1810. )
  1811. if func.__get__ in get_testing_overrides():
  1812. raise AssertionError(msg.format(namespace, func.__name__))
  1813. continue
  1814. else:
  1815. overridable_funcs[func].append(func.__get__)
  1816. continue
  1817. if not callable(func):
  1818. continue
  1819. index[func] = f"{namespace_str}.{func_name}"
  1820. if ignore:
  1821. continue
  1822. # cannot be overridden by __torch_function__
  1823. if func in get_ignored_functions():
  1824. msg = (
  1825. "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
  1826. "but still has an explicit override"
  1827. )
  1828. if func in get_testing_overrides():
  1829. raise AssertionError(msg.format(namespace, func.__name__))
  1830. continue
  1831. overridable_funcs[namespace].append(func)
  1832. return overridable_funcs, index
  1833. @_disable_user_warnings
  1834. def get_overridable_functions() -> dict[Any, list[Callable]]:
  1835. """List functions that are overridable via __torch_function__
  1836. Returns
  1837. -------
  1838. Dict[Any, List[Callable]]
  1839. A dictionary that maps namespaces that contain overridable functions
  1840. to functions in that namespace that can be overridden.
  1841. """
  1842. return _get_overridable_functions()[0]
  1843. @_disable_user_warnings
  1844. def resolve_name(f):
  1845. """Get a human readable string name for a function passed to
  1846. __torch_function__
  1847. Arguments
  1848. ---------
  1849. f : Callable
  1850. Function to resolve the name of.
  1851. Returns
  1852. -------
  1853. str
  1854. Name of the function; if eval'ed it should give back the input
  1855. function.
  1856. """
  1857. if isinstance(f, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
  1858. return str(f)
  1859. return _get_overridable_functions()[1].get(f)
  1860. @functools.cache
  1861. def _get_tensor_methods() -> set[Callable]:
  1862. """Returns a set of the overridable methods on ``torch.Tensor``"""
  1863. overridable_funcs = get_overridable_functions()
  1864. methods = set(overridable_funcs[torch.Tensor])
  1865. return methods
  1866. @_disable_user_warnings
  1867. def is_tensor_method_or_property(func: Callable) -> bool:
  1868. """
  1869. Returns True if the function passed in is a handler for a
  1870. method or property belonging to ``torch.Tensor``, as passed
  1871. into ``__torch_function__``.
  1872. .. note::
  1873. For properties, their ``__get__`` method must be passed in.
  1874. This may be needed, in particular, for the following reasons:
  1875. 1. Methods/properties sometimes don't contain a `__module__` slot.
  1876. 2. They require that the first passed-in argument is an instance
  1877. of ``torch.Tensor``.
  1878. Examples
  1879. --------
  1880. >>> is_tensor_method_or_property(torch.Tensor.add)
  1881. True
  1882. >>> is_tensor_method_or_property(torch.add)
  1883. False
  1884. """
  1885. return func in _get_tensor_methods() or func.__name__ == "__get__"
  1886. def is_tensor_like(inp):
  1887. """
  1888. Returns ``True`` if the passed-in input is a Tensor-like.
  1889. Currently, this occurs whenever there's a ``__torch_function__``
  1890. attribute on the type of the input.
  1891. Examples
  1892. --------
  1893. A subclass of tensor is generally a Tensor-like.
  1894. >>> class SubTensor(torch.Tensor): ...
  1895. >>> is_tensor_like(SubTensor([0]))
  1896. True
  1897. Built-in or user types aren't usually Tensor-like.
  1898. >>> is_tensor_like(6)
  1899. False
  1900. >>> is_tensor_like(None)
  1901. False
  1902. >>> class NotATensor: ...
  1903. >>> is_tensor_like(NotATensor())
  1904. False
  1905. But, they can be made Tensor-like by implementing __torch_function__.
  1906. >>> class TensorLike:
  1907. ... @classmethod
  1908. ... def __torch_function__(cls, func, types, args, kwargs):
  1909. ... return -1
  1910. >>> is_tensor_like(TensorLike())
  1911. True
  1912. """
  1913. return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__")
  1914. class TorchFunctionMode:
  1915. """
  1916. A ``TorchFunctionMode`` allows you to override the meaning of all
  1917. ``__torch_function__`` overridable functions within a dynamic scope,
  1918. without having to actually create a tensor subclass or manually
  1919. monkey-patch functions in the PyTorch API. Some common situations
  1920. where you should use a mode:
  1921. * You want to override the meaning of factory functions, or other
  1922. functions that do not otherwise take a tensor as an argument
  1923. (these cannot be overridden with tensor subclasses).
  1924. * You want to override the behavior of all functions without needing
  1925. to wrap your inputs in tensor subclasses; e.g., if you are just
  1926. interested in logging intermediate computations.
  1927. * You want to control the order of execution of various tensor
  1928. subclasses explicitly, rather than implicitly via the return of
  1929. ``NotImplemented``.
  1930. Independent subclasses of :class:`TorchFunctionMode` are compositional:
  1931. modes can be pushed onto a stack using ``with MyMode():``.
  1932. When you call functions in the PyTorch API inside your
  1933. ``__torch_function__`` implementation, by default, they will forward on to
  1934. the next mode on the mode stack. If you want recursively call back into
  1935. your current ``__torch_function__`` implementation, either explicitly
  1936. invoke ``self.__torch_function__(...)``, or use the context manager
  1937. ``enable_torch_function_mode(self, replace=self.inner)`` to make PyTorch
  1938. API self-referential (beware of infinite loops, in this case!)
  1939. """
  1940. inner: "TorchFunctionMode"
  1941. # Force metaclass to generate constructor at the base of the hierarchy
  1942. def __init__(self) -> None:
  1943. pass
  1944. def __torch_function__(self, func, types, args=(), kwargs=None):
  1945. raise NotImplementedError
  1946. def __enter__(self):
  1947. _push_mode(self)
  1948. return self
  1949. def __exit__(self, exc_type, exc_val, exc_tb):
  1950. _pop_mode()
  1951. @classmethod
  1952. def push(cls, *args, **kwargs):
  1953. warnings.warn(
  1954. "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`",
  1955. stacklevel=2,
  1956. )
  1957. instance = cls(*args, **kwargs)
  1958. return instance
  1959. def _get_current_function_mode():
  1960. stack_len = _len_torch_function_stack()
  1961. return _get_function_stack_at(stack_len - 1) if stack_len > 0 else None
  1962. def _get_current_function_mode_stack():
  1963. stack_len = _len_torch_function_stack()
  1964. return [_get_function_stack_at(i) for i in range(stack_len)]
  1965. def _push_mode(mode):
  1966. _push_on_torch_function_stack(mode)
  1967. def _pop_mode():
  1968. old = _pop_torch_function_stack()
  1969. return old
  1970. @contextlib.contextmanager
  1971. def _pop_mode_temporarily():
  1972. old = _pop_mode()
  1973. try:
  1974. yield old
  1975. finally:
  1976. _push_mode(old)
  1977. class BaseTorchFunctionMode(TorchFunctionMode):
  1978. def __torch_function__(self, func, types, args=(), kwargs=None):
  1979. if kwargs is None:
  1980. kwargs = {}
  1981. return func(*args, **kwargs)
  1982. @contextlib.contextmanager
  1983. def _enable_torch_function():
  1984. old_state = torch._C._get_torch_function_state()
  1985. try:
  1986. torch._C._set_torch_function_state(torch._C._TorchFunctionState.ENABLED)
  1987. yield
  1988. finally:
  1989. torch._C._set_torch_function_state(old_state)
  1990. @contextlib.contextmanager
  1991. def enable_reentrant_dispatch():
  1992. # NB: this can't simply be
  1993. # `enable_reentrant_dispatch = torch._C._RestorePythonTLSSnapshot`
  1994. # because:
  1995. # 1. torch._C._RestorePythonTLSSnapshot is unavailable when this file
  1996. # initially gets imported. Probably an import order thing.
  1997. # 2. enable_reentrant_dispatch is technically public API; assigning
  1998. # it the object would change the __module__ to look private.
  1999. with torch._C._RestorePythonTLSSnapshot():
  2000. try:
  2001. yield
  2002. finally:
  2003. pass