test_sputils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. """unit tests for sparse utility functions"""
  2. import numpy as np
  3. from numpy.testing import assert_equal
  4. import pytest
  5. from pytest import raises as assert_raises
  6. from scipy.sparse import _sputils as sputils, csr_array, bsr_array, dia_array, coo_array
  7. from scipy.sparse._sputils import matrix
  8. class TestSparseUtils:
  9. def test_upcast(self):
  10. assert_equal(sputils.upcast('intc'), np.intc)
  11. assert_equal(sputils.upcast('int32', 'float32'), np.float64)
  12. assert_equal(sputils.upcast('bool', complex, float), np.complex128)
  13. assert_equal(sputils.upcast('i', 'd'), np.float64)
  14. def test_getdtype(self):
  15. A = np.array([1], dtype='int8')
  16. assert_equal(sputils.getdtype(None, default=float), float)
  17. assert_equal(sputils.getdtype(None, a=A), np.int8)
  18. with assert_raises(
  19. ValueError,
  20. match="scipy.sparse does not support dtype object. .*",
  21. ):
  22. sputils.getdtype("O")
  23. with assert_raises(
  24. ValueError,
  25. match="scipy.sparse does not support dtype float16. .*",
  26. ):
  27. sputils.getdtype(None, default=np.float16)
  28. def test_isscalarlike(self):
  29. assert_equal(sputils.isscalarlike(3.0), True)
  30. assert_equal(sputils.isscalarlike(-4), True)
  31. assert_equal(sputils.isscalarlike(2.5), True)
  32. assert_equal(sputils.isscalarlike(1 + 3j), True)
  33. assert_equal(sputils.isscalarlike(np.array(3)), True)
  34. assert_equal(sputils.isscalarlike("16"), True)
  35. assert_equal(sputils.isscalarlike(np.array([3])), False)
  36. assert_equal(sputils.isscalarlike([[3]]), False)
  37. assert_equal(sputils.isscalarlike((1,)), False)
  38. assert_equal(sputils.isscalarlike((1, 2)), False)
  39. def test_isintlike(self):
  40. assert_equal(sputils.isintlike(-4), True)
  41. assert_equal(sputils.isintlike(np.array(3)), True)
  42. assert_equal(sputils.isintlike(np.array([3])), False)
  43. with assert_raises(
  44. ValueError,
  45. match="Inexact indices into sparse matrices are not allowed"
  46. ):
  47. sputils.isintlike(3.0)
  48. assert_equal(sputils.isintlike(2.5), False)
  49. assert_equal(sputils.isintlike(1 + 3j), False)
  50. assert_equal(sputils.isintlike((1,)), False)
  51. assert_equal(sputils.isintlike((1, 2)), False)
  52. def test_isshape(self):
  53. assert_equal(sputils.isshape((1, 2)), True)
  54. assert_equal(sputils.isshape((5, 2)), True)
  55. assert_equal(sputils.isshape((1.5, 2)), False)
  56. assert_equal(sputils.isshape((2, 2, 2)), False)
  57. assert_equal(sputils.isshape(([2], 2)), False)
  58. assert_equal(sputils.isshape((-1, 2), nonneg=False),True)
  59. assert_equal(sputils.isshape((2, -1), nonneg=False),True)
  60. assert_equal(sputils.isshape((-1, 2), nonneg=True),False)
  61. assert_equal(sputils.isshape((2, -1), nonneg=True),False)
  62. assert_equal(sputils.isshape((1.5, 2), allow_nd=(1, 2)), False)
  63. assert_equal(sputils.isshape(([2], 2), allow_nd=(1, 2)), False)
  64. assert_equal(sputils.isshape((2, 2, -2), nonneg=True, allow_nd=(1, 2)),
  65. False)
  66. assert_equal(sputils.isshape((2,), allow_nd=(1, 2)), True)
  67. assert_equal(sputils.isshape((2, 2,), allow_nd=(1, 2)), True)
  68. assert_equal(sputils.isshape((2, 2, 2), allow_nd=(1, 2)), False)
  69. def test_issequence(self):
  70. assert_equal(sputils.issequence((1,)), True)
  71. assert_equal(sputils.issequence((1, 2, 3)), True)
  72. assert_equal(sputils.issequence([1]), True)
  73. assert_equal(sputils.issequence([1, 2, 3]), True)
  74. assert_equal(sputils.issequence(np.array([1, 2, 3])), True)
  75. assert_equal(sputils.issequence(np.array([[1], [2], [3]])), False)
  76. assert_equal(sputils.issequence(3), False)
  77. def test_ismatrix(self):
  78. assert_equal(sputils.ismatrix(((),)), True)
  79. assert_equal(sputils.ismatrix([[1], [2]]), True)
  80. assert_equal(sputils.ismatrix(np.arange(3)[None]), True)
  81. assert_equal(sputils.ismatrix([1, 2]), False)
  82. assert_equal(sputils.ismatrix(np.arange(3)), False)
  83. assert_equal(sputils.ismatrix([[[1]]]), False)
  84. assert_equal(sputils.ismatrix(3), False)
  85. def test_isdense(self):
  86. assert_equal(sputils.isdense(np.array([1])), True)
  87. assert_equal(sputils.isdense(matrix([1])), True)
  88. def test_validateaxis(self):
  89. with assert_raises(ValueError, match="does not accept 0D axis"):
  90. sputils.validateaxis(())
  91. for ax in [1.5, (0, 1.5), (1.5, 0)]:
  92. with assert_raises(TypeError, match="must be an integer"):
  93. sputils.validateaxis(ax)
  94. for ax in [(1, 1), (1, -1), (0, -2)]:
  95. with assert_raises(ValueError, match="duplicate value in axis"):
  96. sputils.validateaxis(ax)
  97. # ndim 1
  98. for ax in [1, -2, (0, 1), (1, -1)]:
  99. with assert_raises(ValueError, match="out of range"):
  100. sputils.validateaxis(ax, ndim=1)
  101. with assert_raises(ValueError, match="duplicate value in axis"):
  102. sputils.validateaxis((0, -1), ndim=1)
  103. # all valid axis values lead to None when canonical
  104. for axis in (0, -1, None, (0,), (-1,)):
  105. assert sputils.validateaxis(axis, ndim=1) is None
  106. # ndim 2
  107. for ax in [5, -5, (0, 5), (-5, 0)]:
  108. with assert_raises(ValueError, match="out of range"):
  109. sputils.validateaxis(ax, ndim=2)
  110. for axis in ((0,), (1,), None):
  111. assert sputils.validateaxis(axis, ndim=2) == axis
  112. axis_2d = {-2: (0,), -1: (1,), 0: (0,), 1: (1,), (0, 1): None, (0, -1): None}
  113. for axis, canonical_axis in axis_2d.items():
  114. assert sputils.validateaxis(axis, ndim=2) == canonical_axis
  115. # ndim 4
  116. for axis in ((2,), (3,), (2, 3), (2, 1), (0, 3)):
  117. assert sputils.validateaxis(axis, ndim=4) == axis
  118. axis_4d = {-4: (0,), -3: (1,), 2: (2,), 3: (3,), (3, -4): (3, 0)}
  119. for axis, canonical_axis in axis_4d.items():
  120. sputils.validateaxis(axis, ndim=4) == canonical_axis
  121. @pytest.mark.parametrize("container", [csr_array, bsr_array])
  122. def test_safely_cast_index_compressed(self, container):
  123. # This is slow to test completely as nnz > imax is big
  124. # and indptr is big for some shapes
  125. # So we don't test large nnz, nor csc_array (same code as csr_array)
  126. imax = np.int64(np.iinfo(np.int32).max)
  127. # Shape 32bit
  128. A32 = container((1, imax))
  129. # indices big type, small values
  130. B32 = A32.copy()
  131. B32.indices = B32.indices.astype(np.int64)
  132. B32.indptr = B32.indptr.astype(np.int64)
  133. # Shape 64bit
  134. # indices big type, small values
  135. A64 = csr_array((1, imax + 1))
  136. # indices small type, small values
  137. B64 = A64.copy()
  138. B64.indices = B64.indices.astype(np.int32)
  139. B64.indptr = B64.indptr.astype(np.int32)
  140. # indices big type, big values
  141. C64 = A64.copy()
  142. C64.indices = np.array([imax + 1], dtype=np.int64)
  143. C64.indptr = np.array([0, 1], dtype=np.int64)
  144. C64.data = np.array([2.2])
  145. assert (A32.indices.dtype, A32.indptr.dtype) == (np.int32, np.int32)
  146. assert (B32.indices.dtype, B32.indptr.dtype) == (np.int64, np.int64)
  147. assert (A64.indices.dtype, A64.indptr.dtype) == (np.int64, np.int64)
  148. assert (B64.indices.dtype, B64.indptr.dtype) == (np.int32, np.int32)
  149. assert (C64.indices.dtype, C64.indptr.dtype) == (np.int64, np.int64)
  150. for A in [A32, B32, A64, B64]:
  151. indices, indptr = sputils.safely_cast_index_arrays(A, np.int32)
  152. assert (indices.dtype, indptr.dtype) == (np.int32, np.int32)
  153. indices, indptr = sputils.safely_cast_index_arrays(A, np.int64)
  154. assert (indices.dtype, indptr.dtype) == (np.int64, np.int64)
  155. indices, indptr = sputils.safely_cast_index_arrays(A, A.indices.dtype)
  156. assert indices is A.indices
  157. assert indptr is A.indptr
  158. with assert_raises(ValueError):
  159. sputils.safely_cast_index_arrays(C64, np.int32)
  160. indices, indptr = sputils.safely_cast_index_arrays(C64, np.int64)
  161. assert indices is C64.indices
  162. assert indptr is C64.indptr
  163. def test_safely_cast_index_coo(self):
  164. # This is slow to test completely as nnz > imax is big
  165. # So we don't test large nnz
  166. imax = np.int64(np.iinfo(np.int32).max)
  167. # Shape 32bit
  168. A32 = coo_array((1, imax))
  169. # coords big type, small values
  170. B32 = A32.copy()
  171. B32.coords = tuple(co.astype(np.int64) for co in B32.coords)
  172. # Shape 64bit
  173. # coords big type, small values
  174. A64 = coo_array((1, imax + 1))
  175. # coords small type, small values
  176. B64 = A64.copy()
  177. B64.coords = tuple(co.astype(np.int32) for co in B64.coords)
  178. # coords big type, big values
  179. C64 = A64.copy()
  180. C64.coords = (np.array([imax + 1]), np.array([0]))
  181. C64.data = np.array([2.2])
  182. assert A32.coords[0].dtype == np.int32
  183. assert B32.coords[0].dtype == np.int64
  184. assert A64.coords[0].dtype == np.int64
  185. assert B64.coords[0].dtype == np.int32
  186. assert C64.coords[0].dtype == np.int64
  187. for A in [A32, B32, A64, B64]:
  188. coords = sputils.safely_cast_index_arrays(A, np.int32)
  189. assert coords[0].dtype == np.int32
  190. coords = sputils.safely_cast_index_arrays(A, np.int64)
  191. assert coords[0].dtype == np.int64
  192. coords = sputils.safely_cast_index_arrays(A, A.coords[0].dtype)
  193. assert coords[0] is A.coords[0]
  194. with assert_raises(ValueError):
  195. sputils.safely_cast_index_arrays(C64, np.int32)
  196. coords = sputils.safely_cast_index_arrays(C64, np.int64)
  197. assert coords[0] is C64.coords[0]
  198. def test_safely_cast_index_dia(self):
  199. # This is slow to test completely as nnz > imax is big
  200. # So we don't test large nnz
  201. imax = np.int64(np.iinfo(np.int32).max)
  202. # Shape 32bit
  203. A32 = dia_array((1, imax))
  204. # offsets big type, small values
  205. B32 = A32.copy()
  206. B32.offsets = B32.offsets.astype(np.int64)
  207. # Shape 64bit
  208. # offsets big type, small values
  209. A64 = dia_array((1, imax + 2))
  210. # offsets small type, small values
  211. B64 = A64.copy()
  212. B64.offsets = B64.offsets.astype(np.int32)
  213. # offsets big type, big values
  214. C64 = A64.copy()
  215. C64.offsets = np.array([imax + 1])
  216. C64.data = np.array([2.2])
  217. assert A32.offsets.dtype == np.int32
  218. assert B32.offsets.dtype == np.int64
  219. assert A64.offsets.dtype == np.int64
  220. assert B64.offsets.dtype == np.int32
  221. assert C64.offsets.dtype == np.int64
  222. for A in [A32, B32, A64, B64]:
  223. offsets = sputils.safely_cast_index_arrays(A, np.int32)
  224. assert offsets.dtype == np.int32
  225. offsets = sputils.safely_cast_index_arrays(A, np.int64)
  226. assert offsets.dtype == np.int64
  227. offsets = sputils.safely_cast_index_arrays(A, A.offsets.dtype)
  228. assert offsets is A.offsets
  229. with assert_raises(ValueError):
  230. sputils.safely_cast_index_arrays(C64, np.int32)
  231. offsets = sputils.safely_cast_index_arrays(C64, np.int64)
  232. assert offsets is C64.offsets
  233. def test_get_index_dtype(self):
  234. imax = np.int64(np.iinfo(np.int32).max)
  235. too_big = imax + 1
  236. # Check that uint32's with no values too large doesn't return
  237. # int64
  238. a1 = np.ones(90, dtype='uint32')
  239. a2 = np.ones(90, dtype='uint32')
  240. assert_equal(
  241. np.dtype(sputils.get_index_dtype((a1, a2), check_contents=True)),
  242. np.dtype('int32')
  243. )
  244. # Check that if we can not convert but all values are less than or
  245. # equal to max that we can just convert to int32
  246. a1[-1] = imax
  247. assert_equal(
  248. np.dtype(sputils.get_index_dtype((a1, a2), check_contents=True)),
  249. np.dtype('int32')
  250. )
  251. # Check that if it can not convert directly and the contents are
  252. # too large that we return int64
  253. a1[-1] = too_big
  254. assert_equal(
  255. np.dtype(sputils.get_index_dtype((a1, a2), check_contents=True)),
  256. np.dtype('int64')
  257. )
  258. # test that if can not convert and didn't specify to check_contents
  259. # we return int64
  260. a1 = np.ones(89, dtype='uint32')
  261. a2 = np.ones(89, dtype='uint32')
  262. assert_equal(
  263. np.dtype(sputils.get_index_dtype((a1, a2))),
  264. np.dtype('int64')
  265. )
  266. # Check that even if we have arrays that can be converted directly
  267. # that if we specify a maxval directly it takes precedence
  268. a1 = np.ones(12, dtype='uint32')
  269. a2 = np.ones(12, dtype='uint32')
  270. assert_equal(
  271. np.dtype(sputils.get_index_dtype(
  272. (a1, a2), maxval=too_big, check_contents=True
  273. )),
  274. np.dtype('int64')
  275. )
  276. # Check that an array with a too max size and maxval set
  277. # still returns int64
  278. a1[-1] = too_big
  279. assert_equal(
  280. np.dtype(sputils.get_index_dtype((a1, a2), maxval=too_big)),
  281. np.dtype('int64')
  282. )
  283. # tests public broadcast_shapes largely from
  284. # numpy/numpy/lib/tests/test_stride_tricks.py
  285. # first 3 cause np.broadcast to raise index too large, but not sputils
  286. @pytest.mark.parametrize("input_shapes,target_shape", [
  287. [((6, 5, 1, 4, 1, 1), (1, 2**32), (2**32, 1)), (6, 5, 1, 4, 2**32, 2**32)],
  288. [((6, 5, 1, 4, 1, 1), (1, 2**32)), (6, 5, 1, 4, 1, 2**32)],
  289. [((1, 2**32), (2**32, 1)), (2**32, 2**32)],
  290. [[2, 2, 2], (2,)],
  291. [[], ()],
  292. [[()], ()],
  293. [[(7,)], (7,)],
  294. [[(1, 2), (2,)], (1, 2)],
  295. [[(2,), (1, 2)], (1, 2)],
  296. [[(1, 1)], (1, 1)],
  297. [[(1, 1), (3, 4)], (3, 4)],
  298. [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)],
  299. [[(5, 6, 1)], (5, 6, 1)],
  300. [[(1, 3), (3, 1)], (3, 3)],
  301. [[(1, 0), (0, 0)], (0, 0)],
  302. [[(0, 1), (0, 0)], (0, 0)],
  303. [[(1, 0), (0, 1)], (0, 0)],
  304. [[(1, 1), (0, 0)], (0, 0)],
  305. [[(1, 1), (1, 0)], (1, 0)],
  306. [[(1, 1), (0, 1)], (0, 1)],
  307. [[(), (0,)], (0,)],
  308. [[(0,), (0, 0)], (0, 0)],
  309. [[(0,), (0, 1)], (0, 0)],
  310. [[(1,), (0, 0)], (0, 0)],
  311. [[(), (0, 0)], (0, 0)],
  312. [[(1, 1), (0,)], (1, 0)],
  313. [[(1,), (0, 1)], (0, 1)],
  314. [[(1,), (1, 0)], (1, 0)],
  315. [[(), (1, 0)], (1, 0)],
  316. [[(), (0, 1)], (0, 1)],
  317. [[(1,), (3,)], (3,)],
  318. [[2, (3, 2)], (3, 2)],
  319. [[(1, 2)] * 32, (1, 2)],
  320. [[(1, 2)] * 100, (1, 2)],
  321. [[(2,)] * 32, (2,)],
  322. ])
  323. def test_broadcast_shapes_successes(self, input_shapes, target_shape):
  324. assert_equal(sputils.broadcast_shapes(*input_shapes), target_shape)
  325. # tests public broadcast_shapes failures
  326. @pytest.mark.parametrize("input_shapes", [
  327. [(3,), (4,)],
  328. [(2, 3), (2,)],
  329. [2, (2, 3)],
  330. [(3,), (3,), (4,)],
  331. [(2, 5), (3, 5)],
  332. [(2, 4), (2, 5)],
  333. [(1, 3, 4), (2, 3, 3)],
  334. [(1, 2), (3, 1), (3, 2), (10, 5)],
  335. [(2,)] * 32 + [(3,)] * 32,
  336. ])
  337. def test_broadcast_shapes_failures(self, input_shapes):
  338. with assert_raises(ValueError, match="cannot be broadcast"):
  339. sputils.broadcast_shapes(*input_shapes)
  340. def test_check_shape_overflow(self):
  341. new_shape = sputils.check_shape([(10, -1)], (65535, 131070))
  342. assert_equal(new_shape, (10, 858967245))
  343. def test_matrix(self):
  344. a = [[1, 2, 3]]
  345. b = np.array(a)
  346. assert isinstance(sputils.matrix(a), np.matrix)
  347. assert isinstance(sputils.matrix(b), np.matrix)
  348. c = sputils.matrix(b)
  349. c[:, :] = 123
  350. assert_equal(b, a)
  351. c = sputils.matrix(b, copy=False)
  352. c[:, :] = 123
  353. assert_equal(b, [[123, 123, 123]])
  354. def test_asmatrix(self):
  355. a = [[1, 2, 3]]
  356. b = np.array(a)
  357. assert isinstance(sputils.asmatrix(a), np.matrix)
  358. assert isinstance(sputils.asmatrix(b), np.matrix)
  359. c = sputils.asmatrix(b)
  360. c[:, :] = 123
  361. assert_equal(b, [[123, 123, 123]])