test_defmatrix.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. import collections.abc
  2. import numpy as np
  3. from numpy import asmatrix, bmat, matrix
  4. from numpy.linalg import matrix_power
  5. from numpy.testing import (
  6. assert_,
  7. assert_almost_equal,
  8. assert_array_almost_equal,
  9. assert_array_equal,
  10. assert_equal,
  11. assert_raises,
  12. )
  13. class TestCtor:
  14. def test_basic(self):
  15. A = np.array([[1, 2], [3, 4]])
  16. mA = matrix(A)
  17. assert_(np.all(mA.A == A))
  18. B = bmat("A,A;A,A")
  19. C = bmat([[A, A], [A, A]])
  20. D = np.array([[1, 2, 1, 2],
  21. [3, 4, 3, 4],
  22. [1, 2, 1, 2],
  23. [3, 4, 3, 4]])
  24. assert_(np.all(B.A == D))
  25. assert_(np.all(C.A == D))
  26. E = np.array([[5, 6], [7, 8]])
  27. AEresult = matrix([[1, 2, 5, 6], [3, 4, 7, 8]])
  28. assert_(np.all(bmat([A, E]) == AEresult))
  29. vec = np.arange(5)
  30. mvec = matrix(vec)
  31. assert_(mvec.shape == (1, 5))
  32. def test_exceptions(self):
  33. # Check for ValueError when called with invalid string data.
  34. assert_raises(ValueError, matrix, "invalid")
  35. def test_bmat_nondefault_str(self):
  36. A = np.array([[1, 2], [3, 4]])
  37. B = np.array([[5, 6], [7, 8]])
  38. Aresult = np.array([[1, 2, 1, 2],
  39. [3, 4, 3, 4],
  40. [1, 2, 1, 2],
  41. [3, 4, 3, 4]])
  42. mixresult = np.array([[1, 2, 5, 6],
  43. [3, 4, 7, 8],
  44. [5, 6, 1, 2],
  45. [7, 8, 3, 4]])
  46. assert_(np.all(bmat("A,A;A,A") == Aresult))
  47. assert_(np.all(bmat("A,A;A,A", ldict={'A': B}) == Aresult))
  48. assert_raises(TypeError, bmat, "A,A;A,A", gdict={'A': B})
  49. assert_(
  50. np.all(bmat("A,A;A,A", ldict={'A': A}, gdict={'A': B}) == Aresult))
  51. b2 = bmat("A,B;C,D", ldict={'A': A, 'B': B}, gdict={'C': B, 'D': A})
  52. assert_(np.all(b2 == mixresult))
  53. class TestProperties:
  54. def test_sum(self):
  55. """Test whether matrix.sum(axis=1) preserves orientation.
  56. Fails in NumPy <= 0.9.6.2127.
  57. """
  58. M = matrix([[1, 2, 0, 0],
  59. [3, 4, 0, 0],
  60. [1, 2, 1, 2],
  61. [3, 4, 3, 4]])
  62. sum0 = matrix([8, 12, 4, 6])
  63. sum1 = matrix([3, 7, 6, 14]).T
  64. sumall = 30
  65. assert_array_equal(sum0, M.sum(axis=0))
  66. assert_array_equal(sum1, M.sum(axis=1))
  67. assert_equal(sumall, M.sum())
  68. assert_array_equal(sum0, np.sum(M, axis=0))
  69. assert_array_equal(sum1, np.sum(M, axis=1))
  70. assert_equal(sumall, np.sum(M))
  71. def test_prod(self):
  72. x = matrix([[1, 2, 3], [4, 5, 6]])
  73. assert_equal(x.prod(), 720)
  74. assert_equal(x.prod(0), matrix([[4, 10, 18]]))
  75. assert_equal(x.prod(1), matrix([[6], [120]]))
  76. assert_equal(np.prod(x), 720)
  77. assert_equal(np.prod(x, axis=0), matrix([[4, 10, 18]]))
  78. assert_equal(np.prod(x, axis=1), matrix([[6], [120]]))
  79. y = matrix([0, 1, 3])
  80. assert_(y.prod() == 0)
  81. def test_max(self):
  82. x = matrix([[1, 2, 3], [4, 5, 6]])
  83. assert_equal(x.max(), 6)
  84. assert_equal(x.max(0), matrix([[4, 5, 6]]))
  85. assert_equal(x.max(1), matrix([[3], [6]]))
  86. assert_equal(np.max(x), 6)
  87. assert_equal(np.max(x, axis=0), matrix([[4, 5, 6]]))
  88. assert_equal(np.max(x, axis=1), matrix([[3], [6]]))
  89. def test_min(self):
  90. x = matrix([[1, 2, 3], [4, 5, 6]])
  91. assert_equal(x.min(), 1)
  92. assert_equal(x.min(0), matrix([[1, 2, 3]]))
  93. assert_equal(x.min(1), matrix([[1], [4]]))
  94. assert_equal(np.min(x), 1)
  95. assert_equal(np.min(x, axis=0), matrix([[1, 2, 3]]))
  96. assert_equal(np.min(x, axis=1), matrix([[1], [4]]))
  97. def test_ptp(self):
  98. x = np.arange(4).reshape((2, 2))
  99. mx = x.view(np.matrix)
  100. assert_(mx.ptp() == 3)
  101. assert_(np.all(mx.ptp(0) == np.array([2, 2])))
  102. assert_(np.all(mx.ptp(1) == np.array([1, 1])))
  103. def test_var(self):
  104. x = np.arange(9).reshape((3, 3))
  105. mx = x.view(np.matrix)
  106. assert_equal(x.var(ddof=0), mx.var(ddof=0))
  107. assert_equal(x.var(ddof=1), mx.var(ddof=1))
  108. def test_basic(self):
  109. import numpy.linalg as linalg
  110. A = np.array([[1., 2.],
  111. [3., 4.]])
  112. mA = matrix(A)
  113. assert_(np.allclose(linalg.inv(A), mA.I))
  114. assert_(np.all(np.array(np.transpose(A) == mA.T)))
  115. assert_(np.all(np.array(np.transpose(A) == mA.H)))
  116. assert_(np.all(A == mA.A))
  117. B = A + 2j * A
  118. mB = matrix(B)
  119. assert_(np.allclose(linalg.inv(B), mB.I))
  120. assert_(np.all(np.array(np.transpose(B) == mB.T)))
  121. assert_(np.all(np.array(np.transpose(B).conj() == mB.H)))
  122. def test_pinv(self):
  123. x = matrix(np.arange(6).reshape(2, 3))
  124. xpinv = matrix([[-0.77777778, 0.27777778],
  125. [-0.11111111, 0.11111111],
  126. [ 0.55555556, -0.05555556]])
  127. assert_almost_equal(x.I, xpinv)
  128. def test_comparisons(self):
  129. A = np.arange(100).reshape(10, 10)
  130. mA = matrix(A)
  131. mB = matrix(A) + 0.1
  132. assert_(np.all(mB == A + 0.1))
  133. assert_(np.all(mB == matrix(A + 0.1)))
  134. assert_(not np.any(mB == matrix(A - 0.1)))
  135. assert_(np.all(mA < mB))
  136. assert_(np.all(mA <= mB))
  137. assert_(np.all(mA <= mA))
  138. assert_(not np.any(mA < mA))
  139. assert_(not np.any(mB < mA))
  140. assert_(np.all(mB >= mA))
  141. assert_(np.all(mB >= mB))
  142. assert_(not np.any(mB > mB))
  143. assert_(np.all(mA == mA))
  144. assert_(not np.any(mA == mB))
  145. assert_(np.all(mB != mA))
  146. assert_(not np.all(abs(mA) > 0))
  147. assert_(np.all(abs(mB > 0)))
  148. def test_asmatrix(self):
  149. A = np.arange(100).reshape(10, 10)
  150. mA = asmatrix(A)
  151. A[0, 0] = -10
  152. assert_(A[0, 0] == mA[0, 0])
  153. def test_noaxis(self):
  154. A = matrix([[1, 0], [0, 1]])
  155. assert_(A.sum() == matrix(2))
  156. assert_(A.mean() == matrix(0.5))
  157. def test_repr(self):
  158. A = matrix([[1, 0], [0, 1]])
  159. assert_(repr(A) == "matrix([[1, 0],\n [0, 1]])")
  160. def test_make_bool_matrix_from_str(self):
  161. A = matrix('True; True; False')
  162. B = matrix([[True], [True], [False]])
  163. assert_array_equal(A, B)
  164. class TestCasting:
  165. def test_basic(self):
  166. A = np.arange(100).reshape(10, 10)
  167. mA = matrix(A)
  168. mB = mA.copy()
  169. O = np.ones((10, 10), np.float64) * 0.1
  170. mB = mB + O
  171. assert_(mB.dtype.type == np.float64)
  172. assert_(np.all(mA != mB))
  173. assert_(np.all(mB == mA + 0.1))
  174. mC = mA.copy()
  175. O = np.ones((10, 10), np.complex128)
  176. mC = mC * O
  177. assert_(mC.dtype.type == np.complex128)
  178. assert_(np.all(mA != mB))
  179. class TestAlgebra:
  180. def test_basic(self):
  181. import numpy.linalg as linalg
  182. A = np.array([[1., 2.], [3., 4.]])
  183. mA = matrix(A)
  184. B = np.identity(2)
  185. for i in range(6):
  186. assert_(np.allclose((mA ** i).A, B))
  187. B = np.dot(B, A)
  188. Ainv = linalg.inv(A)
  189. B = np.identity(2)
  190. for i in range(6):
  191. assert_(np.allclose((mA ** -i).A, B))
  192. B = np.dot(B, Ainv)
  193. assert_(np.allclose((mA * mA).A, np.dot(A, A)))
  194. assert_(np.allclose((mA + mA).A, (A + A)))
  195. assert_(np.allclose((3 * mA).A, (3 * A)))
  196. mA2 = matrix(A)
  197. mA2 *= 3
  198. assert_(np.allclose(mA2.A, 3 * A))
  199. def test_pow(self):
  200. """Test raising a matrix to an integer power works as expected."""
  201. m = matrix("1. 2.; 3. 4.")
  202. m2 = m.copy()
  203. m2 **= 2
  204. mi = m.copy()
  205. mi **= -1
  206. m4 = m2.copy()
  207. m4 **= 2
  208. assert_array_almost_equal(m2, m**2)
  209. assert_array_almost_equal(m4, np.dot(m2, m2))
  210. assert_array_almost_equal(np.dot(mi, m), np.eye(2))
  211. def test_scalar_type_pow(self):
  212. m = matrix([[1, 2], [3, 4]])
  213. for scalar_t in [np.int8, np.uint8]:
  214. two = scalar_t(2)
  215. assert_array_almost_equal(m ** 2, m ** two)
  216. def test_notimplemented(self):
  217. '''Check that 'not implemented' operations produce a failure.'''
  218. A = matrix([[1., 2.],
  219. [3., 4.]])
  220. # __rpow__
  221. with assert_raises(TypeError):
  222. 1.0**A
  223. # __mul__ with something not a list, ndarray, tuple, or scalar
  224. with assert_raises(TypeError):
  225. A * object()
  226. class TestMatrixReturn:
  227. def test_instance_methods(self):
  228. a = matrix([1.0], dtype='f8')
  229. methodargs = {
  230. 'astype': ('intc',),
  231. 'clip': (0.0, 1.0),
  232. 'compress': ([1],),
  233. 'repeat': (1,),
  234. 'reshape': (1,),
  235. 'swapaxes': (0, 0),
  236. 'dot': np.array([1.0]),
  237. }
  238. excluded_methods = [
  239. 'argmin', 'choose', 'dump', 'dumps', 'fill', 'getfield',
  240. 'getA', 'getA1', 'item', 'nonzero', 'put', 'putmask', 'resize',
  241. 'searchsorted', 'setflags', 'setfield', 'sort',
  242. 'partition', 'argpartition', 'to_device',
  243. 'take', 'tofile', 'tolist', 'tobytes', 'all', 'any',
  244. 'sum', 'argmax', 'argmin', 'min', 'max', 'mean', 'var', 'ptp',
  245. 'prod', 'std', 'ctypes', 'bitwise_count',
  246. ]
  247. for attrib in dir(a):
  248. if attrib.startswith('_') or attrib in excluded_methods:
  249. continue
  250. f = getattr(a, attrib)
  251. if isinstance(f, collections.abc.Callable):
  252. # reset contents of a
  253. a.astype('f8')
  254. a.fill(1.0)
  255. args = methodargs.get(attrib, ())
  256. b = f(*args)
  257. assert_(type(b) is matrix, f"{attrib}")
  258. assert_(type(a.real) is matrix)
  259. assert_(type(a.imag) is matrix)
  260. c, d = matrix([0.0]).nonzero()
  261. assert_(type(c) is np.ndarray)
  262. assert_(type(d) is np.ndarray)
  263. class TestIndexing:
  264. def test_basic(self):
  265. x = asmatrix(np.zeros((3, 2), float))
  266. y = np.zeros((3, 1), float)
  267. y[:, 0] = [0.8, 0.2, 0.3]
  268. x[:, 1] = y > 0.5
  269. assert_equal(x, [[0, 1], [0, 0], [0, 0]])
  270. class TestNewScalarIndexing:
  271. a = matrix([[1, 2], [3, 4]])
  272. def test_dimesions(self):
  273. a = self.a
  274. x = a[0]
  275. assert_equal(x.ndim, 2)
  276. def test_array_from_matrix_list(self):
  277. a = self.a
  278. x = np.array([a, a])
  279. assert_equal(x.shape, [2, 2, 2])
  280. def test_array_to_list(self):
  281. a = self.a
  282. assert_equal(a.tolist(), [[1, 2], [3, 4]])
  283. def test_fancy_indexing(self):
  284. a = self.a
  285. x = a[1, [0, 1, 0]]
  286. assert_(isinstance(x, matrix))
  287. assert_equal(x, matrix([[3, 4, 3]]))
  288. x = a[[1, 0]]
  289. assert_(isinstance(x, matrix))
  290. assert_equal(x, matrix([[3, 4], [1, 2]]))
  291. x = a[[[1], [0]], [[1, 0], [0, 1]]]
  292. assert_(isinstance(x, matrix))
  293. assert_equal(x, matrix([[4, 3], [1, 2]]))
  294. def test_matrix_element(self):
  295. x = matrix([[1, 2, 3], [4, 5, 6]])
  296. assert_equal(x[0][0], matrix([[1, 2, 3]]))
  297. assert_equal(x[0][0].shape, (1, 3))
  298. assert_equal(x[0].shape, (1, 3))
  299. assert_equal(x[:, 0].shape, (2, 1))
  300. x = matrix(0)
  301. assert_equal(x[0, 0], 0)
  302. assert_equal(x[0], 0)
  303. assert_equal(x[:, 0].shape, x.shape)
  304. def test_scalar_indexing(self):
  305. x = asmatrix(np.zeros((3, 2), float))
  306. assert_equal(x[0, 0], x[0][0])
  307. def test_row_column_indexing(self):
  308. x = asmatrix(np.eye(2))
  309. assert_array_equal(x[0, :], [[1, 0]])
  310. assert_array_equal(x[1, :], [[0, 1]])
  311. assert_array_equal(x[:, 0], [[1], [0]])
  312. assert_array_equal(x[:, 1], [[0], [1]])
  313. def test_boolean_indexing(self):
  314. A = np.arange(6)
  315. A.shape = (3, 2)
  316. x = asmatrix(A)
  317. assert_array_equal(x[:, np.array([True, False])], x[:, 0])
  318. assert_array_equal(x[np.array([True, False, False]), :], x[0, :])
  319. def test_list_indexing(self):
  320. A = np.arange(6)
  321. A.shape = (3, 2)
  322. x = asmatrix(A)
  323. assert_array_equal(x[:, [1, 0]], x[:, ::-1])
  324. assert_array_equal(x[[2, 1, 0], :], x[::-1, :])
  325. class TestPower:
  326. def test_returntype(self):
  327. a = np.array([[0, 1], [0, 0]])
  328. assert_(type(matrix_power(a, 2)) is np.ndarray)
  329. a = asmatrix(a)
  330. assert_(type(matrix_power(a, 2)) is matrix)
  331. def test_list(self):
  332. assert_array_equal(matrix_power([[0, 1], [0, 0]], 2), [[0, 0], [0, 0]])
  333. class TestShape:
  334. a = np.array([[1], [2]])
  335. m = matrix([[1], [2]])
  336. def test_shape(self):
  337. assert_equal(self.a.shape, (2, 1))
  338. assert_equal(self.m.shape, (2, 1))
  339. def test_numpy_ravel(self):
  340. assert_equal(np.ravel(self.a).shape, (2,))
  341. assert_equal(np.ravel(self.m).shape, (2,))
  342. def test_member_ravel(self):
  343. assert_equal(self.a.ravel().shape, (2,))
  344. assert_equal(self.m.ravel().shape, (1, 2))
  345. def test_member_flatten(self):
  346. assert_equal(self.a.flatten().shape, (2,))
  347. assert_equal(self.m.flatten().shape, (1, 2))
  348. def test_numpy_ravel_order(self):
  349. x = np.array([[1, 2, 3], [4, 5, 6]])
  350. assert_equal(np.ravel(x), [1, 2, 3, 4, 5, 6])
  351. assert_equal(np.ravel(x, order='F'), [1, 4, 2, 5, 3, 6])
  352. assert_equal(np.ravel(x.T), [1, 4, 2, 5, 3, 6])
  353. assert_equal(np.ravel(x.T, order='A'), [1, 2, 3, 4, 5, 6])
  354. x = matrix([[1, 2, 3], [4, 5, 6]])
  355. assert_equal(np.ravel(x), [1, 2, 3, 4, 5, 6])
  356. assert_equal(np.ravel(x, order='F'), [1, 4, 2, 5, 3, 6])
  357. assert_equal(np.ravel(x.T), [1, 4, 2, 5, 3, 6])
  358. assert_equal(np.ravel(x.T, order='A'), [1, 2, 3, 4, 5, 6])
  359. def test_matrix_ravel_order(self):
  360. x = matrix([[1, 2, 3], [4, 5, 6]])
  361. assert_equal(x.ravel(), [[1, 2, 3, 4, 5, 6]])
  362. assert_equal(x.ravel(order='F'), [[1, 4, 2, 5, 3, 6]])
  363. assert_equal(x.T.ravel(), [[1, 4, 2, 5, 3, 6]])
  364. assert_equal(x.T.ravel(order='A'), [[1, 2, 3, 4, 5, 6]])
  365. def test_array_memory_sharing(self):
  366. assert_(np.may_share_memory(self.a, self.a.ravel()))
  367. assert_(not np.may_share_memory(self.a, self.a.flatten()))
  368. def test_matrix_memory_sharing(self):
  369. assert_(np.may_share_memory(self.m, self.m.ravel()))
  370. assert_(not np.may_share_memory(self.m, self.m.flatten()))
  371. def test_expand_dims_matrix(self):
  372. # matrices are always 2d - so expand_dims only makes sense when the
  373. # type is changed away from matrix.
  374. a = np.arange(10).reshape((2, 5)).view(np.matrix)
  375. expanded = np.expand_dims(a, axis=1)
  376. assert_equal(expanded.ndim, 3)
  377. assert_(not isinstance(expanded, np.matrix))