test_torch.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. import random
  2. import math
  3. from sympy import symbols, Derivative
  4. from sympy.printing.pytorch import torch_code
  5. from sympy import (eye, MatrixSymbol, Matrix)
  6. from sympy.tensor.array import NDimArray
  7. from sympy.tensor.array.expressions.array_expressions import (
  8. ArrayTensorProduct, ArrayAdd,
  9. PermuteDims, ArrayDiagonal, _CodegenArrayAbstract)
  10. from sympy.utilities.lambdify import lambdify
  11. from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt
  12. from sympy.functions import \
  13. Abs, ceiling, exp, floor, sign, sin, asin, cos, \
  14. acos, tan, atan, atan2, cosh, acosh, sinh, asinh, tanh, atanh, \
  15. re, im, arg, erf, loggamma, sqrt
  16. from sympy.testing.pytest import skip
  17. from sympy.external import import_module
  18. from sympy.matrices.expressions import \
  19. Determinant, HadamardProduct, Inverse, Trace
  20. from sympy.matrices import randMatrix
  21. from sympy.matrices import Identity, ZeroMatrix, OneMatrix
  22. from sympy import conjugate, I
  23. from sympy import Heaviside, gamma, polygamma
  24. torch = import_module("torch")
  25. M = MatrixSymbol("M", 3, 3)
  26. N = MatrixSymbol("N", 3, 3)
  27. P = MatrixSymbol("P", 3, 3)
  28. Q = MatrixSymbol("Q", 3, 3)
  29. x, y, z, t = symbols("x y z t")
  30. if torch is not None:
  31. llo = [list(range(i, i + 3)) for i in range(0, 9, 3)]
  32. m3x3 = torch.tensor(llo, dtype=torch.float64)
  33. m3x3sympy = Matrix(llo)
  34. def _compare_torch_matrix(variables, expr):
  35. f = lambdify(variables, expr, 'torch')
  36. random_matrices = [randMatrix(i.shape[0], i.shape[1]) for i in variables]
  37. random_variables = [torch.tensor(i.tolist(), dtype=torch.float64) for i in random_matrices]
  38. r = f(*random_variables)
  39. e = expr.subs(dict(zip(variables, random_matrices))).doit()
  40. if isinstance(e, _CodegenArrayAbstract):
  41. e = e.doit()
  42. if hasattr(e, 'is_number') and e.is_number:
  43. if isinstance(r, torch.Tensor) and r.dim() == 0:
  44. r = r.item()
  45. e = float(e)
  46. assert abs(r - e) < 1e-6
  47. return
  48. if e.is_Matrix or isinstance(e, NDimArray):
  49. e = torch.tensor(e.tolist(), dtype=torch.float64)
  50. assert torch.allclose(r, e, atol=1e-6)
  51. else:
  52. raise TypeError(f"Cannot compare {type(r)} with {type(e)}")
  53. def _compare_torch_scalar(variables, expr, rng=lambda: random.uniform(-5, 5)):
  54. f = lambdify(variables, expr, 'torch')
  55. rvs = [rng() for v in variables]
  56. t_rvs = [torch.tensor(i, dtype=torch.float64) for i in rvs]
  57. r = f(*t_rvs)
  58. if isinstance(r, torch.Tensor):
  59. r = r.item()
  60. e = expr.subs(dict(zip(variables, rvs))).doit()
  61. assert abs(r - e) < 1e-6
  62. def _compare_torch_relational(variables, expr, rng=lambda: random.randint(0, 10)):
  63. f = lambdify(variables, expr, 'torch')
  64. rvs = [rng() for v in variables]
  65. t_rvs = [torch.tensor(i, dtype=torch.float64) for i in rvs]
  66. r = f(*t_rvs)
  67. e = bool(expr.subs(dict(zip(variables, rvs))).doit())
  68. assert r.item() == e
  69. def test_torch_math():
  70. if not torch:
  71. skip("PyTorch not installed")
  72. expr = Abs(x)
  73. assert torch_code(expr) == "torch.abs(x)"
  74. f = lambdify(x, expr, 'torch')
  75. ma = torch.tensor([[-1, 2, -3, -4]], dtype=torch.float64)
  76. y_abs = f(ma)
  77. c = torch.abs(ma)
  78. assert torch.all(y_abs == c)
  79. expr = sign(x)
  80. assert torch_code(expr) == "torch.sign(x)"
  81. _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-10, 10))
  82. expr = ceiling(x)
  83. assert torch_code(expr) == "torch.ceil(x)"
  84. _compare_torch_scalar((x,), expr, rng=lambda: random.random())
  85. expr = floor(x)
  86. assert torch_code(expr) == "torch.floor(x)"
  87. _compare_torch_scalar((x,), expr, rng=lambda: random.random())
  88. expr = exp(x)
  89. assert torch_code(expr) == "torch.exp(x)"
  90. _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
  91. expr = sqrt(x)
  92. assert torch_code(expr) == "torch.sqrt(x)"
  93. _compare_torch_scalar((x,), expr, rng=lambda: random.random())
  94. expr = x ** 4
  95. assert torch_code(expr) == "torch.pow(x, 4)"
  96. _compare_torch_scalar((x,), expr, rng=lambda: random.random())
  97. expr = cos(x)
  98. assert torch_code(expr) == "torch.cos(x)"
  99. _compare_torch_scalar((x,), expr, rng=lambda: random.random())
  100. expr = acos(x)
  101. assert torch_code(expr) == "torch.acos(x)"
  102. _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.99, 0.99))
  103. expr = sin(x)
  104. assert torch_code(expr) == "torch.sin(x)"
  105. _compare_torch_scalar((x,), expr, rng=lambda: random.random())
  106. expr = asin(x)
  107. assert torch_code(expr) == "torch.asin(x)"
  108. _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.99, 0.99))
  109. expr = tan(x)
  110. assert torch_code(expr) == "torch.tan(x)"
  111. _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-1.5, 1.5))
  112. expr = atan(x)
  113. assert torch_code(expr) == "torch.atan(x)"
  114. _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-5, 5))
  115. expr = atan2(y, x)
  116. assert torch_code(expr) == "torch.atan2(y, x)"
  117. _compare_torch_scalar((y, x), expr, rng=lambda: random.uniform(-5, 5))
  118. expr = cosh(x)
  119. assert torch_code(expr) == "torch.cosh(x)"
  120. _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
  121. expr = acosh(x)
  122. assert torch_code(expr) == "torch.acosh(x)"
  123. _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(1.1, 5))
  124. expr = sinh(x)
  125. assert torch_code(expr) == "torch.sinh(x)"
  126. _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
  127. expr = asinh(x)
  128. assert torch_code(expr) == "torch.asinh(x)"
  129. _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-5, 5))
  130. expr = tanh(x)
  131. assert torch_code(expr) == "torch.tanh(x)"
  132. _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
  133. expr = atanh(x)
  134. assert torch_code(expr) == "torch.atanh(x)"
  135. _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.9, 0.9))
  136. expr = erf(x)
  137. assert torch_code(expr) == "torch.erf(x)"
  138. _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
  139. expr = loggamma(x)
  140. assert torch_code(expr) == "torch.lgamma(x)"
  141. _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(0.5, 5))
  142. def test_torch_complexes():
  143. assert torch_code(re(x)) == "torch.real(x)"
  144. assert torch_code(im(x)) == "torch.imag(x)"
  145. assert torch_code(arg(x)) == "torch.angle(x)"
  146. def test_torch_relational():
  147. if not torch:
  148. skip("PyTorch not installed")
  149. expr = Eq(x, y)
  150. assert torch_code(expr) == "torch.eq(x, y)"
  151. _compare_torch_relational((x, y), expr)
  152. expr = Ne(x, y)
  153. assert torch_code(expr) == "torch.ne(x, y)"
  154. _compare_torch_relational((x, y), expr)
  155. expr = Ge(x, y)
  156. assert torch_code(expr) == "torch.ge(x, y)"
  157. _compare_torch_relational((x, y), expr)
  158. expr = Gt(x, y)
  159. assert torch_code(expr) == "torch.gt(x, y)"
  160. _compare_torch_relational((x, y), expr)
  161. expr = Le(x, y)
  162. assert torch_code(expr) == "torch.le(x, y)"
  163. _compare_torch_relational((x, y), expr)
  164. expr = Lt(x, y)
  165. assert torch_code(expr) == "torch.lt(x, y)"
  166. _compare_torch_relational((x, y), expr)
  167. def test_torch_matrix():
  168. if torch is None:
  169. skip("PyTorch not installed")
  170. expr = M
  171. assert torch_code(expr) == "M"
  172. f = lambdify((M,), expr, "torch")
  173. eye_mat = eye(3)
  174. eye_tensor = torch.tensor(eye_mat.tolist(), dtype=torch.float64)
  175. assert torch.allclose(f(eye_tensor), eye_tensor)
  176. expr = M * N
  177. assert torch_code(expr) == "torch.matmul(M, N)"
  178. _compare_torch_matrix((M, N), expr)
  179. expr = M ** 3
  180. assert torch_code(expr) == "torch.mm(torch.mm(M, M), M)"
  181. _compare_torch_matrix((M,), expr)
  182. expr = M * N * P * Q
  183. assert torch_code(expr) == "torch.matmul(torch.matmul(torch.matmul(M, N), P), Q)"
  184. _compare_torch_matrix((M, N, P, Q), expr)
  185. expr = Trace(M)
  186. assert torch_code(expr) == "torch.trace(M)"
  187. _compare_torch_matrix((M,), expr)
  188. expr = Determinant(M)
  189. assert torch_code(expr) == "torch.det(M)"
  190. _compare_torch_matrix((M,), expr)
  191. expr = HadamardProduct(M, N)
  192. assert torch_code(expr) == "torch.mul(M, N)"
  193. _compare_torch_matrix((M, N), expr)
  194. expr = Inverse(M)
  195. assert torch_code(expr) == "torch.linalg.inv(M)"
  196. # For inverse, use a matrix that's guaranteed to be invertible
  197. eye_mat = eye(3)
  198. eye_tensor = torch.tensor(eye_mat.tolist(), dtype=torch.float64)
  199. f = lambdify((M,), expr, "torch")
  200. result = f(eye_tensor)
  201. expected = torch.linalg.inv(eye_tensor)
  202. assert torch.allclose(result, expected)
  203. def test_torch_array_operations():
  204. if not torch:
  205. skip("PyTorch not installed")
  206. M = MatrixSymbol("M", 2, 2)
  207. N = MatrixSymbol("N", 2, 2)
  208. P = MatrixSymbol("P", 2, 2)
  209. Q = MatrixSymbol("Q", 2, 2)
  210. ma = torch.tensor([[1., 2.], [3., 4.]], dtype=torch.float64)
  211. mb = torch.tensor([[1., -2.], [-1., 3.]], dtype=torch.float64)
  212. mc = torch.tensor([[2., 0.], [1., 2.]], dtype=torch.float64)
  213. md = torch.tensor([[1., -1.], [4., 7.]], dtype=torch.float64)
  214. cg = ArrayTensorProduct(M, N)
  215. assert torch_code(cg) == 'torch.einsum("ab,cd", M, N)'
  216. f = lambdify((M, N), cg, 'torch')
  217. y = f(ma, mb)
  218. c = torch.einsum("ij,kl", ma, mb)
  219. assert torch.allclose(y, c)
  220. cg = ArrayAdd(M, N)
  221. assert torch_code(cg) == 'torch.add(M, N)'
  222. f = lambdify((M, N), cg, 'torch')
  223. y = f(ma, mb)
  224. c = ma + mb
  225. assert torch.allclose(y, c)
  226. cg = ArrayAdd(M, N, P)
  227. assert torch_code(cg) == 'torch.add(torch.add(M, N), P)'
  228. f = lambdify((M, N, P), cg, 'torch')
  229. y = f(ma, mb, mc)
  230. c = ma + mb + mc
  231. assert torch.allclose(y, c)
  232. cg = ArrayAdd(M, N, P, Q)
  233. assert torch_code(cg) == 'torch.add(torch.add(torch.add(M, N), P), Q)'
  234. f = lambdify((M, N, P, Q), cg, 'torch')
  235. y = f(ma, mb, mc, md)
  236. c = ma + mb + mc + md
  237. assert torch.allclose(y, c)
  238. cg = PermuteDims(M, [1, 0])
  239. assert torch_code(cg) == 'M.permute(1, 0)'
  240. f = lambdify((M,), cg, 'torch')
  241. y = f(ma)
  242. c = ma.T
  243. assert torch.allclose(y, c)
  244. cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0])
  245. assert torch_code(cg) == 'torch.einsum("ab,cd", M, N).permute(1, 2, 3, 0)'
  246. f = lambdify((M, N), cg, 'torch')
  247. y = f(ma, mb)
  248. c = torch.einsum("ab,cd", ma, mb).permute(1, 2, 3, 0)
  249. assert torch.allclose(y, c)
  250. cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2))
  251. assert torch_code(cg) == 'torch.einsum("ab,bc->acb", M, N)'
  252. f = lambdify((M, N), cg, 'torch')
  253. y = f(ma, mb)
  254. c = torch.einsum("ab,bc->acb", ma, mb)
  255. assert torch.allclose(y, c)
  256. def test_torch_derivative():
  257. """Test derivative handling."""
  258. expr = Derivative(sin(x), x)
  259. assert torch_code(expr) == 'torch.autograd.grad(torch.sin(x), x)[0]'
  260. def test_torch_printing_dtype():
  261. if not torch:
  262. skip("PyTorch not installed")
  263. # matrix printing with default dtype
  264. expr = Matrix([[x, sin(y)], [exp(z), -t]])
  265. assert "dtype=torch.float64" in torch_code(expr)
  266. # explicit dtype
  267. assert "dtype=torch.float32" in torch_code(expr, dtype="torch.float32")
  268. # with requires_grad
  269. result = torch_code(expr, requires_grad=True)
  270. assert "requires_grad=True" in result
  271. assert "dtype=torch.float64" in result
  272. # both
  273. result = torch_code(expr, requires_grad=True, dtype="torch.float32")
  274. assert "requires_grad=True" in result
  275. assert "dtype=torch.float32" in result
  276. def test_requires_grad():
  277. if not torch:
  278. skip("PyTorch not installed")
  279. expr = sin(x) + cos(y)
  280. f = lambdify([x, y], expr, 'torch')
  281. # make sure the gradients flow
  282. x_val = torch.tensor(1.0, requires_grad=True)
  283. y_val = torch.tensor(2.0, requires_grad=True)
  284. result = f(x_val, y_val)
  285. assert result.requires_grad
  286. result.backward()
  287. # x_val.grad should be cos(x_val) which is close to cos(1.0)
  288. assert abs(x_val.grad.item() - float(cos(1.0).evalf())) < 1e-6
  289. # y_val.grad should be -sin(y_val) which is close to -sin(2.0)
  290. assert abs(y_val.grad.item() - float(-sin(2.0).evalf())) < 1e-6
  291. def test_torch_multi_variable_derivatives():
  292. if not torch:
  293. skip("PyTorch not installed")
  294. x, y, z = symbols("x y z")
  295. expr = Derivative(sin(x), x)
  296. assert torch_code(expr) == "torch.autograd.grad(torch.sin(x), x)[0]"
  297. expr = Derivative(sin(x), (x, 2))
  298. assert torch_code(
  299. expr) == "torch.autograd.grad(torch.autograd.grad(torch.sin(x), x, create_graph=True)[0], x, create_graph=True)[0]"
  300. expr = Derivative(sin(x * y), x, y)
  301. result = torch_code(expr)
  302. expected = "torch.autograd.grad(torch.autograd.grad(torch.sin(x*y), x, create_graph=True)[0], y, create_graph=True)[0]"
  303. normalized_result = result.replace(" ", "")
  304. normalized_expected = expected.replace(" ", "")
  305. assert normalized_result == normalized_expected
  306. expr = Derivative(sin(x), x, x)
  307. result = torch_code(expr)
  308. expected = "torch.autograd.grad(torch.autograd.grad(torch.sin(x), x, create_graph=True)[0], x, create_graph=True)[0]"
  309. assert result == expected
  310. expr = Derivative(sin(x * y * z), x, (y, 2), z)
  311. result = torch_code(expr)
  312. expected = "torch.autograd.grad(torch.autograd.grad(torch.autograd.grad(torch.autograd.grad(torch.sin(x*y*z), x, create_graph=True)[0], y, create_graph=True)[0], y, create_graph=True)[0], z, create_graph=True)[0]"
  313. normalized_result = result.replace(" ", "")
  314. normalized_expected = expected.replace(" ", "")
  315. assert normalized_result == normalized_expected
  316. def test_torch_derivative_lambdify():
  317. if not torch:
  318. skip("PyTorch not installed")
  319. x = symbols("x")
  320. y = symbols("y")
  321. expr = Derivative(x ** 2, x)
  322. f = lambdify(x, expr, 'torch')
  323. x_val = torch.tensor(2.0, requires_grad=True)
  324. result = f(x_val)
  325. assert torch.isclose(result, torch.tensor(4.0))
  326. expr = Derivative(sin(x), (x, 2))
  327. f = lambdify(x, expr, 'torch')
  328. # Second derivative of sin(x) at x=0 is 0, not -1
  329. x_val = torch.tensor(0.0, requires_grad=True)
  330. result = f(x_val)
  331. assert torch.isclose(result, torch.tensor(0.0), atol=1e-5)
  332. x_val = torch.tensor(math.pi / 2, requires_grad=True)
  333. result = f(x_val)
  334. assert torch.isclose(result, torch.tensor(-1.0), atol=1e-5)
  335. expr = Derivative(x * y ** 2, x, y)
  336. f = lambdify((x, y), expr, 'torch')
  337. x_val = torch.tensor(2.0, requires_grad=True)
  338. y_val = torch.tensor(3.0, requires_grad=True)
  339. result = f(x_val, y_val)
  340. assert torch.isclose(result, torch.tensor(6.0))
  341. def test_torch_special_matrices():
  342. if not torch:
  343. skip("PyTorch not installed")
  344. expr = Identity(3)
  345. assert torch_code(expr) == "torch.eye(3)"
  346. n = symbols("n")
  347. expr = Identity(n)
  348. assert torch_code(expr) == "torch.eye(n, n)"
  349. expr = ZeroMatrix(2, 3)
  350. assert torch_code(expr) == "torch.zeros((2, 3))"
  351. m, n = symbols("m n")
  352. expr = ZeroMatrix(m, n)
  353. assert torch_code(expr) == "torch.zeros((m, n))"
  354. expr = OneMatrix(2, 3)
  355. assert torch_code(expr) == "torch.ones((2, 3))"
  356. expr = OneMatrix(m, n)
  357. assert torch_code(expr) == "torch.ones((m, n))"
  358. def test_torch_special_matrices_lambdify():
  359. if not torch:
  360. skip("PyTorch not installed")
  361. expr = Identity(3)
  362. f = lambdify([], expr, 'torch')
  363. result = f()
  364. expected = torch.eye(3)
  365. assert torch.allclose(result, expected)
  366. expr = ZeroMatrix(2, 3)
  367. f = lambdify([], expr, 'torch')
  368. result = f()
  369. expected = torch.zeros((2, 3))
  370. assert torch.allclose(result, expected)
  371. expr = OneMatrix(2, 3)
  372. f = lambdify([], expr, 'torch')
  373. result = f()
  374. expected = torch.ones((2, 3))
  375. assert torch.allclose(result, expected)
  376. def test_torch_complex_operations():
  377. if not torch:
  378. skip("PyTorch not installed")
  379. expr = conjugate(x)
  380. assert torch_code(expr) == "torch.conj(x)"
  381. # SymPy distributes conjugate over addition and applies specific rules for each term
  382. expr = conjugate(sin(x) + I * cos(y))
  383. assert torch_code(expr) == "torch.sin(torch.conj(x)) - 1j*torch.cos(torch.conj(y))"
  384. expr = I
  385. assert torch_code(expr) == "1j"
  386. expr = 2 * I + x
  387. assert torch_code(expr) == "x + 2*1j"
  388. expr = exp(I * x)
  389. assert torch_code(expr) == "torch.exp(1j*x)"
  390. def test_torch_special_functions():
  391. if not torch:
  392. skip("PyTorch not installed")
  393. expr = Heaviside(x)
  394. assert torch_code(expr) == "torch.heaviside(x, 1/2)"
  395. expr = Heaviside(x, 0)
  396. assert torch_code(expr) == "torch.heaviside(x, 0)"
  397. expr = gamma(x)
  398. assert torch_code(expr) == "torch.special.gamma(x)"
  399. expr = polygamma(0, x) # Use polygamma instead of digamma because sympy will default to that anyway
  400. assert torch_code(expr) == "torch.special.digamma(x)"
  401. expr = gamma(sin(x))
  402. assert torch_code(expr) == "torch.special.gamma(torch.sin(x))"