reference.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601
  1. # mypy: allow-untyped-defs
  2. import math
  3. import operator
  4. from typing import NoReturn
  5. import sympy
  6. import torch
  7. from torch.utils._sympy.functions import (
  8. _keep_float,
  9. BitwiseFn_bitwise_and,
  10. BitwiseFn_bitwise_or,
  11. BitwiseFn_bitwise_xor,
  12. FloatPow,
  13. FloatTrueDiv,
  14. FloorDiv,
  15. IntTrueDiv,
  16. Max,
  17. Min,
  18. Mod,
  19. OpaqueUnaryFn_exp,
  20. OpaqueUnaryFn_log,
  21. OpaqueUnaryFn_log2,
  22. OpaqueUnaryFn_sqrt,
  23. PowByNatural,
  24. RoundDecimal,
  25. RoundToInt,
  26. ToFloat,
  27. TruncToInt,
  28. )
  29. # The sympy interpretation of operators. It will also sometimes work with
  30. # plain int/float, but if you do certain operations you will get out a
  31. # sympy.Basic in the end. If you want the Python/FX traceable interpretation,
  32. # check PythonReferenceAnalysis.
  33. # NB: For magic methods this needs to use normal magic methods
  34. # so that test_magic_methods works
  35. class ReferenceAnalysis:
  36. @staticmethod
  37. def constant(c, dtype):
  38. return sympy.sympify(c)
  39. @staticmethod
  40. def or_(a, b):
  41. return a | b
  42. @staticmethod
  43. def and_(a, b):
  44. return a & b
  45. @staticmethod
  46. def eq(a, b):
  47. if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
  48. return sympy.Eq(a, b)
  49. return a == b
  50. @classmethod
  51. def ne(cls, a, b):
  52. return cls.not_(cls.eq(a, b))
  53. @staticmethod
  54. def lt(a, b):
  55. return a < b
  56. @staticmethod
  57. def gt(a, b):
  58. return a > b
  59. @staticmethod
  60. def le(a, b):
  61. return a <= b
  62. @staticmethod
  63. def ge(a, b):
  64. return a >= b
  65. @staticmethod
  66. def not_(a):
  67. if isinstance(a, bool):
  68. raise AssertionError("not_ needs sympy expr")
  69. return ~a
  70. @staticmethod
  71. def reciprocal(x):
  72. return FloatTrueDiv(1.0, x)
  73. @staticmethod
  74. def square(x):
  75. return PowByNatural(x, 2)
  76. @staticmethod
  77. def trunc_to_int(x, dtype):
  78. return TruncToInt(x)
  79. @staticmethod
  80. def ceil_to_int(x, dtype):
  81. return sympy.ceiling(x)
  82. @staticmethod
  83. def floor_to_int(x, dtype):
  84. return sympy.floor(x)
  85. @staticmethod
  86. def floor(x):
  87. return _keep_float(sympy.floor)(x)
  88. @staticmethod
  89. def ceil(x):
  90. return _keep_float(sympy.ceiling)(x)
  91. @staticmethod
  92. def to_dtype(x, dtype):
  93. if dtype == torch.float64:
  94. return ToFloat(x)
  95. raise NotImplementedError(f"to_dtype {dtype} NYI")
  96. @staticmethod
  97. def mod(x, y):
  98. return Mod(x, y)
  99. @staticmethod
  100. def abs(x):
  101. return abs(x)
  102. @staticmethod
  103. def neg(x):
  104. return -x
  105. @staticmethod
  106. def truediv(a, b):
  107. return FloatTrueDiv(a, b)
  108. @staticmethod
  109. def int_truediv(a, b):
  110. return IntTrueDiv(a, b)
  111. @staticmethod
  112. def floordiv(a, b):
  113. return FloorDiv(a, b)
  114. @staticmethod
  115. def truncdiv(a, b) -> NoReturn:
  116. raise NotImplementedError("TODO: truncdiv")
  117. @staticmethod
  118. def add(a, b):
  119. return _keep_float(operator.add)(a, b)
  120. @classmethod
  121. def sym_sum(cls, args):
  122. return sympy.Add(*args)
  123. @staticmethod
  124. def mul(a, b):
  125. return _keep_float(operator.mul)(a, b)
  126. @staticmethod
  127. def sub(a, b):
  128. return _keep_float(operator.sub)(a, b)
  129. @staticmethod
  130. def exp(x):
  131. return OpaqueUnaryFn_exp(x)
  132. @staticmethod
  133. def log(x):
  134. return OpaqueUnaryFn_log(x)
  135. @staticmethod
  136. def log2(x):
  137. return OpaqueUnaryFn_log2(x)
  138. @staticmethod
  139. def sqrt(x):
  140. return OpaqueUnaryFn_sqrt(x)
  141. @staticmethod
  142. def pow(a, b):
  143. # pyrefly: ignore [bad-argument-type]
  144. return _keep_float(FloatPow)(a, b)
  145. @staticmethod
  146. def pow_by_natural(a, b):
  147. return PowByNatural(a, b)
  148. @staticmethod
  149. def minimum(a, b):
  150. return Min(a, b)
  151. @staticmethod
  152. def maximum(a, b):
  153. return Max(a, b)
  154. @staticmethod
  155. def round_to_int(a, dtype):
  156. return RoundToInt(a)
  157. @staticmethod
  158. def round_decimal(a, b):
  159. return RoundDecimal(a, b)
  160. @staticmethod
  161. def bitwise_and(a, b):
  162. return BitwiseFn_bitwise_and(a, b)
  163. @staticmethod
  164. def bitwise_or(a, b):
  165. return BitwiseFn_bitwise_or(a, b)
  166. @staticmethod
  167. def bitwise_xor(a, b):
  168. return BitwiseFn_bitwise_xor(a, b)
  169. # Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain
  170. # Python types and is FX traceable. Inheritance here is purely for code
  171. # sharing (TODO: considering splitting out a BaseReferenceAnalysis).
  172. class PythonReferenceAnalysis(ReferenceAnalysis):
  173. @staticmethod
  174. def constant(c, dtype):
  175. if dtype is torch.int64:
  176. return int(c)
  177. elif dtype is torch.double:
  178. return float(c)
  179. elif dtype is torch.bool:
  180. return bool(c)
  181. else:
  182. raise AssertionError(f"unrecognized dtype {dtype}")
  183. @staticmethod
  184. def not_(a):
  185. return torch.sym_not(a)
  186. @classmethod
  187. def sym_sum(cls, args):
  188. if len(args) == 0:
  189. return 0
  190. if len(args) == 1:
  191. return args[0]
  192. acc = cls.add(args[0], args[1])
  193. for i in range(2, len(args)):
  194. acc = cls.add(acc, args[i])
  195. return acc
  196. @staticmethod
  197. def floordiv(a, b):
  198. return a // b
  199. @staticmethod
  200. def mod(x, y):
  201. return x % y
  202. @staticmethod
  203. def python_mod(x, y):
  204. return x % y
  205. @staticmethod
  206. def truncdiv(a, b):
  207. return a / b
  208. @staticmethod
  209. def to_dtype(x, dtype):
  210. if dtype == torch.float64:
  211. return torch.sym_float(x)
  212. raise NotImplementedError(f"to_dtype {dtype} NYI")
  213. @staticmethod
  214. def exp(x) -> NoReturn:
  215. raise AssertionError("exp is not valid shape sympy expr")
  216. @staticmethod
  217. def log(x) -> NoReturn:
  218. raise AssertionError("log is not valid shape sympy expr")
  219. @staticmethod
  220. def log2(x):
  221. return torch._sym_log2(x) # type: ignore[attr-defined]
  222. @staticmethod
  223. def sqrt(x):
  224. return torch._sym_sqrt(x) # type: ignore[attr-defined]
  225. @staticmethod
  226. def minimum(a, b):
  227. return torch.sym_min(a, b)
  228. @staticmethod
  229. def maximum(a, b):
  230. return torch.sym_max(a, b)
  231. @staticmethod
  232. def floor_to_int(x, dtype):
  233. return math.floor(x)
  234. @staticmethod
  235. def ceil_to_int(x, dtype):
  236. return math.ceil(x)
  237. @staticmethod
  238. def floor(x):
  239. return float(math.floor(x))
  240. @staticmethod
  241. def ceil(x):
  242. return float(math.ceil(x))
  243. @staticmethod
  244. def truediv(a, b):
  245. return a / b
  246. @staticmethod
  247. def pow(a, b):
  248. return a**b
  249. @staticmethod
  250. def pow_by_natural(a, b):
  251. # Pray that safe_pow is not needed here lol. In particular, this
  252. # never participates in VR low/high ranges, so overflow should be
  253. # unlikely
  254. return a**b
  255. @staticmethod
  256. def round_to_int(a, dtype):
  257. return round(a)
  258. @staticmethod
  259. def round_decimal(a, b):
  260. return round(a, ndigits=b)
  261. @staticmethod
  262. def bitwise_and(a, b):
  263. return a & b
  264. @staticmethod
  265. def bitwise_or(a, b):
  266. return a | b
  267. @staticmethod
  268. def bitwise_xor(a, b):
  269. return a ^ b
  270. # Like PythonReferenceAnalysis, but some export-unfriendly choices of
  271. # operators to make things faster
  272. class OptimizedPythonReferenceAnalysis(PythonReferenceAnalysis):
  273. @staticmethod
  274. def sym_sum(args):
  275. return torch.sym_sum(args)
  276. def _to_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
  277. return torch.ops.prims.convert_element_type.default(x, dtype)
  278. # Suppose we have some int/float arguments. This diagram commutes:
  279. #
  280. # int/float -- PythonReferenceAnalysis.op --> int/float
  281. # | |
  282. # | |
  283. # torch.tensor(..., dtype=torch.int64/torch.float64)
  284. # | |
  285. # V V
  286. # Tensor -- TensorReferenceAnalysis.op --> Tensor
  287. #
  288. # NB: int before and after must be representable in int64 (we will
  289. # insert guards accordingly.)
  290. #
  291. # This is guaranteed to be FX traceable with OpOverloads only.
  292. class TensorReferenceAnalysis:
  293. # NB: This is actually dead, because with Proxy tracing the factory
  294. # function isn't traced correctly. Here for completeness.
  295. @staticmethod
  296. def constant(c, dtype):
  297. d: int | float | bool
  298. if dtype is torch.int64:
  299. d = int(c)
  300. elif dtype is torch.double:
  301. d = float(c)
  302. elif dtype is torch.bool:
  303. d = bool(c)
  304. else:
  305. raise AssertionError(f"unrecognized dtype {dtype}")
  306. return torch.ops.aten.scalar_tensor.default(d, dtype=dtype)
  307. @staticmethod
  308. def or_(a, b):
  309. return torch.ops.aten.logical_or.default(a, b)
  310. @staticmethod
  311. def and_(a, b):
  312. return torch.ops.aten.logical_and.default(a, b)
  313. @staticmethod
  314. def bitwise_and(a, b):
  315. return torch.ops.aten.bitwise_and(a, b)
  316. @staticmethod
  317. def bitwise_or(a, b):
  318. return torch.ops.aten.bitwise_or(a, b)
  319. @staticmethod
  320. def bitwise_xor(a, b):
  321. return torch.ops.aten.bitwise_xor(a, b)
  322. @staticmethod
  323. def eq(a, b):
  324. return torch.ops.aten.eq.Tensor(a, b)
  325. @classmethod
  326. def ne(cls, a, b):
  327. return torch.ops.aten.ne.Tensor(a, b)
  328. @staticmethod
  329. def lt(a, b):
  330. return torch.ops.aten.lt.Tensor(a, b)
  331. @staticmethod
  332. def gt(a, b):
  333. return torch.ops.aten.gt.Tensor(a, b)
  334. @staticmethod
  335. def le(a, b):
  336. return torch.ops.aten.le.Tensor(a, b)
  337. @staticmethod
  338. def ge(a, b):
  339. return torch.ops.aten.ge.Tensor(a, b)
  340. @staticmethod
  341. def not_(a):
  342. return torch.ops.aten.logical_not.default(a)
  343. @staticmethod
  344. def reciprocal(x):
  345. return torch.ops.aten.reciprocal.default(x)
  346. @staticmethod
  347. def square(x):
  348. # TODO: maybe composite implicit autograd doesn't work here?
  349. return torch.ops.aten.square.default(x)
  350. @staticmethod
  351. def trunc_to_int(x, dtype):
  352. return _to_dtype(torch.ops.aten.trunc.default(x), dtype)
  353. @staticmethod
  354. def ceil_to_int(x, dtype):
  355. return _to_dtype(torch.ops.aten.ceil.default(x), dtype)
  356. @staticmethod
  357. def floor_to_int(x, dtype):
  358. return _to_dtype(torch.ops.aten.floor.default(x), dtype)
  359. @staticmethod
  360. def floor(x):
  361. return torch.ops.aten.floor.default(x)
  362. @staticmethod
  363. def ceil(x):
  364. return torch.ops.aten.ceil.default(x)
  365. @staticmethod
  366. def to_dtype(x, dtype):
  367. return _to_dtype(x, dtype)
  368. @staticmethod
  369. def mod(x, y) -> NoReturn:
  370. # TODO: https://github.com/pytorch/pytorch/pull/133654
  371. raise NotImplementedError(
  372. "no C-style modulus operation available from frontend atm"
  373. )
  374. @staticmethod
  375. def abs(x):
  376. return torch.ops.aten.abs.default(x)
  377. @staticmethod
  378. def neg(x):
  379. return torch.ops.aten.neg.default(x)
  380. @staticmethod
  381. def truediv(a, b):
  382. return torch.ops.aten.true_divide.Tensor(a, b)
  383. @staticmethod
  384. def int_truediv(a, b):
  385. raise NotImplementedError(
  386. "Python int truediv difficult to implement in PyTorch atm"
  387. )
  388. # TODO: This is wrong, CPython has a custom implementation of true
  389. # division that results in higher precision when the floats are
  390. # sufficiently large. Short term fix: add a guard here
  391. # pyrefly: ignore [unreachable]
  392. return torch.ops.aten.true_divide.default(
  393. _to_dtype(a, torch.float64), _to_dtype(b, torch.float64)
  394. )
  395. @staticmethod
  396. def floordiv(a, b):
  397. return torch.ops.aten.div.Tensor_mode(a, b, rounding_mode="floor")
  398. @staticmethod
  399. def truncdiv(a, b) -> NoReturn:
  400. raise NotImplementedError(
  401. "no C-style truncdiv operation available from frontend atm"
  402. )
  403. @staticmethod
  404. def add(a, b):
  405. return torch.ops.aten.add.Tensor(a, b)
  406. @staticmethod
  407. def mul(a, b):
  408. return torch.ops.aten.mul.Tensor(a, b)
  409. @staticmethod
  410. def sub(a, b):
  411. return torch.ops.aten.sub.Tensor(a, b)
  412. @staticmethod
  413. def exp(x):
  414. return torch.ops.aten.exp.default(x)
  415. @staticmethod
  416. def log(x):
  417. return torch.ops.aten.log.default(x)
  418. @staticmethod
  419. def log2(x):
  420. return torch.ops.aten.log2.default(x)
  421. @staticmethod
  422. def sqrt(x):
  423. return torch.ops.aten.sqrt.default(x)
  424. @staticmethod
  425. def sin(x):
  426. return torch.ops.aten.sin.default(x)
  427. @staticmethod
  428. def cos(x):
  429. return torch.ops.aten.cos.default(x)
  430. @staticmethod
  431. def tanh(x):
  432. return torch.ops.aten.tanh.default(x)
  433. @staticmethod
  434. def sinh(x):
  435. return torch.ops.aten.sinh.default(x)
  436. @staticmethod
  437. def cosh(x):
  438. return torch.ops.aten.cosh.default(x)
  439. @staticmethod
  440. def tan(x):
  441. return torch.ops.aten.tan.default(x)
  442. @staticmethod
  443. def acos(x):
  444. return torch.ops.aten.acos.default(x)
  445. @staticmethod
  446. def atan(x):
  447. return torch.ops.aten.atan.default(x)
  448. @staticmethod
  449. def asin(x):
  450. return torch.ops.aten.asin.default(x)
  451. @staticmethod
  452. def pow(a, b):
  453. return torch.ops.aten.pow.Tensor_Tensor(a, b)
  454. @staticmethod
  455. def pow_by_natural(a, b):
  456. # NB: pow handles int x int fine
  457. return torch.ops.aten.pow.Tensor_Tensor(a, b)
  458. @staticmethod
  459. def minimum(a, b):
  460. return torch.ops.aten.minimum.default(a, b)
  461. @staticmethod
  462. def maximum(a, b):
  463. return torch.ops.aten.maximum.default(a, b)
  464. @staticmethod
  465. def round_to_int(a, dtype):
  466. return torch.ops.aten.round.default(a)
  467. @staticmethod
  468. def round_decimal(a, b) -> NoReturn:
  469. raise NotImplementedError(
  470. "round decimal doesn't support Tensor second argument atm"
  471. )
  472. # return torch.ops.aten.round.decimals(a, b)