# op_properties.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
  6. import torch
  7. # pointwise operators can go through a faster pathway
  8. tensor_magic_methods = ["add", ""]
  9. pointwise_magic_methods_with_reverse = (
  10. "add",
  11. "sub",
  12. "mul",
  13. "floordiv",
  14. "div",
  15. "truediv",
  16. "mod",
  17. "pow",
  18. "lshift",
  19. "rshift",
  20. "and",
  21. "or",
  22. "xor",
  23. )
  24. pointwise_magic_methods = (
  25. *(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
  26. "eq",
  27. "gt",
  28. "le",
  29. "lt",
  30. "ge",
  31. "gt",
  32. "ne",
  33. "neg",
  34. "pos",
  35. "abs",
  36. "invert",
  37. "iadd",
  38. "isub",
  39. "imul",
  40. "ifloordiv",
  41. "idiv",
  42. "itruediv",
  43. "imod",
  44. "ipow",
  45. "ilshift",
  46. "irshift",
  47. "iand",
  48. "ior",
  49. "ixor",
  50. "int",
  51. "long",
  52. "float",
  53. "complex",
  54. )
  55. pointwise_methods = (*(f"__{m}__" for m in pointwise_magic_methods),)
  56. pointwise = (
  57. *(getattr(torch.Tensor, m) for m in pointwise_methods),
  58. torch.nn.functional.dropout,
  59. torch.where,
  60. torch.Tensor.abs,
  61. torch.abs,
  62. torch.Tensor.acos,
  63. torch.acos,
  64. torch.Tensor.acosh,
  65. torch.acosh,
  66. torch.Tensor.add,
  67. torch.add,
  68. torch.Tensor.addcdiv,
  69. torch.addcdiv,
  70. torch.Tensor.addcmul,
  71. torch.addcmul,
  72. torch.Tensor.addr,
  73. torch.addr,
  74. torch.Tensor.angle,
  75. torch.angle,
  76. torch.Tensor.asin,
  77. torch.asin,
  78. torch.Tensor.asinh,
  79. torch.asinh,
  80. torch.Tensor.atan,
  81. torch.atan,
  82. torch.Tensor.atan2,
  83. torch.atan2,
  84. torch.Tensor.atanh,
  85. torch.atanh,
  86. torch.Tensor.bitwise_and,
  87. torch.bitwise_and,
  88. torch.Tensor.bitwise_left_shift,
  89. torch.bitwise_left_shift,
  90. torch.Tensor.bitwise_not,
  91. torch.bitwise_not,
  92. torch.Tensor.bitwise_or,
  93. torch.bitwise_or,
  94. torch.Tensor.bitwise_right_shift,
  95. torch.bitwise_right_shift,
  96. torch.Tensor.bitwise_xor,
  97. torch.bitwise_xor,
  98. torch.Tensor.ceil,
  99. torch.ceil,
  100. torch.celu,
  101. torch.nn.functional.celu,
  102. torch.Tensor.clamp,
  103. torch.clamp,
  104. torch.Tensor.clamp_max,
  105. torch.clamp_max,
  106. torch.Tensor.clamp_min,
  107. torch.clamp_min,
  108. torch.Tensor.copysign,
  109. torch.copysign,
  110. torch.Tensor.cos,
  111. torch.cos,
  112. torch.Tensor.cosh,
  113. torch.cosh,
  114. torch.Tensor.deg2rad,
  115. torch.deg2rad,
  116. torch.Tensor.digamma,
  117. torch.digamma,
  118. torch.Tensor.div,
  119. torch.div,
  120. torch.dropout,
  121. torch.nn.functional.dropout,
  122. torch.nn.functional.elu,
  123. torch.Tensor.eq,
  124. torch.eq,
  125. torch.Tensor.erf,
  126. torch.erf,
  127. torch.Tensor.erfc,
  128. torch.erfc,
  129. torch.Tensor.erfinv,
  130. torch.erfinv,
  131. torch.Tensor.exp,
  132. torch.exp,
  133. torch.Tensor.exp2,
  134. torch.exp2,
  135. torch.Tensor.expm1,
  136. torch.expm1,
  137. torch.feature_dropout,
  138. torch.Tensor.float_power,
  139. torch.float_power,
  140. torch.Tensor.floor,
  141. torch.floor,
  142. torch.Tensor.floor_divide,
  143. torch.floor_divide,
  144. torch.Tensor.fmod,
  145. torch.fmod,
  146. torch.Tensor.frac,
  147. torch.frac,
  148. torch.Tensor.frexp,
  149. torch.frexp,
  150. torch.Tensor.gcd,
  151. torch.gcd,
  152. torch.Tensor.ge,
  153. torch.ge,
  154. torch.nn.functional.gelu,
  155. torch.nn.functional.glu,
  156. torch.Tensor.gt,
  157. torch.gt,
  158. torch.Tensor.hardshrink,
  159. torch.hardshrink,
  160. torch.nn.functional.hardshrink,
  161. torch.nn.functional.hardsigmoid,
  162. torch.nn.functional.hardswish,
  163. torch.nn.functional.hardtanh,
  164. torch.Tensor.heaviside,
  165. torch.heaviside,
  166. torch.Tensor.hypot,
  167. torch.hypot,
  168. torch.Tensor.i0,
  169. torch.i0,
  170. torch.Tensor.igamma,
  171. torch.igamma,
  172. torch.Tensor.igammac,
  173. torch.igammac,
  174. torch.Tensor.isclose,
  175. torch.isclose,
  176. torch.Tensor.isfinite,
  177. torch.isfinite,
  178. torch.Tensor.isinf,
  179. torch.isinf,
  180. torch.Tensor.isnan,
  181. torch.isnan,
  182. torch.Tensor.isneginf,
  183. torch.isneginf,
  184. torch.Tensor.isposinf,
  185. torch.isposinf,
  186. torch.Tensor.isreal,
  187. torch.isreal,
  188. torch.Tensor.kron,
  189. torch.kron,
  190. torch.Tensor.lcm,
  191. torch.lcm,
  192. torch.Tensor.ldexp,
  193. torch.ldexp,
  194. torch.Tensor.le,
  195. torch.le,
  196. torch.nn.functional.leaky_relu,
  197. torch.Tensor.lerp,
  198. torch.lerp,
  199. torch.Tensor.lgamma,
  200. torch.lgamma,
  201. torch.Tensor.log,
  202. torch.log,
  203. torch.Tensor.log10,
  204. torch.log10,
  205. torch.Tensor.log1p,
  206. torch.log1p,
  207. torch.Tensor.log2,
  208. torch.log2,
  209. torch.nn.functional.logsigmoid,
  210. torch.Tensor.logical_and,
  211. torch.logical_and,
  212. torch.Tensor.logical_not,
  213. torch.logical_not,
  214. torch.Tensor.logical_or,
  215. torch.logical_or,
  216. torch.Tensor.logical_xor,
  217. torch.logical_xor,
  218. torch.Tensor.logit,
  219. torch.logit,
  220. torch.Tensor.lt,
  221. torch.lt,
  222. torch.Tensor.maximum,
  223. torch.maximum,
  224. torch.Tensor.minimum,
  225. torch.minimum,
  226. torch.nn.functional.mish,
  227. torch.Tensor.mvlgamma,
  228. torch.mvlgamma,
  229. torch.Tensor.nan_to_num,
  230. torch.nan_to_num,
  231. torch.Tensor.ne,
  232. torch.ne,
  233. torch.Tensor.neg,
  234. torch.neg,
  235. torch.Tensor.nextafter,
  236. torch.nextafter,
  237. torch.Tensor.outer,
  238. torch.outer,
  239. torch.polar,
  240. torch.Tensor.polygamma,
  241. torch.polygamma,
  242. torch.Tensor.positive,
  243. torch.positive,
  244. torch.Tensor.pow,
  245. torch.pow,
  246. torch.Tensor.prelu,
  247. torch.prelu,
  248. torch.nn.functional.prelu,
  249. torch.Tensor.rad2deg,
  250. torch.rad2deg,
  251. torch.Tensor.reciprocal,
  252. torch.reciprocal,
  253. torch.Tensor.relu,
  254. torch.relu,
  255. torch.nn.functional.relu,
  256. torch.nn.functional.relu6,
  257. torch.Tensor.remainder,
  258. torch.remainder,
  259. torch.Tensor.round,
  260. torch.round,
  261. torch.rrelu,
  262. torch.nn.functional.rrelu,
  263. torch.Tensor.rsqrt,
  264. torch.rsqrt,
  265. torch.rsub,
  266. torch.selu,
  267. torch.nn.functional.selu,
  268. torch.Tensor.sgn,
  269. torch.sgn,
  270. torch.Tensor.sigmoid,
  271. torch.sigmoid,
  272. torch.nn.functional.sigmoid,
  273. torch.Tensor.sign,
  274. torch.sign,
  275. torch.Tensor.signbit,
  276. torch.signbit,
  277. torch.nn.functional.silu,
  278. torch.Tensor.sin,
  279. torch.sin,
  280. torch.Tensor.sinc,
  281. torch.sinc,
  282. torch.Tensor.sinh,
  283. torch.sinh,
  284. torch.nn.functional.softplus,
  285. torch.nn.functional.softshrink,
  286. torch.Tensor.sqrt,
  287. torch.sqrt,
  288. torch.Tensor.square,
  289. torch.square,
  290. torch.Tensor.sub,
  291. torch.sub,
  292. torch.Tensor.tan,
  293. torch.tan,
  294. torch.Tensor.tanh,
  295. torch.tanh,
  296. torch.nn.functional.tanh,
  297. torch.threshold,
  298. torch.nn.functional.threshold,
  299. torch.trapz,
  300. torch.Tensor.true_divide,
  301. torch.true_divide,
  302. torch.Tensor.trunc,
  303. torch.trunc,
  304. torch.Tensor.xlogy,
  305. torch.xlogy,
  306. torch.rand_like,
  307. )