| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- """ Test functions for linalg module using the matrix class."""
- import pytest
- import numpy as np
- from numpy.linalg.tests.test_linalg import (
- CondCases,
- DetCases,
- EigCases,
- EigvalsCases,
- InvCases,
- LinalgCase,
- LinalgTestCase,
- LstsqCases,
- PinvCases,
- SolveCases,
- SVDCases,
- TestQR as _TestQR,
- _TestNorm2D,
- _TestNormDoubleBase,
- _TestNormInt64Base,
- _TestNormSingleBase,
- apply_tag,
- )
# Operand bundles consumed by the shared linalg test mixins (via
# ``MatrixTestCase.TEST_CASES``). No generalized or strided variants are
# built for matrices.
CASES = []

# Square inputs: an empty 0x0 system, a plain ndarray ``a`` paired with a
# matrix ``b``, and a case where both operands are matrices.
CASES.extend(apply_tag('square', [
    LinalgCase(
        "0x0_matrix",
        np.empty((0, 0), dtype=np.double).view(np.matrix),
        np.empty((0, 1), dtype=np.double).view(np.matrix),
        tags={'size-0'},
    ),
    LinalgCase(
        "matrix_b_only",
        np.array([[1., 2.], [3., 4.]]),
        np.matrix([2., 1.]).T,
    ),
    LinalgCase(
        "matrix_a_and_b",
        np.matrix([[1., 2.], [3., 4.]]),
        np.matrix([2., 1.]).T,
    ),
]))

# Hermitian input: a real symmetric matrix; its ``b`` operand is None.
CASES.extend(apply_tag('hermitian', [
    LinalgCase(
        "hmatrix_a_and_b",
        np.matrix([[1., 2.], [2., 1.]]),
        None,
    ),
]))
class MatrixTestCase(LinalgTestCase):
    """Linalg test base whose case list is the matrix-specific ``CASES``."""

    TEST_CASES = CASES
class TestSolveMatrix(SolveCases, MatrixTestCase):
    """Combine the shared ``SolveCases`` tests with the matrix case list."""
class TestInvMatrix(InvCases, MatrixTestCase):
    """Combine the shared ``InvCases`` tests with the matrix case list."""
class TestEigvalsMatrix(EigvalsCases, MatrixTestCase):
    """Combine the shared ``EigvalsCases`` tests with the matrix case list."""
class TestEigMatrix(EigCases, MatrixTestCase):
    """Combine the shared ``EigCases`` tests with the matrix case list."""
class TestSVDMatrix(SVDCases, MatrixTestCase):
    """Combine the shared ``SVDCases`` tests with the matrix case list."""
class TestCondMatrix(CondCases, MatrixTestCase):
    """Combine the shared ``CondCases`` tests with the matrix case list."""
class TestPinvMatrix(PinvCases, MatrixTestCase):
    """Combine the shared ``PinvCases`` tests with the matrix case list."""
class TestDetMatrix(DetCases, MatrixTestCase):
    """Combine the shared ``DetCases`` tests with the matrix case list."""
# Marked thread-unsafe: the reason string below points at gh-29851.
@pytest.mark.thread_unsafe(
    reason="residuals not calculated properly for square tests (gh-29851)"
)
class TestLstsqMatrix(LstsqCases, MatrixTestCase):
    """Combine the shared ``LstsqCases`` tests with the matrix case list."""
class _TestNorm2DMatrix(_TestNorm2D):
    """Point the shared 2-D norm tests at ``np.matrix``.

    Overrides the ``array`` class attribute inherited from ``_TestNorm2D``.
    Presumably the shared norm tests build their inputs through it. The
    leading underscore keeps pytest's default ``Test*`` class collection from
    running this mixin on its own.
    """

    array = np.matrix
class TestNormDoubleMatrix(_TestNorm2DMatrix, _TestNormDoubleBase):
    """Matrix-input 2-D norm tests combined with ``_TestNormDoubleBase``."""
class TestNormSingleMatrix(_TestNorm2DMatrix, _TestNormSingleBase):
    """Matrix-input 2-D norm tests combined with ``_TestNormSingleBase``."""
class TestNormInt64Matrix(_TestNorm2DMatrix, _TestNormInt64Base):
    """Matrix-input 2-D norm tests combined with ``_TestNormInt64Base``."""
class TestQRMatrix(_TestQR):
    """Rerun numpy's ``TestQR`` suite with ``array`` overridden to ``np.matrix``."""

    array = np.matrix
|