test_kdtree.py 48 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536
  1. # Copyright Anne M. Archibald 2008
  2. # Released under the scipy license
  3. import os
  4. from numpy.testing import (assert_equal, assert_array_equal, assert_,
  5. assert_almost_equal, assert_array_almost_equal,
  6. assert_allclose)
  7. from pytest import raises as assert_raises
  8. import pytest
  9. from platform import python_implementation
  10. import numpy as np
  11. from scipy.spatial import KDTree, Rectangle, distance_matrix, cKDTree
  12. from scipy.spatial._ckdtree import cKDTreeNode
  13. from scipy.spatial import minkowski_distance
  14. import itertools
  15. @pytest.fixture(params=[KDTree, cKDTree])
  16. def kdtree_type(request):
  17. return request.param
  18. def KDTreeTest(kls):
  19. """Class decorator to create test cases for KDTree and cKDTree
  20. Tests use the class variable ``kdtree_type`` as the tree constructor.
  21. """
  22. if not kls.__name__.startswith('_Test'):
  23. raise RuntimeError("Expected a class name starting with _Test")
  24. for tree in (KDTree, cKDTree):
  25. test_name = kls.__name__[1:] + '_' + tree.__name__
  26. if test_name in globals():
  27. raise RuntimeError("Duplicated test name: " + test_name)
  28. # Create a new sub-class with kdtree_type defined
  29. test_case = type(test_name, (kls,), {'kdtree_type': tree})
  30. globals()[test_name] = test_case
  31. return kls
  32. def distance_box(a, b, p, boxsize):
  33. diff = a - b
  34. diff[diff > 0.5 * boxsize] -= boxsize
  35. diff[diff < -0.5 * boxsize] += boxsize
  36. d = minkowski_distance(diff, 0, p)
  37. return d
  38. class ConsistencyTests:
  39. def distance(self, a, b, p):
  40. return minkowski_distance(a, b, p)
  41. def test_nearest(self):
  42. x = self.x
  43. d, i = self.kdtree.query(x, 1)
  44. assert_almost_equal(d**2, np.sum((x-self.data[i])**2))
  45. eps = 1e-8
  46. assert_(np.all(np.sum((self.data-x[np.newaxis, :])**2, axis=1) > d**2-eps))
  47. def test_m_nearest(self):
  48. x = self.x
  49. m = self.m
  50. dd, ii = self.kdtree.query(x, m)
  51. d = np.amax(dd)
  52. i = ii[np.argmax(dd)]
  53. assert_almost_equal(d**2, np.sum((x-self.data[i])**2))
  54. eps = 1e-8
  55. assert_equal(
  56. np.sum(np.sum((self.data-x[np.newaxis, :])**2, axis=1) < d**2+eps),
  57. m,
  58. )
  59. def test_points_near(self):
  60. x = self.x
  61. d = self.d
  62. dd, ii = self.kdtree.query(x, k=self.kdtree.n, distance_upper_bound=d)
  63. eps = 1e-8
  64. hits = 0
  65. for near_d, near_i in zip(dd, ii):
  66. if near_d == np.inf:
  67. continue
  68. hits += 1
  69. assert_almost_equal(near_d**2, np.sum((x-self.data[near_i])**2))
  70. assert_(near_d < d+eps, f"near_d={near_d:g} should be less than {d:g}")
  71. assert_equal(np.sum(self.distance(self.data, x, 2) < d**2+eps), hits)
  72. def test_points_near_l1(self):
  73. x = self.x
  74. d = self.d
  75. dd, ii = self.kdtree.query(x, k=self.kdtree.n, p=1, distance_upper_bound=d)
  76. eps = 1e-8
  77. hits = 0
  78. for near_d, near_i in zip(dd, ii):
  79. if near_d == np.inf:
  80. continue
  81. hits += 1
  82. assert_almost_equal(near_d, self.distance(x, self.data[near_i], 1))
  83. assert_(near_d < d+eps, f"near_d={near_d:g} should be less than {d:g}")
  84. assert_equal(np.sum(self.distance(self.data, x, 1) < d+eps), hits)
  85. def test_points_near_linf(self):
  86. x = self.x
  87. d = self.d
  88. dd, ii = self.kdtree.query(x, k=self.kdtree.n, p=np.inf, distance_upper_bound=d)
  89. eps = 1e-8
  90. hits = 0
  91. for near_d, near_i in zip(dd, ii):
  92. if near_d == np.inf:
  93. continue
  94. hits += 1
  95. assert_almost_equal(near_d, self.distance(x, self.data[near_i], np.inf))
  96. assert_(near_d < d+eps, f"near_d={near_d:g} should be less than {d:g}")
  97. assert_equal(np.sum(self.distance(self.data, x, np.inf) < d+eps), hits)
  98. def test_approx(self):
  99. x = self.x
  100. k = self.k
  101. eps = 0.1
  102. d_real, i_real = self.kdtree.query(x, k)
  103. d, i = self.kdtree.query(x, k, eps=eps)
  104. assert_(np.all(d <= d_real*(1+eps)))
  105. @KDTreeTest
  106. class _Test_random(ConsistencyTests):
  107. def setup_method(self):
  108. self.n = 100
  109. self.m = 4
  110. np.random.seed(1234)
  111. self.data = np.random.randn(self.n, self.m)
  112. self.kdtree = self.kdtree_type(self.data, leafsize=2)
  113. self.x = np.random.randn(self.m)
  114. self.d = 0.2
  115. self.k = 10
  116. @KDTreeTest
  117. class _Test_random_far(_Test_random):
  118. def setup_method(self):
  119. super().setup_method()
  120. self.x = np.random.randn(self.m)+10
  121. @KDTreeTest
  122. class _Test_small(ConsistencyTests):
  123. def setup_method(self):
  124. self.data = np.array([[0, 0, 0],
  125. [0, 0, 1],
  126. [0, 1, 0],
  127. [0, 1, 1],
  128. [1, 0, 0],
  129. [1, 0, 1],
  130. [1, 1, 0],
  131. [1, 1, 1]])
  132. self.kdtree = self.kdtree_type(self.data)
  133. self.n = self.kdtree.n
  134. self.m = self.kdtree.m
  135. np.random.seed(1234)
  136. self.x = np.random.randn(3)
  137. self.d = 0.5
  138. self.k = 4
  139. def test_nearest(self):
  140. assert_array_equal(
  141. self.kdtree.query((0, 0, 0.1), 1),
  142. (0.1, 0))
  143. def test_nearest_two(self):
  144. assert_array_equal(
  145. self.kdtree.query((0, 0, 0.1), 2),
  146. ([0.1, 0.9], [0, 1]))
  147. @KDTreeTest
  148. class _Test_small_nonleaf(_Test_small):
  149. def setup_method(self):
  150. super().setup_method()
  151. self.kdtree = self.kdtree_type(self.data, leafsize=1)
  152. class Test_vectorization_KDTree:
  153. def setup_method(self):
  154. self.data = np.array([[0, 0, 0],
  155. [0, 0, 1],
  156. [0, 1, 0],
  157. [0, 1, 1],
  158. [1, 0, 0],
  159. [1, 0, 1],
  160. [1, 1, 0],
  161. [1, 1, 1]])
  162. self.kdtree = KDTree(self.data)
  163. def test_single_query(self):
  164. d, i = self.kdtree.query(np.array([0, 0, 0]))
  165. assert_(isinstance(d, float))
  166. assert_(np.issubdtype(i, np.signedinteger))
  167. def test_vectorized_query(self):
  168. d, i = self.kdtree.query(np.zeros((2, 4, 3)))
  169. assert_equal(np.shape(d), (2, 4))
  170. assert_equal(np.shape(i), (2, 4))
  171. def test_single_query_multiple_neighbors(self):
  172. s = 23
  173. kk = self.kdtree.n+s
  174. d, i = self.kdtree.query(np.array([0, 0, 0]), k=kk)
  175. assert_equal(np.shape(d), (kk,))
  176. assert_equal(np.shape(i), (kk,))
  177. assert_(np.all(~np.isfinite(d[-s:])))
  178. assert_(np.all(i[-s:] == self.kdtree.n))
  179. def test_vectorized_query_multiple_neighbors(self):
  180. s = 23
  181. kk = self.kdtree.n+s
  182. d, i = self.kdtree.query(np.zeros((2, 4, 3)), k=kk)
  183. assert_equal(np.shape(d), (2, 4, kk))
  184. assert_equal(np.shape(i), (2, 4, kk))
  185. assert_(np.all(~np.isfinite(d[:, :, -s:])))
  186. assert_(np.all(i[:, :, -s:] == self.kdtree.n))
  187. def test_query_raises_for_k_none(self):
  188. x = 1.0
  189. with pytest.raises(ValueError, match="k must be an integer or*"):
  190. self.kdtree.query(x, k=None)
  191. class Test_vectorization_cKDTree:
  192. def setup_method(self):
  193. self.data = np.array([[0, 0, 0],
  194. [0, 0, 1],
  195. [0, 1, 0],
  196. [0, 1, 1],
  197. [1, 0, 0],
  198. [1, 0, 1],
  199. [1, 1, 0],
  200. [1, 1, 1]])
  201. self.kdtree = cKDTree(self.data)
  202. def test_single_query(self):
  203. d, i = self.kdtree.query([0, 0, 0])
  204. assert_(isinstance(d, float))
  205. assert_(isinstance(i, int))
  206. def test_vectorized_query(self):
  207. d, i = self.kdtree.query(np.zeros((2, 4, 3)))
  208. assert_equal(np.shape(d), (2, 4))
  209. assert_equal(np.shape(i), (2, 4))
  210. def test_vectorized_query_noncontiguous_values(self):
  211. np.random.seed(1234)
  212. qs = np.random.randn(3, 1000).T
  213. ds, i_s = self.kdtree.query(qs)
  214. for q, d, i in zip(qs, ds, i_s):
  215. assert_equal(self.kdtree.query(q), (d, i))
  216. def test_single_query_multiple_neighbors(self):
  217. s = 23
  218. kk = self.kdtree.n+s
  219. d, i = self.kdtree.query([0, 0, 0], k=kk)
  220. assert_equal(np.shape(d), (kk,))
  221. assert_equal(np.shape(i), (kk,))
  222. assert_(np.all(~np.isfinite(d[-s:])))
  223. assert_(np.all(i[-s:] == self.kdtree.n))
  224. def test_vectorized_query_multiple_neighbors(self):
  225. s = 23
  226. kk = self.kdtree.n+s
  227. d, i = self.kdtree.query(np.zeros((2, 4, 3)), k=kk)
  228. assert_equal(np.shape(d), (2, 4, kk))
  229. assert_equal(np.shape(i), (2, 4, kk))
  230. assert_(np.all(~np.isfinite(d[:, :, -s:])))
  231. assert_(np.all(i[:, :, -s:] == self.kdtree.n))
  232. class ball_consistency:
  233. tol = 0.0
  234. def distance(self, a, b, p):
  235. return minkowski_distance(a * 1.0, b * 1.0, p)
  236. def test_in_ball(self):
  237. x = np.atleast_2d(self.x)
  238. d = np.broadcast_to(self.d, x.shape[:-1])
  239. l = self.T.query_ball_point(x, self.d, p=self.p, eps=self.eps)
  240. for i, ind in enumerate(l):
  241. dist = self.distance(self.data[ind], x[i], self.p) - d[i]*(1.+self.eps)
  242. norm = self.distance(self.data[ind], x[i], self.p) + d[i]*(1.+self.eps)
  243. assert_array_equal(dist < self.tol * norm, True)
  244. def test_found_all(self):
  245. x = np.atleast_2d(self.x)
  246. d = np.broadcast_to(self.d, x.shape[:-1])
  247. l = self.T.query_ball_point(x, self.d, p=self.p, eps=self.eps)
  248. for i, ind in enumerate(l):
  249. c = np.ones(self.T.n, dtype=bool)
  250. c[ind] = False
  251. dist = self.distance(self.data[c], x[i], self.p) - d[i]/(1.+self.eps)
  252. norm = self.distance(self.data[c], x[i], self.p) + d[i]/(1.+self.eps)
  253. assert_array_equal(dist > -self.tol * norm, True)
  254. @KDTreeTest
  255. class _Test_random_ball(ball_consistency):
  256. def setup_method(self):
  257. n = 100
  258. m = 4
  259. np.random.seed(1234)
  260. self.data = np.random.randn(n, m)
  261. self.T = self.kdtree_type(self.data, leafsize=2)
  262. self.x = np.random.randn(m)
  263. self.p = 2.
  264. self.eps = 0
  265. self.d = 0.2
  266. @KDTreeTest
  267. class _Test_random_ball_periodic(ball_consistency):
  268. def distance(self, a, b, p):
  269. return distance_box(a, b, p, 1.0)
  270. def setup_method(self):
  271. n = 10000
  272. m = 4
  273. np.random.seed(1234)
  274. self.data = np.random.uniform(size=(n, m))
  275. self.T = self.kdtree_type(self.data, leafsize=2, boxsize=1)
  276. self.x = np.full(m, 0.1)
  277. self.p = 2.
  278. self.eps = 0
  279. self.d = 0.2
  280. def test_in_ball_outside(self):
  281. l = self.T.query_ball_point(self.x + 1.0, self.d, p=self.p, eps=self.eps)
  282. for i in l:
  283. assert_(self.distance(self.data[i], self.x, self.p) <= self.d*(1.+self.eps))
  284. l = self.T.query_ball_point(self.x - 1.0, self.d, p=self.p, eps=self.eps)
  285. for i in l:
  286. assert_(self.distance(self.data[i], self.x, self.p) <= self.d*(1.+self.eps))
  287. def test_found_all_outside(self):
  288. c = np.ones(self.T.n, dtype=bool)
  289. l = self.T.query_ball_point(self.x + 1.0, self.d, p=self.p, eps=self.eps)
  290. c[l] = False
  291. assert np.all(
  292. self.distance(self.data[c], self.x, self.p) >= self.d/(1.+self.eps)
  293. )
  294. l = self.T.query_ball_point(self.x - 1.0, self.d, p=self.p, eps=self.eps)
  295. c[l] = False
  296. assert np.all(
  297. self.distance(self.data[c], self.x, self.p) >= self.d/(1.+self.eps)
  298. )
  299. @KDTreeTest
  300. class _Test_random_ball_largep_issue9890(ball_consistency):
  301. # allow some roundoff errors due to numerical issues
  302. tol = 1e-13
  303. def setup_method(self):
  304. n = 1000
  305. m = 2
  306. np.random.seed(123)
  307. self.data = np.random.randint(100, 1000, size=(n, m))
  308. self.T = self.kdtree_type(self.data)
  309. self.x = self.data
  310. self.p = 100
  311. self.eps = 0
  312. self.d = 10
  313. @KDTreeTest
  314. class _Test_random_ball_approx(_Test_random_ball):
  315. def setup_method(self):
  316. super().setup_method()
  317. self.eps = 0.1
  318. @KDTreeTest
  319. class _Test_random_ball_approx_periodic(_Test_random_ball):
  320. def setup_method(self):
  321. super().setup_method()
  322. self.eps = 0.1
  323. @KDTreeTest
  324. class _Test_random_ball_far(_Test_random_ball):
  325. def setup_method(self):
  326. super().setup_method()
  327. self.d = 2.
  328. @KDTreeTest
  329. class _Test_random_ball_far_periodic(_Test_random_ball_periodic):
  330. def setup_method(self):
  331. super().setup_method()
  332. self.d = 2.
  333. @KDTreeTest
  334. class _Test_random_ball_l1(_Test_random_ball):
  335. def setup_method(self):
  336. super().setup_method()
  337. self.p = 1
  338. @KDTreeTest
  339. class _Test_random_ball_linf(_Test_random_ball):
  340. def setup_method(self):
  341. super().setup_method()
  342. self.p = np.inf
  343. def test_random_ball_vectorized(kdtree_type):
  344. n = 20
  345. m = 5
  346. np.random.seed(1234)
  347. T = kdtree_type(np.random.randn(n, m))
  348. r = T.query_ball_point(np.random.randn(2, 3, m), 1)
  349. assert_equal(r.shape, (2, 3))
  350. assert_(isinstance(r[0, 0], list))
  351. @pytest.mark.fail_slow(5)
  352. def test_query_ball_point_multithreading(kdtree_type):
  353. np.random.seed(0)
  354. n = 5000
  355. k = 2
  356. points = np.random.randn(n, k)
  357. T = kdtree_type(points)
  358. l1 = T.query_ball_point(points, 0.003, workers=1)
  359. l2 = T.query_ball_point(points, 0.003, workers=64)
  360. l3 = T.query_ball_point(points, 0.003, workers=-1)
  361. for i in range(n):
  362. if l1[i] or l2[i]:
  363. assert_array_equal(l1[i], l2[i])
  364. for i in range(n):
  365. if l1[i] or l3[i]:
  366. assert_array_equal(l1[i], l3[i])
  367. class two_trees_consistency:
  368. def distance(self, a, b, p):
  369. return minkowski_distance(a, b, p)
  370. def test_all_in_ball(self):
  371. r = self.T1.query_ball_tree(self.T2, self.d, p=self.p, eps=self.eps)
  372. for i, l in enumerate(r):
  373. for j in l:
  374. assert (self.distance(self.data1[i], self.data2[j], self.p)
  375. <= self.d*(1.+self.eps))
  376. def test_found_all(self):
  377. r = self.T1.query_ball_tree(self.T2, self.d, p=self.p, eps=self.eps)
  378. for i, l in enumerate(r):
  379. c = np.ones(self.T2.n, dtype=bool)
  380. c[l] = False
  381. assert np.all(self.distance(self.data2[c], self.data1[i], self.p)
  382. >= self.d/(1.+self.eps))
  383. @KDTreeTest
  384. class _Test_two_random_trees(two_trees_consistency):
  385. def setup_method(self):
  386. n = 50
  387. m = 4
  388. np.random.seed(1234)
  389. self.data1 = np.random.randn(n, m)
  390. self.T1 = self.kdtree_type(self.data1, leafsize=2)
  391. self.data2 = np.random.randn(n, m)
  392. self.T2 = self.kdtree_type(self.data2, leafsize=2)
  393. self.p = 2.
  394. self.eps = 0
  395. self.d = 0.2
  396. @KDTreeTest
  397. class _Test_two_random_trees_periodic(two_trees_consistency):
  398. def distance(self, a, b, p):
  399. return distance_box(a, b, p, 1.0)
  400. def setup_method(self):
  401. n = 50
  402. m = 4
  403. np.random.seed(1234)
  404. self.data1 = np.random.uniform(size=(n, m))
  405. self.T1 = self.kdtree_type(self.data1, leafsize=2, boxsize=1.0)
  406. self.data2 = np.random.uniform(size=(n, m))
  407. self.T2 = self.kdtree_type(self.data2, leafsize=2, boxsize=1.0)
  408. self.p = 2.
  409. self.eps = 0
  410. self.d = 0.2
  411. @KDTreeTest
  412. class _Test_two_random_trees_far(_Test_two_random_trees):
  413. def setup_method(self):
  414. super().setup_method()
  415. self.d = 2
  416. @KDTreeTest
  417. class _Test_two_random_trees_far_periodic(_Test_two_random_trees_periodic):
  418. def setup_method(self):
  419. super().setup_method()
  420. self.d = 2
  421. @KDTreeTest
  422. class _Test_two_random_trees_linf(_Test_two_random_trees):
  423. def setup_method(self):
  424. super().setup_method()
  425. self.p = np.inf
  426. @KDTreeTest
  427. class _Test_two_random_trees_linf_periodic(_Test_two_random_trees_periodic):
  428. def setup_method(self):
  429. super().setup_method()
  430. self.p = np.inf
  431. class Test_rectangle:
  432. def setup_method(self):
  433. self.rect = Rectangle([0, 0], [1, 1])
  434. def test_min_inside(self):
  435. assert_almost_equal(self.rect.min_distance_point([0.5, 0.5]), 0)
  436. def test_min_one_side(self):
  437. assert_almost_equal(self.rect.min_distance_point([0.5, 1.5]), 0.5)
  438. def test_min_two_sides(self):
  439. assert_almost_equal(self.rect.min_distance_point([2, 2]), np.sqrt(2))
  440. def test_max_inside(self):
  441. assert_almost_equal(self.rect.max_distance_point([0.5, 0.5]), 1/np.sqrt(2))
  442. def test_max_one_side(self):
  443. assert_almost_equal(self.rect.max_distance_point([0.5, 1.5]),
  444. np.hypot(0.5, 1.5))
  445. def test_max_two_sides(self):
  446. assert_almost_equal(self.rect.max_distance_point([2, 2]), 2*np.sqrt(2))
  447. def test_split(self):
  448. less, greater = self.rect.split(0, 0.1)
  449. assert_array_equal(less.maxes, [0.1, 1])
  450. assert_array_equal(less.mins, [0, 0])
  451. assert_array_equal(greater.maxes, [1, 1])
  452. assert_array_equal(greater.mins, [0.1, 0])
  453. def test_distance_l2():
  454. assert_almost_equal(minkowski_distance([0, 0], [1, 1], 2), np.sqrt(2))
  455. def test_distance_l1():
  456. assert_almost_equal(minkowski_distance([0, 0], [1, 1], 1), 2)
  457. def test_distance_linf():
  458. assert_almost_equal(minkowski_distance([0, 0], [1, 1], np.inf), 1)
  459. def test_distance_vectorization():
  460. np.random.seed(1234)
  461. x = np.random.randn(10, 1, 3)
  462. y = np.random.randn(1, 7, 3)
  463. assert_equal(minkowski_distance(x, y).shape, (10, 7))
  464. class count_neighbors_consistency:
  465. def test_one_radius(self):
  466. r = 0.2
  467. assert_equal(self.T1.count_neighbors(self.T2, r),
  468. np.sum([len(l) for l in self.T1.query_ball_tree(self.T2, r)]))
  469. def test_large_radius(self):
  470. r = 1000
  471. assert_equal(self.T1.count_neighbors(self.T2, r),
  472. np.sum([len(l) for l in self.T1.query_ball_tree(self.T2, r)]))
  473. def test_multiple_radius(self):
  474. rs = np.exp(np.linspace(np.log(0.01), np.log(10), 3))
  475. results = self.T1.count_neighbors(self.T2, rs)
  476. assert_(np.all(np.diff(results) >= 0))
  477. for r, result in zip(rs, results):
  478. assert_equal(self.T1.count_neighbors(self.T2, r), result)
  479. @KDTreeTest
  480. class _Test_count_neighbors(count_neighbors_consistency):
  481. def setup_method(self):
  482. n = 50
  483. m = 2
  484. np.random.seed(1234)
  485. self.T1 = self.kdtree_type(np.random.randn(n, m), leafsize=2)
  486. self.T2 = self.kdtree_type(np.random.randn(n, m), leafsize=2)
  487. class sparse_distance_matrix_consistency:
  488. def distance(self, a, b, p):
  489. return minkowski_distance(a, b, p)
  490. def test_consistency_with_neighbors(self):
  491. M = self.T1.sparse_distance_matrix(self.T2, self.r)
  492. r = self.T1.query_ball_tree(self.T2, self.r)
  493. for i, l in enumerate(r):
  494. for j in l:
  495. assert_almost_equal(
  496. M[i, j],
  497. self.distance(self.T1.data[i], self.T2.data[j], self.p),
  498. decimal=14
  499. )
  500. for ((i, j), d) in M.items():
  501. assert_(j in r[i])
  502. def test_zero_distance(self):
  503. # raises an exception for bug 870 (FIXME: Does it?)
  504. self.T1.sparse_distance_matrix(self.T1, self.r)
  505. def test_consistency(self):
  506. # Test consistency with a distance_matrix
  507. M1 = self.T1.sparse_distance_matrix(self.T2, self.r)
  508. expected = distance_matrix(self.T1.data, self.T2.data)
  509. expected[expected > self.r] = 0
  510. assert_array_almost_equal(M1.toarray(), expected, decimal=14)
  511. def test_against_logic_error_regression(self):
  512. # regression test for gh-5077 logic error
  513. np.random.seed(0)
  514. too_many = np.array(np.random.randn(18, 2), dtype=int)
  515. tree = self.kdtree_type(
  516. too_many, balanced_tree=False, compact_nodes=False)
  517. d = tree.sparse_distance_matrix(tree, 3).toarray()
  518. assert_array_almost_equal(d, d.T, decimal=14)
  519. def test_ckdtree_return_types(self):
  520. # brute-force reference
  521. ref = np.zeros((self.n, self.n))
  522. for i in range(self.n):
  523. for j in range(self.n):
  524. v = self.data1[i, :] - self.data2[j, :]
  525. ref[i, j] = np.dot(v, v)
  526. ref = np.sqrt(ref)
  527. ref[ref > self.r] = 0.
  528. # test return type 'dict'
  529. dist = np.zeros((self.n, self.n))
  530. r = self.T1.sparse_distance_matrix(self.T2, self.r, output_type='dict')
  531. for i, j in r.keys():
  532. dist[i, j] = r[(i, j)]
  533. assert_array_almost_equal(ref, dist, decimal=14)
  534. # test return type 'ndarray'
  535. dist = np.zeros((self.n, self.n))
  536. r = self.T1.sparse_distance_matrix(self.T2, self.r,
  537. output_type='ndarray')
  538. for k in range(r.shape[0]):
  539. i = r['i'][k]
  540. j = r['j'][k]
  541. v = r['v'][k]
  542. dist[i, j] = v
  543. assert_array_almost_equal(ref, dist, decimal=14)
  544. # test return type 'dok_matrix'
  545. r = self.T1.sparse_distance_matrix(self.T2, self.r,
  546. output_type='dok_matrix')
  547. assert_array_almost_equal(ref, r.toarray(), decimal=14)
  548. # test return type 'coo_matrix'
  549. r = self.T1.sparse_distance_matrix(self.T2, self.r,
  550. output_type='coo_matrix')
  551. assert_array_almost_equal(ref, r.toarray(), decimal=14)
  552. @KDTreeTest
  553. class _Test_sparse_distance_matrix(sparse_distance_matrix_consistency):
  554. def setup_method(self):
  555. n = 50
  556. m = 4
  557. np.random.seed(1234)
  558. data1 = np.random.randn(n, m)
  559. data2 = np.random.randn(n, m)
  560. self.T1 = self.kdtree_type(data1, leafsize=2)
  561. self.T2 = self.kdtree_type(data2, leafsize=2)
  562. self.r = 0.5
  563. self.p = 2
  564. self.data1 = data1
  565. self.data2 = data2
  566. self.n = n
  567. self.m = m
  568. def test_distance_matrix():
  569. m = 10
  570. n = 11
  571. k = 4
  572. np.random.seed(1234)
  573. xs = np.random.randn(m, k)
  574. ys = np.random.randn(n, k)
  575. ds = distance_matrix(xs, ys)
  576. assert_equal(ds.shape, (m, n))
  577. for i in range(m):
  578. for j in range(n):
  579. assert_almost_equal(minkowski_distance(xs[i], ys[j]), ds[i, j])
  580. def test_distance_matrix_looping():
  581. m = 10
  582. n = 11
  583. k = 4
  584. np.random.seed(1234)
  585. xs = np.random.randn(m, k)
  586. ys = np.random.randn(n, k)
  587. ds = distance_matrix(xs, ys)
  588. dsl = distance_matrix(xs, ys, threshold=1)
  589. assert_equal(ds, dsl)
  590. def check_onetree_query(T, d):
  591. r = T.query_ball_tree(T, d)
  592. s = set()
  593. for i, l in enumerate(r):
  594. for j in l:
  595. if i < j:
  596. s.add((i, j))
  597. assert_(s == T.query_pairs(d))
  598. def test_onetree_query(kdtree_type):
  599. np.random.seed(0)
  600. n = 50
  601. k = 4
  602. points = np.random.randn(n, k)
  603. T = kdtree_type(points)
  604. check_onetree_query(T, 0.1)
  605. points = np.random.randn(3*n, k)
  606. points[:n] *= 0.001
  607. points[n:2*n] += 2
  608. T = kdtree_type(points)
  609. check_onetree_query(T, 0.1)
  610. check_onetree_query(T, 0.001)
  611. check_onetree_query(T, 0.00001)
  612. check_onetree_query(T, 1e-6)
  613. def test_query_pairs_single_node(kdtree_type):
  614. tree = kdtree_type([[0, 1]])
  615. assert_equal(tree.query_pairs(0.5), set())
  616. def test_kdtree_query_pairs(kdtree_type):
  617. np.random.seed(0)
  618. n = 50
  619. k = 2
  620. r = 0.1
  621. r2 = r**2
  622. points = np.random.randn(n, k)
  623. T = kdtree_type(points)
  624. # brute force reference
  625. brute = set()
  626. for i in range(n):
  627. for j in range(i+1, n):
  628. v = points[i, :] - points[j, :]
  629. if np.dot(v, v) <= r2:
  630. brute.add((i, j))
  631. l0 = sorted(brute)
  632. # test default return type
  633. s = T.query_pairs(r)
  634. l1 = sorted(s)
  635. assert_array_equal(l0, l1)
  636. # test return type 'set'
  637. s = T.query_pairs(r, output_type='set')
  638. l1 = sorted(s)
  639. assert_array_equal(l0, l1)
  640. # test return type 'ndarray'
  641. s = set()
  642. arr = T.query_pairs(r, output_type='ndarray')
  643. for i in range(arr.shape[0]):
  644. s.add((int(arr[i, 0]), int(arr[i, 1])))
  645. l2 = sorted(s)
  646. assert_array_equal(l0, l2)
  647. def test_query_pairs_eps(kdtree_type):
  648. spacing = np.sqrt(2)
  649. # irrational spacing to have potential rounding errors
  650. x_range = np.linspace(0, 3 * spacing, 4)
  651. y_range = np.linspace(0, 3 * spacing, 4)
  652. xy_array = [(xi, yi) for xi in x_range for yi in y_range]
  653. tree = kdtree_type(xy_array)
  654. pairs_eps = tree.query_pairs(r=spacing, eps=.1)
  655. # result: 24 with eps, 16 without due to rounding
  656. pairs = tree.query_pairs(r=spacing * 1.01)
  657. # result: 24
  658. assert_equal(pairs, pairs_eps)
  659. def test_ball_point_ints(kdtree_type):
  660. # Regression test for #1373.
  661. x, y = np.mgrid[0:4, 0:4]
  662. points = list(zip(x.ravel(), y.ravel()))
  663. tree = kdtree_type(points)
  664. assert_equal(sorted([4, 8, 9, 12]),
  665. sorted(tree.query_ball_point((2, 0), 1)))
  666. points = np.asarray(points, dtype=float)
  667. tree = kdtree_type(points)
  668. assert_equal(sorted([4, 8, 9, 12]),
  669. sorted(tree.query_ball_point((2, 0), 1)))
  670. def test_kdtree_comparisons():
  671. # Regression test: node comparisons were done wrong in 0.12 w/Py3.
  672. nodes = [KDTree.node() for _ in range(3)]
  673. assert_equal(sorted(nodes), sorted(nodes[::-1]))
  674. def test_kdtree_build_modes(kdtree_type):
  675. # check if different build modes for KDTree give similar query results
  676. np.random.seed(0)
  677. n = 5000
  678. k = 4
  679. points = np.random.randn(n, k)
  680. T1 = kdtree_type(points).query(points, k=5)[-1]
  681. T2 = kdtree_type(points, compact_nodes=False).query(points, k=5)[-1]
  682. T3 = kdtree_type(points, balanced_tree=False).query(points, k=5)[-1]
  683. T4 = kdtree_type(points, compact_nodes=False,
  684. balanced_tree=False).query(points, k=5)[-1]
  685. assert_array_equal(T1, T2)
  686. assert_array_equal(T1, T3)
  687. assert_array_equal(T1, T4)
  688. def test_kdtree_pickle(kdtree_type):
  689. # test if it is possible to pickle a KDTree
  690. import pickle
  691. np.random.seed(0)
  692. n = 50
  693. k = 4
  694. points = np.random.randn(n, k)
  695. T1 = kdtree_type(points)
  696. tmp = pickle.dumps(T1)
  697. T2 = pickle.loads(tmp)
  698. T1 = T1.query(points, k=5)[-1]
  699. T2 = T2.query(points, k=5)[-1]
  700. assert_array_equal(T1, T2)
  701. def test_kdtree_pickle_boxsize(kdtree_type):
  702. # test if it is possible to pickle a periodic KDTree
  703. import pickle
  704. np.random.seed(0)
  705. n = 50
  706. k = 4
  707. points = np.random.uniform(size=(n, k))
  708. T1 = kdtree_type(points, boxsize=1.0)
  709. tmp = pickle.dumps(T1)
  710. T2 = pickle.loads(tmp)
  711. T1 = T1.query(points, k=5)[-1]
  712. T2 = T2.query(points, k=5)[-1]
  713. assert_array_equal(T1, T2)
  714. def test_kdtree_copy_data(kdtree_type):
  715. # check if copy_data=True makes the kd-tree
  716. # impervious to data corruption by modification of
  717. # the data arrray
  718. np.random.seed(0)
  719. n = 5000
  720. k = 4
  721. points = np.random.randn(n, k)
  722. T = kdtree_type(points, copy_data=True)
  723. q = points.copy()
  724. T1 = T.query(q, k=5)[-1]
  725. points[...] = np.random.randn(n, k)
  726. T2 = T.query(q, k=5)[-1]
  727. assert_array_equal(T1, T2)
  728. def test_ckdtree_parallel(kdtree_type, monkeypatch):
  729. # check if parallel=True also generates correct query results
  730. np.random.seed(0)
  731. n = 5000
  732. k = 4
  733. points = np.random.randn(n, k)
  734. T = kdtree_type(points)
  735. T1 = T.query(points, k=5, workers=64)[-1]
  736. T2 = T.query(points, k=5, workers=-1)[-1]
  737. T3 = T.query(points, k=5)[-1]
  738. assert_array_equal(T1, T2)
  739. assert_array_equal(T1, T3)
  740. monkeypatch.setattr(os, 'cpu_count', lambda: None)
  741. with pytest.raises(NotImplementedError, match="Cannot determine the"):
  742. T.query(points, 1, workers=-1)
  743. def test_ckdtree_view():
  744. # Check that the nodes can be correctly viewed from Python.
  745. # This test also sanity checks each node in the cKDTree, and
  746. # thus verifies the internal structure of the kd-tree.
  747. np.random.seed(0)
  748. n = 100
  749. k = 4
  750. points = np.random.randn(n, k)
  751. kdtree = cKDTree(points)
  752. # walk the whole kd-tree and sanity check each node
  753. def recurse_tree(n):
  754. assert_(isinstance(n, cKDTreeNode))
  755. if n.split_dim == -1:
  756. assert_(n.lesser is None)
  757. assert_(n.greater is None)
  758. assert_(n.indices.shape[0] <= kdtree.leafsize)
  759. else:
  760. recurse_tree(n.lesser)
  761. recurse_tree(n.greater)
  762. x = n.lesser.data_points[:, n.split_dim]
  763. y = n.greater.data_points[:, n.split_dim]
  764. assert_(x.max() < y.min())
  765. recurse_tree(kdtree.tree)
  766. # check that indices are correctly retrieved
  767. n = kdtree.tree
  768. assert_array_equal(np.sort(n.indices), range(100))
  769. # check that data_points are correctly retrieved
  770. assert_array_equal(kdtree.data[n.indices, :], n.data_points)
  771. # KDTree is specialized to type double points, so no need to make
  772. # a unit test corresponding to test_ball_point_ints()
  773. def test_kdtree_list_k(kdtree_type):
  774. # check kdtree periodic boundary
  775. n = 200
  776. m = 2
  777. klist = [1, 2, 3]
  778. kint = 3
  779. np.random.seed(1234)
  780. data = np.random.uniform(size=(n, m))
  781. kdtree = kdtree_type(data, leafsize=1)
  782. # check agreement between arange(1, k+1) and k
  783. dd, ii = kdtree.query(data, klist)
  784. dd1, ii1 = kdtree.query(data, kint)
  785. assert_equal(dd, dd1)
  786. assert_equal(ii, ii1)
  787. # now check skipping one element
  788. klist = np.array([1, 3])
  789. kint = 3
  790. dd, ii = kdtree.query(data, kint)
  791. dd1, ii1 = kdtree.query(data, klist)
  792. assert_equal(dd1, dd[..., klist - 1])
  793. assert_equal(ii1, ii[..., klist - 1])
  794. # check k == 1 special case
  795. # and k == [1] non-special case
  796. dd, ii = kdtree.query(data, 1)
  797. dd1, ii1 = kdtree.query(data, [1])
  798. assert_equal(len(dd.shape), 1)
  799. assert_equal(len(dd1.shape), 2)
  800. assert_equal(dd, np.ravel(dd1))
  801. assert_equal(ii, np.ravel(ii1))
  802. @pytest.mark.fail_slow(10)
  803. def test_kdtree_box(kdtree_type):
  804. # check ckdtree periodic boundary
  805. n = 2000
  806. m = 3
  807. k = 3
  808. np.random.seed(1234)
  809. data = np.random.uniform(size=(n, m))
  810. kdtree = kdtree_type(data, leafsize=1, boxsize=1.0)
  811. # use the standard python KDTree for the simulated periodic box
  812. kdtree2 = kdtree_type(data, leafsize=1)
  813. for p in [1, 2, 3.0, np.inf]:
  814. dd, ii = kdtree.query(data, k, p=p)
  815. dd1, ii1 = kdtree.query(data + 1.0, k, p=p)
  816. assert_almost_equal(dd, dd1)
  817. assert_equal(ii, ii1)
  818. dd1, ii1 = kdtree.query(data - 1.0, k, p=p)
  819. assert_almost_equal(dd, dd1)
  820. assert_equal(ii, ii1)
  821. dd2, ii2 = simulate_periodic_box(kdtree2, data, k, boxsize=1.0, p=p)
  822. assert_almost_equal(dd, dd2)
  823. assert_equal(ii, ii2)
  824. def test_kdtree_box_0boxsize(kdtree_type):
  825. # check ckdtree periodic boundary that mimics non-periodic
  826. n = 2000
  827. m = 2
  828. k = 3
  829. np.random.seed(1234)
  830. data = np.random.uniform(size=(n, m))
  831. kdtree = kdtree_type(data, leafsize=1, boxsize=0.0)
  832. # use the standard python KDTree for the simulated periodic box
  833. kdtree2 = kdtree_type(data, leafsize=1)
  834. for p in [1, 2, np.inf]:
  835. dd, ii = kdtree.query(data, k, p=p)
  836. dd1, ii1 = kdtree2.query(data, k, p=p)
  837. assert_almost_equal(dd, dd1)
  838. assert_equal(ii, ii1)
  839. def test_kdtree_box_upper_bounds(kdtree_type):
  840. data = np.linspace(0, 2, 10).reshape(-1, 2)
  841. data[:, 1] += 10
  842. with pytest.raises(ValueError):
  843. kdtree_type(data, leafsize=1, boxsize=1.0)
  844. with pytest.raises(ValueError):
  845. kdtree_type(data, leafsize=1, boxsize=(0.0, 2.0))
  846. # skip a dimension.
  847. kdtree_type(data, leafsize=1, boxsize=(2.0, 0.0))
  848. def test_kdtree_box_lower_bounds(kdtree_type):
  849. data = np.linspace(-1, 1, 10)
  850. assert_raises(ValueError, kdtree_type, data, leafsize=1, boxsize=1.0)
  851. def simulate_periodic_box(kdtree, data, k, boxsize, p):
  852. dd = []
  853. ii = []
  854. x = np.arange(3 ** data.shape[1])
  855. nn = np.array(np.unravel_index(x, [3] * data.shape[1])).T
  856. nn = nn - 1.0
  857. for n in nn:
  858. image = data + n * 1.0 * boxsize
  859. dd2, ii2 = kdtree.query(image, k, p=p)
  860. dd2 = dd2.reshape(-1, k)
  861. ii2 = ii2.reshape(-1, k)
  862. dd.append(dd2)
  863. ii.append(ii2)
  864. dd = np.concatenate(dd, axis=-1)
  865. ii = np.concatenate(ii, axis=-1)
  866. result = np.empty([len(data), len(nn) * k], dtype=[
  867. ('ii', 'i8'),
  868. ('dd', 'f8')])
  869. result['ii'][:] = ii
  870. result['dd'][:] = dd
  871. result.sort(order='dd')
  872. return result['dd'][:, :k], result['ii'][:, :k]
  873. @pytest.mark.skipif(python_implementation() == 'PyPy',
  874. reason="Fails on PyPy CI runs. See #9507")
  875. def test_ckdtree_memuse():
  876. # unit test adaptation of gh-5630
  877. # NOTE: this will fail when run via valgrind,
  878. # because rss is no longer a reliable memory usage indicator.
  879. try:
  880. import resource
  881. except ImportError:
  882. # resource is not available on Windows
  883. return
  884. # Make some data
  885. dx, dy = 0.05, 0.05
  886. y, x = np.mgrid[slice(1, 5 + dy, dy),
  887. slice(1, 5 + dx, dx)]
  888. z = np.sin(x)**10 + np.cos(10 + y*x) * np.cos(x)
  889. z_copy = np.empty_like(z)
  890. z_copy[:] = z
  891. # Place FILLVAL in z_copy at random number of random locations
  892. FILLVAL = 99.
  893. mask = np.random.randint(0, z.size, np.random.randint(50) + 5)
  894. z_copy.flat[mask] = FILLVAL
  895. igood = np.vstack(np.nonzero(x != FILLVAL)).T
  896. ibad = np.vstack(np.nonzero(x == FILLVAL)).T
  897. mem_use = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
  898. # burn-in
  899. for i in range(10):
  900. tree = cKDTree(igood)
  901. # count memleaks while constructing and querying cKDTree
  902. num_leaks = 0
  903. for i in range(100):
  904. mem_use = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
  905. tree = cKDTree(igood)
  906. dist, iquery = tree.query(ibad, k=4, p=2)
  907. new_mem_use = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
  908. if new_mem_use > mem_use:
  909. num_leaks += 1
  910. # ideally zero leaks, but errors might accidentally happen
  911. # outside cKDTree
  912. assert_(num_leaks < 10)
  913. def test_kdtree_weights(kdtree_type):
  914. data = np.linspace(0, 1, 4).reshape(-1, 1)
  915. tree1 = kdtree_type(data, leafsize=1)
  916. weights = np.ones(len(data), dtype='f4')
  917. nw = tree1._build_weights(weights)
  918. assert_array_equal(nw, [4, 2, 1, 1, 2, 1, 1])
  919. assert_raises(ValueError, tree1._build_weights, weights[:-1])
  920. for i in range(10):
  921. # since weights are uniform, these shall agree:
  922. c1 = tree1.count_neighbors(tree1, np.linspace(0, 10, i))
  923. c2 = tree1.count_neighbors(tree1, np.linspace(0, 10, i),
  924. weights=(weights, weights))
  925. c3 = tree1.count_neighbors(tree1, np.linspace(0, 10, i),
  926. weights=(weights, None))
  927. c4 = tree1.count_neighbors(tree1, np.linspace(0, 10, i),
  928. weights=(None, weights))
  929. tree1.count_neighbors(tree1, np.linspace(0, 10, i),
  930. weights=weights)
  931. assert_array_equal(c1, c2)
  932. assert_array_equal(c1, c3)
  933. assert_array_equal(c1, c4)
  934. for i in range(len(data)):
  935. # this tests removal of one data point by setting weight to 0
  936. w1 = weights.copy()
  937. w1[i] = 0
  938. data2 = data[w1 != 0]
  939. tree2 = kdtree_type(data2)
  940. c1 = tree1.count_neighbors(tree1, np.linspace(0, 10, 100),
  941. weights=(w1, w1))
  942. # "c2 is correct"
  943. c2 = tree2.count_neighbors(tree2, np.linspace(0, 10, 100))
  944. assert_array_equal(c1, c2)
  945. #this asserts for two different trees, singular weights
  946. # crashes
  947. assert_raises(ValueError, tree1.count_neighbors,
  948. tree2, np.linspace(0, 10, 100), weights=w1)
  949. @pytest.mark.fail_slow(10)
  950. def test_kdtree_count_neighbous_multiple_r(kdtree_type):
  951. n = 2000
  952. m = 2
  953. np.random.seed(1234)
  954. data = np.random.normal(size=(n, m))
  955. kdtree = kdtree_type(data, leafsize=1)
  956. r0 = [0, 0.01, 0.01, 0.02, 0.05]
  957. i0 = np.arange(len(r0))
  958. n0 = kdtree.count_neighbors(kdtree, r0)
  959. nnc = kdtree.count_neighbors(kdtree, r0, cumulative=False)
  960. assert_equal(n0, nnc.cumsum())
  961. for i, r in zip(itertools.permutations(i0),
  962. itertools.permutations(r0)):
  963. # permute n0 by i and it shall agree
  964. n = kdtree.count_neighbors(kdtree, r)
  965. assert_array_equal(n, n0[list(i)])
  966. def test_len0_arrays(kdtree_type):
  967. # make sure len-0 arrays are handled correctly
  968. # in range queries (gh-5639)
  969. rng = np.random.RandomState(1234)
  970. X = rng.rand(10, 2)
  971. Y = rng.rand(10, 2)
  972. tree = kdtree_type(X)
  973. # query_ball_point (single)
  974. d, i = tree.query([.5, .5], k=1)
  975. z = tree.query_ball_point([.5, .5], 0.1*d)
  976. assert_array_equal(z, [])
  977. # query_ball_point (multiple)
  978. d, i = tree.query(Y, k=1)
  979. mind = d.min()
  980. z = tree.query_ball_point(Y, 0.1*mind)
  981. y = np.empty(shape=(10, ), dtype=object)
  982. y.fill([])
  983. assert_array_equal(y, z)
  984. # query_ball_tree
  985. other = kdtree_type(Y)
  986. y = tree.query_ball_tree(other, 0.1*mind)
  987. assert_array_equal(10*[[]], y)
  988. # count_neighbors
  989. y = tree.count_neighbors(other, 0.1*mind)
  990. assert_(y == 0)
  991. # sparse_distance_matrix
  992. y = tree.sparse_distance_matrix(other, 0.1*mind, output_type='dok_matrix')
  993. assert_array_equal(y == np.zeros((10, 10)), True)
  994. y = tree.sparse_distance_matrix(other, 0.1*mind, output_type='coo_matrix')
  995. assert_array_equal(y == np.zeros((10, 10)), True)
  996. y = tree.sparse_distance_matrix(other, 0.1*mind, output_type='dict')
  997. assert_equal(y, {})
  998. y = tree.sparse_distance_matrix(other, 0.1*mind, output_type='ndarray')
  999. _dtype = [('i', np.intp), ('j', np.intp), ('v', np.float64)]
  1000. res_dtype = np.dtype(_dtype, align=True)
  1001. z = np.empty(shape=(0, ), dtype=res_dtype)
  1002. assert_array_equal(y, z)
  1003. # query_pairs
  1004. d, i = tree.query(X, k=2)
  1005. mind = d[:, -1].min()
  1006. y = tree.query_pairs(0.1*mind, output_type='set')
  1007. assert_equal(y, set())
  1008. y = tree.query_pairs(0.1*mind, output_type='ndarray')
  1009. z = np.empty(shape=(0, 2), dtype=np.intp)
  1010. assert_array_equal(y, z)
  1011. def test_kdtree_duplicated_inputs(kdtree_type):
  1012. # check kdtree with duplicated inputs
  1013. n = 1024
  1014. for m in range(1, 8):
  1015. data = np.ones((n, m))
  1016. data[n//2:] = 2
  1017. for balanced, compact in itertools.product((False, True), repeat=2):
  1018. kdtree = kdtree_type(data, balanced_tree=balanced,
  1019. compact_nodes=compact, leafsize=1)
  1020. assert kdtree.size == 3
  1021. tree = (kdtree.tree if kdtree_type is cKDTree else
  1022. kdtree.tree._node)
  1023. assert_equal(
  1024. np.sort(tree.lesser.indices),
  1025. np.arange(0, n // 2))
  1026. assert_equal(
  1027. np.sort(tree.greater.indices),
  1028. np.arange(n // 2, n))
  1029. def test_kdtree_noncumulative_nondecreasing(kdtree_type):
  1030. # check kdtree with duplicated inputs
  1031. # it shall not divide more than 3 nodes.
  1032. # root left (1), and right (2)
  1033. kdtree = kdtree_type([[0]], leafsize=1)
  1034. assert_raises(ValueError, kdtree.count_neighbors,
  1035. kdtree, [0.1, 0], cumulative=False)
  1036. def test_short_knn(kdtree_type):
  1037. # The test case is based on github: #6425 by @SteveDoyle2
  1038. xyz = np.array([
  1039. [0., 0., 0.],
  1040. [1.01, 0., 0.],
  1041. [0., 1., 0.],
  1042. [0., 1.01, 0.],
  1043. [1., 0., 0.],
  1044. [1., 1., 0.]],
  1045. dtype='float64')
  1046. ckdt = kdtree_type(xyz)
  1047. deq, ieq = ckdt.query(xyz, k=4, distance_upper_bound=0.2)
  1048. assert_array_almost_equal(deq,
  1049. [[0., np.inf, np.inf, np.inf],
  1050. [0., 0.01, np.inf, np.inf],
  1051. [0., 0.01, np.inf, np.inf],
  1052. [0., 0.01, np.inf, np.inf],
  1053. [0., 0.01, np.inf, np.inf],
  1054. [0., np.inf, np.inf, np.inf]])
  1055. def test_query_ball_point_vector_r(kdtree_type):
  1056. np.random.seed(1234)
  1057. data = np.random.normal(size=(100, 3))
  1058. query = np.random.normal(size=(100, 3))
  1059. tree = kdtree_type(data)
  1060. d = np.random.uniform(0, 0.3, size=len(query))
  1061. rvector = tree.query_ball_point(query, d)
  1062. rscalar = [tree.query_ball_point(qi, di) for qi, di in zip(query, d)]
  1063. for a, b in zip(rvector, rscalar):
  1064. assert_array_equal(sorted(a), sorted(b))
  1065. def test_query_ball_point_length(kdtree_type):
  1066. np.random.seed(1234)
  1067. data = np.random.normal(size=(100, 3))
  1068. query = np.random.normal(size=(100, 3))
  1069. tree = kdtree_type(data)
  1070. d = 0.3
  1071. length = tree.query_ball_point(query, d, return_length=True)
  1072. length2 = [len(ind) for ind in tree.query_ball_point(query, d, return_length=False)]
  1073. length3 = [len(tree.query_ball_point(qi, d)) for qi in query]
  1074. length4 = [tree.query_ball_point(qi, d, return_length=True) for qi in query]
  1075. assert_array_equal(length, length2)
  1076. assert_array_equal(length, length3)
  1077. assert_array_equal(length, length4)
  1078. def test_discontiguous(kdtree_type):
  1079. np.random.seed(1234)
  1080. data = np.random.normal(size=(100, 3))
  1081. d_contiguous = np.arange(100) * 0.04
  1082. d_discontiguous = np.ascontiguousarray(
  1083. np.arange(100)[::-1] * 0.04)[::-1]
  1084. query_contiguous = np.random.normal(size=(100, 3))
  1085. query_discontiguous = np.ascontiguousarray(query_contiguous.T).T
  1086. assert query_discontiguous.strides[-1] != query_contiguous.strides[-1]
  1087. assert d_discontiguous.strides[-1] != d_contiguous.strides[-1]
  1088. tree = kdtree_type(data)
  1089. length1 = tree.query_ball_point(query_contiguous,
  1090. d_contiguous, return_length=True)
  1091. length2 = tree.query_ball_point(query_discontiguous,
  1092. d_discontiguous, return_length=True)
  1093. assert_array_equal(length1, length2)
  1094. d1, i1 = tree.query(query_contiguous, 1)
  1095. d2, i2 = tree.query(query_discontiguous, 1)
  1096. assert_array_equal(d1, d2)
  1097. assert_array_equal(i1, i2)
  1098. @pytest.mark.parametrize("balanced_tree, compact_nodes",
  1099. [(True, False),
  1100. (True, True),
  1101. (False, False),
  1102. (False, True)])
  1103. def test_kdtree_empty_input(kdtree_type, balanced_tree, compact_nodes):
  1104. # https://github.com/scipy/scipy/issues/5040
  1105. np.random.seed(1234)
  1106. empty_v3 = np.empty(shape=(0, 3))
  1107. query_v3 = np.ones(shape=(1, 3))
  1108. query_v2 = np.ones(shape=(2, 3))
  1109. tree = kdtree_type(empty_v3, balanced_tree=balanced_tree,
  1110. compact_nodes=compact_nodes)
  1111. length = tree.query_ball_point(query_v3, 0.3, return_length=True)
  1112. assert length == 0
  1113. dd, ii = tree.query(query_v2, 2)
  1114. assert ii.shape == (2, 2)
  1115. assert dd.shape == (2, 2)
  1116. assert np.isinf(dd).all()
  1117. N = tree.count_neighbors(tree, [0, 1])
  1118. assert_array_equal(N, [0, 0])
  1119. M = tree.sparse_distance_matrix(tree, 0.3)
  1120. assert M.shape == (0, 0)
  1121. @KDTreeTest
  1122. class _Test_sorted_query_ball_point:
  1123. def setup_method(self):
  1124. np.random.seed(1234)
  1125. self.x = np.random.randn(100, 1)
  1126. self.ckdt = self.kdtree_type(self.x)
  1127. def test_return_sorted_True(self):
  1128. idxs_list = self.ckdt.query_ball_point(self.x, 1., return_sorted=True)
  1129. for idxs in idxs_list:
  1130. assert_array_equal(idxs, sorted(idxs))
  1131. for xi in self.x:
  1132. idxs = self.ckdt.query_ball_point(xi, 1., return_sorted=True)
  1133. assert_array_equal(idxs, sorted(idxs))
  1134. def test_return_sorted_None(self):
  1135. """Previous behavior was to sort the returned indices if there were
  1136. multiple points per query but not sort them if there was a single point
  1137. per query."""
  1138. idxs_list = self.ckdt.query_ball_point(self.x, 1.)
  1139. for idxs in idxs_list:
  1140. assert_array_equal(idxs, sorted(idxs))
  1141. idxs_list_single = [self.ckdt.query_ball_point(xi, 1.) for xi in self.x]
  1142. idxs_list_False = self.ckdt.query_ball_point(self.x, 1., return_sorted=False)
  1143. for idxs0, idxs1 in zip(idxs_list_False, idxs_list_single):
  1144. assert_array_equal(idxs0, idxs1)
  1145. def test_kdtree_complex_data():
  1146. # Test that KDTree rejects complex input points (gh-9108)
  1147. points = np.random.rand(10, 2).view(complex)
  1148. with pytest.raises(TypeError, match="complex data"):
  1149. t = KDTree(points)
  1150. t = KDTree(points.real)
  1151. with pytest.raises(TypeError, match="complex data"):
  1152. t.query(points)
  1153. with pytest.raises(TypeError, match="complex data"):
  1154. t.query_ball_point(points, r=1)
  1155. def test_kdtree_tree_access():
  1156. # Test KDTree.tree can be used to traverse the KDTree
  1157. np.random.seed(1234)
  1158. points = np.random.rand(100, 4)
  1159. t = KDTree(points)
  1160. root = t.tree
  1161. assert isinstance(root, KDTree.innernode)
  1162. assert root.children == points.shape[0]
  1163. # Visit the tree and assert some basic properties for each node
  1164. nodes = [root]
  1165. while nodes:
  1166. n = nodes.pop(-1)
  1167. if isinstance(n, KDTree.leafnode):
  1168. assert isinstance(n.children, int)
  1169. assert n.children == len(n.idx)
  1170. assert_array_equal(points[n.idx], n._node.data_points)
  1171. else:
  1172. assert isinstance(n, KDTree.innernode)
  1173. assert isinstance(n.split_dim, int)
  1174. assert 0 <= n.split_dim < t.m
  1175. assert isinstance(n.split, float)
  1176. assert isinstance(n.children, int)
  1177. assert n.children == n.less.children + n.greater.children
  1178. nodes.append(n.greater)
  1179. nodes.append(n.less)
  1180. def test_kdtree_attributes():
  1181. # Test KDTree's attributes are available
  1182. np.random.seed(1234)
  1183. points = np.random.rand(100, 4)
  1184. t = KDTree(points)
  1185. assert isinstance(t.m, int)
  1186. assert t.n == points.shape[0]
  1187. assert isinstance(t.n, int)
  1188. assert t.m == points.shape[1]
  1189. assert isinstance(t.leafsize, int)
  1190. assert t.leafsize == 10
  1191. assert_array_equal(t.maxes, np.amax(points, axis=0))
  1192. assert_array_equal(t.mins, np.amin(points, axis=0))
  1193. assert t.data is points
  1194. @pytest.mark.parametrize("kdtree_class", [KDTree, cKDTree])
  1195. def test_kdtree_count_neighbors_weighted(kdtree_class):
  1196. rng = np.random.RandomState(1234)
  1197. r = np.arange(0.05, 1, 0.05)
  1198. A = rng.random(21).reshape((7,3))
  1199. B = rng.random(45).reshape((15,3))
  1200. wA = rng.random(7)
  1201. wB = rng.random(15)
  1202. kdA = kdtree_class(A)
  1203. kdB = kdtree_class(B)
  1204. nAB = kdA.count_neighbors(kdB, r, cumulative=False, weights=(wA,wB))
  1205. # Compare against brute-force
  1206. weights = wA[None, :] * wB[:, None]
  1207. dist = np.linalg.norm(A[None, :, :] - B[:, None, :], axis=-1)
  1208. expect = [np.sum(weights[(prev_radius < dist) & (dist <= radius)])
  1209. for prev_radius, radius in zip(itertools.chain([0], r[:-1]), r)]
  1210. assert_allclose(nAB, expect)
  1211. def test_kdtree_nan():
  1212. vals = [1, 5, -10, 7, -4, -16, -6, 6, 3, -11]
  1213. n = len(vals)
  1214. data = np.concatenate([vals, np.full(n, np.nan)])[:, None]
  1215. with pytest.raises(ValueError, match="must be finite"):
  1216. KDTree(data)
  1217. def test_nonfinite_inputs_gh_18223():
  1218. rng = np.random.default_rng(12345)
  1219. coords = rng.uniform(size=(100, 3), low=0.0, high=0.1)
  1220. t = KDTree(coords, balanced_tree=False, compact_nodes=False)
  1221. bad_coord = [np.nan for _ in range(3)]
  1222. with pytest.raises(ValueError, match="must be finite"):
  1223. t.query(bad_coord)
  1224. with pytest.raises(ValueError, match="must be finite"):
  1225. t.query_ball_point(bad_coord, 1)
  1226. coords[0, :] = np.nan
  1227. with pytest.raises(ValueError, match="must be finite"):
  1228. KDTree(coords, balanced_tree=True, compact_nodes=False)
  1229. with pytest.raises(ValueError, match="must be finite"):
  1230. KDTree(coords, balanced_tree=False, compact_nodes=True)
  1231. with pytest.raises(ValueError, match="must be finite"):
  1232. KDTree(coords, balanced_tree=True, compact_nodes=True)
  1233. with pytest.raises(ValueError, match="must be finite"):
  1234. KDTree(coords, balanced_tree=False, compact_nodes=False)
  1235. @pytest.mark.parametrize("incantation", [cKDTree, KDTree])
  1236. def test_gh_18800(incantation):
  1237. # our prohibition on non-finite values
  1238. # in kd-tree workflows means we need
  1239. # coercion to NumPy arrays enforced
  1240. class ArrLike(np.ndarray):
  1241. def __new__(cls, input_array):
  1242. obj = np.asarray(input_array).view(cls)
  1243. # we override all() to mimic the problem
  1244. # pandas DataFrames encountered in gh-18800
  1245. obj.all = None
  1246. return obj
  1247. def __array_finalize__(self, obj):
  1248. if obj is None:
  1249. return
  1250. self.all = getattr(obj, 'all', None)
  1251. points = [
  1252. [66.22, 32.54],
  1253. [22.52, 22.39],
  1254. [31.01, 81.21],
  1255. ]
  1256. arr = np.array(points)
  1257. arr_like = ArrLike(arr)
  1258. tree = incantation(points, 10)
  1259. tree.query(arr_like, 1)
  1260. tree.query_ball_point(arr_like, 200)