test_matrix_linalg.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. """ Test functions for linalg module using the matrix class."""
  2. import pytest
  3. import numpy as np
  4. from numpy.linalg.tests.test_linalg import (
  5. CondCases,
  6. DetCases,
  7. EigCases,
  8. EigvalsCases,
  9. InvCases,
  10. LinalgCase,
  11. LinalgTestCase,
  12. LstsqCases,
  13. PinvCases,
  14. SolveCases,
  15. SVDCases,
  16. TestQR as _TestQR,
  17. _TestNorm2D,
  18. _TestNormDoubleBase,
  19. _TestNormInt64Base,
  20. _TestNormSingleBase,
  21. apply_tag,
  22. )
  # Module-level list of LinalgCase objects; MatrixTestCase uses it as its
  # TEST_CASES.
  #
  # 'square' group:
  #   * "0x0_matrix"     - 0x0 matrix ``a`` with a 0x1 matrix ``b``; also
  #                        carries the 'size-0' tag.
  #   * "matrix_b_only"  - plain ndarray ``a``, only ``b`` is an np.matrix.
  #   * "matrix_a_and_b" - both operands are np.matrix instances.
  # 'hermitian' group:
  #   * "hmatrix_a_and_b" - symmetric real 2x2 matrix, no right-hand side.
  #
  # No need to make generalized or strided cases for matrices.
  CASES = [
      *apply_tag('square', [
          LinalgCase("0x0_matrix",
                     np.empty((0, 0), dtype=np.double).view(np.matrix),
                     np.empty((0, 1), dtype=np.double).view(np.matrix),
                     tags={'size-0'}),
          LinalgCase("matrix_b_only",
                     np.array([[1., 2.], [3., 4.]]),
                     np.matrix([2., 1.]).T),
          LinalgCase("matrix_a_and_b",
                     np.matrix([[1., 2.], [3., 4.]]),
                     np.matrix([2., 1.]).T),
      ]),
      *apply_tag('hermitian', [
          LinalgCase("hmatrix_a_and_b",
                     np.matrix([[1., 2.], [2., 1.]]),
                     None),
      ]),
  ]
  class MatrixTestCase(LinalgTestCase):
      """Shared base for the np.matrix-flavoured linalg test classes.

      Sets ``TEST_CASES`` to the module-level ``CASES`` list so that the
      concrete ``Test*Matrix`` classes mixing this in see the matrix cases.
      """

      # Bound to the very same list object as ``CASES`` (not a copy).
      TEST_CASES = CASES
  class TestSolveMatrix(SolveCases, MatrixTestCase):
      """Combine the shared ``SolveCases`` tests with ``MatrixTestCase``'s cases."""
  class TestInvMatrix(InvCases, MatrixTestCase):
      """Combine the shared ``InvCases`` tests with ``MatrixTestCase``'s cases."""
  class TestEigvalsMatrix(EigvalsCases, MatrixTestCase):
      """Combine the shared ``EigvalsCases`` tests with ``MatrixTestCase``'s cases."""
  class TestEigMatrix(EigCases, MatrixTestCase):
      """Combine the shared ``EigCases`` tests with ``MatrixTestCase``'s cases."""
  class TestSVDMatrix(SVDCases, MatrixTestCase):
      """Combine the shared ``SVDCases`` tests with ``MatrixTestCase``'s cases."""
  class TestCondMatrix(CondCases, MatrixTestCase):
      """Combine the shared ``CondCases`` tests with ``MatrixTestCase``'s cases."""
  class TestPinvMatrix(PinvCases, MatrixTestCase):
      """Combine the shared ``PinvCases`` tests with ``MatrixTestCase``'s cases."""
  class TestDetMatrix(DetCases, MatrixTestCase):
      """Combine the shared ``DetCases`` tests with ``MatrixTestCase``'s cases."""
  @pytest.mark.thread_unsafe(
      reason="residuals not calculated properly for square tests (gh-29851)"
  )
  class TestLstsqMatrix(LstsqCases, MatrixTestCase):
      """Combine the shared ``LstsqCases`` tests with ``MatrixTestCase``'s cases.

      Carries the ``thread_unsafe`` marker; see the marker's reason string
      (gh-29851) for the rationale.
      """
  class _TestNorm2DMatrix(_TestNorm2D):
      """Base for the matrix variants of the 2-D norm tests.

      Overrides the ``array`` class attribute with ``np.matrix`` (presumably
      the constructor the inherited ``_TestNorm2D`` tests use to build their
      inputs; confirm against ``_TestNorm2D``). The leading underscore keeps
      pytest from collecting this class on its own.
      """

      array = np.matrix
  class TestNormDoubleMatrix(_TestNorm2DMatrix, _TestNormDoubleBase):
      """Matrix-based 2-D norm tests combined with ``_TestNormDoubleBase``."""
  class TestNormSingleMatrix(_TestNorm2DMatrix, _TestNormSingleBase):
      """Matrix-based 2-D norm tests combined with ``_TestNormSingleBase``."""
  class TestNormInt64Matrix(_TestNorm2DMatrix, _TestNormInt64Base):
      """Matrix-based 2-D norm tests combined with ``_TestNormInt64Base``."""
  class TestQRMatrix(_TestQR):
      """Inherit the ``TestQR`` tests with ``array`` overridden to ``np.matrix``."""

      array = np.matrix