test_quaternion.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. from sympy.testing.pytest import slow
  2. from sympy.core.function import diff
  3. from sympy.core.function import expand
  4. from sympy.core.numbers import (E, I, Rational, pi)
  5. from sympy.core.singleton import S
  6. from sympy.core.symbol import (Symbol, symbols)
  7. from sympy.functions.elementary.complexes import (Abs, conjugate, im, re, sign)
  8. from sympy.functions.elementary.exponential import log
  9. from sympy.functions.elementary.miscellaneous import sqrt
  10. from sympy.functions.elementary.trigonometric import (acos, asin, cos, sin, atan2, atan)
  11. from sympy.integrals.integrals import integrate
  12. from sympy.matrices.dense import Matrix
  13. from sympy.simplify import simplify
  14. from sympy.simplify.trigsimp import trigsimp
  15. from sympy.algebras.quaternion import Quaternion
  16. from sympy.testing.pytest import raises
  17. import math
  18. from itertools import permutations, product
  # Module-level symbols shared by the tests: four generic quaternion
  # components and a generic rotation angle.
  w, x, y, z = symbols('w x y z')
  phi = Symbol('phi')
  def test_quaternion_construction():
      """Build quaternions from components, an axis-angle pair and a rotation
      matrix, and reject non-commutative components."""
      quat = Quaternion(w, x, y, z)
      assert quat + quat == Quaternion(2*w, 2*x, 2*y, 2*z)

      # Rotation by 2*pi/3 about the normalized (1, 1, 1) diagonal.
      comp = sqrt(3)/3
      diagonal = Quaternion.from_axis_angle((comp, comp, comp),
                                            pi*Rational(2, 3))
      assert diagonal == Quaternion(*[S.Half] * 4)

      # Rotation by phi about the z axis, recovered from its matrix.
      rot_z = Matrix([[cos(phi), -sin(phi), 0],
                      [sin(phi), cos(phi), 0],
                      [0, 0, 1]])
      from_matrix = trigsimp(Quaternion.from_rotation_matrix(rot_z))
      expected = Quaternion(sqrt(2)*sqrt(cos(phi) + 1)/2, 0, 0,
                            sqrt(2 - 2*cos(phi))*sign(sin(phi))/2)
      assert from_matrix == expected

      non_commutative = Symbol('nc', commutative=False)
      raises(ValueError, lambda: Quaternion(w, x, non_commutative, z))
  def test_quaternion_construction_norm():
      """The norm is multiplicative, and an explicit ``norm=1`` is respected."""
      generic = Quaternion(*symbols('a:d'))
      plain = Quaternion(w, x, y, z)
      product_norm_sq = (generic*plain).norm()**2
      assert expand(product_norm_sq - generic.norm()**2 * plain.norm()**2) == 0

      # With norm=1 declared, the product's norm must equal the other factor's.
      unit = Quaternion(w, x, y, z, norm=1)
      assert (generic * unit).norm() == generic.norm()
  def test_issue_25254():
      """Regression test for issue 25254.

      Computing the inverse used to cache the norm, which then broke later
      multiplications.
      """
      identity = Quaternion(1, 0, 0, 0)
      rotation = Quaternion.from_axis_angle((1, 1, 1), 3 * math.pi/4)
      rotation_inv = rotation.inverse()  # this call used to cache the norm
      conjugated = rotation * identity * rotation_inv
      assert (conjugated - identity).norm() < 1E-10
  def test_to_and_from_Matrix():
      """Round-trip a quaternion through ``to_Matrix`` and ``from_Matrix``."""
      quat = Quaternion(w, x, y, z)
      full_round_trip = Quaternion.from_Matrix(quat.to_Matrix())
      assert (quat - full_round_trip).is_zero_quaternion()

      # With the flag set, the round trip reproduces only the vector part.
      vector_round_trip = Quaternion.from_Matrix(quat.to_Matrix(True))
      assert (quat.vector_part() - vector_round_trip).is_zero_quaternion()
  def test_product_matrices():
      """Left/right product matrices reproduce quaternion multiplication."""
      left = Quaternion(w, x, y, z)
      right = Quaternion(*(symbols("a:d")))
      product_vec = (left * right).to_Matrix()
      assert product_vec == left.product_matrix_left * right.to_Matrix()
      assert product_vec == right.product_matrix_right * left.to_Matrix()

      # The lower-right 3x3 block of left-matrix * right-matrix^T equals the
      # rotation matrix scaled by the squared norm.
      via_products = (left.product_matrix_left
                      * left.product_matrix_right.T)[1:, 1:]
      via_rotation = simplify(left.to_rotation_matrix() * left.norm()**2)
      assert via_products == via_rotation
  def test_quaternion_axis_angle():
      """``from_axis_angle`` matches hand-computed quaternions."""
      half_root2 = sqrt(2)/2
      inv_root3 = 1/sqrt(3)
      third = sqrt(3)/3
      cases = [
          # (axis, angle, expected components (a, b, c, d))
          ((1, 0, 0), 0, (1, 0, 0, 0)),
          ((1, 0, 0), pi/2, (half_root2, half_root2, 0, 0)),
          ((0, 1, 0), pi/2, (half_root2, 0, half_root2, 0)),
          ((0, 0, 1), pi/2, (half_root2, 0, 0, half_root2)),
          ((1, 0, 0), pi, (0, 1, 0, 0)),
          ((0, 1, 0), pi, (0, 0, 1, 0)),
          ((0, 0, 1), pi, (0, 0, 0, 1)),
          ((1, 1, 1), pi, (0, inv_root3, inv_root3, inv_root3)),
          ((third, third, third), pi*2/3, (S.Half, S.Half, S.Half, S.Half)),
      ]
      for axis, angle, components in cases:
          assert Quaternion.from_axis_angle(axis, angle) == Quaternion(*components)
  def test_quaternion_axis_angle_simplification():
      """With a non-normalized axis, the components stay expressed through
      cos/sin of asin(4)/2."""
      half_angle = asin(4)/2
      scaled_sin = sqrt(14)*sin(half_angle)
      result = Quaternion.from_axis_angle((1, 2, 3), asin(4))
      assert result.a == cos(half_angle)
      assert result.b == scaled_sin/14
      assert result.c == scaled_sin/7
      assert result.d == 3*scaled_sin/14
  def test_quaternion_complex_real_addition():
      """Combine quaternions with complex, real and non-commutative scalars."""
      cplx = symbols("a", complex=True)
      real = symbols("b", real=True)
      # Not complex (non-commutative), so arithmetic with it must be rejected.
      noncomm = symbols("c", commutative=False)

      quat = Quaternion(w, x, y, z)
      assert cplx + quat == Quaternion(w + re(cplx), x + im(cplx), y, z)
      assert 1 + quat == Quaternion(1 + w, x, y, z)
      assert I + quat == Quaternion(w, 1 + x, y, z)
      assert real + quat == Quaternion(w + real, x, y, z)
      for bad_op in (lambda: noncomm + quat,
                     lambda: quat * noncomm,
                     lambda: noncomm * quat):
          raises(ValueError, bad_op)

      assert -quat == Quaternion(-w, -x, -y, -z)

      scalar = 2 + 3*I
      complex_quat = Quaternion(3 + 4*I, 2 + 5*I, 0, 7 + 8*I, real_field=False)
      real_quat = Quaternion(1, 4, 7, 8)
      assert complex_quat + scalar == Quaternion(5 + 7*I, 2 + 5*I, 0, 7 + 8*I)
      assert real_quat + scalar == Quaternion(3, 7, 7, 8)
      assert complex_quat * scalar == Quaternion(
          scalar*(3 + 4*I), scalar*(2 + 5*I), 0, scalar*(7 + 8*I))
      assert real_quat * scalar == Quaternion(-10, 11, 38, -5)

      base = Quaternion(1, 2, 3, 4)
      zero = Quaternion(0, 0, 0, 0)
      assert base + zero == base
      assert base - zero == base
      assert base - base == zero
  def test_quaternion_subs():
      """Substituting phi=0 into a z-axis rotation gives the identity."""
      z_rotation = Quaternion.from_axis_angle((0, 0, 1), phi)
      identity = Quaternion(1, 0, 0, 0)
      assert z_rotation.subs(phi, 0) == identity
  def test_quaternion_evalf():
      """``evalf`` on a quaternion matches evaluating each component."""
      for real_part, k_part in ((sqrt(2), sqrt(3)), (1/sqrt(2), 1/sqrt(2))):
          numeric = Quaternion(real_part, 0, 0, k_part).evalf()
          assert numeric == Quaternion(real_part.evalf(), 0, 0, k_part.evalf())
  def test_quaternion_functions():
      """Exercise the algebraic, analytic and geometric ``Quaternion`` methods."""
      sym = Quaternion(w, x, y, z)
      num = Quaternion(1, 2, 3, 4)
      zero = Quaternion(0, 0, 0, 0)

      # Conjugate, norm, normalization, inverse and integer powers.
      norm_sq = w**2 + x**2 + y**2 + z**2
      assert conjugate(sym) == Quaternion(w, -x, -y, -z)
      assert sym.norm() == sqrt(norm_sq)
      assert sym.normalize() == Quaternion(w, x, y, z) / sqrt(norm_sq)
      assert sym.inverse() == Quaternion(w, -x, -y, -z) / norm_sq
      assert sym.inverse() == sym.pow(-1)
      raises(ValueError, lambda: zero.inverse())

      square = Quaternion(w**2 - x**2 - y**2 - z**2, 2*w*x, 2*w*y, 2*w*z)
      assert sym.pow(2) == square
      assert sym**2 == square

      inv_square = Quaternion(Rational(-7, 225), Rational(-1, 225),
                              Rational(-1, 150), Rational(-2, 225))
      assert num.pow(-2) == inv_square
      assert num**(-2) == inv_square
      # Non-integer exponents are not supported.
      assert num.pow(-0.5) == NotImplemented
      raises(TypeError, lambda: num**(-0.5))

      # exp, log and cos/sin power of (1, 2, 3, 4).
      root29 = sqrt(29)
      polar_angle = acos(sqrt(30)/30)
      assert num.exp() == Quaternion(
          E * cos(root29),
          *[k * root29 * E * sin(root29) / 29 for k in (2, 3, 4)])
      assert num.log() == Quaternion(
          log(sqrt(30)),
          *[k * root29 * polar_angle / 29 for k in (2, 3, 4)])
      assert num.pow_cos_sin(2) == Quaternion(
          30 * cos(2 * polar_angle),
          *[k * root29 * sin(2 * polar_angle) / 29 for k in (60, 90, 120)])

      # Component-wise calculus.
      assert diff(Quaternion(x, x, x, x), x) == Quaternion(1, 1, 1, 1)
      half_x_sq = x**2 / 2
      assert integrate(Quaternion(x, x, x, x), x) == Quaternion(*[half_x_sq] * 4)
      antiderivatives = [
          (Quaternion(1, x, x**2, x**3), (x,),
           Quaternion(x, x**2/2, x**3/3, x**4/4)),
          (Quaternion(sin(x), cos(x), sin(2*x), cos(2*x)), (x,),
           Quaternion(-cos(x), sin(x), -cos(2*x)/2, sin(2*x)/2)),
          (Quaternion(x**2, y**2, z**2, x*y*z), (x, y),
           Quaternion(x**3*y/3, x*y**3/3, x*y*z**2, x**2*y**2*z/4)),
      ]
      for integrand, variables, expected in antiderivatives:
          assert integrand.integrate(*variables) == expected

      assert Quaternion.rotate_point((1, 1, 1), num) == (S.One / 5, 1, S(7) / 5)
      # Symbolic exponents are rejected, even when declared integer.
      for exponent in (Symbol('n'), Symbol('n', integer=True)):
          raises(TypeError, lambda: num**exponent)

      # Scalar/vector parts and the rotation axis.
      assert Quaternion(22, 23, 55, 8).scalar_part() == 22
      assert Quaternion(w, x, y, z).scalar_part() == w
      assert Quaternion(22, 23, 55, 8).vector_part() == Quaternion(0, 23, 55, 8)
      assert Quaternion(w, x, y, z).vector_part() == Quaternion(0, x, y, z)
      assert num.axis() == Quaternion(0, *[k*root29/29 for k in (2, 3, 4)])
      assert num.axis().pow(2) == Quaternion(-1, 0, 0, 0)
      assert zero.axis().scalar_part() == 0
      vec_len = sqrt(x**2 + y**2 + z**2)
      assert sym.axis() == Quaternion(0, x/vec_len, y/vec_len, z/vec_len)

      pure_cases = ((zero, True), (num, False),
                    (Quaternion(0, 0, 0, 3), True),
                    (Quaternion(0, 2, 10, 3), True),
                    (Quaternion(w, 2, 10, 3), None))
      for candidate, expected in pure_cases:
          assert candidate.is_pure() is expected

      assert num.angle() == 2*atan(root29)
      assert sym.angle() == 2*atan2(vec_len, w)

      arc_cases = ((Quaternion(2, 4, 6, 8), True),
                   (Quaternion(1, -2, -3, -4), True),
                   (Quaternion(1, 8, 12, 16), True),
                   (Quaternion(1, 2, 3, 4), True),
                   (Quaternion(w, 4, 6, 8), True),
                   (Quaternion(2, 7, 4, 1), False),
                   (Quaternion(w, x, y, z), None))
      for other, expected in arc_cases:
          assert Quaternion.arc_coplanar(num, other) is expected
      raises(ValueError, lambda: Quaternion.arc_coplanar(num, zero))

      vector_cases = (
          ((Quaternion(0, 8, 12, 16), Quaternion(0, 4, 6, 8),
            Quaternion(0, 2, 3, 4)), True),
          ((Quaternion(0, 0, 0, 0), Quaternion(0, 4, 6, 8),
            Quaternion(0, 2, 3, 4)), True),
          ((Quaternion(0, 8, 2, 6), Quaternion(0, 1, 6, 6),
            Quaternion(0, 0, 3, 4)), False),
          ((Quaternion(0, 1, 3, 4), Quaternion(0, 4, w, 6),
            Quaternion(0, 6, 8, 1)), None),
      )
      for triple, expected in vector_cases:
          assert Quaternion.vector_coplanar(*triple) is expected
      raises(ValueError, lambda:
             Quaternion.vector_coplanar(zero, Quaternion(0, 4, 6, 8), num))

      for other, expected in ((Quaternion(0, 2, 4, 6), True),
                              (Quaternion(0, 2, 2, 6), False),
                              (Quaternion(w, x, y, 6), None)):
          assert Quaternion(0, 1, 2, 3).parallel(other) is expected
      raises(ValueError, lambda: zero.parallel(num))

      assert Quaternion(0, 1, 2, 3).orthogonal(Quaternion(0, -2, 1, 0)) is True
      for other, expected in ((Quaternion(0, 2, 2, 6), False),
                              (Quaternion(w, x, y, 6), None)):
          assert Quaternion(0, 2, 4, 7).orthogonal(other) is expected
      raises(ValueError, lambda: zero.orthogonal(num))

      assert num.index_vector() == Quaternion(
          0, *[k*sqrt(870)/29 for k in (2, 3, 4)])
      assert Quaternion(0, 3, 9, 4).index_vector() == Quaternion(0, 3, 9, 4)
      assert Quaternion(4, 3, 9, 4).mensor() == log(sqrt(122))
      assert Quaternion(3, 3, 0, 2).mensor() == log(sqrt(22))

      assert zero.is_zero_quaternion() is True
      assert num.is_zero_quaternion() is False
      assert Quaternion(w, 0, 0, 0).is_zero_quaternion() is None
  def test_quaternion_conversions():
      """Convert a numeric and a symbolic quaternion to axis-angle form and to
      3x3 / 4x4 rotation matrices."""
      numeric = Quaternion(1, 2, 3, 4)

      root29 = sqrt(29)
      assert numeric.to_axis_angle() == (
          (2 * root29/29, 3 * root29/29, 4 * root29/29),
          2 * acos(sqrt(30)/30))

      rotation_rows = [
          [Rational(-2, 3), Rational(2, 15), Rational(11, 15)],
          [Rational(2, 3), Rational(-1, 3), Rational(2, 3)],
          [Rational(1, 3), Rational(14, 15), Rational(2, 15)],
      ]
      assert numeric.to_rotation_matrix() == Matrix(rotation_rows)

      # Passing (1, 1, 1) yields the 4x4 form: the same 3x3 block plus a
      # translation column and a [0, 0, 0, 1] bottom row.
      translation = [Rational(4, 5), S.Zero, Rational(-2, 5)]
      homogeneous_rows = [row + [shift]
                          for row, shift in zip(rotation_rows, translation)]
      homogeneous_rows.append([S.Zero, S.Zero, S.Zero, S.One])
      assert numeric.to_rotation_matrix((1, 1, 1)) == Matrix(homogeneous_rows)

      theta = symbols("theta", real=True)
      c, s = cos(theta), sin(theta)
      z_rot = Quaternion(cos(theta/2), 0, 0, sin(theta/2))
      assert trigsimp(z_rot.to_rotation_matrix()) == Matrix([
          [c, -s, 0],
          [s, c, 0],
          [0, 0, 1]])
      assert z_rot.to_axis_angle() == (
          (0, 0, sin(theta/2)/Abs(sin(theta/2))), 2*acos(cos(theta/2)))
      assert trigsimp(z_rot.to_rotation_matrix((1, 1, 1))) == Matrix([
          [c, -s, 0, s - c + 1],
          [s, c, 0, -s - c + 1],
          [0, 0, 1, 0],
          [0, 0, 0, 1]])
  def test_rotation_matrix_homogeneous():
      """Both ``homogeneous`` variants agree once scaled by the squared norm."""
      quat = Quaternion(w, x, y, z)
      scale = quat.norm()**2
      homogeneous = quat.to_rotation_matrix(homogeneous=True) * scale
      inhomogeneous = simplify(quat.to_rotation_matrix(homogeneous=False) * scale)
      assert homogeneous == inhomogeneous
  def test_quaternion_rotation_iss1593():
      """
      Regression test for issue 1593: a sign error in the definition of the
      rotation matrix.

      The correct definition is given at
      https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation#Quaternion-derived_rotation_matrix
      """
      x_rotation = Quaternion(cos(phi/2), sin(phi/2), 0, 0)
      expected = Matrix([
          [1, 0, 0],
          [0, cos(phi), -sin(phi)],
          [0, sin(phi), cos(phi)]])
      assert trigsimp(x_rotation.to_rotation_matrix()) == expected
  def test_quaternion_multiplication():
      """``mul`` / ``*`` with scalars, complex numbers and other quaternions."""
      complex_quat = Quaternion(3 + 4*I, 2 + 5*I, 0, 7 + 8*I, real_field=False)
      integer_quat = Quaternion(1, 2, 3, 5)
      symbolic_quat = Quaternion(1, 1, 1, y)

      assert Quaternion._generic_mul(S(4), S.One) == 4
      assert (Quaternion._generic_mul(S(4), complex_quat) ==
              Quaternion(12 + 16*I, 8 + 20*I, 0, 28 + 32*I))
      assert integer_quat.mul(2) == Quaternion(2, 4, 6, 10)
      assert integer_quat.mul(symbolic_quat) == Quaternion(
          -5*y - 4, 3*y - 2, 9 - 2*y, y + 4)
      assert integer_quat.mul(symbolic_quat) == integer_quat*symbolic_quat

      # A complex scalar multiplies like the quaternion (re, im, 0, 0).
      zc = symbols('z', complex=True)
      zc_quat = Quaternion(re(zc), im(zc), 0, 0)
      real_quat = Quaternion(*symbols('q:4', real=True))
      assert zc * real_quat == zc_quat * real_quat
      assert real_quat * zc == real_quat * zc_quat
  def test_issue_16318():
      """Regression tests for issue 16318."""
      # Reflected division by the zero quaternion must raise.
      zero = Quaternion(0, 0, 0, 0)
      raises(ValueError, lambda: 1/zero)

      # rotate_point accepts an (axis, angle) pair.
      axis, angle = Quaternion(1, 2, 3, 4).to_axis_angle()
      assert Quaternion.rotate_point((1, 1, 1), (axis, angle)) == \
          (S.One / 5, 1, S(7) / 5)

      # to_axis_angle for a quaternion whose scalar part is negative.
      expected_axis = (-sqrt(3)/3,) * 3
      expected_angle = 2*pi/3
      assert (expected_axis, expected_angle) == \
          Quaternion(-1, 1, 1, 1).to_axis_angle()
  @slow
  def test_to_euler():
      """Angles from ``to_euler`` rebuild the normalized quaternion."""
      quat = Quaternion(w, x, y, z)
      expected = quat.normalize()

      lowercase = ['zxy', 'zyx', 'zyz', 'zxz']
      for seq in lowercase + [s.upper() for s in lowercase]:
          angles = quat.to_euler(seq)
          rebuilt = simplify(Quaternion.from_euler(angles, seq))
          assert rebuilt == expected
  def test_to_euler_iss24504():
      """
      Regression test for issue 24504: the degenerate-case check in
      ``to_euler`` was wrong.
      """
      angles = (phi, 0, 0)
      quat = Quaternion.from_euler(angles, 'zyz')
      assert trigsimp(quat.to_euler('zyz'), inverse=True) == angles
  def test_to_euler_numerical_singilarities():
      """``to_euler`` exactly inverts ``from_euler`` at singular middle angles
      (0 or pi for symmetric sequences, +-pi/2 for asymmetric ones)."""
      cases = [
          # symmetric sequences
          ((pi/2, 0, 0), 'zyz'),
          ((pi/2, 0, 0), 'ZYZ'),
          ((pi/2, pi, 0), 'zyz'),
          ((pi/2, pi, 0), 'ZYZ'),
          # asymmetric sequences
          ((pi/2, pi/2, 0), 'zyx'),
          ((pi/2, -pi/2, 0), 'zyx'),
          ((pi/2, pi/2, 0), 'ZYX'),
          ((pi/2, -pi/2, 0), 'ZYX'),
      ]
      for angles, seq in cases:
          round_trip = Quaternion.from_euler(angles, seq).to_euler(seq)
          assert round_trip == angles
  @slow
  def test_to_euler_options():
      """Check that two ``to_euler`` option settings give the same angles.

      For every rotation sequence built from a permutation of 'xyz'/'XYZ'
      (symmetric and asymmetric), and every nonzero quaternion with components
      in {-1, 0, 1}, the angles from ``to_euler(seq, True, True)`` and
      ``to_euler(seq, False, False)`` must coincide modulo 2*pi.
      """
      def check_case(q, seq):
          # ``seq`` is passed explicitly instead of being read from the
          # enclosing loop through a closure.
          angles1 = Matrix(q.to_euler(seq, True, True))
          angles2 = Matrix(q.to_euler(seq, False, False))
          angle_errors = simplify(angles1 - angles2).evalf()
          for angle_error in angle_errors:
              # Wrap the difference into [-pi, pi) so that angles differing by
              # a multiple of 2*pi compare equal, then bound its magnitude: a
              # one-sided ``angle_error < tol`` would accept any negative error.
              angle_error = (angle_error + pi) % (2 * pi) - pi
              assert abs(angle_error) < 10e-7  # 10e-7 == 1e-6

      for xyz in ('xyz', 'XYZ'):
          for seq_tuple in permutations(xyz):
              for symmetric in (True, False):
                  if symmetric:
                      seq = ''.join([seq_tuple[0], seq_tuple[1], seq_tuple[0]])
                  else:
                      seq = ''.join(seq_tuple)

                  for elements in product([-1, 0, 1], repeat=4):
                      q = Quaternion(*elements)
                      if not q.is_zero_quaternion():
                          check_case(q, seq)