test__remove_redundancy.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. """
  2. Unit tests for the redundancy-removal routines in scipy.optimize._remove_redundancy.
  3. """
  4. # TODO: add tests for:
  5. # https://github.com/scipy/scipy/issues/5400
  6. # https://github.com/scipy/scipy/issues/6690
  7. import numpy as np
  8. from numpy.testing import (
  9. assert_,
  10. assert_allclose,
  11. assert_equal)
  12. from .test_linprog import magic_square
  13. from scipy.optimize._remove_redundancy import _remove_redundancy_svd
  14. from scipy.optimize._remove_redundancy import _remove_redundancy_pivot_dense
  15. from scipy.optimize._remove_redundancy import _remove_redundancy_pivot_sparse
  16. from scipy.optimize._remove_redundancy import _remove_redundancy_id
  17. from scipy.sparse import csc_array
  18. def redundancy_removed(A, B):
  19. """Checks whether a matrix contains only independent rows of another"""
  20. for rowA in A:
  21. # `rowA in B` is not a reliable check
  22. for rowB in B:
  23. if np.all(rowA == rowB):
  24. break
  25. else:
  26. return False
  27. return A.shape[0] == np.linalg.matrix_rank(A) == np.linalg.matrix_rank(B)
  28. class RRCommonTests:
  29. def setup_method(self):
  30. self.rng = np.random.default_rng(2017)
  31. def test_no_redundancy(self):
  32. m, n = 10, 10
  33. A0 = self.rng.random((m, n))
  34. b0 = self.rng.random(m)
  35. A1, b1, status, message = self.rr(A0, b0)
  36. assert_allclose(A0, A1)
  37. assert_allclose(b0, b1)
  38. assert_equal(status, 0)
  39. def test_infeasible_zero_row(self):
  40. A = np.eye(3)
  41. A[1, :] = 0
  42. b = self.rng.random(3)
  43. A1, b1, status, message = self.rr(A, b)
  44. assert_equal(status, 2)
  45. def test_remove_zero_row(self):
  46. A = np.eye(3)
  47. A[1, :] = 0
  48. b = self.rng.random(3)
  49. b[1] = 0
  50. A1, b1, status, message = self.rr(A, b)
  51. assert_equal(status, 0)
  52. assert_allclose(A1, A[[0, 2], :])
  53. assert_allclose(b1, b[[0, 2]])
  54. def test_infeasible_m_gt_n(self):
  55. m, n = 20, 10
  56. A0 = self.rng.random((m, n))
  57. b0 = self.rng.random(m)
  58. A1, b1, status, message = self.rr(A0, b0)
  59. assert_equal(status, 2)
  60. def test_infeasible_m_eq_n(self):
  61. m, n = 10, 10
  62. A0 = self.rng.random((m, n))
  63. b0 = self.rng.random(m)
  64. A0[-1, :] = 2 * A0[-2, :]
  65. A1, b1, status, message = self.rr(A0, b0)
  66. assert_equal(status, 2)
  67. def test_infeasible_m_lt_n(self):
  68. m, n = 9, 10
  69. A0 = self.rng.random((m, n))
  70. b0 = self.rng.random(m)
  71. A0[-1, :] = np.arange(m - 1).dot(A0[:-1])
  72. A1, b1, status, message = self.rr(A0, b0)
  73. assert_equal(status, 2)
  74. def test_m_gt_n(self):
  75. rng = np.random.default_rng(2032)
  76. m, n = 20, 10
  77. A0 = rng.random((m, n))
  78. b0 = rng.random(m)
  79. x = np.linalg.solve(A0[:n, :], b0[:n])
  80. b0[n:] = A0[n:, :].dot(x)
  81. A1, b1, status, message = self.rr(A0, b0)
  82. assert_equal(status, 0)
  83. assert_equal(A1.shape[0], n)
  84. assert_equal(np.linalg.matrix_rank(A1), n)
  85. def test_m_gt_n_rank_deficient(self):
  86. m, n = 20, 10
  87. A0 = np.zeros((m, n))
  88. A0[:, 0] = 1
  89. b0 = np.ones(m)
  90. A1, b1, status, message = self.rr(A0, b0)
  91. assert_equal(status, 0)
  92. assert_allclose(A1, A0[0:1, :])
  93. assert_allclose(b1, b0[0])
  94. def test_m_lt_n_rank_deficient(self):
  95. m, n = 9, 10
  96. A0 = self.rng.random((m, n))
  97. b0 = self.rng.random(m)
  98. A0[-1, :] = np.arange(m - 1).dot(A0[:-1])
  99. b0[-1] = np.arange(m - 1).dot(b0[:-1])
  100. A1, b1, status, message = self.rr(A0, b0)
  101. assert_equal(status, 0)
  102. assert_equal(A1.shape[0], 8)
  103. assert_equal(np.linalg.matrix_rank(A1), 8)
  104. def test_dense1(self):
  105. A = np.ones((6, 6))
  106. A[0, :3] = 0
  107. A[1, 3:] = 0
  108. A[3:, ::2] = -1
  109. A[3, :2] = 0
  110. A[4, 2:] = 0
  111. b = np.zeros(A.shape[0])
  112. A1, b1, status, message = self.rr(A, b)
  113. assert_(redundancy_removed(A1, A))
  114. assert_equal(status, 0)
  115. def test_dense2(self):
  116. A = np.eye(6)
  117. A[-2, -1] = 1
  118. A[-1, :] = 1
  119. b = np.zeros(A.shape[0])
  120. A1, b1, status, message = self.rr(A, b)
  121. assert_(redundancy_removed(A1, A))
  122. assert_equal(status, 0)
  123. def test_dense3(self):
  124. A = np.eye(6)
  125. A[-2, -1] = 1
  126. A[-1, :] = 1
  127. b = self.rng.random(A.shape[0])
  128. b[-1] = np.sum(b[:-1])
  129. A1, b1, status, message = self.rr(A, b)
  130. assert_(redundancy_removed(A1, A))
  131. assert_equal(status, 0)
  132. def test_m_gt_n_sparse(self):
  133. rng = np.random.default_rng(2013)
  134. m, n = 20, 5
  135. p = 0.1
  136. A = rng.random((m, n))
  137. A[rng.random((m, n)) > p] = 0
  138. rank = np.linalg.matrix_rank(A)
  139. b = np.zeros(A.shape[0])
  140. A1, b1, status, message = self.rr(A, b)
  141. assert_equal(status, 0)
  142. assert_equal(A1.shape[0], rank)
  143. assert_equal(np.linalg.matrix_rank(A1), rank)
  144. def test_m_lt_n_sparse(self):
  145. rng = np.random.default_rng(2017)
  146. m, n = 20, 50
  147. p = 0.05
  148. A = rng.random((m, n))
  149. A[rng.random((m, n)) > p] = 0
  150. rank = np.linalg.matrix_rank(A)
  151. b = np.zeros(A.shape[0])
  152. A1, b1, status, message = self.rr(A, b)
  153. assert_equal(status, 0)
  154. assert_equal(A1.shape[0], rank)
  155. assert_equal(np.linalg.matrix_rank(A1), rank)
  156. def test_m_eq_n_sparse(self):
  157. rng = np.random.default_rng(2017)
  158. m, n = 100, 100
  159. p = 0.01
  160. A = rng.random((m, n))
  161. A[rng.random((m, n)) > p] = 0
  162. rank = np.linalg.matrix_rank(A)
  163. b = np.zeros(A.shape[0])
  164. A1, b1, status, message = self.rr(A, b)
  165. assert_equal(status, 0)
  166. assert_equal(A1.shape[0], rank)
  167. assert_equal(np.linalg.matrix_rank(A1), rank)
  168. def test_magic_square(self):
  169. A, b, c, numbers, _ = magic_square(3)
  170. A1, b1, status, message = self.rr(A, b)
  171. assert_equal(status, 0)
  172. assert_equal(A1.shape[0], 23)
  173. assert_equal(np.linalg.matrix_rank(A1), 23)
  174. def test_magic_square2(self):
  175. A, b, c, numbers, _ = magic_square(4)
  176. A1, b1, status, message = self.rr(A, b)
  177. assert_equal(status, 0)
  178. assert_equal(A1.shape[0], 39)
  179. assert_equal(np.linalg.matrix_rank(A1), 39)
  180. class TestRRSVD(RRCommonTests):
  181. def rr(self, A, b):
  182. return _remove_redundancy_svd(A, b)
  183. class TestRRPivotDense(RRCommonTests):
  184. def rr(self, A, b):
  185. return _remove_redundancy_pivot_dense(A, b)
  186. class TestRRID(RRCommonTests):
  187. def rr(self, A, b):
  188. return _remove_redundancy_id(A, b)
  189. class TestRRPivotSparse(RRCommonTests):
  190. def rr(self, A, b):
  191. rr_res = _remove_redundancy_pivot_sparse(csc_array(A), b)
  192. A1, b1, status, message = rr_res
  193. return A1.toarray(), b1, status, message