test_special_matrices.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617
  1. import pytest
  2. import numpy as np
  3. from numpy import arange, array, eye, copy, sqrt
  4. from numpy.testing import (assert_equal, assert_array_equal,
  5. assert_array_almost_equal, assert_allclose)
  6. from pytest import raises as assert_raises
  7. from scipy.fft import fft
  8. from scipy.special import comb
  9. from scipy.linalg import (toeplitz, hankel, circulant, hadamard, leslie, dft,
  10. companion, block_diag,
  11. helmert, hilbert, invhilbert, pascal, invpascal,
  12. fiedler, fiedler_companion, eigvals,
  13. convolution_matrix)
  14. from numpy.linalg import cond
  15. from scipy._lib._array_api import (make_xp_test_case, xp_assert_equal, xp_size,
  16. xp_default_dtype)
  17. class TestToeplitz:
  18. def test_basic(self):
  19. y = toeplitz([1, 2, 3])
  20. assert_array_equal(y, [[1, 2, 3], [2, 1, 2], [3, 2, 1]])
  21. y = toeplitz([1, 2, 3], [1, 4, 5])
  22. assert_array_equal(y, [[1, 4, 5], [2, 1, 4], [3, 2, 1]])
  23. def test_complex_01(self):
  24. data = (1.0 + arange(3.0)) * (1.0 + 1.0j)
  25. x = copy(data)
  26. t = toeplitz(x)
  27. # Calling toeplitz should not change x.
  28. assert_array_equal(x, data)
  29. # According to the docstring, x should be the first column of t.
  30. col0 = t[:, 0]
  31. assert_array_equal(col0, data)
  32. assert_array_equal(t[0, 1:], data[1:].conj())
  33. def test_scalar_00(self):
  34. """Scalar arguments still produce a 2D array."""
  35. t = toeplitz(10)
  36. assert_array_equal(t, [[10]])
  37. t = toeplitz(10, 20)
  38. assert_array_equal(t, [[10]])
  39. def test_scalar_01(self):
  40. c = array([1, 2, 3])
  41. t = toeplitz(c, 1)
  42. assert_array_equal(t, [[1], [2], [3]])
  43. def test_scalar_02(self):
  44. c = array([1, 2, 3])
  45. t = toeplitz(c, array(1))
  46. assert_array_equal(t, [[1], [2], [3]])
  47. def test_scalar_03(self):
  48. c = array([1, 2, 3])
  49. t = toeplitz(c, array([1]))
  50. assert_array_equal(t, [[1], [2], [3]])
  51. def test_scalar_04(self):
  52. r = array([10, 2, 3])
  53. t = toeplitz(1, r)
  54. assert_array_equal(t, [[1, 2, 3]])
  55. class TestHankel:
  56. def test_basic(self):
  57. y = hankel([1, 2, 3])
  58. assert_array_equal(y, [[1, 2, 3], [2, 3, 0], [3, 0, 0]])
  59. y = hankel([1, 2, 3], [3, 4, 5])
  60. assert_array_equal(y, [[1, 2, 3], [2, 3, 4], [3, 4, 5]])
  61. class TestCirculant:
  62. def test_basic(self):
  63. y = circulant([1, 2, 3])
  64. assert_array_equal(y, [[1, 3, 2], [2, 1, 3], [3, 2, 1]])
  65. class TestHadamard:
  66. def test_basic(self):
  67. y = hadamard(1)
  68. assert_array_equal(y, [[1]])
  69. y = hadamard(2, dtype=float)
  70. assert_array_equal(y, [[1.0, 1.0], [1.0, -1.0]])
  71. y = hadamard(4)
  72. assert_array_equal(y, [[1, 1, 1, 1],
  73. [1, -1, 1, -1],
  74. [1, 1, -1, -1],
  75. [1, -1, -1, 1]])
  76. assert_raises(ValueError, hadamard, 0)
  77. assert_raises(ValueError, hadamard, 5)
  78. class TestLeslie:
  79. def test_bad_shapes(self):
  80. assert_raises(ValueError, leslie, [[1, 1], [2, 2]], [3, 4, 5])
  81. assert_raises(ValueError, leslie, [1, 2], [1, 2])
  82. assert_raises(ValueError, leslie, [1], [])
  83. def test_basic(self):
  84. a = leslie([1, 2, 3], [0.25, 0.5])
  85. expected = array([[1.0, 2.0, 3.0],
  86. [0.25, 0.0, 0.0],
  87. [0.0, 0.5, 0.0]])
  88. assert_array_equal(a, expected)
  89. class TestCompanion:
  90. def test_bad_shapes(self):
  91. assert_raises(ValueError, companion, [0, 4, 5])
  92. assert_raises(ValueError, companion, [1])
  93. assert_raises(ValueError, companion, [])
  94. def test_basic(self):
  95. c = companion([1, 2, 3])
  96. expected = array([
  97. [-2.0, -3.0],
  98. [1.0, 0.0]])
  99. assert_array_equal(c, expected)
  100. c = companion([2.0, 5.0, -10.0])
  101. expected = array([
  102. [-2.5, 5.0],
  103. [1.0, 0.0]])
  104. assert_array_equal(c, expected)
  105. c = companion([(1.0, 2.0, 3.0),
  106. (4.0, 5.0, 6.0)])
  107. expected = array([
  108. ([-2.00, -3.00],
  109. [+1.00, +0.00]),
  110. ([-1.25, -1.50],
  111. [+1.00, +0.00])
  112. ])
  113. assert_array_equal(c, expected)
  114. @make_xp_test_case(block_diag)
  115. class TestBlockDiag:
  116. def test_basic(self, xp):
  117. dtype = xp.asarray(1).dtype
  118. x = block_diag(xp.eye(2, dtype=dtype), xp.asarray([[1, 2], [3, 4], [5, 6]]),
  119. xp.asarray([[1, 2, 3]]))
  120. xp_assert_equal(x, xp.asarray([[1, 0, 0, 0, 0, 0, 0],
  121. [0, 1, 0, 0, 0, 0, 0],
  122. [0, 0, 1, 2, 0, 0, 0],
  123. [0, 0, 3, 4, 0, 0, 0],
  124. [0, 0, 5, 6, 0, 0, 0],
  125. [0, 0, 0, 0, 1, 2, 3]]))
  126. def test_dtype(self, xp):
  127. x = block_diag(xp.asarray([[1.5]]))
  128. assert x.dtype == xp_default_dtype(xp)
  129. x = block_diag(xp.asarray([[True]]))
  130. assert x.dtype == xp.bool
  131. def test_mixed_dtypes(self, xp):
  132. actual = block_diag(xp.asarray([[1.]]), xp.asarray([[1j]]))
  133. desired = xp.asarray([[1, 0], [0, 1j]])
  134. xp_assert_equal(actual, desired)
  135. def test_scalar_and_1d_args(self, xp):
  136. a = block_diag(xp.asarray(1))
  137. assert a.shape == (1, 1)
  138. xp_assert_equal(a, xp.asarray([[1]]))
  139. a = block_diag(xp.asarray([2, 3]), xp.asarray(4))
  140. xp_assert_equal(a, xp.asarray([[2, 3, 0], [0, 0, 4]]))
  141. def test_no_args(self):
  142. a = block_diag()
  143. assert a.ndim == 2
  144. assert a.nbytes == 0
  145. def test_empty_matrix_arg(self, xp):
  146. # regression test for gh-4596: check the shape of the result
  147. # for empty matrix inputs. Empty matrices are no longer ignored
  148. # (gh-4908) it is viewed as a shape (1, 0) matrix.
  149. dtype = xp.asarray(1).dtype
  150. a = block_diag(xp.asarray([[1, 0], [0, 1]]),
  151. xp.asarray([], dtype=dtype),
  152. xp.asarray([[2, 3], [4, 5], [6, 7]]))
  153. xp_assert_equal(a, xp.asarray([[1, 0, 0, 0],
  154. [0, 1, 0, 0],
  155. [0, 0, 0, 0],
  156. [0, 0, 2, 3],
  157. [0, 0, 4, 5],
  158. [0, 0, 6, 7]]))
  159. @pytest.mark.skip_xp_backends("dask.array", reason="dask/dask#11800")
  160. def test_zerosized_matrix_arg(self, xp):
  161. # test for gh-4908: check the shape of the result for
  162. # zero-sized matrix inputs, i.e. matrices with shape (0,n) or (n,0).
  163. # note that [[]] takes shape (1,0)
  164. dtype = xp.asarray(1).dtype
  165. a = block_diag(xp.asarray([[1, 0], [0, 1]]),
  166. xp.asarray([[]], dtype=dtype),
  167. xp.asarray([[2, 3], [4, 5], [6, 7]]),
  168. xp.zeros([0, 2], dtype=dtype))
  169. xp_assert_equal(a, xp.asarray([[1, 0, 0, 0, 0, 0],
  170. [0, 1, 0, 0, 0, 0],
  171. [0, 0, 0, 0, 0, 0],
  172. [0, 0, 2, 3, 0, 0],
  173. [0, 0, 4, 5, 0, 0],
  174. [0, 0, 6, 7, 0, 0]]))
  175. class TestHelmert:
  176. def test_orthogonality(self):
  177. for n in range(1, 7):
  178. H = helmert(n, full=True)
  179. Id = np.eye(n)
  180. assert_allclose(H.dot(H.T), Id, atol=1e-12)
  181. assert_allclose(H.T.dot(H), Id, atol=1e-12)
  182. def test_subspace(self):
  183. for n in range(2, 7):
  184. H_full = helmert(n, full=True)
  185. H_partial = helmert(n)
  186. for U in H_full[1:, :].T, H_partial.T:
  187. C = np.eye(n) - np.full((n, n), 1 / n)
  188. assert_allclose(U.dot(U.T), C)
  189. assert_allclose(U.T.dot(U), np.eye(n-1), atol=1e-12)
  190. class TestHilbert:
  191. def test_basic(self):
  192. h3 = array([[1.0, 1/2., 1/3.],
  193. [1/2., 1/3., 1/4.],
  194. [1/3., 1/4., 1/5.]])
  195. assert_array_almost_equal(hilbert(3), h3)
  196. assert_array_equal(hilbert(1), [[1.0]])
  197. h0 = hilbert(0)
  198. assert_equal(h0.shape, (0, 0))
  199. class TestInvHilbert:
  200. def test_basic(self):
  201. invh1 = array([[1]])
  202. assert_array_equal(invhilbert(1, exact=True), invh1)
  203. assert_array_equal(invhilbert(1), invh1)
  204. invh2 = array([[4, -6],
  205. [-6, 12]])
  206. assert_array_equal(invhilbert(2, exact=True), invh2)
  207. assert_array_almost_equal(invhilbert(2), invh2)
  208. invh3 = array([[9, -36, 30],
  209. [-36, 192, -180],
  210. [30, -180, 180]])
  211. assert_array_equal(invhilbert(3, exact=True), invh3)
  212. assert_array_almost_equal(invhilbert(3), invh3)
  213. invh4 = array([[16, -120, 240, -140],
  214. [-120, 1200, -2700, 1680],
  215. [240, -2700, 6480, -4200],
  216. [-140, 1680, -4200, 2800]])
  217. assert_array_equal(invhilbert(4, exact=True), invh4)
  218. assert_array_almost_equal(invhilbert(4), invh4)
  219. invh5 = array([[25, -300, 1050, -1400, 630],
  220. [-300, 4800, -18900, 26880, -12600],
  221. [1050, -18900, 79380, -117600, 56700],
  222. [-1400, 26880, -117600, 179200, -88200],
  223. [630, -12600, 56700, -88200, 44100]])
  224. assert_array_equal(invhilbert(5, exact=True), invh5)
  225. assert_array_almost_equal(invhilbert(5), invh5)
  226. invh17 = array([
  227. [289, -41616, 1976760, -46124400, 629598060, -5540462928,
  228. 33374693352, -143034400080, 446982500250, -1033026222800,
  229. 1774926873720, -2258997839280, 2099709530100, -1384423866000,
  230. 613101997800, -163493866080, 19835652870],
  231. [-41616, 7990272, -426980160, 10627061760, -151103534400,
  232. 1367702848512, -8410422724704, 36616806420480, -115857864064800,
  233. 270465047424000, -468580694662080, 600545887119360,
  234. -561522320049600, 372133135180800, -165537539406000,
  235. 44316454993920, -5395297580640],
  236. [1976760, -426980160, 24337869120, -630981792000, 9228108708000,
  237. -85267724461920, 532660105897920, -2348052711713280,
  238. 7504429831470000, -17664748409880000, 30818191841236800,
  239. -39732544853164800, 37341234283298400, -24857330514030000,
  240. 11100752642520000, -2982128117299200, 364182586693200],
  241. [-46124400, 10627061760, -630981792000, 16826181120000,
  242. -251209625940000, 2358021022156800, -14914482965141760,
  243. 66409571644416000, -214015221119700000, 507295338950400000,
  244. -890303319857952000, 1153715376477081600, -1089119333262870000,
  245. 727848632044800000, -326170262829600000, 87894302404608000,
  246. -10763618673376800],
  247. [629598060, -151103534400, 9228108708000,
  248. -251209625940000, 3810012660090000, -36210360321495360,
  249. 231343968720664800, -1038687206500944000, 3370739732635275000,
  250. -8037460526495400000, 14178080368737885600, -18454939322943942000,
  251. 17489975175339030000, -11728977435138600000, 5272370630081100000,
  252. -1424711708039692800, 174908803442373000],
  253. [-5540462928, 1367702848512, -85267724461920, 2358021022156800,
  254. -36210360321495360, 347619459086355456, -2239409617216035264,
  255. 10124803292907663360, -33052510749726468000,
  256. 79217210949138662400, -140362995650505067440,
  257. 183420385176741672960, -174433352415381259200,
  258. 117339159519533952000, -52892422160973595200,
  259. 14328529177999196160, -1763080738699119840],
  260. [33374693352, -8410422724704, 532660105897920,
  261. -14914482965141760, 231343968720664800, -2239409617216035264,
  262. 14527452132196331328, -66072377044391477760,
  263. 216799987176909536400, -521925895055522958000,
  264. 928414062734059661760, -1217424500995626443520,
  265. 1161358898976091015200, -783401860847777371200,
  266. 354015418167362952000, -96120549902411274240,
  267. 11851820521255194480],
  268. [-143034400080, 36616806420480, -2348052711713280,
  269. 66409571644416000, -1038687206500944000, 10124803292907663360,
  270. -66072377044391477760, 302045152202932469760,
  271. -995510145200094810000, 2405996923185123840000,
  272. -4294704507885446054400, 5649058909023744614400,
  273. -5403874060541811254400, 3654352703663101440000,
  274. -1655137020003255360000, 450325202737117593600,
  275. -55630994283442749600],
  276. [446982500250, -115857864064800, 7504429831470000,
  277. -214015221119700000, 3370739732635275000, -33052510749726468000,
  278. 216799987176909536400, -995510145200094810000,
  279. 3293967392206196062500, -7988661659013106500000,
  280. 14303908928401362270000, -18866974090684772052000,
  281. 18093328327706957325000, -12263364009096700500000,
  282. 5565847995255512250000, -1517208935002984080000,
  283. 187754605706619279900],
  284. [-1033026222800, 270465047424000, -17664748409880000,
  285. 507295338950400000, -8037460526495400000, 79217210949138662400,
  286. -521925895055522958000, 2405996923185123840000,
  287. -7988661659013106500000, 19434404971634224000000,
  288. -34894474126569249192000, 46141453390504792320000,
  289. -44349976506971935800000, 30121928988527376000000,
  290. -13697025107665828500000, 3740200989399948902400,
  291. -463591619028689580000],
  292. [1774926873720, -468580694662080,
  293. 30818191841236800, -890303319857952000, 14178080368737885600,
  294. -140362995650505067440, 928414062734059661760,
  295. -4294704507885446054400, 14303908928401362270000,
  296. -34894474126569249192000, 62810053427824648545600,
  297. -83243376594051600326400, 80177044485212743068000,
  298. -54558343880470209780000, 24851882355348879230400,
  299. -6797096028813368678400, 843736746632215035600],
  300. [-2258997839280, 600545887119360, -39732544853164800,
  301. 1153715376477081600, -18454939322943942000, 183420385176741672960,
  302. -1217424500995626443520, 5649058909023744614400,
  303. -18866974090684772052000, 46141453390504792320000,
  304. -83243376594051600326400, 110552468520163390156800,
  305. -106681852579497947388000, 72720410752415168870400,
  306. -33177973900974346080000, 9087761081682520473600,
  307. -1129631016152221783200],
  308. [2099709530100, -561522320049600, 37341234283298400,
  309. -1089119333262870000, 17489975175339030000,
  310. -174433352415381259200, 1161358898976091015200,
  311. -5403874060541811254400, 18093328327706957325000,
  312. -44349976506971935800000, 80177044485212743068000,
  313. -106681852579497947388000, 103125790826848015808400,
  314. -70409051543137015800000, 32171029219823375700000,
  315. -8824053728865840192000, 1098252376814660067000],
  316. [-1384423866000, 372133135180800,
  317. -24857330514030000, 727848632044800000, -11728977435138600000,
  318. 117339159519533952000, -783401860847777371200,
  319. 3654352703663101440000, -12263364009096700500000,
  320. 30121928988527376000000, -54558343880470209780000,
  321. 72720410752415168870400, -70409051543137015800000,
  322. 48142941226076592000000, -22027500987368499000000,
  323. 6049545098753157120000, -753830033789944188000],
  324. [613101997800, -165537539406000,
  325. 11100752642520000, -326170262829600000, 5272370630081100000,
  326. -52892422160973595200, 354015418167362952000,
  327. -1655137020003255360000, 5565847995255512250000,
  328. -13697025107665828500000, 24851882355348879230400,
  329. -33177973900974346080000, 32171029219823375700000,
  330. -22027500987368499000000, 10091416708498869000000,
  331. -2774765838662800128000, 346146444087219270000],
  332. [-163493866080, 44316454993920, -2982128117299200,
  333. 87894302404608000, -1424711708039692800,
  334. 14328529177999196160, -96120549902411274240,
  335. 450325202737117593600, -1517208935002984080000,
  336. 3740200989399948902400, -6797096028813368678400,
  337. 9087761081682520473600, -8824053728865840192000,
  338. 6049545098753157120000, -2774765838662800128000,
  339. 763806510427609497600, -95382575704033754400],
  340. [19835652870, -5395297580640, 364182586693200, -10763618673376800,
  341. 174908803442373000, -1763080738699119840, 11851820521255194480,
  342. -55630994283442749600, 187754605706619279900,
  343. -463591619028689580000, 843736746632215035600,
  344. -1129631016152221783200, 1098252376814660067000,
  345. -753830033789944188000, 346146444087219270000,
  346. -95382575704033754400, 11922821963004219300]
  347. ])
  348. assert_array_equal(invhilbert(17, exact=True), invh17)
  349. assert_allclose(invhilbert(17), invh17.astype(float), rtol=1e-12)
  350. def test_inverse(self):
  351. for n in range(1, 10):
  352. a = hilbert(n)
  353. b = invhilbert(n)
  354. # The Hilbert matrix is increasingly badly conditioned,
  355. # so take that into account in the test
  356. c = cond(a)
  357. assert_allclose(a.dot(b), eye(n), atol=1e-15*c, rtol=1e-15*c)
  358. class TestPascal:
  359. cases = [
  360. (1, array([[1]]), array([[1]])),
  361. (2, array([[1, 1],
  362. [1, 2]]),
  363. array([[1, 0],
  364. [1, 1]])),
  365. (3, array([[1, 1, 1],
  366. [1, 2, 3],
  367. [1, 3, 6]]),
  368. array([[1, 0, 0],
  369. [1, 1, 0],
  370. [1, 2, 1]])),
  371. (4, array([[1, 1, 1, 1],
  372. [1, 2, 3, 4],
  373. [1, 3, 6, 10],
  374. [1, 4, 10, 20]]),
  375. array([[1, 0, 0, 0],
  376. [1, 1, 0, 0],
  377. [1, 2, 1, 0],
  378. [1, 3, 3, 1]])),
  379. ]
  380. def check_case(self, n, sym, low):
  381. assert_array_equal(pascal(n), sym)
  382. assert_array_equal(pascal(n, kind='lower'), low)
  383. assert_array_equal(pascal(n, kind='upper'), low.T)
  384. assert_array_almost_equal(pascal(n, exact=False), sym)
  385. assert_array_almost_equal(pascal(n, exact=False, kind='lower'), low)
  386. assert_array_almost_equal(pascal(n, exact=False, kind='upper'), low.T)
  387. def test_cases(self):
  388. for n, sym, low in self.cases:
  389. self.check_case(n, sym, low)
  390. def test_big(self):
  391. p = pascal(50)
  392. assert p[-1, -1] == comb(98, 49, exact=True)
  393. def test_threshold(self):
  394. # Regression test. An early version of `pascal` returned an
  395. # array of type np.uint64 for n=35, but that data type is too small
  396. # to hold p[-1, -1]. The second assert_equal below would fail
  397. # because p[-1, -1] overflowed.
  398. p = pascal(34)
  399. assert_equal(2*p.item(-1, -2), p.item(-1, -1), err_msg="n = 34")
  400. p = pascal(35)
  401. assert_equal(2.*p.item(-1, -2), 1.*p.item(-1, -1), err_msg="n = 35")
  402. def test_invpascal():
  403. def check_invpascal(n, kind, exact):
  404. ip = invpascal(n, kind=kind, exact=exact)
  405. p = pascal(n, kind=kind, exact=exact)
  406. # Matrix-multiply ip and p, and check that we get the identity matrix.
  407. # We can't use the simple expression e = ip.dot(p), because when
  408. # n < 35 and exact is True, p.dtype is np.uint64 and ip.dtype is
  409. # np.int64. The product of those dtypes is np.float64, which loses
  410. # precision when n is greater than 18. Instead we'll cast both to
  411. # object arrays, and then multiply.
  412. e = ip.astype(object).dot(p.astype(object))
  413. assert_array_equal(e, eye(n), err_msg=f"n={n} kind={kind!r} exact={exact!r}")
  414. kinds = ['symmetric', 'lower', 'upper']
  415. ns = [1, 2, 5, 18]
  416. for n in ns:
  417. for kind in kinds:
  418. for exact in [True, False]:
  419. check_invpascal(n, kind, exact)
  420. ns = [19, 34, 35, 50]
  421. for n in ns:
  422. for kind in kinds:
  423. check_invpascal(n, kind, True)
  424. def test_dft():
  425. m = dft(2)
  426. expected = array([[1.0, 1.0], [1.0, -1.0]])
  427. assert_array_almost_equal(m, expected)
  428. m = dft(2, scale='n')
  429. assert_array_almost_equal(m, expected/2.0)
  430. m = dft(2, scale='sqrtn')
  431. assert_array_almost_equal(m, expected/sqrt(2.0))
  432. x = array([0, 1, 2, 3, 4, 5, 0, 1])
  433. m = dft(8)
  434. mx = m.dot(x)
  435. fx = fft(x)
  436. assert_array_almost_equal(mx, fx)
  437. @make_xp_test_case(fiedler)
  438. def test_fiedler(xp):
  439. f = fiedler(xp.asarray([]))
  440. assert xp_size(f) == 0
  441. f = fiedler(xp.asarray([123.]))
  442. xp_assert_equal(f, xp.asarray([[0.]]))
  443. f = fiedler(xp.arange(1, 7))
  444. des = xp.asarray([[0, 1, 2, 3, 4, 5],
  445. [1, 0, 1, 2, 3, 4],
  446. [2, 1, 0, 1, 2, 3],
  447. [3, 2, 1, 0, 1, 2],
  448. [4, 3, 2, 1, 0, 1],
  449. [5, 4, 3, 2, 1, 0]])
  450. xp_assert_equal(f, des)
  451. def test_fiedler_companion():
  452. fc = fiedler_companion([])
  453. assert_equal(fc.size, 0)
  454. fc = fiedler_companion([1.])
  455. assert_equal(fc.size, 0)
  456. fc = fiedler_companion([1., 2.])
  457. assert_array_equal(fc, np.array([[-2.]]))
  458. fc = fiedler_companion([1e-12, 2., 3.])
  459. assert_array_almost_equal(fc, companion([1e-12, 2., 3.]))
  460. with assert_raises(ValueError):
  461. fiedler_companion([0, 1, 2])
  462. fc = fiedler_companion([1., -16., 86., -176., 105.])
  463. assert_array_almost_equal(eigvals(fc),
  464. np.array([7., 5., 3., 1.]))
  465. class TestConvolutionMatrix:
  466. """
  467. Test convolution_matrix vs. numpy.convolve for various parameters.
  468. """
  469. def create_vector(self, n, cpx):
  470. """Make a complex or real test vector of length n."""
  471. x = np.linspace(-2.5, 2.2, n)
  472. if cpx:
  473. x = x + 1j*np.linspace(-1.5, 3.1, n)
  474. return x
  475. def test_bad_n(self):
  476. # n must be a positive integer
  477. with pytest.raises(ValueError, match='n must be a positive integer'):
  478. convolution_matrix([1, 2, 3], 0)
  479. def test_empty_first_arg(self):
  480. # first arg must have at least one value
  481. with pytest.raises(ValueError, match=r'len\(a\)'):
  482. convolution_matrix([], 4)
  483. def test_bad_mode(self):
  484. # mode must be in ('full', 'valid', 'same')
  485. with pytest.raises(ValueError, match='mode.*must be one of'):
  486. convolution_matrix((1, 1), 4, mode='invalid argument')
  487. @pytest.mark.parametrize('cpx', [False, True])
  488. @pytest.mark.parametrize('na', [1, 2, 9])
  489. @pytest.mark.parametrize('nv', [1, 2, 9])
  490. @pytest.mark.parametrize('mode', [None, 'full', 'valid', 'same'])
  491. def test_against_numpy_convolve(self, cpx, na, nv, mode):
  492. a = self.create_vector(na, cpx)
  493. v = self.create_vector(nv, cpx)
  494. if mode is None:
  495. y1 = np.convolve(v, a)
  496. A = convolution_matrix(a, nv)
  497. else:
  498. y1 = np.convolve(v, a, mode)
  499. A = convolution_matrix(a, nv, mode)
  500. y2 = A @ v
  501. assert_array_almost_equal(y1, y2)
  502. @pytest.mark.fail_slow(5) # `leslie` has an import in the function
  503. @pytest.mark.parametrize('f, args', [(circulant, ()),
  504. (companion, ()),
  505. (convolution_matrix, (5, 'same')),
  506. (fiedler, ()),
  507. (fiedler_companion, ()),
  508. (hankel, (np.arange(9),)),
  509. (leslie, (np.arange(9),)),
  510. (toeplitz, (np.arange(9),)),
  511. ])
  512. def test_batch(f, args):
  513. rng = np.random.default_rng(283592436523456)
  514. batch_shape = (2, 3)
  515. m = 10
  516. A = rng.random(batch_shape + (m,))
  517. if f in {hankel}:
  518. message = "Beginning in SciPy 1.19, multidimensional input will be..."
  519. with pytest.warns(FutureWarning, match=message):
  520. f(A, *args)
  521. return
  522. res = f(A, *args)
  523. ref = np.asarray([f(a, *args) for a in A.reshape(-1, m)])
  524. ref = ref.reshape(A.shape[:-1] + ref.shape[-2:])
  525. assert_allclose(res, ref)