test_nnls.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. import numpy as np
  2. from numpy.testing import assert_allclose
  3. from pytest import raises as assert_raises
  4. from scipy.optimize import nnls
  5. import pytest
  6. class TestNNLS:
  7. def setup_method(self):
  8. self.rng = np.random.default_rng(1685225766635251)
  9. def test_nnls(self):
  10. a = np.arange(25.0).reshape(-1, 5)
  11. x = np.arange(5.0)
  12. y = a @ x
  13. x, res = nnls(a, y)
  14. assert res < 1e-7
  15. assert np.linalg.norm((a @ x) - y) < 1e-7
  16. def test_nnls_tall(self):
  17. a = self.rng.uniform(low=-10, high=10, size=[50, 10])
  18. x = np.abs(self.rng.uniform(low=-2, high=2, size=[10]))
  19. x[::2] = 0
  20. b = a @ x
  21. xact, rnorm = nnls(a, b)
  22. assert_allclose(xact, x, rtol=0., atol=1e-10)
  23. assert rnorm < 1e-12
  24. def test_nnls_wide(self):
  25. # If too wide then problem becomes too ill-conditioned ans starts
  26. # emitting warnings, hence small m, n difference.
  27. a = self.rng.uniform(low=-10, high=10, size=[100, 120])
  28. x = np.abs(self.rng.uniform(low=-2, high=2, size=[120]))
  29. x[::2] = 0
  30. b = a @ x
  31. xact, rnorm = nnls(a, b)
  32. assert_allclose(xact, x, rtol=0., atol=1e-10)
  33. assert rnorm < 1e-12
  34. def test_maxiter(self):
  35. # test that maxiter argument does stop iterations
  36. a = self.rng.uniform(size=(5, 10))
  37. b = self.rng.uniform(size=5)
  38. with assert_raises(RuntimeError):
  39. nnls(a, b, maxiter=1)
  40. def test_nnls_inner_loop_case1(self):
  41. # See gh-20168
  42. n = np.array(
  43. [3, 2, 0, 1, 1, 1, 3, 8, 14, 16, 29, 23, 41, 47, 53, 57, 67, 76,
  44. 103, 89, 97, 94, 85, 95, 78, 78, 78, 77, 73, 50, 50, 56, 68, 98,
  45. 95, 112, 134, 145, 158, 172, 213, 234, 222, 215, 216, 216, 206,
  46. 183, 135, 156, 110, 92, 63, 60, 52, 29, 20, 16, 12, 5, 5, 5, 1, 2,
  47. 3, 0, 2])
  48. k = np.array(
  49. [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
  50. 0., 0., 0., 0.7205812007860187, 0., 1.4411624015720375,
  51. 0.7205812007860187, 2.882324803144075, 5.76464960628815,
  52. 5.76464960628815, 12.249880413362318, 15.132205216506394,
  53. 20.176273622008523, 27.382085629868712, 48.27894045266326,
  54. 47.558359251877235, 68.45521407467177, 97.99904330689854,
  55. 108.0871801179028, 135.46926574777152, 140.51333415327366,
  56. 184.4687874012208, 171.49832578707245, 205.36564222401535,
  57. 244.27702706646033, 214.01261663344755, 228.42424064916793,
  58. 232.02714665309804, 205.36564222401535, 172.9394881886445,
  59. 191.67459940908097, 162.1307701768542, 153.48379576742198,
  60. 110.96950492104689, 103.04311171240067, 86.46974409432225,
  61. 60.528820866025576, 43.234872047161126, 23.779179625938617,
  62. 24.499760826724636, 17.29394881886445, 11.5292992125763,
  63. 5.76464960628815, 5.044068405502131, 3.6029060039300935, 0.,
  64. 2.882324803144075, 0., 0., 0.])
  65. d = np.array(
  66. [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
  67. 0., 0., 0., 0.003889242101538, 0., 0.007606268390096, 0.,
  68. 0.025457371599973, 0.036952882091577, 0., 0.08518359183449,
  69. 0.048201126400243, 0.196234990022205, 0.144116240157247,
  70. 0.171145134062442, 0., 0., 0.269555036538714, 0., 0., 0.,
  71. 0.010893241091872, 0., 0., 0., 0., 0., 0., 0., 0.,
  72. 0.048167058272886, 0.011238724891049, 0., 0., 0.055162603456078,
  73. 0., 0., 0., 0., 0.027753339088588, 0., 0., 0., 0., 0., 0., 0., 0.,
  74. 0., 0.])
  75. # The following code sets up a system of equations such that
  76. # $k_i-p_i*n_i$ is minimized for $p_i$ with weights $n_i$ and
  77. # monotonicity constraints on $p_i$. This translates to a system of
  78. # equations of the form $k_i - (d_1 + ... + d_i) * n_i$ and
  79. # non-negativity constraints on the $d_i$. If $n_i$ is zero the
  80. # system is modified such that $d_i - d_{i+1}$ is then minimized.
  81. N = len(n)
  82. A = np.diag(n) @ np.tril(np.ones((N, N)))
  83. w = n ** 0.5
  84. nz = (n == 0).nonzero()[0]
  85. A[nz, nz] = 1
  86. A[nz, np.minimum(nz + 1, N - 1)] = -1
  87. w[nz] = 1
  88. k[nz] = 0
  89. W = np.diag(w)
  90. # Small perturbations can already make the infinite loop go away (just
  91. # uncomment the next line)
  92. # k = k + 1e-10 * np.random.normal(size=N)
  93. dact, _ = nnls(W @ A, W @ k)
  94. assert_allclose(dact, d, rtol=0., atol=1e-10)
  95. def test_nnls_inner_loop_case2(self):
  96. # See gh-20168
  97. n = np.array(
  98. [1, 0, 1, 2, 2, 2, 3, 3, 5, 4, 14, 14, 19, 26, 36, 42, 36, 64, 64,
  99. 64, 81, 85, 85, 95, 95, 95, 75, 76, 69, 81, 62, 59, 68, 64, 71, 67,
  100. 74, 78, 118, 135, 153, 159, 210, 195, 218, 243, 236, 215, 196, 175,
  101. 185, 149, 144, 103, 104, 75, 56, 40, 32, 26, 17, 9, 12, 8, 2, 1, 1,
  102. 1])
  103. k = np.array(
  104. [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
  105. 0., 0., 0., 0., 0., 0.7064355064917867, 0., 0., 2.11930651947536,
  106. 0.7064355064917867, 0., 3.5321775324589333, 7.064355064917867,
  107. 11.302968103868587, 16.95445215580288, 20.486629688261814,
  108. 20.486629688261814, 37.44108184406469, 55.808405012851146,
  109. 78.41434122058831, 103.13958394780086, 105.965325973768,
  110. 125.74552015553803, 149.057891869767, 176.60887662294667,
  111. 197.09550631120848, 211.930651947536, 204.86629688261814,
  112. 233.8301526487814, 221.1143135319292, 195.6826352982249,
  113. 197.80194181770025, 191.4440222592742, 187.91184472681525,
  114. 144.11284332432447, 131.39700420747232, 116.5618585711448,
  115. 93.24948685691584, 89.01087381796512, 53.68909849337579,
  116. 45.211872415474346, 31.083162285638615, 24.72524272721253,
  117. 16.95445215580288, 9.890097090885014, 9.890097090885014,
  118. 2.8257420259671466, 2.8257420259671466, 1.4128710129835733,
  119. 0.7064355064917867, 1.4128710129835733])
  120. d = np.array(
  121. [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
  122. 0., 0., 0., 0., 0., 0.0021916146355674473, 0., 0.,
  123. 0.011252740799789484, 0., 0., 0.037746623295934395,
  124. 0.03602328132946222, 0.09509167709829734, 0.10505765870204821,
  125. 0.01391037014274718, 0.0188296228752321, 0.20723559202324254,
  126. 0.3056220879462608, 0.13304643490426477, 0., 0., 0., 0., 0., 0.,
  127. 0., 0., 0., 0., 0., 0.043185876949706214, 0.0037266261379722554,
  128. 0., 0., 0., 0., 0., 0.094797899357143, 0., 0., 0., 0., 0., 0., 0.,
  129. 0., 0.23450935613672663, 0., 0., 0.07064355064917871])
  130. # The following code sets up a system of equations such that
  131. # $k_i-p_i*n_i$ is minimized for $p_i$ with weights $n_i$ and
  132. # monotonicity constraints on $p_i$. This translates to a system of
  133. # equations of the form $k_i - (d_1 + ... + d_i) * n_i$ and
  134. # non-negativity constraints on the $d_i$. If $n_i$ is zero the
  135. # system is modified such that $d_i - d_{i+1}$ is then minimized.
  136. N = len(n)
  137. A = np.diag(n) @ np.tril(np.ones((N, N)))
  138. w = n ** 0.5
  139. nz = (n == 0).nonzero()[0]
  140. A[nz, nz] = 1
  141. A[nz, np.minimum(nz + 1, N - 1)] = -1
  142. w[nz] = 1
  143. k[nz] = 0
  144. W = np.diag(w)
  145. dact, _ = nnls(W @ A, W @ k)
  146. p = np.cumsum(dact)
  147. assert np.all(dact >= 0)
  148. assert np.linalg.norm(k - n * p, ord=np.inf) < 28
  149. assert_allclose(dact, d, rtol=0., atol=1e-10)
  150. def test_nnls_gh20302(self):
  151. # See gh-20302
  152. A = np.array(
  153. [0.33408569134321575, 0.11136189711440525, 0.049140798007949286,
  154. 0.03712063237146841, 0.055680948557202625, 0.16642814595936478,
  155. 0.11095209730624318, 0.09791993030943345, 0.14793612974165757,
  156. 0.44380838922497273, 0.11099502671044059, 0.11099502671044059,
  157. 0.14693672599330593, 0.3329850801313218, 1.498432860590948,
  158. 0.0832374225132955, 0.11098323001772734, 0.19589481249472837,
  159. 0.5919105600945457, 3.5514633605672747, 0.06658716751427037,
  160. 0.11097861252378394, 0.24485832778293645, 0.9248217710315328,
  161. 6.936163282736496, 0.05547609388181014, 0.11095218776362029,
  162. 0.29376003042571264, 1.3314262531634435, 11.982836278470993,
  163. 0.047506113282944136, 0.11084759766020298, 0.3423969672933396,
  164. 1.8105107617833156, 19.010362998724812, 0.041507335004505576,
  165. 0.11068622667868154, 0.39074115283013344, 2.361306169145206,
  166. 28.335674029742474, 0.03682846280947718, 0.11048538842843154,
  167. 0.4387861797121048, 2.9831054875676517, 40.2719240821633,
  168. 0.03311278164362387, 0.11037593881207958, 0.4870572300443105,
  169. 3.6791979604026523, 55.187969406039784, 0.030079304092299915,
  170. 0.11029078167176636, 0.5353496017200152, 4.448394860761242,
  171. 73.3985152025605, 0.02545939709595835, 0.11032405408248619,
  172. 0.6328767609778363, 6.214921713313388, 121.19097340961108,
  173. 0.022080881724881523, 0.11040440862440762, 0.7307742886903428,
  174. 8.28033064683057, 186.30743955368786, 0.020715838214945492,
  175. 0.1104844704797093, 0.7800578384588346, 9.42800814760186,
  176. 226.27219554244465, 0.01843179728340054, 0.11059078370040323,
  177. 0.8784095015912599, 11.94380463964355, 322.48272527037585,
  178. 0.015812787653789077, 0.11068951357652354, 1.0257259848595766,
  179. 16.27135849574896, 512.5477926160922, 0.014438550529330062,
  180. 0.11069555405819713, 1.1234754801775881, 19.519316032262093,
  181. 673.4164031130423, 0.012760770585072577, 0.110593345070629,
  182. 1.2688431112524712, 24.920367089248398, 971.8943164806875,
  183. 0.011427556646114315, 0.11046638091243838, 1.413623342459821,
  184. 30.967408782453557, 1347.0822820367298, 0.010033330264470307,
  185. 0.11036663290917338, 1.6071533470570285, 40.063087746029936,
  186. 1983.122843428482, 0.008950061496507258, 0.11038409179025618,
  187. 1.802244865119193, 50.37194055362024, 2795.642700725923,
  188. 0.008071078821135658, 0.11030474388885401, 1.9956465761433504,
  189. 61.80742482572119, 3801.1566267818534, 0.007191031207777556,
  190. 0.11026247851925586, 2.238160187262168, 77.7718015155818,
  191. 5366.2543045751445, 0.00636834224248, 0.11038459886965334,
  192. 2.5328963107984297, 99.49331844784753, 7760.4788389321075,
  193. 0.005624259098118485, 0.11061042892966355, 2.879742607664547,
  194. 128.34496770138628, 11358.529641572684, 0.0050354270614989555,
  195. 0.11077939535297703, 3.2263279459292575, 160.85168205252265,
  196. 15924.316523199741, 0.0044997853165982555, 0.1109947044760903,
  197. 3.6244287189055613, 202.60233390369015, 22488.859063309606,
  198. 0.004023601950058174, 0.1113196539516095, 4.07713905729421,
  199. 255.6270320242126, 31825.565487014468, 0.0036024117873727094,
  200. 0.111674765408554, 4.582933773135057, 321.9583486728612,
  201. 44913.18963986413, 0.003201503089582304, 0.11205260813538065,
  202. 5.191786833370116, 411.79333489752383, 64857.45024636,
  203. 0.0028633044552448853, 0.11262330857296549, 5.864295861648949,
  204. 522.7223161899905, 92521.84996562831, 0.0025691897303891965,
  205. 0.11304434813712465, 6.584584405106342, 656.5615739804199,
  206. 129999.19164812315, 0.0022992911894424675, 0.11343169867916175,
  207. 7.4080129906658305, 828.2026426227864, 183860.98666225857,
  208. 0.0020449922071108764, 0.11383789952917212, 8.388975556433872,
  209. 1058.2750599896935, 265097.9025274183, 0.001831274615120854,
  210. 0.11414945100919989, 9.419351803810935, 1330.564050780237,
  211. 373223.2162438565, 0.0016363333454631633, 0.11454333418242145,
  212. 10.6143816579462, 1683.787012481595, 530392.9089317025,
  213. 0.0014598610433380044, 0.11484240207592301, 11.959688127956882,
  214. 2132.0874753402027, 754758.9662704318, 0.0012985240015312626,
  215. 0.11513579480243862, 13.514425358573531, 2715.5160990137824,
  216. 1083490.9235064993, 0.0011614735761289934, 0.11537304189548002,
  217. 15.171418602667567, 3415.195870828736, 1526592.554260445,
  218. 0.0010347472698811352, 0.11554677847006009, 17.080800985009617,
  219. 4322.412404600832, 2172012.2333119176, 0.0009232988811258664,
  220. 0.1157201264344419, 19.20004861829407, 5453.349531598553,
  221. 3075689.135821584, 0.0008228871862975205, 0.11602709326795038,
  222. 21.65735242414206, 6920.203923780365, 4390869.389638642,
  223. 0.00073528900066722, 0.11642075843897651, 24.40223571298994,
  224. 8755.811207598026, 6238515.485413593, 0.0006602764384729194,
  225. 0.11752920604817965, 27.694443541914293, 11171.386093291572,
  226. 8948280.260726549, 0.0005935538977939806, 0.11851292825953147,
  227. 31.325508920763063, 14174.185724149384, 12735505.873148222,
  228. 0.0005310755355633124, 0.11913794514470308, 35.381052949627765,
  229. 17987.010118815077, 18157886.71494382, 0.00047239949671590953,
  230. 0.1190446731724092, 39.71342528048061, 22679.438775422022,
  231. 25718483.571328573, 0.00041829129789387623, 0.11851586773659825,
  232. 44.45299332965028, 28542.57147989741, 36391778.63686921,
  233. 0.00037321512015419886, 0.11880681324908665, 50.0668539579632,
  234. 36118.26128449941, 51739409.29004541, 0.0003315539616702064,
  235. 0.1184752823034871, 56.04387059062639, 45383.29960621684,
  236. 72976345.76679668, 0.00029456064937920213, 0.11831519416731286,
  237. 62.91195073220101, 57265.53993693082, 103507463.43600245,
  238. 0.00026301867496859703, 0.11862142241083726, 70.8217262087034,
  239. 72383.14781936012, 146901598.49939138, 0.00023618734450420032,
  240. 0.11966825454879482, 80.26535457124461, 92160.51176984518,
  241. 210125966.835247, 0.00021165918071578316, 0.12043407382728061,
  242. 90.7169587544247, 116975.56852918258, 299515943.218972,
  243. 0.00018757727511329545, 0.11992440455576689, 101.49899864101785,
  244. 147056.26174166967, 423080865.0307836, 0.00016654469159895833,
  245. 0.11957908856805206, 113.65970431102812, 184937.67016486943,
  246. 597533612.3026931, 0.00014717439179415048, 0.11872067604728138,
  247. 126.77899683346702, 231758.58906776624, 841283678.3159915,
  248. 0.00012868496382376066, 0.1166314722122684, 139.93635237349534,
  249. 287417.30847929465, 1172231492.6328032, 0.00011225559452625302,
  250. 0.11427619522772557, 154.0034283704458, 355281.4912295324,
  251. 1627544511.322488, 9.879511142981067e-05, 0.11295574406808354,
  252. 170.96532050841535, 442971.0111288653, 2279085852.2580123,
  253. 8.71257780313587e-05, 0.11192758284428547, 190.35067416684697,
  254. 554165.2523674504, 3203629323.93623, 7.665069027765277e-05,
  255. 0.11060694607065294, 211.28835951100046, 690933.608546013,
  256. 4486577387.093535, 6.734021094824451e-05, 0.10915848194710433,
  257. 234.24338803525194, 860487.9079859136, 6276829044.8032465,
  258. 5.9191625040287665e-05, 0.10776821865668373, 259.7454711820425,
  259. 1071699.0387579766, 8780430224.544102, 5.1856803674907676e-05,
  260. 0.10606444911641115, 287.1843540288165, 1331126.3723998806,
  261. 12251687131.5685, 4.503421404759231e-05, 0.10347361247668461,
  262. 314.7338642485931, 1638796.0697522392, 16944331963.203278,
  263. 3.90470387455642e-05, 0.1007804070023012, 344.3427560918527,
  264. 2014064.4865519698, 23392351979.057854, 3.46557661636393e-05,
  265. 0.10046706610839032, 385.56603915081587, 2533036.2523656,
  266. 33044724430.235435, 3.148745865254635e-05, 0.1025441570117926,
  267. 442.09038234164746, 3262712.3882769793, 47815050050.199135,
  268. 2.9790762078715404e-05, 0.1089845379379672, 527.8068231298969,
  269. 4375751.903321453, 72035815708.42941, 2.8772639817606534e-05,
  270. 0.11823636789048445, 643.2048194503195, 5989838.001888927,
  271. 110764084330.93005, 2.7951691815106586e-05, 0.12903432664913705,
  272. 788.5500418523591, 8249371.000613411, 171368308481.2427,
  273. 2.6844392423114212e-05, 0.1392060709754626, 955.6296403631383,
  274. 11230229.319931043, 262063016295.25085, 2.499458273851386e-05,
  275. 0.14559344445184325, 1122.7022399726002, 14820229.698461473,
  276. 388475270970.9214, 2.337386729019776e-05, 0.15294300496886065,
  277. 1324.8158105672455, 19644861.137128454, 578442936182.7473,
  278. 2.0081014872174113e-05, 0.14760215298210377, 1436.2385042492353,
  279. 23923681.729276657, 791311658718.4193, 1.773374462991839e-05,
  280. 0.14642752940923615, 1600.5596278736678, 29949429.82503553,
  281. 1112815989293.9326, 1.5303115839590797e-05, 0.14194150045081785,
  282. 1742.873058605698, 36634451.931305364, 1529085389160.7544,
  283. 1.3148448731163076e-05, 0.13699368732998807, 1889.5284359054356,
  284. 44614279.74469635, 2091762812969.9607, 1.1739194407590062e-05,
  285. 0.13739553134643406, 2128.794599579694, 56462810.11822766,
  286. 2973783283306.8145, 1.0293367506254706e-05, 0.13533033372723272,
  287. 2355.372854690074, 70176508.28667311, 4151852759764.441,
  288. 9.678312586863569e-06, 0.14293577249119244, 2794.531827932675,
  289. 93528671.31952812, 6215821967224.52, -1.174086323572049e-05,
  290. 0.1429501325944908, 3139.4804810720925, 118031680.16618933,
  291. -6466892421886.174, -2.1188265307407812e-05, 0.1477108290912869,
  292. 3644.1133424610953, 153900132.62392554, -4828013117542.036,
  293. -8.614483025123122e-05, 0.16037100755883044, 4444.386620899393,
  294. 210846007.89660168, -1766340937974.433, 4.981445776141726e-05,
  295. 0.16053420251962536, 4997.558254401547, 266327328.4755411,
  296. 3862250287024.725, 1.8500019169456637e-05, 0.15448417164977674,
  297. 5402.289867444643, 323399508.1475582, 12152445411933.408,
  298. -5.647882376069748e-05, 0.1406372975946189, 5524.633133597753,
  299. 371512945.9909363, -4162951345292.1514, 2.8048523486337994e-05,
  300. 0.13183417571186926, 5817.462495763679, 439447252.3728975,
  301. 9294740538175.03]).reshape(89, 5)
  302. b = np.ones(89, dtype=np.float64)
  303. sol, rnorm = nnls(A, b)
  304. assert_allclose(sol, np.array([0.61124315, 8.22262829, 0., 0., 0.]))
  305. assert_allclose(rnorm, 1.0556460808977297)
  306. def test_nnls_gh21021_ex1(self):
  307. # Review examples used in gh-21021
  308. A = [[0.004734199143798789, -0.09661916455815653, -0.04308779048103441,
  309. 0.4039475561867938, -0.27742598780954364, -0.20816924034369574,
  310. -0.17264070902176, 0.05251808558963846],
  311. [-0.030263548855047975, -0.30356483926431466, 0.18080406600591398,
  312. -0.06892233941254086, -0.41837298885432317, 0.30245352819647003,
  313. -0.19008975278116397, -0.00990809825429995],
  314. [-0.2561747595787612, -0.04376282125249583, 0.4422181991706678,
  315. -0.13720906318924858, -0.0069523811763796475, -0.059238287107464795,
  316. 0.028663214369642594, 0.5415531284893763],
  317. [0.2949336072968401, 0.33997647534935094, 0.38441519339815755,
  318. -0.306001783010386, 0.18120773805949028, -0.36669767490747895,
  319. -0.021539960590992304, -0.2784251712424615],
  320. [0.5009075736232653, -0.20161970347571165, 0.08404512586550646,
  321. 0.2520496489348788, 0.14812015101612894, -0.25823455803981266,
  322. -0.1596872058396596, 0.5960141613922691]
  323. ]
  324. b = [18.036779281222124, -18.126530733870887, 13.535652034584029,
  325. -2.6654275476795966, 9.166315328199575]
  326. # Obtained from matlab's lstnonneg
  327. des_sol = np.array([0., 118.017802006619, 45.1996532316584, 102.62156313537,
  328. 0., 55.8590204314398, 0., 29.7328833253434])
  329. sol, res = nnls(A, b)
  330. assert_allclose(sol, des_sol)
  331. assert np.abs(np.linalg.norm(A@sol - b) - res) < 5e-14
  332. def test_nnls_gh21021_ex2(self):
  333. A = np.array([
  334. [0.2508259992635229, -0.24031300195203256],
  335. [0.510647748500133, 0.2872936081767836],
  336. [0.8196387904102849, -0.03520620107046682],
  337. [0.030739759120097084, -0.07768656359879388]])
  338. b = np.array([24.456141951303913,
  339. 28.047143273432333,
  340. 41.10526799545987,
  341. -1.2078282698324068])
  342. sol, res = nnls(A, b)
  343. assert_allclose(sol, np.array([54.3047953202271, 0.0]))
  344. assert np.abs(np.linalg.norm(A@sol - b) - res) < 5e-14
  345. def test_nnls_gh21021_ex3(self):
  346. A = np.array([
  347. [0.08247592017366788, 0.058398241636675674, -0.1031496693415968,
  348. 0.03156983127072098, -0.029503680182026665],
  349. [0.21463607509982277, -0.2164518969308173, -0.10816833396662294,
  350. 0.12133867146012027, -0.15025010408668332],
  351. [0.07251900316494089, -0.003044559315020767, 0.042682817961676424,
  352. -0.018157525489298176, 0.11561953260568134],
  353. [0.2328797918159187, -0.09112909645892767, 0.21348169727099078,
  354. 0.00449447624089599, -0.16615256386885716],
  355. [-0.02440856024843897, -0.20131427208575386, 0.030275781997161483,
  356. -0.04560777213546784, 0.11007266012013553],
  357. [-0.2928391429686263, -0.20437574856615687, -0.020892110811574407,
  358. -0.10455040720819309, 0.05337267000160461],
  359. [0.22041503019400316, 0.014262782992311842, 0.08274606359871121,
  360. -0.17933172096518907, -0.11809690350702161],
  361. [0.10440436007469953, 0.09171452270577712, 0.03942347724809893,
  362. 0.11457669688231396, 0.07529747295631585],
  363. [-0.052087576116032056, -0.15787717158077047, -0.08232202515883282,
  364. -0.03194837933710708, -0.0546812506025729],
  365. [-0.010388407673304468, 0.015174707581808923, 0.04764509565386281,
  366. -0.1781221936030805, 0.10218894080536609],
  367. [0.03272263140115928, -0.27576456949442574, 0.024897570959901753,
  368. -0.1417129166632282, -0.03320796462136591],
  369. [-0.12490006751823997, -0.03012003515442302, -0.051495264012509506,
  370. 0.012070729698374614, 0.04811700123118234],
  371. [0.15254854117990788, -0.051863547789218374, 0.058012914127346174,
  372. -0.06717991061422621, -0.14514671564242257],
  373. [0.12251250415395559, -0.17462495626695362, -0.025334728552179834,
  374. 0.11425350676877533, 0.06183915953812639],
  375. [0.19334259720491218, 0.2164301986218955, -0.018882278726614483,
  376. 0.07950236716817938, -0.2220529357431092],
  377. [-0.01822205701890852, 0.12630444976752267, -0.03118092027244001,
  378. 0.02773743885242581, 0.06444433740044248],
  379. [0.13344116850581977, -0.05142877469996826, 0.3385702016705455,
  380. -0.25814970787123004, 0.2679034842977378],
  381. [0.1309747058619377, 0.12090608957940627, -0.13957978654106512,
  382. 0.17048819760322642, -0.241775259969348],
  383. [0.28613102173467275, -0.47153463906732174, 0.20359970518269746,
  384. -0.0962095202871843, -0.07703076550836387],
  385. [0.2212788380372723, 0.02569245145758152, -0.021596152392209966,
  386. 0.04610005150029433, -0.2024454395619734],
  387. [-0.043225338359410316, 0.17816095186290315, -0.014709092962616079,
  388. 0.06993970293287989, -0.09033722782555903],
  389. [0.17747622942563512, -0.20991014784011458, 0.06265720409894943,
  390. 0.0689704059061795, 0.024474319398401525],
  391. [-0.1163880385601698, 0.29989570587630027, 0.033443765320984545,
  392. 0.008470296514656, -0.0014457113271462002],
  393. [0.024375314902718406, 0.05279830705548363, 0.02691082431023144,
  394. 0.05265079368002343, 0.15542988147487913],
  395. [-0.01855218360922308, -0.050265869142888164, 0.2567912677240452,
  396. -0.2606428528561333, 0.25334396245022245]])
  397. b = np.array([-7.876625373734849, -8.259856278691373, 3.2593082374900963,
  398. 16.30170376973345, 2.311892943629045, -1.595345202555738,
  399. 6.318582970536518, 3.0104212955340093, -6.286202915842167,
  400. 3.6382333725029294, 1.9012066681249356, -3.932236581436514,
  401. 4.4299317131740406, -1.9345885161292682, -1.4418721521970805,
  402. -2.3810103256943926, 25.853603392922526, -10.658470311610483,
  403. 15.547103681119214, -1.6491066136547277, -1.1232029689817422,
  404. 4.7845749463206975, 2.553803732013229, 2.0549409701753705,
  405. 19.60887153608244])
  406. sol, res = nnls(A, b)
  407. assert_allclose(sol, np.array([0.0, 0.0, 76.3611306173957, 0.0, 0.0]),
  408. atol=5e-14)
  409. assert np.abs(np.linalg.norm(A@sol - b) - res) < 5e-14
  410. def test_atol_deprecation_warning(self):
  411. """Test that using atol parameter triggers deprecation warning"""
  412. a = np.array([[1, 0], [1, 0], [0, 1]])
  413. b = np.array([2, 1, 1])
  414. with pytest.warns(DeprecationWarning, match="{'atol'}"):
  415. nnls(a, b, atol=1e-8)
  416. def test_2D_singleton_RHS_input(self):
  417. # Test that a 2D singleton RHS input is accepted
  418. A = np.array([[1.0, 0.5, -1.],
  419. [1.0, 0.5, 0.0],
  420. [-1., 0.0, 1.0]])
  421. b = np.array([[-1.0, 2.0, 2.0]]).T
  422. x, r = nnls(A, b)
  423. assert_allclose(x, np.array([1.0, 2.0, 3.0]))
  424. assert_allclose(r, 0.0)
  425. def test_2D_not_singleton_RHS_input_2(self):
  426. # Test that a 2D but not a column vector RHS input is rejected
  427. A = np.array([[1.0, 0.5, -1.],
  428. [1.0, 0.5, 0.0],
  429. [1.0, 0.5, 0.0],
  430. [0.0, 0.0, 1.0]])
  431. b = np.ones(shape=[4, 2], dtype=np.float64)
  432. with pytest.raises(ValueError, match="Expected a 1D array"):
  433. nnls(A, b)
  434. def test_gh_22791_32bit(self):
  435. # Scikit-learn got hit by this problem on 32-bit arch.
  436. desired = [0, 0, 1.05617285, 0, 0, 0, 0, 0.23123048, 0, 0, 0, 0.26128651]
  437. rng = np.random.RandomState(42)
  438. n_samples, n_features = 5, 12
  439. X = rng.randn(n_samples, n_features)
  440. X[:2, :] = 0
  441. y = rng.randn(n_samples)
  442. coef, _ = nnls(X, y)
  443. assert_allclose(coef, desired)