test_matrix_io.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import os
  2. import numpy as np
  3. import tempfile
  4. import pytest
  5. from pytest import raises as assert_raises
  6. from numpy.testing import assert_equal, assert_
  7. from scipy.sparse import (sparray, csr_array, coo_array, save_npz, load_npz,
  8. csc_matrix, csr_matrix, bsr_matrix, dia_matrix,
  9. coo_matrix, dok_matrix)
  10. DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
  11. def _save_and_load(matrix):
  12. fd, tmpfile = tempfile.mkstemp(suffix='.npz')
  13. os.close(fd)
  14. try:
  15. save_npz(tmpfile, matrix)
  16. loaded_matrix = load_npz(tmpfile)
  17. finally:
  18. os.remove(tmpfile)
  19. return loaded_matrix
  20. def _check_save_and_load(dense_matrix):
  21. for matrix_class in [csc_matrix, csr_matrix, bsr_matrix, dia_matrix, coo_matrix]:
  22. matrix = matrix_class(dense_matrix)
  23. loaded_matrix = _save_and_load(matrix)
  24. assert_(type(loaded_matrix) is matrix_class)
  25. assert_(loaded_matrix.shape == dense_matrix.shape)
  26. assert_(loaded_matrix.dtype == dense_matrix.dtype)
  27. assert_equal(loaded_matrix.toarray(), dense_matrix)
  28. def test_save_and_load_random():
  29. N = 10
  30. np.random.seed(0)
  31. dense_matrix = np.random.random((N, N))
  32. dense_matrix[dense_matrix > 0.7] = 0
  33. _check_save_and_load(dense_matrix)
  34. def test_save_and_load_empty():
  35. dense_matrix = np.zeros((4,6))
  36. _check_save_and_load(dense_matrix)
  37. def test_save_and_load_one_entry():
  38. dense_matrix = np.zeros((4,6))
  39. dense_matrix[1,2] = 1
  40. _check_save_and_load(dense_matrix)
  41. def test_sparray_vs_spmatrix():
  42. #save/load matrix
  43. fd, tmpfile = tempfile.mkstemp(suffix='.npz')
  44. os.close(fd)
  45. try:
  46. save_npz(tmpfile, csr_matrix([[1.2, 0, 0.9], [0, 0.3, 0]]))
  47. loaded_matrix = load_npz(tmpfile)
  48. finally:
  49. os.remove(tmpfile)
  50. #save/load array
  51. fd, tmpfile = tempfile.mkstemp(suffix='.npz')
  52. os.close(fd)
  53. try:
  54. save_npz(tmpfile, csr_array([[1.2, 0, 0.9], [0, 0.3, 0]]))
  55. loaded_array = load_npz(tmpfile)
  56. finally:
  57. os.remove(tmpfile)
  58. assert not isinstance(loaded_matrix, sparray)
  59. assert isinstance(loaded_array, sparray)
  60. assert_(loaded_matrix.dtype == loaded_array.dtype)
  61. assert_equal(loaded_matrix.toarray(), loaded_array.toarray())
  62. @pytest.mark.parametrize("value", [0, 1.2])
  63. @pytest.mark.parametrize("ndim", [1, 2, 3])
  64. def test_nd_coo_format(ndim, value):
  65. A = coo_array([value]).reshape((1,) * ndim)
  66. #save/load array
  67. fd, tmpfile = tempfile.mkstemp(suffix='.npz')
  68. os.close(fd)
  69. try:
  70. save_npz(tmpfile, A)
  71. loaded_A = load_npz(tmpfile)
  72. finally:
  73. os.remove(tmpfile)
  74. assert isinstance(loaded_A, coo_array)
  75. assert_(loaded_A.shape == A.shape)
  76. assert_equal(A.toarray(), loaded_A.toarray())
  77. def test_malicious_load():
  78. class Executor:
  79. def __reduce__(self):
  80. return (assert_, (False, 'unexpected code execution'))
  81. fd, tmpfile = tempfile.mkstemp(suffix='.npz')
  82. os.close(fd)
  83. try:
  84. np.savez(tmpfile, format=Executor())
  85. # Should raise a ValueError, not execute code
  86. assert_raises(ValueError, load_npz, tmpfile)
  87. finally:
  88. os.remove(tmpfile)
  89. def test_py23_compatibility():
  90. # Try loading files saved on Python 2 and Python 3. They are not
  91. # the same, since files saved with SciPy versions < 1.0.0 may
  92. # contain unicode.
  93. a = load_npz(os.path.join(DATA_DIR, 'csc_py2.npz'))
  94. b = load_npz(os.path.join(DATA_DIR, 'csc_py3.npz'))
  95. c = csc_matrix([[0]])
  96. assert_equal(a.toarray(), c.toarray())
  97. assert_equal(b.toarray(), c.toarray())
  98. def test_implemented_error():
  99. # Attempts to save an unsupported type and checks that an
  100. # NotImplementedError is raised.
  101. x = dok_matrix((2,3))
  102. x[0,1] = 1
  103. assert_raises(NotImplementedError, save_npz, 'x.npz', x)