test_layout.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631
  1. """Unit tests for layout functions."""
  2. import pytest
  3. import networkx as nx
  4. np = pytest.importorskip("numpy")
  5. pytest.importorskip("scipy")
  6. class TestLayout:
  7. @classmethod
  8. def setup_class(cls):
  9. cls.Gi = nx.grid_2d_graph(5, 5)
  10. cls.Gs = nx.Graph()
  11. nx.add_path(cls.Gs, "abcdef")
  12. cls.bigG = nx.grid_2d_graph(25, 25) # > 500 nodes for sparse
  13. def test_spring_fixed_without_pos(self):
  14. G = nx.path_graph(4)
  15. # No pos dict at all
  16. with pytest.raises(ValueError, match="nodes are fixed without positions"):
  17. nx.spring_layout(G, fixed=[0])
  18. pos = {0: (1, 1), 2: (0, 0)}
  19. # Node 1 not in pos dict
  20. with pytest.raises(ValueError, match="nodes are fixed without positions"):
  21. nx.spring_layout(G, fixed=[0, 1], pos=pos)
  22. # All fixed nodes in pos dict
  23. out = nx.spring_layout(G, fixed=[0, 2], pos=pos) # No ValueError
  24. assert all(np.array_equal(out[n], pos[n]) for n in (0, 2))
  25. def test_spring_init_pos(self):
  26. # Tests GH #2448
  27. import math
  28. G = nx.Graph()
  29. G.add_edges_from([(0, 1), (1, 2), (2, 0), (2, 3)])
  30. init_pos = {0: (0.0, 0.0)}
  31. fixed_pos = [0]
  32. pos = nx.fruchterman_reingold_layout(G, pos=init_pos, fixed=fixed_pos)
  33. has_nan = any(math.isnan(c) for coords in pos.values() for c in coords)
  34. assert not has_nan, "values should not be nan"
  35. def test_smoke_empty_graph(self):
  36. G = []
  37. nx.random_layout(G)
  38. nx.circular_layout(G)
  39. nx.planar_layout(G)
  40. nx.spring_layout(G)
  41. nx.fruchterman_reingold_layout(G)
  42. nx.spectral_layout(G)
  43. nx.shell_layout(G)
  44. nx.bipartite_layout(G, G)
  45. nx.spiral_layout(G)
  46. nx.multipartite_layout(G)
  47. nx.kamada_kawai_layout(G)
  48. def test_smoke_int(self):
  49. G = self.Gi
  50. nx.random_layout(G)
  51. nx.circular_layout(G)
  52. nx.planar_layout(G)
  53. nx.spring_layout(G)
  54. nx.forceatlas2_layout(G)
  55. nx.fruchterman_reingold_layout(G)
  56. nx.fruchterman_reingold_layout(self.bigG)
  57. nx.spectral_layout(G)
  58. nx.spectral_layout(G.to_directed())
  59. nx.spectral_layout(self.bigG)
  60. nx.spectral_layout(self.bigG.to_directed())
  61. nx.shell_layout(G)
  62. nx.spiral_layout(G)
  63. nx.kamada_kawai_layout(G)
  64. nx.kamada_kawai_layout(G, dim=1)
  65. nx.kamada_kawai_layout(G, dim=3)
  66. nx.arf_layout(G)
  67. def test_smoke_string(self):
  68. G = self.Gs
  69. nx.random_layout(G)
  70. nx.circular_layout(G)
  71. nx.planar_layout(G)
  72. nx.spring_layout(G)
  73. nx.forceatlas2_layout(G)
  74. nx.fruchterman_reingold_layout(G)
  75. nx.spectral_layout(G)
  76. nx.shell_layout(G)
  77. nx.spiral_layout(G)
  78. nx.kamada_kawai_layout(G)
  79. nx.kamada_kawai_layout(G, dim=1)
  80. nx.kamada_kawai_layout(G, dim=3)
  81. nx.arf_layout(G)
  82. def check_scale_and_center(self, pos, scale, center):
  83. center = np.array(center)
  84. low = center - scale
  85. hi = center + scale
  86. vpos = np.array(list(pos.values()))
  87. length = vpos.max(0) - vpos.min(0)
  88. assert (length <= 2 * scale).all()
  89. assert (vpos >= low).all()
  90. assert (vpos <= hi).all()
  91. def test_scale_and_center_arg(self):
  92. sc = self.check_scale_and_center
  93. c = (4, 5)
  94. G = nx.complete_graph(9)
  95. G.add_node(9)
  96. sc(nx.random_layout(G, center=c), scale=0.5, center=(4.5, 5.5))
  97. # rest can have 2*scale length: [-scale, scale]
  98. sc(nx.spring_layout(G, scale=2, center=c), scale=2, center=c)
  99. sc(nx.spectral_layout(G, scale=2, center=c), scale=2, center=c)
  100. sc(nx.circular_layout(G, scale=2, center=c), scale=2, center=c)
  101. sc(nx.shell_layout(G, scale=2, center=c), scale=2, center=c)
  102. sc(nx.spiral_layout(G, scale=2, center=c), scale=2, center=c)
  103. sc(nx.kamada_kawai_layout(G, scale=2, center=c), scale=2, center=c)
  104. c = (2, 3, 5)
  105. sc(nx.kamada_kawai_layout(G, dim=3, scale=2, center=c), scale=2, center=c)
  106. def test_planar_layout_non_planar_input(self):
  107. G = nx.complete_graph(9)
  108. pytest.raises(nx.NetworkXException, nx.planar_layout, G)
  109. def test_smoke_planar_layout_embedding_input(self):
  110. embedding = nx.PlanarEmbedding()
  111. embedding.set_data({0: [1, 2], 1: [0, 2], 2: [0, 1]})
  112. nx.planar_layout(embedding)
  113. def test_default_scale_and_center(self):
  114. sc = self.check_scale_and_center
  115. c = (0, 0)
  116. G = nx.complete_graph(9)
  117. G.add_node(9)
  118. sc(nx.random_layout(G), scale=0.5, center=(0.5, 0.5))
  119. sc(nx.spring_layout(G), scale=1, center=c)
  120. sc(nx.spectral_layout(G), scale=1, center=c)
  121. sc(nx.circular_layout(G), scale=1, center=c)
  122. sc(nx.shell_layout(G), scale=1, center=c)
  123. sc(nx.spiral_layout(G), scale=1, center=c)
  124. sc(nx.kamada_kawai_layout(G), scale=1, center=c)
  125. c = (0, 0, 0)
  126. sc(nx.kamada_kawai_layout(G, dim=3), scale=1, center=c)
  127. def test_circular_planar_and_shell_dim_error(self):
  128. G = nx.path_graph(4)
  129. pytest.raises(ValueError, nx.circular_layout, G, dim=1)
  130. pytest.raises(ValueError, nx.shell_layout, G, dim=1)
  131. pytest.raises(ValueError, nx.shell_layout, G, dim=3)
  132. pytest.raises(ValueError, nx.planar_layout, G, dim=1)
  133. pytest.raises(ValueError, nx.planar_layout, G, dim=3)
  134. def test_adjacency_interface_numpy(self):
  135. A = nx.to_numpy_array(self.Gs)
  136. pos = nx.drawing.layout._fruchterman_reingold(A)
  137. assert pos.shape == (6, 2)
  138. pos = nx.drawing.layout._fruchterman_reingold(A, dim=3)
  139. assert pos.shape == (6, 3)
  140. pos = nx.drawing.layout._sparse_fruchterman_reingold(A)
  141. assert pos.shape == (6, 2)
  142. def test_adjacency_interface_scipy(self):
  143. A = nx.to_scipy_sparse_array(self.Gs, dtype="d")
  144. pos = nx.drawing.layout._sparse_fruchterman_reingold(A)
  145. assert pos.shape == (6, 2)
  146. pos = nx.drawing.layout._sparse_spectral(A)
  147. assert pos.shape == (6, 2)
  148. pos = nx.drawing.layout._sparse_fruchterman_reingold(A, dim=3)
  149. assert pos.shape == (6, 3)
  150. def test_single_nodes(self):
  151. G = nx.path_graph(1)
  152. vpos = nx.shell_layout(G)
  153. assert not vpos[0].any()
  154. G = nx.path_graph(4)
  155. vpos = nx.shell_layout(G, [[0], [1, 2], [3]])
  156. assert not vpos[0].any()
  157. assert vpos[3].any() # ensure node 3 not at origin (#3188)
  158. assert np.linalg.norm(vpos[3]) <= 1 # ensure node 3 fits (#3753)
  159. vpos = nx.shell_layout(G, [[0], [1, 2], [3]], rotate=0)
  160. assert np.linalg.norm(vpos[3]) <= 1 # ensure node 3 fits (#3753)
  161. def test_smoke_initial_pos_forceatlas2(self):
  162. pos = nx.circular_layout(self.Gi)
  163. npos = nx.forceatlas2_layout(self.Gi, pos=pos)
  164. def test_smoke_initial_pos_fruchterman_reingold(self):
  165. pos = nx.circular_layout(self.Gi)
  166. npos = nx.fruchterman_reingold_layout(self.Gi, pos=pos)
  167. def test_smoke_initial_pos_arf(self):
  168. pos = nx.circular_layout(self.Gi)
  169. npos = nx.arf_layout(self.Gi, pos=pos)
  170. def test_fixed_node_fruchterman_reingold(self):
  171. # Dense version (numpy based)
  172. pos = nx.circular_layout(self.Gi)
  173. npos = nx.spring_layout(self.Gi, pos=pos, fixed=[(0, 0)])
  174. assert tuple(pos[(0, 0)]) == tuple(npos[(0, 0)])
  175. # Sparse version (scipy based)
  176. pos = nx.circular_layout(self.bigG)
  177. npos = nx.spring_layout(self.bigG, pos=pos, fixed=[(0, 0)])
  178. for axis in range(2):
  179. assert pos[(0, 0)][axis] == pytest.approx(npos[(0, 0)][axis], abs=1e-7)
  180. def test_center_parameter(self):
  181. G = nx.path_graph(1)
  182. nx.random_layout(G, center=(1, 1))
  183. vpos = nx.circular_layout(G, center=(1, 1))
  184. assert tuple(vpos[0]) == (1, 1)
  185. vpos = nx.planar_layout(G, center=(1, 1))
  186. assert tuple(vpos[0]) == (1, 1)
  187. vpos = nx.spring_layout(G, center=(1, 1))
  188. assert tuple(vpos[0]) == (1, 1)
  189. vpos = nx.fruchterman_reingold_layout(G, center=(1, 1))
  190. assert tuple(vpos[0]) == (1, 1)
  191. vpos = nx.spectral_layout(G, center=(1, 1))
  192. assert tuple(vpos[0]) == (1, 1)
  193. vpos = nx.shell_layout(G, center=(1, 1))
  194. assert tuple(vpos[0]) == (1, 1)
  195. vpos = nx.spiral_layout(G, center=(1, 1))
  196. assert tuple(vpos[0]) == (1, 1)
  197. def test_center_wrong_dimensions(self):
  198. G = nx.path_graph(1)
  199. assert id(nx.spring_layout) == id(nx.fruchterman_reingold_layout)
  200. pytest.raises(ValueError, nx.random_layout, G, center=(1, 1, 1))
  201. pytest.raises(ValueError, nx.circular_layout, G, center=(1, 1, 1))
  202. pytest.raises(ValueError, nx.planar_layout, G, center=(1, 1, 1))
  203. pytest.raises(ValueError, nx.spring_layout, G, center=(1, 1, 1))
  204. pytest.raises(ValueError, nx.spring_layout, G, dim=3, center=(1, 1))
  205. pytest.raises(ValueError, nx.spectral_layout, G, center=(1, 1, 1))
  206. pytest.raises(ValueError, nx.spectral_layout, G, dim=3, center=(1, 1))
  207. pytest.raises(ValueError, nx.shell_layout, G, center=(1, 1, 1))
  208. pytest.raises(ValueError, nx.spiral_layout, G, center=(1, 1, 1))
  209. pytest.raises(ValueError, nx.kamada_kawai_layout, G, center=(1, 1, 1))
  210. def test_empty_graph(self):
  211. G = nx.empty_graph()
  212. vpos = nx.random_layout(G, center=(1, 1))
  213. assert vpos == {}
  214. vpos = nx.circular_layout(G, center=(1, 1))
  215. assert vpos == {}
  216. vpos = nx.planar_layout(G, center=(1, 1))
  217. assert vpos == {}
  218. vpos = nx.bipartite_layout(G, G)
  219. assert vpos == {}
  220. vpos = nx.spring_layout(G, center=(1, 1))
  221. assert vpos == {}
  222. vpos = nx.fruchterman_reingold_layout(G, center=(1, 1))
  223. assert vpos == {}
  224. vpos = nx.spectral_layout(G, center=(1, 1))
  225. assert vpos == {}
  226. vpos = nx.shell_layout(G, center=(1, 1))
  227. assert vpos == {}
  228. vpos = nx.spiral_layout(G, center=(1, 1))
  229. assert vpos == {}
  230. vpos = nx.multipartite_layout(G, center=(1, 1))
  231. assert vpos == {}
  232. vpos = nx.kamada_kawai_layout(G, center=(1, 1))
  233. assert vpos == {}
  234. vpos = nx.forceatlas2_layout(G)
  235. assert vpos == {}
  236. vpos = nx.arf_layout(G)
  237. assert vpos == {}
  238. def test_bipartite_layout(self):
  239. G = nx.complete_bipartite_graph(3, 5)
  240. top, bottom = nx.bipartite.sets(G)
  241. vpos = nx.bipartite_layout(G, top)
  242. assert len(vpos) == len(G)
  243. top_x = vpos[list(top)[0]][0]
  244. bottom_x = vpos[list(bottom)[0]][0]
  245. for node in top:
  246. assert vpos[node][0] == top_x
  247. for node in bottom:
  248. assert vpos[node][0] == bottom_x
  249. vpos = nx.bipartite_layout(
  250. G, top, align="horizontal", center=(2, 2), scale=2, aspect_ratio=1
  251. )
  252. assert len(vpos) == len(G)
  253. top_y = vpos[list(top)[0]][1]
  254. bottom_y = vpos[list(bottom)[0]][1]
  255. for node in top:
  256. assert vpos[node][1] == top_y
  257. for node in bottom:
  258. assert vpos[node][1] == bottom_y
  259. pytest.raises(ValueError, nx.bipartite_layout, G, top, align="foo")
  260. def test_multipartite_layout(self):
  261. sizes = (0, 5, 7, 2, 8)
  262. G = nx.complete_multipartite_graph(*sizes)
  263. vpos = nx.multipartite_layout(G)
  264. assert len(vpos) == len(G)
  265. start = 0
  266. for n in sizes:
  267. end = start + n
  268. assert all(vpos[start][0] == vpos[i][0] for i in range(start + 1, end))
  269. start += n
  270. vpos = nx.multipartite_layout(G, align="horizontal", scale=2, center=(2, 2))
  271. assert len(vpos) == len(G)
  272. start = 0
  273. for n in sizes:
  274. end = start + n
  275. assert all(vpos[start][1] == vpos[i][1] for i in range(start + 1, end))
  276. start += n
  277. pytest.raises(ValueError, nx.multipartite_layout, G, align="foo")
  278. def test_kamada_kawai_costfn_1d(self):
  279. costfn = nx.drawing.layout._kamada_kawai_costfn
  280. pos = np.array([4.0, 7.0])
  281. invdist = 1 / np.array([[0.1, 2.0], [2.0, 0.3]])
  282. cost, grad = costfn(pos, np, invdist, meanweight=0, dim=1)
  283. assert cost == pytest.approx(((3 / 2.0 - 1) ** 2), abs=1e-7)
  284. assert grad[0] == pytest.approx((-0.5), abs=1e-7)
  285. assert grad[1] == pytest.approx(0.5, abs=1e-7)
  286. def check_kamada_kawai_costfn(self, pos, invdist, meanwt, dim):
  287. costfn = nx.drawing.layout._kamada_kawai_costfn
  288. cost, grad = costfn(pos.ravel(), np, invdist, meanweight=meanwt, dim=dim)
  289. expected_cost = 0.5 * meanwt * np.sum(np.sum(pos, axis=0) ** 2)
  290. for i in range(pos.shape[0]):
  291. for j in range(i + 1, pos.shape[0]):
  292. diff = np.linalg.norm(pos[i] - pos[j])
  293. expected_cost += (diff * invdist[i][j] - 1.0) ** 2
  294. assert cost == pytest.approx(expected_cost, abs=1e-7)
  295. dx = 1e-4
  296. for nd in range(pos.shape[0]):
  297. for dm in range(pos.shape[1]):
  298. idx = nd * pos.shape[1] + dm
  299. ps = pos.flatten()
  300. ps[idx] += dx
  301. cplus = costfn(ps, np, invdist, meanweight=meanwt, dim=pos.shape[1])[0]
  302. ps[idx] -= 2 * dx
  303. cminus = costfn(ps, np, invdist, meanweight=meanwt, dim=pos.shape[1])[0]
  304. assert grad[idx] == pytest.approx((cplus - cminus) / (2 * dx), abs=1e-5)
  305. def test_kamada_kawai_costfn(self):
  306. invdist = 1 / np.array([[0.1, 2.1, 1.7], [2.1, 0.2, 0.6], [1.7, 0.6, 0.3]])
  307. meanwt = 0.3
  308. # 2d
  309. pos = np.array([[1.3, -3.2], [2.7, -0.3], [5.1, 2.5]])
  310. self.check_kamada_kawai_costfn(pos, invdist, meanwt, 2)
  311. # 3d
  312. pos = np.array([[0.9, 8.6, -8.7], [-10, -0.5, -7.1], [9.1, -8.1, 1.6]])
  313. self.check_kamada_kawai_costfn(pos, invdist, meanwt, 3)
  314. def test_spiral_layout(self):
  315. G = self.Gs
  316. # a lower value of resolution should result in a more compact layout
  317. # intuitively, the total distance from the start and end nodes
  318. # via each node in between (transiting through each) will be less,
  319. # assuming rescaling does not occur on the computed node positions
  320. pos_standard = np.array(list(nx.spiral_layout(G, resolution=0.35).values()))
  321. pos_tighter = np.array(list(nx.spiral_layout(G, resolution=0.34).values()))
  322. distances = np.linalg.norm(pos_standard[:-1] - pos_standard[1:], axis=1)
  323. distances_tighter = np.linalg.norm(pos_tighter[:-1] - pos_tighter[1:], axis=1)
  324. assert sum(distances) > sum(distances_tighter)
  325. # return near-equidistant points after the first value if set to true
  326. pos_equidistant = np.array(list(nx.spiral_layout(G, equidistant=True).values()))
  327. distances_equidistant = np.linalg.norm(
  328. pos_equidistant[:-1] - pos_equidistant[1:], axis=1
  329. )
  330. assert np.allclose(
  331. distances_equidistant[1:], distances_equidistant[-1], atol=0.01
  332. )
  333. def test_spiral_layout_equidistant(self):
  334. G = nx.path_graph(10)
  335. nx.spiral_layout(G, equidistant=True, store_pos_as="pos")
  336. pos = nx.get_node_attributes(G, "pos")
  337. # Extract individual node positions as an array
  338. p = np.array(list(pos.values()))
  339. # Elementwise-distance between node positions
  340. dist = np.linalg.norm(p[1:] - p[:-1], axis=1)
  341. assert np.allclose(np.diff(dist), 0, atol=1e-3)
  342. def test_forceatlas2_layout_partial_input_test(self):
  343. # check whether partial pos input still returns a full proper position
  344. G = self.Gs
  345. node = nx.utils.arbitrary_element(G)
  346. pos = nx.circular_layout(G)
  347. del pos[node]
  348. pos = nx.forceatlas2_layout(G, pos=pos)
  349. assert len(pos) == len(G)
  350. def test_rescale_layout_dict(self):
  351. G = nx.empty_graph()
  352. vpos = nx.random_layout(G, center=(1, 1))
  353. assert nx.rescale_layout_dict(vpos) == {}
  354. G = nx.empty_graph(2)
  355. vpos = {0: (0.0, 0.0), 1: (1.0, 1.0)}
  356. s_vpos = nx.rescale_layout_dict(vpos)
  357. assert np.linalg.norm([sum(x) for x in zip(*s_vpos.values())]) < 1e-6
  358. G = nx.empty_graph(3)
  359. vpos = {0: (0, 0), 1: (1, 1), 2: (0.5, 0.5)}
  360. s_vpos = nx.rescale_layout_dict(vpos)
  361. expectation = {
  362. 0: np.array((-1, -1)),
  363. 1: np.array((1, 1)),
  364. 2: np.array((0, 0)),
  365. }
  366. for k, v in expectation.items():
  367. assert (s_vpos[k] == v).all()
  368. s_vpos = nx.rescale_layout_dict(vpos, scale=2)
  369. expectation = {
  370. 0: np.array((-2, -2)),
  371. 1: np.array((2, 2)),
  372. 2: np.array((0, 0)),
  373. }
  374. for k, v in expectation.items():
  375. assert (s_vpos[k] == v).all()
  376. def test_arf_layout_partial_input_test(self):
  377. # Checks whether partial pos input still returns a proper position.
  378. G = self.Gs
  379. node = nx.utils.arbitrary_element(G)
  380. pos = nx.circular_layout(G)
  381. del pos[node]
  382. pos = nx.arf_layout(G, pos=pos)
  383. assert len(pos) == len(G)
  384. def test_arf_layout_negative_a_check(self):
  385. """
  386. Checks input parameters correctly raises errors. For example, `a` should be larger than 1
  387. """
  388. G = self.Gs
  389. pytest.raises(ValueError, nx.arf_layout, G=G, a=-1)
  390. def test_smoke_seed_input(self):
  391. G = self.Gs
  392. nx.random_layout(G, seed=42)
  393. nx.spring_layout(G, seed=42)
  394. nx.arf_layout(G, seed=42)
  395. nx.forceatlas2_layout(G, seed=42)
  396. def test_node_at_center(self):
  397. # see gh-7791 avoid divide by zero
  398. G = nx.path_graph(3)
  399. orig_pos = {i: [i - 1, 0.0] for i in range(3)}
  400. new_pos = nx.forceatlas2_layout(G, pos=orig_pos)
  401. def test_initial_only_some_pos(self):
  402. G = nx.path_graph(3)
  403. orig_pos = {i: [i - 1, 0.0] for i in range(2)}
  404. new_pos = nx.forceatlas2_layout(G, pos=orig_pos, seed=42)
  405. def test_multipartite_layout_nonnumeric_partition_labels():
  406. """See gh-5123."""
  407. G = nx.Graph()
  408. G.add_node(0, subset="s0")
  409. G.add_node(1, subset="s0")
  410. G.add_node(2, subset="s1")
  411. G.add_node(3, subset="s1")
  412. G.add_edges_from([(0, 2), (0, 3), (1, 2)])
  413. pos = nx.multipartite_layout(G)
  414. assert len(pos) == len(G)
  415. def test_multipartite_layout_layer_order():
  416. """Return the layers in sorted order if the layers of the multipartite
  417. graph are sortable. See gh-5691"""
  418. G = nx.Graph()
  419. node_group = dict(zip(("a", "b", "c", "d", "e"), (2, 3, 1, 2, 4)))
  420. for node, layer in node_group.items():
  421. G.add_node(node, subset=layer)
  422. # Horizontal alignment, therefore y-coord determines layers
  423. pos = nx.multipartite_layout(G, align="horizontal")
  424. layers = nx.utils.groups(node_group)
  425. pos_from_layers = nx.multipartite_layout(G, align="horizontal", subset_key=layers)
  426. for (n1, p1), (n2, p2) in zip(pos.items(), pos_from_layers.items()):
  427. assert n1 == n2 and (p1 == p2).all()
  428. # Nodes "a" and "d" are in the same layer
  429. assert pos["a"][-1] == pos["d"][-1]
  430. # positions should be sorted according to layer
  431. assert pos["c"][-1] < pos["a"][-1] < pos["b"][-1] < pos["e"][-1]
  432. # Make sure that multipartite_layout still works when layers are not sortable
  433. G.nodes["a"]["subset"] = "layer_0" # Can't sort mixed strs/ints
  434. pos_nosort = nx.multipartite_layout(G) # smoke test: this should not raise
  435. assert pos_nosort.keys() == pos.keys()
  436. def _num_nodes_per_bfs_layer(pos):
  437. """Helper function to extract the number of nodes in each layer of bfs_layout"""
  438. x = np.array(list(pos.values()))[:, 0] # node positions in layered dimension
  439. _, layer_count = np.unique(x, return_counts=True)
  440. return layer_count
  441. @pytest.mark.parametrize("n", range(2, 7))
  442. def test_bfs_layout_complete_graph(n):
  443. """The complete graph should result in two layers: the starting node and
  444. a second layer containing all neighbors."""
  445. G = nx.complete_graph(n)
  446. nx.bfs_layout(G, start=0, store_pos_as="pos")
  447. pos = nx.get_node_attributes(G, "pos")
  448. assert np.array_equal(_num_nodes_per_bfs_layer(pos), [1, n - 1])
  449. def test_bfs_layout_barbell():
  450. G = nx.barbell_graph(5, 3)
  451. # Start in one of the "bells"
  452. pos = nx.bfs_layout(G, start=0)
  453. # start, bell-1, [1] * len(bar)+1, bell-1
  454. expected_nodes_per_layer = [1, 4, 1, 1, 1, 1, 4]
  455. assert np.array_equal(_num_nodes_per_bfs_layer(pos), expected_nodes_per_layer)
  456. # Start in the other "bell" - expect same layer pattern
  457. pos = nx.bfs_layout(G, start=12)
  458. assert np.array_equal(_num_nodes_per_bfs_layer(pos), expected_nodes_per_layer)
  459. # Starting in the center of the bar, expect layers to be symmetric
  460. pos = nx.bfs_layout(G, start=6)
  461. # Expected layers: {6 (start)}, {5, 7}, {4, 8}, {8 nodes from remainder of bells}
  462. expected_nodes_per_layer = [1, 2, 2, 8]
  463. assert np.array_equal(_num_nodes_per_bfs_layer(pos), expected_nodes_per_layer)
  464. def test_bfs_layout_disconnected():
  465. G = nx.complete_graph(5)
  466. G.add_edges_from([(10, 11), (11, 12)])
  467. with pytest.raises(nx.NetworkXError, match="bfs_layout didn't include all nodes"):
  468. nx.bfs_layout(G, start=0)
  469. def test_bipartite_layout_default_nodes_raises_non_bipartite_input():
  470. G = nx.complete_graph(5)
  471. with pytest.raises(nx.NetworkXError, match="Graph is not bipartite"):
  472. nx.bipartite_layout(G)
  473. # No exception if nodes are explicitly specified
  474. pos = nx.bipartite_layout(G, nodes=[2, 3])
  475. def test_bipartite_layout_default_nodes():
  476. G = nx.complete_bipartite_graph(3, 3)
  477. pos = nx.bipartite_layout(G) # no nodes specified
  478. # X coords of nodes should be the same within the bipartite sets
  479. for nodeset in nx.bipartite.sets(G):
  480. xs = [pos[k][0] for k in nodeset]
  481. assert all(x == pytest.approx(xs[0]) for x in xs)
  482. @pytest.mark.parametrize(
  483. "layout",
  484. [
  485. nx.random_layout,
  486. nx.circular_layout,
  487. nx.shell_layout,
  488. nx.spring_layout,
  489. nx.kamada_kawai_layout,
  490. nx.spectral_layout,
  491. nx.planar_layout,
  492. nx.spiral_layout,
  493. nx.forceatlas2_layout,
  494. ],
  495. )
  496. def test_layouts_negative_dim(layout):
  497. """Test all layouts that support dim kwarg handle invalid inputs."""
  498. G = nx.path_graph(4)
  499. valid_err_msgs = "|".join(
  500. [
  501. "negative dimensions.*not allowed",
  502. "can only handle 2",
  503. "cannot handle.*2",
  504. ]
  505. )
  506. with pytest.raises(ValueError, match=valid_err_msgs):
  507. layout(G, dim=-1)
  508. @pytest.mark.parametrize(
  509. ("num_nodes", "expected_method"), [(100, "force"), (501, "energy")]
  510. )
  511. @pytest.mark.parametrize(
  512. "extra_layout_kwargs",
  513. [
  514. {}, # No extra kwargs
  515. {"pos": {0: (0, 0)}, "fixed": [0]}, # Fixed node position
  516. {"dim": 3}, # 3D layout
  517. ],
  518. )
  519. def test_spring_layout_graph_size_heuristic(
  520. num_nodes, expected_method, extra_layout_kwargs
  521. ):
  522. """Expect 'force' layout for n < 500 and 'energy' for n >= 500"""
  523. G = nx.cycle_graph(num_nodes)
  524. # Seeded layout to compare explicit method to one determined by "auto"
  525. seed = 163674319
  526. # Compare explicit method to auto method
  527. expected = nx.spring_layout(
  528. G, method=expected_method, seed=seed, **extra_layout_kwargs
  529. )
  530. actual = nx.spring_layout(G, method="auto", seed=seed, **extra_layout_kwargs)
  531. assert np.allclose(list(expected.values()), list(actual.values()), atol=1e-5)