test_extract.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
"""Tests for the sparse matrix extraction functions (find, tril, triu)."""
  2. from numpy.testing import assert_equal
  3. from scipy.sparse import csr_matrix, csr_array, sparray
  4. import numpy as np
  5. from scipy.sparse import _extract
  6. class TestExtract:
  7. def setup_method(self):
  8. self.cases = [
  9. csr_array([[1,2]]),
  10. csr_array([[1,0]]),
  11. csr_array([[0,0]]),
  12. csr_array([[1],[2]]),
  13. csr_array([[1],[0]]),
  14. csr_array([[0],[0]]),
  15. csr_array([[1,2],[3,4]]),
  16. csr_array([[0,1],[0,0]]),
  17. csr_array([[0,0],[1,0]]),
  18. csr_array([[0,0],[0,0]]),
  19. csr_array([[1,2,0,0,3],[4,5,0,6,7],[0,0,8,9,0]]),
  20. csr_array([[1,2,0,0,3],[4,5,0,6,7],[0,0,8,9,0]]).T,
  21. ]
  22. def test_find(self):
  23. for A in self.cases:
  24. I,J,V = _extract.find(A)
  25. B = csr_array((V,(I,J)), shape=A.shape)
  26. assert_equal(A.toarray(), B.toarray())
  27. def test_tril(self):
  28. for A in self.cases:
  29. B = A.toarray()
  30. for k in [-3,-2,-1,0,1,2,3]:
  31. assert_equal(_extract.tril(A,k=k).toarray(), np.tril(B,k=k))
  32. def test_triu(self):
  33. for A in self.cases:
  34. B = A.toarray()
  35. for k in [-3,-2,-1,0,1,2,3]:
  36. assert_equal(_extract.triu(A,k=k).toarray(), np.triu(B,k=k))
  37. def test_array_vs_matrix(self):
  38. for A in self.cases:
  39. assert isinstance(_extract.tril(A), sparray)
  40. assert isinstance(_extract.triu(A), sparray)
  41. M = csr_matrix(A)
  42. assert not isinstance(_extract.tril(M), sparray)
  43. assert not isinstance(_extract.triu(M), sparray)