_builtins.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. # mypy: allow-untyped-defs
  2. import cmath
  3. import math
  4. import warnings
  5. from collections import OrderedDict
  6. from typing import Optional
  7. import torch
  8. import torch.backends.cudnn as cudnn
  9. from torch.nn.modules.utils import (
  10. _list_with_default,
  11. _pair,
  12. _quadruple,
  13. _single,
  14. _triple,
  15. )
  16. _builtin_table: Optional[dict[int, str]] = None
  17. _modules_containing_builtins = (
  18. torch,
  19. torch._C._nn,
  20. torch._C._fft, # type: ignore[attr-defined]
  21. torch._C._linalg, # type: ignore[attr-defined]
  22. torch._C._nested, # type: ignore[attr-defined]
  23. torch._C._sparse, # type: ignore[attr-defined]
  24. torch._C._special, # type: ignore[attr-defined]
  25. )
  26. _builtin_ops = [
  27. # Pairs of (function, op_name)
  28. (_pair, "aten::_pair"),
  29. (_quadruple, "aten::_quadruple"),
  30. (_single, "aten::_single"),
  31. (_triple, "aten::_triple"),
  32. (_list_with_default, "aten::list_with_default"),
  33. (OrderedDict, "aten::dict"),
  34. (dict, "aten::dict"),
  35. (cudnn.is_acceptable, "aten::cudnn_is_acceptable"),
  36. (math.ceil, "aten::ceil"),
  37. (math.copysign, "aten::copysign"),
  38. (math.erf, "aten::erf"),
  39. (math.erfc, "aten::erfc"),
  40. (math.exp, "aten::exp"),
  41. (math.expm1, "aten::expm1"),
  42. (math.fabs, "aten::fabs"),
  43. (math.floor, "aten::floor"),
  44. (math.gamma, "aten::gamma"),
  45. (math.lgamma, "aten::lgamma"),
  46. (math.log, "aten::log"),
  47. (math.log10, "aten::log10"),
  48. (math.log1p, "aten::log1p"),
  49. (math.pow, "aten::pow"),
  50. (math.sqrt, "aten::sqrt"),
  51. (math.isnan, "aten::isnan"),
  52. (math.asinh, "aten::asinh"),
  53. (math.atanh, "aten::atanh"),
  54. (math.cosh, "aten::cosh"),
  55. (math.sinh, "aten::sinh"),
  56. (math.tanh, "aten::tanh"),
  57. (math.acos, "aten::acos"),
  58. (math.asin, "aten::asin"),
  59. (math.atan, "aten::atan"),
  60. (math.atan2, "aten::atan2"),
  61. (math.cos, "aten::cos"),
  62. (math.sin, "aten::sin"),
  63. (math.tan, "aten::tan"),
  64. (math.asinh, "aten::asinh"),
  65. (math.atanh, "aten::atanh"),
  66. (math.acosh, "aten::acosh"),
  67. (math.fmod, "aten::fmod"),
  68. (math.modf, "aten::modf"),
  69. (math.factorial, "aten::factorial"),
  70. (math.frexp, "aten::frexp"),
  71. (math.isinf, "aten::isinf"),
  72. (math.degrees, "aten::degrees"),
  73. (math.radians, "aten::radians"),
  74. (cmath.isnan, "aten::isnan"),
  75. (cmath.isfinite, "aten::isfinite"),
  76. (cmath.isinf, "aten::isinf"),
  77. (cmath.phase, "aten::angle"),
  78. (cmath.rect, "aten::polar"),
  79. (cmath.log, "aten::log"),
  80. (cmath.log10, "aten::log10"),
  81. (cmath.sqrt, "aten::sqrt"),
  82. (cmath.exp, "aten::exp"),
  83. (cmath.sin, "aten::sin"),
  84. (cmath.tan, "aten::tan"),
  85. (cmath.cos, "aten::cos"),
  86. (cmath.asin, "aten::asin"),
  87. (cmath.acos, "aten::acos"),
  88. (cmath.atan, "aten::atan"),
  89. (cmath.sinh, "aten::sinh"),
  90. (cmath.cosh, "aten::cosh"),
  91. (cmath.tanh, "aten::tanh"),
  92. (cmath.asinh, "aten::asinh"),
  93. (cmath.acosh, "aten::acosh"),
  94. (cmath.atanh, "aten::atanh"),
  95. (math.ldexp, "aten::ldexp"),
  96. (torch._assert, "aten::_assert"),
  97. (torch.autograd.grad, "aten::grad"),
  98. (torch.autograd.backward, "aten::backward"),
  99. (torch._C._infer_size, "aten::_infer_size"),
  100. (
  101. torch.nn.functional._no_grad_embedding_renorm_, # type: ignore[attr-defined]
  102. "aten::_no_grad_embedding_renorm_",
  103. ),
  104. (torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"),
  105. (torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"),
  106. (torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"),
  107. (torch.nn.init._no_grad_uniform_, "aten::_no_grad_uniform_"),
  108. (torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
  109. (torch._C._get_tracing_state, "aten::_get_tracing_state"),
  110. (torch._C._get_cpu_capability, "aten::_get_cpu_capability"),
  111. (warnings.warn, "aten::warn"),
  112. (torch._VF.stft, "aten::stft"), # type: ignore[attr-defined]
  113. (torch._VF.istft, "aten::istft"), # type: ignore[attr-defined]
  114. (torch._VF.cdist, "aten::cdist"), # type: ignore[attr-defined]
  115. (torch._VF.norm, "aten::norm"), # type: ignore[attr-defined]
  116. (torch._VF.unique_dim, "aten::unique_dim"),
  117. (torch._VF.unique_consecutive, "aten::unique_consecutive"), # type: ignore[attr-defined]
  118. (torch._VF.nuclear_norm, "aten::nuclear_norm"),
  119. (torch._VF.frobenius_norm, "aten::frobenius_norm"),
  120. (torch._VF.tensordot, "aten::tensordot"), # type: ignore[attr-defined]
  121. ]
  122. # ops in torch.functional are bound to torch
  123. # in these cases, we want to resolve the function to their python implementation
  124. # instead looking up a builtin "aten::" schema
  125. def _gen_torch_functional_registered_ops():
  126. # eventually ops should encompass all of torch/functional.py, (torch.functional.__all__)
  127. # but we are currently only able to compile some of the functions. additionally,
  128. # some functions directly map to their aten:: implementations.
  129. # TODO: add support for more ops
  130. ops = [
  131. "stft",
  132. "istft",
  133. "lu",
  134. "cdist",
  135. "norm",
  136. "unique",
  137. "unique_consecutive",
  138. "tensordot",
  139. ]
  140. return {getattr(torch.functional, name) for name in ops}
  141. _functional_registered_ops = _gen_torch_functional_registered_ops()
  142. def _is_special_functional_bound_op(fn):
  143. return fn in _functional_registered_ops
  144. # lazily built to ensure the correct initialization order
  145. def _get_builtin_table():
  146. global _builtin_table
  147. if _builtin_table is not None:
  148. return _builtin_table
  149. _builtin_table = {}
  150. def register_all(mod) -> None:
  151. for name in dir(mod):
  152. v = getattr(mod, name)
  153. if (
  154. callable(v)
  155. and not _is_special_functional_bound_op(v)
  156. and v is not torch.no_grad
  157. and v is not torch.autocast
  158. ):
  159. # Fixup inconsistency in segment_reduce
  160. if name == "_segment_reduce":
  161. name = name[1:]
  162. _builtin_ops.append((v, "aten::" + name))
  163. for mod in _modules_containing_builtins:
  164. register_all(mod)
  165. _builtin_ops.append((math.gcd, "aten::gcd"))
  166. _builtin_ops.append((math.isfinite, "aten::isfinite"))
  167. _builtin_ops.append((math.remainder, "aten::mathremainder")) # type: ignore[attr-defined]
  168. import torch.distributed.autograd as dist_autograd
  169. if dist_autograd.is_available():
  170. _builtin_ops.append((dist_autograd.get_gradients, "aten::get_gradients"))
  171. _builtin_ops.append((dist_autograd.backward, "aten::dist_backward"))
  172. # populate the _builtin_table from _builtin_ops
  173. for builtin, aten_op in _builtin_ops:
  174. _builtin_table[id(builtin)] = aten_op
  175. return _builtin_table
  176. def _register_builtin(fn, op) -> None:
  177. _get_builtin_table()[id(fn)] = op
  178. def _find_builtin(fn):
  179. return _get_builtin_table().get(id(fn))