test_matching.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. from itertools import product
  2. import numpy as np
  3. from numpy.testing import assert_array_equal, assert_equal
  4. import pytest
  5. from scipy.sparse import csr_array, coo_array, diags_array
  6. from scipy.sparse.csgraph import (
  7. maximum_bipartite_matching, min_weight_full_bipartite_matching
  8. )
  9. def test_maximum_bipartite_matching_raises_on_dense_input():
  10. with pytest.raises(TypeError):
  11. graph = np.array([[0, 1], [0, 0]])
  12. maximum_bipartite_matching(graph)
  13. def test_maximum_bipartite_matching_empty_graph():
  14. graph = csr_array((0, 0))
  15. x = maximum_bipartite_matching(graph, perm_type='row')
  16. y = maximum_bipartite_matching(graph, perm_type='column')
  17. expected_matching = np.array([])
  18. assert_array_equal(expected_matching, x)
  19. assert_array_equal(expected_matching, y)
  20. def test_maximum_bipartite_matching_empty_left_partition():
  21. graph = csr_array((2, 0))
  22. x = maximum_bipartite_matching(graph, perm_type='row')
  23. y = maximum_bipartite_matching(graph, perm_type='column')
  24. assert_array_equal(np.array([]), x)
  25. assert_array_equal(np.array([-1, -1]), y)
  26. def test_maximum_bipartite_matching_empty_right_partition():
  27. graph = csr_array((0, 3))
  28. x = maximum_bipartite_matching(graph, perm_type='row')
  29. y = maximum_bipartite_matching(graph, perm_type='column')
  30. assert_array_equal(np.array([-1, -1, -1]), x)
  31. assert_array_equal(np.array([]), y)
  32. def test_maximum_bipartite_matching_graph_with_no_edges():
  33. graph = csr_array((2, 2))
  34. x = maximum_bipartite_matching(graph, perm_type='row')
  35. y = maximum_bipartite_matching(graph, perm_type='column')
  36. assert_array_equal(np.array([-1, -1]), x)
  37. assert_array_equal(np.array([-1, -1]), y)
  38. def test_maximum_bipartite_matching_graph_that_causes_augmentation():
  39. # In this graph, column 1 is initially assigned to row 1, but it should be
  40. # reassigned to make room for row 2.
  41. graph = csr_array([[1, 1], [1, 0]])
  42. x = maximum_bipartite_matching(graph, perm_type='column')
  43. y = maximum_bipartite_matching(graph, perm_type='row')
  44. expected_matching = np.array([1, 0])
  45. assert_array_equal(expected_matching, x)
  46. assert_array_equal(expected_matching, y)
  47. def test_maximum_bipartite_matching_graph_with_more_rows_than_columns():
  48. graph = csr_array([[1, 1], [1, 0], [0, 1]])
  49. x = maximum_bipartite_matching(graph, perm_type='column')
  50. y = maximum_bipartite_matching(graph, perm_type='row')
  51. assert_array_equal(np.array([0, -1, 1]), x)
  52. assert_array_equal(np.array([0, 2]), y)
  53. def test_maximum_bipartite_matching_graph_with_more_columns_than_rows():
  54. graph = csr_array([[1, 1, 0], [0, 0, 1]])
  55. x = maximum_bipartite_matching(graph, perm_type='column')
  56. y = maximum_bipartite_matching(graph, perm_type='row')
  57. assert_array_equal(np.array([0, 2]), x)
  58. assert_array_equal(np.array([0, -1, 1]), y)
  59. def test_maximum_bipartite_matching_explicit_zeros_count_as_edges():
  60. data = [0, 0]
  61. indices = [1, 0]
  62. indptr = [0, 1, 2]
  63. graph = csr_array((data, indices, indptr), shape=(2, 2))
  64. x = maximum_bipartite_matching(graph, perm_type='row')
  65. y = maximum_bipartite_matching(graph, perm_type='column')
  66. expected_matching = np.array([1, 0])
  67. assert_array_equal(expected_matching, x)
  68. assert_array_equal(expected_matching, y)
  69. def test_maximum_bipartite_matching_feasibility_of_result():
  70. # This is a regression test for GitHub issue #11458
  71. data = np.ones(50, dtype=int)
  72. indices = [11, 12, 19, 22, 23, 5, 22, 3, 8, 10, 5, 6, 11, 12, 13, 5, 13,
  73. 14, 20, 22, 3, 15, 3, 13, 14, 11, 12, 19, 22, 23, 5, 22, 3, 8,
  74. 10, 5, 6, 11, 12, 13, 5, 13, 14, 20, 22, 3, 15, 3, 13, 14]
  75. indptr = [0, 5, 7, 10, 10, 15, 20, 22, 22, 23, 25, 30, 32, 35, 35, 40, 45,
  76. 47, 47, 48, 50]
  77. graph = csr_array((data, indices, indptr), shape=(20, 25))
  78. x = maximum_bipartite_matching(graph, perm_type='row')
  79. y = maximum_bipartite_matching(graph, perm_type='column')
  80. assert (x != -1).sum() == 13
  81. assert (y != -1).sum() == 13
  82. # Ensure that each element of the matching is in fact an edge in the graph.
  83. for u, v in zip(range(graph.shape[0]), y):
  84. if v != -1:
  85. assert graph[u, v]
  86. for u, v in zip(x, range(graph.shape[1])):
  87. if u != -1:
  88. assert graph[u, v]
  89. def test_matching_large_random_graph_with_one_edge_incident_to_each_vertex():
  90. np.random.seed(42)
  91. A = diags_array(np.ones(25), offsets=0, format='csr')
  92. rand_perm = np.random.permutation(25)
  93. rand_perm2 = np.random.permutation(25)
  94. Rrow = np.arange(25)
  95. Rcol = rand_perm
  96. Rdata = np.ones(25, dtype=int)
  97. Rmat = csr_array((Rdata, (Rrow, Rcol)))
  98. Crow = rand_perm2
  99. Ccol = np.arange(25)
  100. Cdata = np.ones(25, dtype=int)
  101. Cmat = csr_array((Cdata, (Crow, Ccol)))
  102. # Randomly permute identity matrix
  103. B = Rmat @ A @ Cmat
  104. # Row permute
  105. perm = maximum_bipartite_matching(B, perm_type='row')
  106. Rrow = np.arange(25)
  107. Rcol = perm
  108. Rdata = np.ones(25, dtype=int)
  109. Rmat = csr_array((Rdata, (Rrow, Rcol)))
  110. C1 = Rmat @ B
  111. # Column permute
  112. perm2 = maximum_bipartite_matching(B, perm_type='column')
  113. Crow = perm2
  114. Ccol = np.arange(25)
  115. Cdata = np.ones(25, dtype=int)
  116. Cmat = csr_array((Cdata, (Crow, Ccol)))
  117. C2 = B @ Cmat
  118. # Should get identity matrix back
  119. assert_equal(any(C1.diagonal() == 0), False)
  120. assert_equal(any(C2.diagonal() == 0), False)
  121. @pytest.mark.parametrize('num_rows,num_cols', [(0, 0), (2, 0), (0, 3)])
  122. def test_min_weight_full_matching_trivial_graph(num_rows, num_cols):
  123. biadjacency = csr_array((num_cols, num_rows))
  124. biadjacency1 = coo_array((num_cols, num_rows))
  125. row_ind, col_ind = min_weight_full_bipartite_matching(biadjacency)
  126. assert len(row_ind) == 0
  127. assert len(col_ind) == 0
  128. row_ind1, col_ind1 = min_weight_full_bipartite_matching(biadjacency1)
  129. assert len(row_ind1) == 0
  130. assert len(col_ind1) == 0
  131. @pytest.mark.parametrize('biadjacency',
  132. [
  133. [[1, 1, 1], [1, 0, 0], [1, 0, 0]],
  134. [[1, 1, 1], [0, 0, 1], [0, 0, 1]],
  135. [[1, 0, 0, 1], [1, 1, 0, 1], [0, 0, 0, 0]],
  136. [[1, 0, 0], [2, 0, 0]],
  137. [[0, 1, 0], [0, 2, 0]],
  138. [[1, 0], [2, 0], [5, 0]]
  139. ])
  140. def test_min_weight_full_matching_infeasible_problems(biadjacency):
  141. with pytest.raises(ValueError):
  142. min_weight_full_bipartite_matching(csr_array(biadjacency))
  143. with pytest.raises(ValueError):
  144. min_weight_full_bipartite_matching(coo_array(biadjacency))
def test_min_weight_full_matching_large_infeasible():
    """A 22x22 problem with no full matching is rejected (CSR and COO)."""
    # Regression test for GitHub issue #17269
    # Rows 0-8 each hold a single weight of 0.001, in columns 13-21.
    # Rows 13, 14, 15 and 21 have nonzero entries only in columns 0, 2 and
    # 8: four rows competing for three columns, so Hall's condition fails
    # and no full matching can exist.
    a = np.asarray([
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0, 0.0, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0, 0.0, 0.0, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0, 0.0, 0.0, 0.0, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0, 0.0, 0.0, 0.0, 0.0, 0.001, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.001, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.001, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.001, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.001, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.001],
        [0.0, 0.11687445, 0.0, 0.0, 0.01319788, 0.07509257, 0.0,
         0.0, 0.0, 0.74228317, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.81087935, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.8408466, 0.0, 0.0, 0.0, 0.0, 0.01194389,
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.82994211, 0.0, 0.0, 0.0, 0.11468516, 0.0, 0.0, 0.0,
         0.11173505, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0, 0.0],
        [0.18796507, 0.0, 0.04002318, 0.0, 0.0, 0.0, 0.0, 0.0, 0.75883335,
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.71545464, 0.0, 0.0, 0.0, 0.0, 0.0, 0.02748488,
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.78470564, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14829198,
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.10870609, 0.0, 0.0, 0.0, 0.8918677, 0.0, 0.0, 0.0, 0.06306644,
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.63844085, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7442354, 0.0, 0.0, 0.0,
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.09850549, 0.0, 0.0, 0.18638258,
         0.2769244, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.73182464, 0.0, 0.0, 0.46443561,
         0.38589284, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.29510278, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.09666032, 0.0,
         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    ])
    # The same infeasibility must be reported for both supported formats.
    with pytest.raises(ValueError, match='no full matching exists'):
        min_weight_full_bipartite_matching(csr_array(a))
    with pytest.raises(ValueError, match='no full matching exists'):
        min_weight_full_bipartite_matching(coo_array(a))
  199. def test_explicit_zero_causes_warning():
  200. biadjacency = csr_array(((2, 0, 3), (0, 1, 1), (0, 2, 3)))
  201. with pytest.warns(UserWarning):
  202. min_weight_full_bipartite_matching(biadjacency)
  203. with pytest.warns(UserWarning):
  204. min_weight_full_bipartite_matching(biadjacency.tocoo())
  205. # General test for linear sum assignment solvers to make it possible to rely
  206. # on the same tests for scipy.optimize.linear_sum_assignment.
  207. def linear_sum_assignment_assertions(
  208. solver, array_type, sign, test_case
  209. ):
  210. cost_matrix, expected_cost = test_case
  211. maximize = sign == -1
  212. cost_matrix = sign * array_type(cost_matrix)
  213. expected_cost = sign * np.array(expected_cost)
  214. row_ind, col_ind = solver(cost_matrix, maximize=maximize)
  215. assert_array_equal(row_ind, np.sort(row_ind))
  216. assert_array_equal(expected_cost,
  217. np.array(cost_matrix[row_ind, col_ind]).flatten())
  218. cost_matrix = cost_matrix.T
  219. row_ind, col_ind = solver(cost_matrix, maximize=maximize)
  220. assert_array_equal(row_ind, np.sort(row_ind))
  221. assert_array_equal(np.sort(expected_cost),
  222. np.sort(np.array(
  223. cost_matrix[row_ind, col_ind])).flatten())
  224. linear_sum_assignment_test_cases = product(
  225. [-1, 1],
  226. [
  227. # Square
  228. ([[400, 150, 400],
  229. [400, 450, 600],
  230. [300, 225, 300]],
  231. [150, 400, 300]),
  232. # Rectangular variant
  233. ([[400, 150, 400, 1],
  234. [400, 450, 600, 2],
  235. [300, 225, 300, 3]],
  236. [150, 2, 300]),
  237. ([[10, 10, 8],
  238. [9, 8, 1],
  239. [9, 7, 4]],
  240. [10, 1, 7]),
  241. # Square
  242. ([[10, 10, 8, 11],
  243. [9, 8, 1, 1],
  244. [9, 7, 4, 10]],
  245. [10, 1, 4]),
  246. # Rectangular variant
  247. ([[10, float("inf"), float("inf")],
  248. [float("inf"), float("inf"), 1],
  249. [float("inf"), 7, float("inf")]],
  250. [10, 1, 7])
  251. ])
  252. @pytest.mark.parametrize('sign,test_case', linear_sum_assignment_test_cases)
  253. def test_min_weight_full_matching_small_inputs(sign, test_case):
  254. linear_sum_assignment_assertions(
  255. min_weight_full_bipartite_matching, csr_array, sign, test_case)