nx_pylab.py 101 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762777277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978
  1. """
  2. **********
  3. Matplotlib
  4. **********
  5. Draw networks with matplotlib.
  6. Examples
  7. --------
  8. >>> G = nx.complete_graph(5)
  9. >>> nx.draw(G)
  10. See Also
  11. --------
  12. - :doc:`matplotlib <matplotlib:index>`
  13. - :func:`matplotlib.pyplot.scatter`
  14. - :obj:`matplotlib.patches.FancyArrowPatch`
  15. """
  16. import collections
  17. import itertools
  18. import math
  19. from numbers import Number
  20. import networkx as nx
  21. __all__ = [
  22. "display",
  23. "apply_matplotlib_colors",
  24. "draw",
  25. "draw_networkx",
  26. "draw_networkx_nodes",
  27. "draw_networkx_edges",
  28. "draw_networkx_labels",
  29. "draw_networkx_edge_labels",
  30. "draw_bipartite",
  31. "draw_circular",
  32. "draw_kamada_kawai",
  33. "draw_random",
  34. "draw_spectral",
  35. "draw_spring",
  36. "draw_planar",
  37. "draw_shell",
  38. "draw_forceatlas2",
  39. ]
  40. def apply_matplotlib_colors(
  41. G, src_attr, dest_attr, map, vmin=None, vmax=None, nodes=True
  42. ):
  43. """
  44. Apply colors from a matplotlib colormap to a graph.
  45. Reads values from the `src_attr` and use a matplotlib colormap
  46. to produce a color. Write the color to `dest_attr`.
  47. Parameters
  48. ----------
  49. G : nx.Graph
  50. The graph to read and compute colors for.
  51. src_attr : str or other attribute name
  52. The name of the attribute to read from the graph.
  53. dest_attr : str or other attribute name
  54. The name of the attribute to write to on the graph.
  55. map : matplotlib.colormap
  56. The matplotlib colormap to use.
  57. vmin : float, default None
  58. The minimum value for scaling the colormap. If `None`, find the
  59. minimum value of `src_attr`.
  60. vmax : float, default None
  61. The maximum value for scaling the colormap. If `None`, find the
  62. maximum value of `src_attr`.
  63. nodes : bool, default True
  64. Whether the attribute names are edge attributes or node attributes.
  65. """
  66. import matplotlib as mpl
  67. if nodes:
  68. type_iter = G.nodes()
  69. elif G.is_multigraph():
  70. type_iter = G.edges(keys=True)
  71. else:
  72. type_iter = G.edges()
  73. if vmin is None or vmax is None:
  74. vals = [type_iter[a][src_attr] for a in type_iter]
  75. if vmin is None:
  76. vmin = min(vals)
  77. if vmax is None:
  78. vmax = max(vals)
  79. mapper = mpl.cm.ScalarMappable(cmap=map)
  80. mapper.set_clim(vmin, vmax)
  81. def do_map(x):
  82. # Cast numpy scalars to float
  83. return tuple(float(x) for x in mapper.to_rgba(x))
  84. if nodes:
  85. nx.set_node_attributes(
  86. G, {n: do_map(G.nodes[n][src_attr]) for n in G.nodes()}, dest_attr
  87. )
  88. else:
  89. nx.set_edge_attributes(
  90. G, {e: do_map(G.edges[e][src_attr]) for e in type_iter}, dest_attr
  91. )
  92. class CurvedArrowTextBase:
  93. def __init__(
  94. self,
  95. arrow,
  96. *args,
  97. label_pos=0.5,
  98. labels_horizontal=False,
  99. ax=None,
  100. **kwargs,
  101. ):
  102. # Bind to FancyArrowPatch
  103. self.arrow = arrow
  104. # how far along the text should be on the curve,
  105. # 0 is at start, 1 is at end etc.
  106. self.label_pos = label_pos
  107. self.labels_horizontal = labels_horizontal
  108. if ax is None:
  109. ax = plt.gca()
  110. self.ax = ax
  111. self.x, self.y, self.angle = self._update_text_pos_angle(arrow)
  112. # Create text object
  113. super().__init__(self.x, self.y, *args, rotation=self.angle, **kwargs)
  114. # Bind to axis
  115. self.ax.add_artist(self)
  116. def _get_arrow_path_disp(self, arrow):
  117. """
  118. This is part of FancyArrowPatch._get_path_in_displaycoord
  119. It omits the second part of the method where path is converted
  120. to polygon based on width
  121. The transform is taken from ax, not the object, as the object
  122. has not been added yet, and doesn't have transform
  123. """
  124. dpi_cor = arrow._dpi_cor
  125. trans_data = self.ax.transData
  126. if arrow._posA_posB is None:
  127. raise ValueError(
  128. "Can only draw labels for fancy arrows with "
  129. "posA and posB inputs, not custom path"
  130. )
  131. posA = arrow._convert_xy_units(arrow._posA_posB[0])
  132. posB = arrow._convert_xy_units(arrow._posA_posB[1])
  133. (posA, posB) = trans_data.transform((posA, posB))
  134. _path = arrow.get_connectionstyle()(
  135. posA,
  136. posB,
  137. patchA=arrow.patchA,
  138. patchB=arrow.patchB,
  139. shrinkA=arrow.shrinkA * dpi_cor,
  140. shrinkB=arrow.shrinkB * dpi_cor,
  141. )
  142. # Return is in display coordinates
  143. return _path
  144. def _update_text_pos_angle(self, arrow):
  145. # Fractional label position
  146. # Text position at a proportion t along the line in display coords
  147. # default is 0.5 so text appears at the halfway point
  148. import matplotlib as mpl
  149. import numpy as np
  150. t = self.label_pos
  151. tt = 1 - t
  152. path_disp = self._get_arrow_path_disp(arrow)
  153. conn = arrow.get_connectionstyle()
  154. # 1. Calculate x and y
  155. points = path_disp.vertices
  156. if is_curve := isinstance(
  157. conn,
  158. mpl.patches.ConnectionStyle.Angle3 | mpl.patches.ConnectionStyle.Arc3,
  159. ):
  160. # Arc3 or Angle3 type Connection Styles - Bezier curve
  161. (x1, y1), (cx, cy), (x2, y2) = points
  162. x = tt**2 * x1 + 2 * t * tt * cx + t**2 * x2
  163. y = tt**2 * y1 + 2 * t * tt * cy + t**2 * y2
  164. else:
  165. if not isinstance(
  166. conn,
  167. mpl.patches.ConnectionStyle.Angle
  168. | mpl.patches.ConnectionStyle.Arc
  169. | mpl.patches.ConnectionStyle.Bar,
  170. ):
  171. msg = f"invalid connection style: {type(conn)}"
  172. raise TypeError(msg)
  173. # A. Collect lines
  174. codes = path_disp.codes
  175. lines = [
  176. points[i - 1 : i + 1]
  177. for i in range(1, len(points))
  178. if codes[i] == mpl.path.Path.LINETO
  179. ]
  180. # B. If more than one line, find the right one and position in it
  181. if (nlines := len(lines)) != 1:
  182. dists = [math.dist(*line) for line in lines]
  183. dist_tot = sum(dists)
  184. cdist = 0
  185. last_cut = 0
  186. i_last = nlines - 1
  187. for i, dist in enumerate(dists):
  188. cdist += dist
  189. cut = cdist / dist_tot
  190. if i == i_last or t < cut:
  191. t = (t - last_cut) / (dist / dist_tot)
  192. tt = 1 - t
  193. lines = [lines[i]]
  194. break
  195. last_cut = cut
  196. [[(cx1, cy1), (cx2, cy2)]] = lines
  197. x = cx1 * tt + cx2 * t
  198. y = cy1 * tt + cy2 * t
  199. # 2. Calculate Angle
  200. if self.labels_horizontal:
  201. # Horizontal text labels
  202. angle = 0
  203. else:
  204. # Labels parallel to curve
  205. if is_curve:
  206. change_x = 2 * tt * (cx - x1) + 2 * t * (x2 - cx)
  207. change_y = 2 * tt * (cy - y1) + 2 * t * (y2 - cy)
  208. else:
  209. change_x = (cx2 - cx1) / 2
  210. change_y = (cy2 - cy1) / 2
  211. angle = np.arctan2(change_y, change_x) / (2 * np.pi) * 360
  212. # Text is "right way up"
  213. if angle > 90:
  214. angle -= 180
  215. elif angle < -90:
  216. angle += 180
  217. (x, y) = self.ax.transData.inverted().transform((x, y))
  218. return x, y, angle
  219. def draw(self, renderer):
  220. # recalculate the text position and angle
  221. self.x, self.y, self.angle = self._update_text_pos_angle(self.arrow)
  222. self.set_position((self.x, self.y))
  223. self.set_rotation(self.angle)
  224. # redraw text
  225. super().draw(renderer)
  226. def display(
  227. G,
  228. canvas=None,
  229. **kwargs,
  230. ):
  231. """Draw the graph G.
  232. Draw the graph as a collection of nodes connected by edges.
  233. The exact details of what the graph looks like are controlled by the below
  234. attributes. All nodes and nodes at the end of visible edges must have a
  235. position set, but nearly all other node and edge attributes are options and
  236. nodes or edges missing the attribute will use the default listed below. A more
  237. complete description of each parameter is given below this summary.
  238. .. list-table:: Default Visualization Attributes
  239. :widths: 25 25 50
  240. :header-rows: 1
  241. * - Parameter
  242. - Default Attribute
  243. - Default Value
  244. * - node_pos
  245. - `"pos"`
  246. - If there is not position, a layout will be calculated with `nx.spring_layout`.
  247. * - node_visible
  248. - `"visible"`
  249. - True
  250. * - node_color
  251. - `"color"`
  252. - #1f78b4
  253. * - node_size
  254. - `"size"`
  255. - 300
  256. * - node_label
  257. - `"label"`
  258. - Dict describing the node label. Defaults create a black text with
  259. the node name as the label. The dict respects these keys and defaults:
  260. * size : 12
  261. * color : black
  262. * family : sans serif
  263. * weight : normal
  264. * alpha : 1.0
  265. * h_align : center
  266. * v_align : center
  267. * bbox : Dict describing a `matplotlib.patches.FancyBboxPatch`.
  268. Default is None.
  269. * - node_shape
  270. - `"shape"`
  271. - "o"
  272. * - node_alpha
  273. - `"alpha"`
  274. - 1.0
  275. * - node_border_width
  276. - `"border_width"`
  277. - 1.0
  278. * - node_border_color
  279. - `"border_color"`
  280. - Matching node_color
  281. * - edge_visible
  282. - `"visible"`
  283. - True
  284. * - edge_width
  285. - `"width"`
  286. - 1.0
  287. * - edge_color
  288. - `"color"`
  289. - Black (#000000)
  290. * - edge_label
  291. - `"label"`
  292. - Dict describing the edge label. Defaults create black text with a
  293. white bounding box. The dictionary respects these keys and defaults:
  294. * size : 12
  295. * color : black
  296. * family : sans serif
  297. * weight : normal
  298. * alpha : 1.0
  299. * bbox : Dict describing a `matplotlib.patches.FancyBboxPatch`.
  300. Default {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}
  301. * h_align : "center"
  302. * v_align : "center"
  303. * pos : 0.5
  304. * rotate : True
  305. * - edge_style
  306. - `"style"`
  307. - "-"
  308. * - edge_alpha
  309. - `"alpha"`
  310. - 1.0
  311. * - edge_arrowstyle
  312. - `"arrowstyle"`
  313. - ``"-|>"`` if `G` is directed else ``"-"``
  314. * - edge_arrowsize
  315. - `"arrowsize"`
  316. - 10 if `G` is directed else 0
  317. * - edge_curvature
  318. - `"curvature"`
  319. - arc3
  320. * - edge_source_margin
  321. - `"source_margin"`
  322. - 0
  323. * - edge_target_margin
  324. - `"target_margin"`
  325. - 0
  326. Parameters
  327. ----------
  328. G : graph
  329. A networkx graph
  330. canvas : Matplotlib Axes object, optional
  331. Draw the graph in specified Matplotlib axes
  332. node_pos : string or function, default "pos"
  333. A string naming the node attribute storing the position of nodes as a tuple.
  334. Or a function to be called with input `G` which returns the layout as a dict keyed
  335. by node to position tuple like the NetworkX layout functions.
  336. If no nodes in the graph has the attribute, a spring layout is calculated.
  337. node_visible : string or bool, default visible
  338. A string naming the node attribute which stores if a node should be drawn.
  339. If `True`, all nodes will be visible while if `False` no nodes will be visible.
  340. If incomplete, nodes missing this attribute will be shown by default.
  341. node_color : string, default "color"
  342. A string naming the node attribute which stores the color of each node.
  343. Visible nodes without this attribute will use '#1f78b4' as a default.
  344. node_size : string or number, default "size"
  345. A string naming the node attribute which stores the size of each node.
  346. Visible nodes without this attribute will use a default size of 300.
  347. node_label : string or bool, default "label"
  348. A string naming the node attribute which stores the label of each node.
  349. The attribute value can be a string, False (no label for that node),
  350. True (the node is the label) or a dict keyed by node to the label.
  351. If a dict is specified, these keys are read to further control the label:
  352. * label : The text of the label; default: name of the node
  353. * size : Font size of the label; default: 12
  354. * color : Font color of the label; default: black
  355. * family : Font family of the label; default: "sans-serif"
  356. * weight : Font weight of the label; default: "normal"
  357. * alpha : Alpha value of the label; default: 1.0
  358. * h_align : The horizontal alignment of the label.
  359. one of "left", "center", "right"; default: "center"
  360. * v_align : The vertical alignment of the label.
  361. one of "top", "center", "bottom"; default: "center"
  362. * bbox : A dict of parameters for `matplotlib.patches.FancyBboxPatch`.
  363. Visible nodes without this attribute will be treated as if the value was True.
  364. node_shape : string, default "shape"
  365. A string naming the node attribute which stores the label of each node.
  366. The values of this attribute are expected to be one of the matplotlib shapes,
  367. one of 'so^>v<dph8'. Visible nodes without this attribute will use 'o'.
  368. node_alpha : string, default "alpha"
  369. A string naming the node attribute which stores the alpha of each node.
  370. The values of this attribute are expected to be floats between 0.0 and 1.0.
  371. Visible nodes without this attribute will be treated as if the value was 1.0.
  372. node_border_width : string, default "border_width"
  373. A string naming the node attribute storing the width of the border of the node.
  374. The values of this attribute are expected to be numeric. Visible nodes without
  375. this attribute will use the assumed default of 1.0.
  376. node_border_color : string, default "border_color"
  377. A string naming the node attribute which storing the color of the border of the node.
  378. Visible nodes missing this attribute will use the final node_color value.
  379. edge_visible : string or bool, default "visible"
  380. A string nameing the edge attribute which stores if an edge should be drawn.
  381. If `True`, all edges will be drawn while if `False` no edges will be visible.
  382. If incomplete, edges missing this attribute will be shown by default. Values
  383. of this attribute are expected to be booleans.
  384. edge_width : string or int, default "width"
  385. A string nameing the edge attribute which stores the width of each edge.
  386. Visible edges without this attribute will use a default width of 1.0.
  387. edge_color : string or color, default "color"
  388. A string nameing the edge attribute which stores of color of each edge.
  389. Visible edges without this attribute will be drawn black. Each color can be
  390. a string or rgb (or rgba) tuple of floats from 0.0 to 1.0.
  391. edge_label : string, default "label"
  392. A string naming the edge attribute which stores the label of each edge.
  393. The values of this attribute can be a string, number or False or None. In
  394. the latter two cases, no edge label is displayed.
  395. If a dict is specified, these keys are read to further control the label:
  396. * label : The text of the label, or the name of an edge attribute holding the label.
  397. * size : Font size of the label; default: 12
  398. * color : Font color of the label; default: black
  399. * family : Font family of the label; default: "sans-serif"
  400. * weight : Font weight of the label; default: "normal"
  401. * alpha : Alpha value of the label; default: 1.0
  402. * h_align : The horizontal alignment of the label.
  403. one of "left", "center", "right"; default: "center"
  404. * v_align : The vertical alignment of the label.
  405. one of "top", "center", "bottom"; default: "center"
  406. * bbox : A dict of parameters for `matplotlib.patches.FancyBboxPatch`.
  407. * rotate : Whether to rotate labels to lie parallel to the edge, default: True.
  408. * pos : A float showing how far along the edge to put the label; default: 0.5.
  409. edge_style : string, default "style"
  410. A string naming the edge attribute which stores the style of each edge.
  411. Visible edges without this attribute will be drawn solid. Values of this
  412. attribute can be line styles, e.g. '-', '--', '-.' or ':' or words like 'solid'
  413. or 'dashed'. If no edge in the graph has this attribute and it is a non-default
  414. value, assume that it describes the edge style for all edges in the graph.
  415. edge_alpha : string or float, default "alpha"
  416. A string naming the edge attribute which stores the alpha value of each edge.
  417. Visible edges without this attribute will use an alpha value of 1.0.
  418. edge_arrowstyle : string, default "arrowstyle"
  419. A string naming the edge attribute which stores the type of arrowhead to use for
  420. each edge. Visible edges without this attribute use ``"-"`` for undirected graphs
  421. and ``"-|>"`` for directed graphs.
  422. See `matplotlib.patches.ArrowStyle` for more options
  423. edge_arrowsize : string or int, default "arrowsize"
  424. A string naming the edge attribute which stores the size of the arrowhead for each
  425. edge. Visible edges without this attribute will use a default value of 10.
  426. edge_curvature : string, default "curvature"
  427. A string naming the edge attribute storing the curvature and connection style
  428. of each edge. Visible edges without this attribute will use "arc3" as a default
  429. value, resulting an a straight line between the two nodes. Curvature can be given
  430. as 'arc3,rad=0.2' to specify both the style and radius of curvature.
  431. Please see `matplotlib.patches.ConnectionStyle` and
  432. `matplotlib.patches.FancyArrowPatch` for more information.
  433. edge_source_margin : string or int, default "source_margin"
  434. A string naming the edge attribute which stores the minimum margin (gap) between
  435. the source node and the start of the edge. Visible edges without this attribute
  436. will use a default value of 0.
  437. edge_target_margin : string or int, default "target_margin"
  438. A string naming the edge attribute which stores the minimumm margin (gap) between
  439. the target node and the end of the edge. Visible edges without this attribute
  440. will use a default value of 0.
  441. hide_ticks : bool, default True
  442. Weather to remove the ticks from the axes of the matplotlib object.
  443. Raises
  444. ------
  445. NetworkXError
  446. If a node or edge is missing a required parameter such as `pos` or
  447. if `display` receives an argument not listed above.
  448. ValueError
  449. If a node or edge has an invalid color format, i.e. not a color string,
  450. rgb tuple or rgba tuple.
  451. Returns
  452. -------
  453. The input graph. This is potentially useful for dispatching visualization
  454. functions.
  455. """
  456. from collections import Counter
  457. import matplotlib as mpl
  458. import matplotlib.pyplot as plt
  459. import numpy as np
  460. defaults = {
  461. "node_pos": None,
  462. "node_visible": True,
  463. "node_color": "#1f78b4",
  464. "node_size": 300,
  465. "node_label": {
  466. "size": 12,
  467. "color": "#000000",
  468. "family": "sans-serif",
  469. "weight": "normal",
  470. "alpha": 1.0,
  471. "h_align": "center",
  472. "v_align": "center",
  473. "bbox": None,
  474. },
  475. "node_shape": "o",
  476. "node_alpha": 1.0,
  477. "node_border_width": 1.0,
  478. "node_border_color": "face",
  479. "edge_visible": True,
  480. "edge_width": 1.0,
  481. "edge_color": "#000000",
  482. "edge_label": {
  483. "size": 12,
  484. "color": "#000000",
  485. "family": "sans-serif",
  486. "weight": "normal",
  487. "alpha": 1.0,
  488. "bbox": {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)},
  489. "h_align": "center",
  490. "v_align": "center",
  491. "pos": 0.5,
  492. "rotate": True,
  493. },
  494. "edge_style": "-",
  495. "edge_alpha": 1.0,
  496. "edge_arrowstyle": "-|>" if G.is_directed() else "-",
  497. "edge_arrowsize": 10 if G.is_directed() else 0,
  498. "edge_curvature": "arc3",
  499. "edge_source_margin": 0,
  500. "edge_target_margin": 0,
  501. "hide_ticks": True,
  502. }
  503. # Check arguments
  504. for kwarg in kwargs:
  505. if kwarg not in defaults:
  506. raise nx.NetworkXError(
  507. f"Unrecognized visualization keyword argument: {kwarg}"
  508. )
  509. if canvas is None:
  510. canvas = plt.gca()
  511. if kwargs.get("hide_ticks", defaults["hide_ticks"]):
  512. canvas.tick_params(
  513. axis="both",
  514. which="both",
  515. bottom=False,
  516. left=False,
  517. labelbottom=False,
  518. labelleft=False,
  519. )
  520. ### Helper methods and classes
  521. def node_property_sequence(seq, attr):
  522. """Return a list of attribute values for `seq`, using a default if needed"""
  523. # All node attribute parameters start with "node_"
  524. param_name = f"node_{attr}"
  525. default = defaults[param_name]
  526. attr = kwargs.get(param_name, attr)
  527. if default is None:
  528. # raise instead of using non-existant default value
  529. for n in seq:
  530. if attr not in node_subgraph.nodes[n]:
  531. raise nx.NetworkXError(f"Attribute '{attr}' missing for node {n}")
  532. # If `attr` is not a graph attr and was explicitly passed as an argument
  533. # it must be a user-default value. Allow attr=None to tell draw to skip
  534. # attributes which are on the graph
  535. if (
  536. attr is not None
  537. and nx.get_node_attributes(node_subgraph, attr) == {}
  538. and any(attr == v for k, v in kwargs.items() if "node" in k)
  539. ):
  540. return [attr for _ in seq]
  541. return [node_subgraph.nodes[n].get(attr, default) for n in seq]
  542. def compute_colors(color, alpha):
  543. if isinstance(color, str):
  544. rgba = mpl.colors.colorConverter.to_rgba(color)
  545. # Using a non-default alpha value overrides any alpha value in the color
  546. if alpha != defaults["node_alpha"]:
  547. return (rgba[0], rgba[1], rgba[2], alpha)
  548. return rgba
  549. if isinstance(color, tuple) and len(color) == 3:
  550. return (color[0], color[1], color[2], alpha)
  551. if isinstance(color, tuple) and len(color) == 4:
  552. return color
  553. raise ValueError(f"Invalid format for color: {color}")
  554. # Find which edges can be plotted as a line collection
  555. #
  556. # Non-default values for these attributes require fancy arrow patches:
  557. # - any arrow style (including the default -|> for directed graphs)
  558. # - arrow size (by extension of style)
  559. # - connection style
  560. # - min_source_margin
  561. # - min_target_margin
  562. def collection_compatible(e):
  563. return (
  564. get_edge_attr(e, "arrowstyle") == "-"
  565. and get_edge_attr(e, "curvature") == "arc3"
  566. and get_edge_attr(e, "source_margin") == 0
  567. and get_edge_attr(e, "target_margin") == 0
  568. # Self-loops will use fancy arrow patches
  569. and e[0] != e[1]
  570. )
  571. def edge_property_sequence(seq, attr):
  572. """Return a list of attribute values for `seq`, using a default if needed"""
  573. param_name = f"edge_{attr}"
  574. default = defaults[param_name]
  575. attr = kwargs.get(param_name, attr)
  576. if default is None:
  577. # raise instead of using non-existant default value
  578. for e in seq:
  579. if attr not in edge_subgraph.edges[e]:
  580. raise nx.NetworkXError(f"Attribute '{attr}' missing for edge {e}")
  581. if (
  582. attr is not None
  583. and nx.get_edge_attributes(edge_subgraph, attr) == {}
  584. and any(attr == v for k, v in kwargs.items() if "edge" in k)
  585. ):
  586. return [attr for _ in seq]
  587. return [edge_subgraph.edges[e].get(attr, default) for e in seq]
  588. def get_edge_attr(e, attr):
  589. """Return the final edge attribute value, using default if not None"""
  590. param_name = f"edge_{attr}"
  591. default = defaults[param_name]
  592. attr = kwargs.get(param_name, attr)
  593. if default is None and attr not in edge_subgraph.edges[e]:
  594. raise nx.NetworkXError(f"Attribute '{attr}' missing from edge {e}")
  595. if (
  596. attr is not None
  597. and nx.get_edge_attributes(edge_subgraph, attr) == {}
  598. and attr in kwargs.values()
  599. ):
  600. return attr
  601. return edge_subgraph.edges[e].get(attr, default)
  602. def get_node_attr(n, attr, use_edge_subgraph=True):
  603. """Return the final node attribute value, using default if not None"""
  604. subgraph = edge_subgraph if use_edge_subgraph else node_subgraph
  605. param_name = f"node_{attr}"
  606. default = defaults[param_name]
  607. attr = kwargs.get(param_name, attr)
  608. if default is None and attr not in subgraph.nodes[n]:
  609. raise nx.NetworkXError(f"Attribute '{attr}' missing from node {n}")
  610. if (
  611. attr is not None
  612. and nx.get_node_attributes(subgraph, attr) == {}
  613. and attr in kwargs.values()
  614. ):
  615. return attr
  616. return subgraph.nodes[n].get(attr, default)
  617. # Taken from ConnectionStyleFactory
  618. def self_loop(edge_index, node_size):
  619. def self_loop_connection(posA, posB, *args, **kwargs):
  620. if not np.all(posA == posB):
  621. raise nx.NetworkXError(
  622. "`self_loop` connection style method"
  623. "is only to be used for self-loops"
  624. )
  625. # this is called with _screen space_ values
  626. # so convert back to data space
  627. data_loc = canvas.transData.inverted().transform(posA)
  628. # Scale self loop based on the size of the base node
  629. # Size of nodes are given in points ** 2 and each point is 1/72 of an inch
  630. v_shift = np.sqrt(node_size) / 72
  631. h_shift = v_shift * 0.5
  632. # put the top of the loop first so arrow is not hidden by node
  633. path = np.asarray(
  634. [
  635. # 1
  636. [0, v_shift],
  637. # 4 4 4
  638. [h_shift, v_shift],
  639. [h_shift, 0],
  640. [0, 0],
  641. # 4 4 4
  642. [-h_shift, 0],
  643. [-h_shift, v_shift],
  644. [0, v_shift],
  645. ]
  646. )
  647. # Rotate self loop 90 deg. if more than 1
  648. # This will allow for maximum of 4 visible self loops
  649. if edge_index % 4:
  650. x, y = path.T
  651. for _ in range(edge_index % 4):
  652. x, y = y, -x
  653. path = np.array([x, y]).T
  654. return mpl.path.Path(
  655. canvas.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4]
  656. )
  657. return self_loop_connection
  658. def to_marker_edge(size, marker):
  659. if marker in "s^>v<d":
  660. return np.sqrt(2 * size) / 2
  661. else:
  662. return np.sqrt(size) / 2
  663. def build_fancy_arrow(e):
  664. source_margin = to_marker_edge(
  665. get_node_attr(e[0], "size"),
  666. get_node_attr(e[0], "shape"),
  667. )
  668. source_margin = max(
  669. source_margin,
  670. get_edge_attr(e, "source_margin"),
  671. )
  672. target_margin = to_marker_edge(
  673. get_node_attr(e[1], "size"),
  674. get_node_attr(e[1], "shape"),
  675. )
  676. target_margin = max(
  677. target_margin,
  678. get_edge_attr(e, "target_margin"),
  679. )
  680. return mpl.patches.FancyArrowPatch(
  681. edge_subgraph.nodes[e[0]][pos],
  682. edge_subgraph.nodes[e[1]][pos],
  683. arrowstyle=get_edge_attr(e, "arrowstyle"),
  684. connectionstyle=(
  685. get_edge_attr(e, "curvature")
  686. if e[0] != e[1]
  687. else self_loop(
  688. 0 if len(e) == 2 else e[2] % 4,
  689. get_node_attr(e[0], "size"),
  690. )
  691. ),
  692. color=get_edge_attr(e, "color"),
  693. linestyle=get_edge_attr(e, "style"),
  694. linewidth=get_edge_attr(e, "width"),
  695. mutation_scale=get_edge_attr(e, "arrowsize"),
  696. shrinkA=source_margin,
  697. shrinkB=source_margin,
  698. zorder=1,
  699. )
  700. class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text):
  701. pass
  702. ### Draw the nodes first
  703. node_visible = kwargs.get("node_visible", "visible")
  704. if isinstance(node_visible, bool):
  705. if node_visible:
  706. visible_nodes = G.nodes()
  707. else:
  708. visible_nodes = []
  709. else:
  710. visible_nodes = [
  711. n for n, v in nx.get_node_attributes(G, node_visible, True).items() if v
  712. ]
  713. node_subgraph = G.subgraph(visible_nodes)
  714. # Ignore the default dict value since that's for default values to use, not
  715. # default attribute name
  716. pos = kwargs.get("node_pos", "pos")
  717. default_display_pos_attr = "display's position attribute name"
  718. if callable(pos):
  719. nx.set_node_attributes(
  720. node_subgraph, pos(node_subgraph), default_display_pos_attr
  721. )
  722. pos = default_display_pos_attr
  723. kwargs["node_pos"] = default_display_pos_attr
  724. elif nx.get_node_attributes(G, pos) == {}:
  725. nx.set_node_attributes(
  726. node_subgraph, nx.spring_layout(node_subgraph), default_display_pos_attr
  727. )
  728. pos = default_display_pos_attr
  729. kwargs["node_pos"] = default_display_pos_attr
  730. # Each shape requires a new scatter object since they can't have different
  731. # shapes.
  732. if len(visible_nodes) > 0:
  733. node_shape = kwargs.get("node_shape", "shape")
  734. for shape in Counter(
  735. nx.get_node_attributes(
  736. node_subgraph, node_shape, defaults["node_shape"]
  737. ).values()
  738. ):
  739. # Filter position just on this shape.
  740. nodes_with_shape = [
  741. n
  742. for n, s in node_subgraph.nodes(data=node_shape)
  743. if s == shape or (s is None and shape == defaults["node_shape"])
  744. ]
  745. # There are two property sequences to create before hand.
  746. # 1. position, since it is used for x and y parameters to scatter
  747. # 2. edgecolor, since the spaeical 'face' parameter value can only be
  748. # be passed in as the sole string, not part of a list of strings.
  749. position = np.asarray(node_property_sequence(nodes_with_shape, "pos"))
  750. color = np.asarray(
  751. [
  752. compute_colors(c, a)
  753. for c, a in zip(
  754. node_property_sequence(nodes_with_shape, "color"),
  755. node_property_sequence(nodes_with_shape, "alpha"),
  756. )
  757. ]
  758. )
  759. border_color = np.asarray(
  760. [
  761. (
  762. c
  763. if (
  764. c := get_node_attr(
  765. n,
  766. "border_color",
  767. False,
  768. )
  769. )
  770. != "face"
  771. else color[i]
  772. )
  773. for i, n in enumerate(nodes_with_shape)
  774. ]
  775. )
  776. canvas.scatter(
  777. position[:, 0],
  778. position[:, 1],
  779. s=node_property_sequence(nodes_with_shape, "size"),
  780. c=color,
  781. marker=shape,
  782. linewidths=node_property_sequence(nodes_with_shape, "border_width"),
  783. edgecolors=border_color,
  784. zorder=2,
  785. )
  786. ### Draw node labels
  787. node_label = kwargs.get("node_label", "label")
  788. # Plot labels if node_label is not None and not False
  789. if node_label is not None and node_label is not False:
  790. default_dict = {}
  791. if isinstance(node_label, dict):
  792. default_dict = node_label
  793. node_label = None
  794. for n, lbl in node_subgraph.nodes(data=node_label):
  795. if lbl is False:
  796. continue
  797. # We work with label dicts down here...
  798. if not isinstance(lbl, dict):
  799. lbl = {"label": lbl if lbl is not None else n}
  800. lbl_text = lbl.get("label", n)
  801. if not isinstance(lbl_text, str):
  802. lbl_text = str(lbl_text)
  803. lbl.update(default_dict)
  804. x, y = node_subgraph.nodes[n][pos]
  805. canvas.text(
  806. x,
  807. y,
  808. lbl_text,
  809. size=lbl.get("size", defaults["node_label"]["size"]),
  810. color=lbl.get("color", defaults["node_label"]["color"]),
  811. family=lbl.get("family", defaults["node_label"]["family"]),
  812. weight=lbl.get("weight", defaults["node_label"]["weight"]),
  813. horizontalalignment=lbl.get(
  814. "h_align", defaults["node_label"]["h_align"]
  815. ),
  816. verticalalignment=lbl.get("v_align", defaults["node_label"]["v_align"]),
  817. transform=canvas.transData,
  818. bbox=lbl.get("bbox", defaults["node_label"]["bbox"]),
  819. )
  820. ### Draw edges
  821. edge_visible = kwargs.get("edge_visible", "visible")
  822. if isinstance(edge_visible, bool):
  823. if edge_visible:
  824. visible_edges = G.edges()
  825. else:
  826. visible_edges = []
  827. else:
  828. visible_edges = [
  829. e for e, v in nx.get_edge_attributes(G, edge_visible, True).items() if v
  830. ]
  831. edge_subgraph = G.edge_subgraph(visible_edges)
  832. nx.set_node_attributes(
  833. edge_subgraph, nx.get_node_attributes(node_subgraph, pos), name=pos
  834. )
  835. collection_edges = (
  836. [e for e in edge_subgraph.edges(keys=True) if collection_compatible(e)]
  837. if edge_subgraph.is_multigraph()
  838. else [e for e in edge_subgraph.edges() if collection_compatible(e)]
  839. )
  840. non_collection_edges = (
  841. [e for e in edge_subgraph.edges(keys=True) if not collection_compatible(e)]
  842. if edge_subgraph.is_multigraph()
  843. else [e for e in edge_subgraph.edges() if not collection_compatible(e)]
  844. )
  845. edge_position = np.asarray(
  846. [
  847. (
  848. get_node_attr(u, "pos", use_edge_subgraph=True),
  849. get_node_attr(v, "pos", use_edge_subgraph=True),
  850. )
  851. for u, v, *_ in collection_edges
  852. ]
  853. )
  854. # Only plot a line collection if needed
  855. if len(collection_edges) > 0:
  856. edge_collection = mpl.collections.LineCollection(
  857. edge_position,
  858. colors=edge_property_sequence(collection_edges, "color"),
  859. linewidths=edge_property_sequence(collection_edges, "width"),
  860. linestyle=edge_property_sequence(collection_edges, "style"),
  861. alpha=edge_property_sequence(collection_edges, "alpha"),
  862. antialiaseds=(1,),
  863. zorder=1,
  864. )
  865. canvas.add_collection(edge_collection)
  866. fancy_arrows = {}
  867. if len(non_collection_edges) > 0:
  868. for e in non_collection_edges:
  869. # Cache results for use in edge labels
  870. fancy_arrows[e] = build_fancy_arrow(e)
  871. canvas.add_patch(fancy_arrows[e])
  872. ### Draw edge labels
  873. edge_label = kwargs.get("edge_label", "label")
  874. default_dict = {}
  875. if isinstance(edge_label, dict):
  876. default_dict = edge_label
  877. # Restore the default label attribute key of 'label'
  878. edge_label = "label"
  879. # Handle multigraphs
  880. edge_label_data = (
  881. edge_subgraph.edges(data=edge_label, keys=True)
  882. if edge_subgraph.is_multigraph()
  883. else edge_subgraph.edges(data=edge_label)
  884. )
  885. if edge_label is not None and edge_label is not False:
  886. for *e, lbl in edge_label_data:
  887. e = tuple(e)
  888. # I'm not sure how I want to handle None here... For now it means no label
  889. if lbl is False or lbl is None:
  890. continue
  891. if not isinstance(lbl, dict):
  892. lbl = {"label": lbl}
  893. lbl.update(default_dict)
  894. lbl_text = lbl.get("label")
  895. if not isinstance(lbl_text, str):
  896. lbl_text = str(lbl_text)
  897. # In the old code, every non-self-loop is placed via a fancy arrow patch
  898. # Only compute a new fancy arrow if needed by caching the results from
  899. # edge placement.
  900. try:
  901. arrow = fancy_arrows[e]
  902. except KeyError:
  903. arrow = build_fancy_arrow(e)
  904. if e[0] == e[1]:
  905. # Taken directly from draw_networkx_edge_labels
  906. connectionstyle_obj = arrow.get_connectionstyle()
  907. posA = canvas.transData.transform(edge_subgraph.nodes[e[0]][pos])
  908. path_disp = connectionstyle_obj(posA, posA)
  909. path_data = canvas.transData.inverted().transform_path(path_disp)
  910. x, y = path_data.vertices[0]
  911. canvas.text(
  912. x,
  913. y,
  914. lbl_text,
  915. size=lbl.get("size", defaults["edge_label"]["size"]),
  916. color=lbl.get("color", defaults["edge_label"]["color"]),
  917. family=lbl.get("family", defaults["edge_label"]["family"]),
  918. weight=lbl.get("weight", defaults["edge_label"]["weight"]),
  919. alpha=lbl.get("alpha", defaults["edge_label"]["alpha"]),
  920. horizontalalignment=lbl.get(
  921. "h_align", defaults["edge_label"]["h_align"]
  922. ),
  923. verticalalignment=lbl.get(
  924. "v_align", defaults["edge_label"]["v_align"]
  925. ),
  926. rotation=0,
  927. transform=canvas.transData,
  928. bbox=lbl.get("bbox", defaults["edge_label"]["bbox"]),
  929. zorder=1,
  930. )
  931. continue
  932. CurvedArrowText(
  933. arrow,
  934. lbl_text,
  935. size=lbl.get("size", defaults["edge_label"]["size"]),
  936. color=lbl.get("color", defaults["edge_label"]["color"]),
  937. family=lbl.get("family", defaults["edge_label"]["family"]),
  938. weight=lbl.get("weight", defaults["edge_label"]["weight"]),
  939. alpha=lbl.get("alpha", defaults["edge_label"]["alpha"]),
  940. bbox=lbl.get("bbox", defaults["edge_label"]["bbox"]),
  941. horizontalalignment=lbl.get(
  942. "h_align", defaults["edge_label"]["h_align"]
  943. ),
  944. verticalalignment=lbl.get("v_align", defaults["edge_label"]["v_align"]),
  945. label_pos=lbl.get("pos", defaults["edge_label"]["pos"]),
  946. labels_horizontal=lbl.get("rotate", defaults["edge_label"]["rotate"]),
  947. transform=canvas.transData,
  948. zorder=1,
  949. ax=canvas,
  950. )
  951. # If we had to add an attribute, remove it here
  952. if pos == default_display_pos_attr:
  953. nx.remove_node_attributes(G, default_display_pos_attr)
  954. return G
  955. def draw(G, pos=None, ax=None, **kwds):
  956. """Draw the graph G with Matplotlib.
  957. Draw the graph as a simple representation with no node
  958. labels or edge labels and using the full Matplotlib figure area
  959. and no axis labels by default. See draw_networkx() for more
  960. full-featured drawing that allows title, axis labels etc.
  961. Parameters
  962. ----------
  963. G : graph
  964. A networkx graph
  965. pos : dictionary, optional
  966. A dictionary with nodes as keys and positions as values.
  967. If not specified a spring layout positioning will be computed.
  968. See :py:mod:`networkx.drawing.layout` for functions that
  969. compute node positions.
  970. ax : Matplotlib Axes object, optional
  971. Draw the graph in specified Matplotlib axes.
  972. kwds : optional keywords
  973. See networkx.draw_networkx() for a description of optional keywords.
  974. Examples
  975. --------
  976. >>> G = nx.dodecahedral_graph()
  977. >>> nx.draw(G)
  978. >>> nx.draw(G, pos=nx.spring_layout(G)) # use spring layout
  979. See Also
  980. --------
  981. draw_networkx
  982. draw_networkx_nodes
  983. draw_networkx_edges
  984. draw_networkx_labels
  985. draw_networkx_edge_labels
  986. Notes
  987. -----
  988. This function has the same name as pylab.draw and pyplot.draw
  989. so beware when using `from networkx import *`
  990. since you might overwrite the pylab.draw function.
  991. With pyplot use
  992. >>> import matplotlib.pyplot as plt
  993. >>> G = nx.dodecahedral_graph()
  994. >>> nx.draw(G) # networkx draw()
  995. >>> plt.draw() # pyplot draw()
  996. Also see the NetworkX drawing examples at
  997. https://networkx.org/documentation/latest/auto_examples/index.html
  998. """
  999. import matplotlib.pyplot as plt
  1000. if ax is None:
  1001. cf = plt.gcf()
  1002. else:
  1003. cf = ax.get_figure()
  1004. cf.set_facecolor("w")
  1005. if ax is None:
  1006. if cf.axes:
  1007. ax = cf.gca()
  1008. else:
  1009. ax = cf.add_axes((0, 0, 1, 1))
  1010. if "with_labels" not in kwds:
  1011. kwds["with_labels"] = "labels" in kwds
  1012. draw_networkx(G, pos=pos, ax=ax, **kwds)
  1013. ax.set_axis_off()
  1014. plt.draw_if_interactive()
  1015. return
  1016. def draw_networkx(G, pos=None, arrows=None, with_labels=True, **kwds):
  1017. r"""Draw the graph G using Matplotlib.
  1018. Draw the graph with Matplotlib with options for node positions,
  1019. labeling, titles, and many other drawing features.
  1020. See draw() for simple drawing without labels or axes.
  1021. Parameters
  1022. ----------
  1023. G : graph
  1024. A networkx graph
  1025. pos : dictionary, optional
  1026. A dictionary with nodes as keys and positions as values.
  1027. If not specified a spring layout positioning will be computed.
  1028. See :py:mod:`networkx.drawing.layout` for functions that
  1029. compute node positions.
  1030. arrows : bool or None, optional (default=None)
  1031. If `None`, directed graphs draw arrowheads with
  1032. `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
  1033. via `~matplotlib.collections.LineCollection` for speed.
  1034. If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
  1035. If `False`, draw edges using LineCollection (linear and fast).
  1036. For directed graphs, if True draw arrowheads.
  1037. Note: Arrows will be the same color as edges.
  1038. arrowstyle : str (default='-\|>' for directed graphs)
  1039. For directed graphs, choose the style of the arrowsheads.
  1040. For undirected graphs default to '-'
  1041. See `matplotlib.patches.ArrowStyle` for more options.
  1042. arrowsize : int or list (default=10)
  1043. For directed graphs, choose the size of the arrow head's length and
  1044. width. A list of values can be passed in to assign a different size for arrow head's length and width.
  1045. See `matplotlib.patches.FancyArrowPatch` for attribute `mutation_scale`
  1046. for more info.
  1047. with_labels : bool (default=True)
  1048. Set to True to draw labels on the nodes.
  1049. ax : Matplotlib Axes object, optional
  1050. Draw the graph in the specified Matplotlib axes.
  1051. nodelist : list (default=list(G))
  1052. Draw only specified nodes
  1053. edgelist : list (default=list(G.edges()))
  1054. Draw only specified edges
  1055. node_size : scalar or array (default=300)
  1056. Size of nodes. If an array is specified it must be the
  1057. same length as nodelist.
  1058. node_color : color or array of colors (default='#1f78b4')
  1059. Node color. Can be a single color or a sequence of colors with the same
  1060. length as nodelist. Color can be string or rgb (or rgba) tuple of
  1061. floats from 0-1. If numeric values are specified they will be
  1062. mapped to colors using the cmap and vmin,vmax parameters. See
  1063. matplotlib.scatter for more details.
  1064. node_shape : string (default='o')
  1065. The shape of the node. Specification is as matplotlib.scatter
  1066. marker, one of 'so^>v<dph8'.
  1067. alpha : float or None (default=None)
  1068. The node and edge transparency
  1069. cmap : Matplotlib colormap, optional
  1070. Colormap for mapping intensities of nodes
  1071. vmin,vmax : float, optional
  1072. Minimum and maximum for node colormap scaling
  1073. linewidths : scalar or sequence (default=1.0)
  1074. Line width of symbol border
  1075. width : float or array of floats (default=1.0)
  1076. Line width of edges
  1077. edge_color : color or array of colors (default='k')
  1078. Edge color. Can be a single color or a sequence of colors with the same
  1079. length as edgelist. Color can be string or rgb (or rgba) tuple of
  1080. floats from 0-1. If numeric values are specified they will be
  1081. mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.
  1082. edge_cmap : Matplotlib colormap, optional
  1083. Colormap for mapping intensities of edges
  1084. edge_vmin,edge_vmax : floats, optional
  1085. Minimum and maximum for edge colormap scaling
  1086. style : string (default=solid line)
  1087. Edge line style e.g.: '-', '--', '-.', ':'
  1088. or words like 'solid' or 'dashed'.
  1089. (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)
  1090. labels : dictionary (default=None)
  1091. Node labels in a dictionary of text labels keyed by node
  1092. font_size : int (default=12 for nodes, 10 for edges)
  1093. Font size for text labels
  1094. font_color : color (default='k' black)
  1095. Font color string. Color can be string or rgb (or rgba) tuple of
  1096. floats from 0-1.
  1097. font_weight : string (default='normal')
  1098. Font weight
  1099. font_family : string (default='sans-serif')
  1100. Font family
  1101. label : string, optional
  1102. Label for graph legend
  1103. hide_ticks : bool, optional
  1104. Hide ticks of axes. When `True` (the default), ticks and ticklabels
  1105. are removed from the axes. To set ticks and tick labels to the pyplot default,
  1106. use ``hide_ticks=False``.
  1107. kwds : optional keywords
  1108. See networkx.draw_networkx_nodes(), networkx.draw_networkx_edges(), and
  1109. networkx.draw_networkx_labels() for a description of optional keywords.
  1110. Notes
  1111. -----
  1112. For directed graphs, arrows are drawn at the head end. Arrows can be
  1113. turned off with keyword arrows=False.
  1114. Examples
  1115. --------
  1116. >>> G = nx.dodecahedral_graph()
  1117. >>> nx.draw(G)
  1118. >>> nx.draw(G, pos=nx.spring_layout(G)) # use spring layout
  1119. >>> import matplotlib.pyplot as plt
  1120. >>> limits = plt.axis("off") # turn off axis
  1121. Also see the NetworkX drawing examples at
  1122. https://networkx.org/documentation/latest/auto_examples/index.html
  1123. See Also
  1124. --------
  1125. draw
  1126. draw_networkx_nodes
  1127. draw_networkx_edges
  1128. draw_networkx_labels
  1129. draw_networkx_edge_labels
  1130. """
  1131. from inspect import signature
  1132. import matplotlib.pyplot as plt
  1133. # Get all valid keywords by inspecting the signatures of draw_networkx_nodes,
  1134. # draw_networkx_edges, draw_networkx_labels
  1135. valid_node_kwds = signature(draw_networkx_nodes).parameters.keys()
  1136. valid_edge_kwds = signature(draw_networkx_edges).parameters.keys()
  1137. valid_label_kwds = signature(draw_networkx_labels).parameters.keys()
  1138. # Create a set with all valid keywords across the three functions and
  1139. # remove the arguments of this function (draw_networkx)
  1140. valid_kwds = (valid_node_kwds | valid_edge_kwds | valid_label_kwds) - {
  1141. "G",
  1142. "pos",
  1143. "arrows",
  1144. "with_labels",
  1145. }
  1146. if any(k not in valid_kwds for k in kwds):
  1147. invalid_args = ", ".join([k for k in kwds if k not in valid_kwds])
  1148. raise ValueError(f"Received invalid argument(s): {invalid_args}")
  1149. node_kwds = {k: v for k, v in kwds.items() if k in valid_node_kwds}
  1150. edge_kwds = {k: v for k, v in kwds.items() if k in valid_edge_kwds}
  1151. label_kwds = {k: v for k, v in kwds.items() if k in valid_label_kwds}
  1152. if pos is None:
  1153. pos = nx.drawing.spring_layout(G) # default to spring layout
  1154. draw_networkx_nodes(G, pos, **node_kwds)
  1155. draw_networkx_edges(G, pos, arrows=arrows, **edge_kwds)
  1156. if with_labels:
  1157. draw_networkx_labels(G, pos, **label_kwds)
  1158. plt.draw_if_interactive()
  1159. def draw_networkx_nodes(
  1160. G,
  1161. pos,
  1162. nodelist=None,
  1163. node_size=300,
  1164. node_color="#1f78b4",
  1165. node_shape="o",
  1166. alpha=None,
  1167. cmap=None,
  1168. vmin=None,
  1169. vmax=None,
  1170. ax=None,
  1171. linewidths=None,
  1172. edgecolors=None,
  1173. label=None,
  1174. margins=None,
  1175. hide_ticks=True,
  1176. ):
  1177. """Draw the nodes of the graph G.
  1178. This draws only the nodes of the graph G.
  1179. Parameters
  1180. ----------
  1181. G : graph
  1182. A networkx graph
  1183. pos : dictionary
  1184. A dictionary with nodes as keys and positions as values.
  1185. Positions should be sequences of length 2.
  1186. ax : Matplotlib Axes object, optional
  1187. Draw the graph in the specified Matplotlib axes.
  1188. nodelist : list (default list(G))
  1189. Draw only specified nodes
  1190. node_size : scalar or array (default=300)
  1191. Size of nodes. If an array it must be the same length as nodelist.
  1192. node_color : color or array of colors (default='#1f78b4')
  1193. Node color. Can be a single color or a sequence of colors with the same
  1194. length as nodelist. Color can be string or rgb (or rgba) tuple of
  1195. floats from 0-1. If numeric values are specified they will be
  1196. mapped to colors using the cmap and vmin,vmax parameters. See
  1197. matplotlib.scatter for more details.
  1198. node_shape : string (default='o')
  1199. The shape of the node. Specification is as matplotlib.scatter
  1200. marker, one of 'so^>v<dph8'.
  1201. alpha : float or array of floats (default=None)
  1202. The node transparency. This can be a single alpha value,
  1203. in which case it will be applied to all the nodes of color. Otherwise,
  1204. if it is an array, the elements of alpha will be applied to the colors
  1205. in order (cycling through alpha multiple times if necessary).
  1206. cmap : Matplotlib colormap (default=None)
  1207. Colormap for mapping intensities of nodes
  1208. vmin,vmax : floats or None (default=None)
  1209. Minimum and maximum for node colormap scaling
  1210. linewidths : [None | scalar | sequence] (default=1.0)
  1211. Line width of symbol border
  1212. edgecolors : [None | scalar | sequence] (default = node_color)
  1213. Colors of node borders. Can be a single color or a sequence of colors with the
  1214. same length as nodelist. Color can be string or rgb (or rgba) tuple of floats
  1215. from 0-1. If numeric values are specified they will be mapped to colors
  1216. using the cmap and vmin,vmax parameters. See `~matplotlib.pyplot.scatter` for more details.
  1217. label : [None | string]
  1218. Label for legend
  1219. margins : float or 2-tuple, optional
  1220. Sets the padding for axis autoscaling. Increase margin to prevent
  1221. clipping for nodes that are near the edges of an image. Values should
  1222. be in the range ``[0, 1]``. See :meth:`matplotlib.axes.Axes.margins`
  1223. for details. The default is `None`, which uses the Matplotlib default.
  1224. hide_ticks : bool, optional
  1225. Hide ticks of axes. When `True` (the default), ticks and ticklabels
  1226. are removed from the axes. To set ticks and tick labels to the pyplot default,
  1227. use ``hide_ticks=False``.
  1228. Returns
  1229. -------
  1230. matplotlib.collections.PathCollection
  1231. `PathCollection` of the nodes.
  1232. Examples
  1233. --------
  1234. >>> G = nx.dodecahedral_graph()
  1235. >>> nodes = nx.draw_networkx_nodes(G, pos=nx.spring_layout(G))
  1236. Also see the NetworkX drawing examples at
  1237. https://networkx.org/documentation/latest/auto_examples/index.html
  1238. See Also
  1239. --------
  1240. draw
  1241. draw_networkx
  1242. draw_networkx_edges
  1243. draw_networkx_labels
  1244. draw_networkx_edge_labels
  1245. """
  1246. from collections.abc import Iterable
  1247. import matplotlib as mpl
  1248. import matplotlib.collections # call as mpl.collections
  1249. import matplotlib.pyplot as plt
  1250. import numpy as np
  1251. if ax is None:
  1252. ax = plt.gca()
  1253. if nodelist is None:
  1254. nodelist = list(G)
  1255. if len(nodelist) == 0: # empty nodelist, no drawing
  1256. return mpl.collections.PathCollection(None)
  1257. try:
  1258. xy = np.asarray([pos[v] for v in nodelist])
  1259. except KeyError as err:
  1260. raise nx.NetworkXError(f"Node {err} has no position.") from err
  1261. if isinstance(alpha, Iterable):
  1262. node_color = apply_alpha(node_color, alpha, nodelist, cmap, vmin, vmax)
  1263. alpha = None
  1264. if not isinstance(node_shape, np.ndarray) and not isinstance(node_shape, list):
  1265. node_shape = np.array([node_shape for _ in range(len(nodelist))])
  1266. elif isinstance(node_shape, list):
  1267. node_shape = np.asarray(node_shape)
  1268. for shape in np.unique(node_shape):
  1269. node_collection = ax.scatter(
  1270. xy[node_shape == shape, 0],
  1271. xy[node_shape == shape, 1],
  1272. s=node_size,
  1273. c=node_color,
  1274. marker=shape,
  1275. cmap=cmap,
  1276. vmin=vmin,
  1277. vmax=vmax,
  1278. alpha=alpha,
  1279. linewidths=linewidths,
  1280. edgecolors=edgecolors,
  1281. label=label,
  1282. )
  1283. if hide_ticks:
  1284. ax.tick_params(
  1285. axis="both",
  1286. which="both",
  1287. bottom=False,
  1288. left=False,
  1289. labelbottom=False,
  1290. labelleft=False,
  1291. )
  1292. if margins is not None:
  1293. if isinstance(margins, Iterable):
  1294. ax.margins(*margins)
  1295. else:
  1296. ax.margins(margins)
  1297. node_collection.set_zorder(2)
  1298. return node_collection
  1299. class FancyArrowFactory:
  1300. """Draw arrows with `matplotlib.patches.FancyarrowPatch`"""
  1301. class ConnectionStyleFactory:
  1302. def __init__(self, connectionstyles, selfloop_height, ax=None):
  1303. import matplotlib as mpl
  1304. import matplotlib.path # call as mpl.path
  1305. import numpy as np
  1306. self.ax = ax
  1307. self.mpl = mpl
  1308. self.np = np
  1309. self.base_connection_styles = [
  1310. mpl.patches.ConnectionStyle(cs) for cs in connectionstyles
  1311. ]
  1312. self.n = len(self.base_connection_styles)
  1313. self.selfloop_height = selfloop_height
  1314. def curved(self, edge_index):
  1315. return self.base_connection_styles[edge_index % self.n]
  1316. def self_loop(self, edge_index):
  1317. def self_loop_connection(posA, posB, *args, **kwargs):
  1318. if not self.np.all(posA == posB):
  1319. raise nx.NetworkXError(
  1320. "`self_loop` connection style method"
  1321. "is only to be used for self-loops"
  1322. )
  1323. # this is called with _screen space_ values
  1324. # so convert back to data space
  1325. data_loc = self.ax.transData.inverted().transform(posA)
  1326. v_shift = 0.1 * self.selfloop_height
  1327. h_shift = v_shift * 0.5
  1328. # put the top of the loop first so arrow is not hidden by node
  1329. path = self.np.asarray(
  1330. [
  1331. # 1
  1332. [0, v_shift],
  1333. # 4 4 4
  1334. [h_shift, v_shift],
  1335. [h_shift, 0],
  1336. [0, 0],
  1337. # 4 4 4
  1338. [-h_shift, 0],
  1339. [-h_shift, v_shift],
  1340. [0, v_shift],
  1341. ]
  1342. )
  1343. # Rotate self loop 90 deg. if more than 1
  1344. # This will allow for maximum of 4 visible self loops
  1345. if edge_index % 4:
  1346. x, y = path.T
  1347. for _ in range(edge_index % 4):
  1348. x, y = y, -x
  1349. path = self.np.array([x, y]).T
  1350. return self.mpl.path.Path(
  1351. self.ax.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4]
  1352. )
  1353. return self_loop_connection
  1354. def __init__(
  1355. self,
  1356. edge_pos,
  1357. edgelist,
  1358. nodelist,
  1359. edge_indices,
  1360. node_size,
  1361. selfloop_height,
  1362. connectionstyle="arc3",
  1363. node_shape="o",
  1364. arrowstyle="-",
  1365. arrowsize=10,
  1366. edge_color="k",
  1367. alpha=None,
  1368. linewidth=1.0,
  1369. style="solid",
  1370. min_source_margin=0,
  1371. min_target_margin=0,
  1372. ax=None,
  1373. ):
  1374. import matplotlib as mpl
  1375. import matplotlib.patches # call as mpl.patches
  1376. import matplotlib.pyplot as plt
  1377. import numpy as np
  1378. if isinstance(connectionstyle, str):
  1379. connectionstyle = [connectionstyle]
  1380. elif np.iterable(connectionstyle):
  1381. connectionstyle = list(connectionstyle)
  1382. else:
  1383. msg = "ConnectionStyleFactory arg `connectionstyle` must be str or iterable"
  1384. raise nx.NetworkXError(msg)
  1385. self.ax = ax
  1386. self.mpl = mpl
  1387. self.np = np
  1388. self.edge_pos = edge_pos
  1389. self.edgelist = edgelist
  1390. self.nodelist = nodelist
  1391. self.node_shape = node_shape
  1392. self.min_source_margin = min_source_margin
  1393. self.min_target_margin = min_target_margin
  1394. self.edge_indices = edge_indices
  1395. self.node_size = node_size
  1396. self.connectionstyle_factory = self.ConnectionStyleFactory(
  1397. connectionstyle, selfloop_height, ax
  1398. )
  1399. self.arrowstyle = arrowstyle
  1400. self.arrowsize = arrowsize
  1401. self.arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha)
  1402. self.linewidth = linewidth
  1403. self.style = style
  1404. if isinstance(arrowsize, list) and len(arrowsize) != len(edge_pos):
  1405. raise ValueError("arrowsize should have the same length as edgelist")
  1406. def __call__(self, i):
  1407. (x1, y1), (x2, y2) = self.edge_pos[i]
  1408. shrink_source = 0 # space from source to tail
  1409. shrink_target = 0 # space from head to target
  1410. if (
  1411. self.np.iterable(self.min_source_margin)
  1412. and not isinstance(self.min_source_margin, str)
  1413. and not isinstance(self.min_source_margin, tuple)
  1414. ):
  1415. min_source_margin = self.min_source_margin[i]
  1416. else:
  1417. min_source_margin = self.min_source_margin
  1418. if (
  1419. self.np.iterable(self.min_target_margin)
  1420. and not isinstance(self.min_target_margin, str)
  1421. and not isinstance(self.min_target_margin, tuple)
  1422. ):
  1423. min_target_margin = self.min_target_margin[i]
  1424. else:
  1425. min_target_margin = self.min_target_margin
  1426. if self.np.iterable(self.node_size): # many node sizes
  1427. source, target = self.edgelist[i][:2]
  1428. source_node_size = self.node_size[self.nodelist.index(source)]
  1429. target_node_size = self.node_size[self.nodelist.index(target)]
  1430. shrink_source = self.to_marker_edge(source_node_size, self.node_shape)
  1431. shrink_target = self.to_marker_edge(target_node_size, self.node_shape)
  1432. else:
  1433. shrink_source = self.to_marker_edge(self.node_size, self.node_shape)
  1434. shrink_target = shrink_source
  1435. shrink_source = max(shrink_source, min_source_margin)
  1436. shrink_target = max(shrink_target, min_target_margin)
  1437. # scale factor of arrow head
  1438. if isinstance(self.arrowsize, list):
  1439. mutation_scale = self.arrowsize[i]
  1440. else:
  1441. mutation_scale = self.arrowsize
  1442. if len(self.arrow_colors) > i:
  1443. arrow_color = self.arrow_colors[i]
  1444. elif len(self.arrow_colors) == 1:
  1445. arrow_color = self.arrow_colors[0]
  1446. else: # Cycle through colors
  1447. arrow_color = self.arrow_colors[i % len(self.arrow_colors)]
  1448. if self.np.iterable(self.linewidth):
  1449. if len(self.linewidth) > i:
  1450. linewidth = self.linewidth[i]
  1451. else:
  1452. linewidth = self.linewidth[i % len(self.linewidth)]
  1453. else:
  1454. linewidth = self.linewidth
  1455. if (
  1456. self.np.iterable(self.style)
  1457. and not isinstance(self.style, str)
  1458. and not isinstance(self.style, tuple)
  1459. ):
  1460. if len(self.style) > i:
  1461. linestyle = self.style[i]
  1462. else: # Cycle through styles
  1463. linestyle = self.style[i % len(self.style)]
  1464. else:
  1465. linestyle = self.style
  1466. if x1 == x2 and y1 == y2:
  1467. connectionstyle = self.connectionstyle_factory.self_loop(
  1468. self.edge_indices[i]
  1469. )
  1470. else:
  1471. connectionstyle = self.connectionstyle_factory.curved(self.edge_indices[i])
  1472. if (
  1473. self.np.iterable(self.arrowstyle)
  1474. and not isinstance(self.arrowstyle, str)
  1475. and not isinstance(self.arrowstyle, tuple)
  1476. ):
  1477. arrowstyle = self.arrowstyle[i]
  1478. else:
  1479. arrowstyle = self.arrowstyle
  1480. return self.mpl.patches.FancyArrowPatch(
  1481. (x1, y1),
  1482. (x2, y2),
  1483. arrowstyle=arrowstyle,
  1484. shrinkA=shrink_source,
  1485. shrinkB=shrink_target,
  1486. mutation_scale=mutation_scale,
  1487. color=arrow_color,
  1488. linewidth=linewidth,
  1489. connectionstyle=connectionstyle,
  1490. linestyle=linestyle,
  1491. zorder=1, # arrows go behind nodes
  1492. )
  1493. def to_marker_edge(self, marker_size, marker):
  1494. if marker in "s^>v<d": # `large` markers need extra space
  1495. return self.np.sqrt(2 * marker_size) / 2
  1496. else:
  1497. return self.np.sqrt(marker_size) / 2
  1498. def draw_networkx_edges(
  1499. G,
  1500. pos,
  1501. edgelist=None,
  1502. width=1.0,
  1503. edge_color="k",
  1504. style="solid",
  1505. alpha=None,
  1506. arrowstyle=None,
  1507. arrowsize=10,
  1508. edge_cmap=None,
  1509. edge_vmin=None,
  1510. edge_vmax=None,
  1511. ax=None,
  1512. arrows=None,
  1513. label=None,
  1514. node_size=300,
  1515. nodelist=None,
  1516. node_shape="o",
  1517. connectionstyle="arc3",
  1518. min_source_margin=0,
  1519. min_target_margin=0,
  1520. hide_ticks=True,
  1521. ):
  1522. r"""Draw the edges of the graph G.
  1523. This draws only the edges of the graph G.
  1524. Parameters
  1525. ----------
  1526. G : graph
  1527. A networkx graph
  1528. pos : dictionary
  1529. A dictionary with nodes as keys and positions as values.
  1530. Positions should be sequences of length 2.
  1531. edgelist : collection of edge tuples (default=G.edges())
  1532. Draw only specified edges
  1533. width : float or array of floats (default=1.0)
  1534. Line width of edges
  1535. edge_color : color or array of colors (default='k')
  1536. Edge color. Can be a single color or a sequence of colors with the same
  1537. length as edgelist. Color can be string or rgb (or rgba) tuple of
  1538. floats from 0-1. If numeric values are specified they will be
  1539. mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.
  1540. style : string or array of strings (default='solid')
  1541. Edge line style e.g.: '-', '--', '-.', ':'
  1542. or words like 'solid' or 'dashed'.
  1543. Can be a single style or a sequence of styles with the same
  1544. length as the edge list.
  1545. If less styles than edges are given the styles will cycle.
  1546. If more styles than edges are given the styles will be used sequentially
  1547. and not be exhausted.
  1548. Also, `(offset, onoffseq)` tuples can be used as style instead of a strings.
  1549. (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)
  1550. alpha : float or array of floats (default=None)
  1551. The edge transparency. This can be a single alpha value,
  1552. in which case it will be applied to all specified edges. Otherwise,
  1553. if it is an array, the elements of alpha will be applied to the colors
  1554. in order (cycling through alpha multiple times if necessary).
  1555. edge_cmap : Matplotlib colormap, optional
  1556. Colormap for mapping intensities of edges
  1557. edge_vmin,edge_vmax : floats, optional
  1558. Minimum and maximum for edge colormap scaling
  1559. ax : Matplotlib Axes object, optional
  1560. Draw the graph in the specified Matplotlib axes.
  1561. arrows : bool or None, optional (default=None)
  1562. If `None`, directed graphs draw arrowheads with
  1563. `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
  1564. via `~matplotlib.collections.LineCollection` for speed.
  1565. If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
  1566. If `False`, draw edges using LineCollection (linear and fast).
  1567. Note: Arrowheads will be the same color as edges.
  1568. arrowstyle : str or list of strs (default='-\|>' for directed graphs)
  1569. For directed graphs and `arrows==True` defaults to '-\|>',
  1570. For undirected graphs default to '-'.
  1571. See `matplotlib.patches.ArrowStyle` for more options.
  1572. arrowsize : int or list of ints(default=10)
  1573. For directed graphs, choose the size of the arrow head's length and
  1574. width. See `matplotlib.patches.FancyArrowPatch` for attribute
  1575. `mutation_scale` for more info.
  1576. connectionstyle : string or iterable of strings (default="arc3")
  1577. Pass the connectionstyle parameter to create curved arc of rounding
  1578. radius rad. For example, connectionstyle='arc3,rad=0.2'.
  1579. See `matplotlib.patches.ConnectionStyle` and
  1580. `matplotlib.patches.FancyArrowPatch` for more info.
  1581. If Iterable, index indicates i'th edge key of MultiGraph
  1582. node_size : scalar or array (default=300)
  1583. Size of nodes. Though the nodes are not drawn with this function, the
  1584. node size is used in determining edge positioning.
  1585. nodelist : list, optional (default=G.nodes())
  1586. This provides the node order for the `node_size` array (if it is an array).
  1587. node_shape : string (default='o')
  1588. The marker used for nodes, used in determining edge positioning.
  1589. Specification is as a `matplotlib.markers` marker, e.g. one of 'so^>v<dph8'.
  1590. label : None or string
  1591. Label for legend
  1592. min_source_margin : int or list of ints (default=0)
  1593. The minimum margin (gap) at the beginning of the edge at the source.
  1594. min_target_margin : int or list of ints (default=0)
  1595. The minimum margin (gap) at the end of the edge at the target.
  1596. hide_ticks : bool, optional
  1597. Hide ticks of axes. When `True` (the default), ticks and ticklabels
  1598. are removed from the axes. To set ticks and tick labels to the pyplot default,
  1599. use ``hide_ticks=False``.
  1600. Returns
  1601. -------
  1602. matplotlib.collections.LineCollection or a list of matplotlib.patches.FancyArrowPatch
  1603. If ``arrows=True``, a list of FancyArrowPatches is returned.
  1604. If ``arrows=False``, a LineCollection is returned.
  1605. If ``arrows=None`` (the default), then a LineCollection is returned if
  1606. `G` is undirected, otherwise returns a list of FancyArrowPatches.
  1607. Notes
  1608. -----
  1609. For directed graphs, arrows are drawn at the head end. Arrows can be
  1610. turned off with keyword arrows=False or by passing an arrowstyle without
  1611. an arrow on the end.
  1612. Be sure to include `node_size` as a keyword argument; arrows are
  1613. drawn considering the size of nodes.
  1614. Self-loops are always drawn with `~matplotlib.patches.FancyArrowPatch`
  1615. regardless of the value of `arrows` or whether `G` is directed.
  1616. When ``arrows=False`` or ``arrows=None`` and `G` is undirected, the
  1617. FancyArrowPatches corresponding to the self-loops are not explicitly
  1618. returned. They should instead be accessed via the ``Axes.patches``
  1619. attribute (see examples).
  1620. Examples
  1621. --------
  1622. >>> G = nx.dodecahedral_graph()
  1623. >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
  1624. >>> G = nx.DiGraph()
  1625. >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)])
  1626. >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
  1627. >>> alphas = [0.3, 0.4, 0.5]
  1628. >>> for i, arc in enumerate(arcs): # change alpha values of arcs
  1629. ... arc.set_alpha(alphas[i])
  1630. The FancyArrowPatches corresponding to self-loops are not always
  1631. returned, but can always be accessed via the ``patches`` attribute of the
  1632. `matplotlib.Axes` object.
  1633. >>> import matplotlib.pyplot as plt
  1634. >>> fig, ax = plt.subplots()
  1635. >>> G = nx.Graph([(0, 1), (0, 0)]) # Self-loop at node 0
  1636. >>> edge_collection = nx.draw_networkx_edges(G, pos=nx.circular_layout(G), ax=ax)
  1637. >>> self_loop_fap = ax.patches[0]
  1638. Also see the NetworkX drawing examples at
  1639. https://networkx.org/documentation/latest/auto_examples/index.html
  1640. See Also
  1641. --------
  1642. draw
  1643. draw_networkx
  1644. draw_networkx_nodes
  1645. draw_networkx_labels
  1646. draw_networkx_edge_labels
  1647. """
  1648. import warnings
  1649. import matplotlib as mpl
  1650. import matplotlib.collections # call as mpl.collections
  1651. import matplotlib.colors # call as mpl.colors
  1652. import matplotlib.pyplot as plt
  1653. import numpy as np
  1654. # The default behavior is to use LineCollection to draw edges for
  1655. # undirected graphs (for performance reasons) and use FancyArrowPatches
  1656. # for directed graphs.
  1657. # The `arrows` keyword can be used to override the default behavior
  1658. if arrows is None:
  1659. use_linecollection = not (G.is_directed() or G.is_multigraph())
  1660. else:
  1661. if not isinstance(arrows, bool):
  1662. raise TypeError("Argument `arrows` must be of type bool or None")
  1663. use_linecollection = not arrows
  1664. if isinstance(connectionstyle, str):
  1665. connectionstyle = [connectionstyle]
  1666. elif np.iterable(connectionstyle):
  1667. connectionstyle = list(connectionstyle)
  1668. else:
  1669. msg = "draw_networkx_edges arg `connectionstyle` must be str or iterable"
  1670. raise nx.NetworkXError(msg)
  1671. # Some kwargs only apply to FancyArrowPatches. Warn users when they use
  1672. # non-default values for these kwargs when LineCollection is being used
  1673. # instead of silently ignoring the specified option
  1674. if use_linecollection:
  1675. msg = (
  1676. "\n\nThe {0} keyword argument is not applicable when drawing edges\n"
  1677. "with LineCollection.\n\n"
  1678. "To make this warning go away, either specify `arrows=True` to\n"
  1679. "force FancyArrowPatches or use the default values.\n"
  1680. "Note that using FancyArrowPatches may be slow for large graphs.\n"
  1681. )
  1682. if arrowstyle is not None:
  1683. warnings.warn(msg.format("arrowstyle"), category=UserWarning, stacklevel=2)
  1684. if arrowsize != 10:
  1685. warnings.warn(msg.format("arrowsize"), category=UserWarning, stacklevel=2)
  1686. if min_source_margin != 0:
  1687. warnings.warn(
  1688. msg.format("min_source_margin"), category=UserWarning, stacklevel=2
  1689. )
  1690. if min_target_margin != 0:
  1691. warnings.warn(
  1692. msg.format("min_target_margin"), category=UserWarning, stacklevel=2
  1693. )
  1694. if any(cs != "arc3" for cs in connectionstyle):
  1695. warnings.warn(
  1696. msg.format("connectionstyle"), category=UserWarning, stacklevel=2
  1697. )
  1698. # NOTE: Arrowstyle modification must occur after the warnings section
  1699. if arrowstyle is None:
  1700. arrowstyle = "-|>" if G.is_directed() else "-"
  1701. if ax is None:
  1702. ax = plt.gca()
  1703. if edgelist is None:
  1704. edgelist = list(G.edges) # (u, v, k) for multigraph (u, v) otherwise
  1705. if len(edgelist):
  1706. if G.is_multigraph():
  1707. key_count = collections.defaultdict(lambda: itertools.count(0))
  1708. edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist]
  1709. else:
  1710. edge_indices = [0] * len(edgelist)
  1711. else: # no edges!
  1712. return []
  1713. if nodelist is None:
  1714. nodelist = list(G.nodes())
  1715. # FancyArrowPatch handles color=None different from LineCollection
  1716. if edge_color is None:
  1717. edge_color = "k"
  1718. # set edge positions
  1719. edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])
  1720. # Check if edge_color is an array of floats and map to edge_cmap.
  1721. # This is the only case handled differently from matplotlib
  1722. if (
  1723. np.iterable(edge_color)
  1724. and (len(edge_color) == len(edge_pos))
  1725. and np.all([isinstance(c, Number) for c in edge_color])
  1726. ):
  1727. if edge_cmap is not None:
  1728. assert isinstance(edge_cmap, mpl.colors.Colormap)
  1729. else:
  1730. edge_cmap = plt.get_cmap()
  1731. if edge_vmin is None:
  1732. edge_vmin = min(edge_color)
  1733. if edge_vmax is None:
  1734. edge_vmax = max(edge_color)
  1735. color_normal = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax)
  1736. edge_color = [edge_cmap(color_normal(e)) for e in edge_color]
  1737. # compute initial view
  1738. minx = np.amin(np.ravel(edge_pos[:, :, 0]))
  1739. maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
  1740. miny = np.amin(np.ravel(edge_pos[:, :, 1]))
  1741. maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
  1742. w = maxx - minx
  1743. h = maxy - miny
  1744. # Self-loops are scaled by view extent, except in cases the extent
  1745. # is 0, e.g. for a single node. In this case, fall back to scaling
  1746. # by the maximum node size
  1747. selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max()
  1748. fancy_arrow_factory = FancyArrowFactory(
  1749. edge_pos,
  1750. edgelist,
  1751. nodelist,
  1752. edge_indices,
  1753. node_size,
  1754. selfloop_height,
  1755. connectionstyle,
  1756. node_shape,
  1757. arrowstyle,
  1758. arrowsize,
  1759. edge_color,
  1760. alpha,
  1761. width,
  1762. style,
  1763. min_source_margin,
  1764. min_target_margin,
  1765. ax=ax,
  1766. )
  1767. # Draw the edges
  1768. if use_linecollection:
  1769. edge_collection = mpl.collections.LineCollection(
  1770. edge_pos,
  1771. colors=edge_color,
  1772. linewidths=width,
  1773. antialiaseds=(1,),
  1774. linestyle=style,
  1775. alpha=alpha,
  1776. )
  1777. edge_collection.set_cmap(edge_cmap)
  1778. edge_collection.set_clim(edge_vmin, edge_vmax)
  1779. edge_collection.set_zorder(1) # edges go behind nodes
  1780. edge_collection.set_label(label)
  1781. ax.add_collection(edge_collection)
  1782. edge_viz_obj = edge_collection
  1783. # Make sure selfloop edges are also drawn
  1784. # ---------------------------------------
  1785. selfloops_to_draw = [loop for loop in nx.selfloop_edges(G) if loop in edgelist]
  1786. if selfloops_to_draw:
  1787. edgelist_tuple = list(map(tuple, edgelist))
  1788. arrow_collection = []
  1789. for loop in selfloops_to_draw:
  1790. i = edgelist_tuple.index(loop)
  1791. arrow = fancy_arrow_factory(i)
  1792. arrow_collection.append(arrow)
  1793. ax.add_patch(arrow)
  1794. else:
  1795. edge_viz_obj = []
  1796. for i in range(len(edgelist)):
  1797. arrow = fancy_arrow_factory(i)
  1798. ax.add_patch(arrow)
  1799. edge_viz_obj.append(arrow)
  1800. # update view after drawing
  1801. padx, pady = 0.05 * w, 0.05 * h
  1802. corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
  1803. ax.update_datalim(corners)
  1804. ax.autoscale_view()
  1805. if hide_ticks:
  1806. ax.tick_params(
  1807. axis="both",
  1808. which="both",
  1809. bottom=False,
  1810. left=False,
  1811. labelbottom=False,
  1812. labelleft=False,
  1813. )
  1814. return edge_viz_obj
  1815. def draw_networkx_labels(
  1816. G,
  1817. pos,
  1818. labels=None,
  1819. font_size=12,
  1820. font_color="k",
  1821. font_family="sans-serif",
  1822. font_weight="normal",
  1823. alpha=None,
  1824. bbox=None,
  1825. horizontalalignment="center",
  1826. verticalalignment="center",
  1827. ax=None,
  1828. clip_on=True,
  1829. hide_ticks=True,
  1830. ):
  1831. """Draw node labels on the graph G.
  1832. Parameters
  1833. ----------
  1834. G : graph
  1835. A networkx graph
  1836. pos : dictionary
  1837. A dictionary with nodes as keys and positions as values.
  1838. Positions should be sequences of length 2.
  1839. labels : dictionary (default={n: n for n in G})
  1840. Node labels in a dictionary of text labels keyed by node.
  1841. Node-keys in labels should appear as keys in `pos`.
  1842. If needed use: `{n:lab for n,lab in labels.items() if n in pos}`
  1843. font_size : int or dictionary of nodes to ints (default=12)
  1844. Font size for text labels.
  1845. font_color : color or dictionary of nodes to colors (default='k' black)
  1846. Font color string. Color can be string or rgb (or rgba) tuple of
  1847. floats from 0-1.
  1848. font_weight : string or dictionary of nodes to strings (default='normal')
  1849. Font weight.
  1850. font_family : string or dictionary of nodes to strings (default='sans-serif')
  1851. Font family.
  1852. alpha : float or None or dictionary of nodes to floats (default=None)
  1853. The text transparency.
  1854. bbox : Matplotlib bbox, (default is Matplotlib's ax.text default)
  1855. Specify text box properties (e.g. shape, color etc.) for node labels.
  1856. horizontalalignment : string or array of strings (default='center')
  1857. Horizontal alignment {'center', 'right', 'left'}. If an array is
  1858. specified it must be the same length as `nodelist`.
  1859. verticalalignment : string (default='center')
  1860. Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}.
  1861. If an array is specified it must be the same length as `nodelist`.
  1862. ax : Matplotlib Axes object, optional
  1863. Draw the graph in the specified Matplotlib axes.
  1864. clip_on : bool (default=True)
  1865. Turn on clipping of node labels at axis boundaries
  1866. hide_ticks : bool, optional
  1867. Hide ticks of axes. When `True` (the default), ticks and ticklabels
  1868. are removed from the axes. To set ticks and tick labels to the pyplot default,
  1869. use ``hide_ticks=False``.
  1870. Returns
  1871. -------
  1872. dict
  1873. `dict` of labels keyed on the nodes
  1874. Examples
  1875. --------
  1876. >>> G = nx.dodecahedral_graph()
  1877. >>> labels = nx.draw_networkx_labels(G, pos=nx.spring_layout(G))
  1878. Also see the NetworkX drawing examples at
  1879. https://networkx.org/documentation/latest/auto_examples/index.html
  1880. See Also
  1881. --------
  1882. draw
  1883. draw_networkx
  1884. draw_networkx_nodes
  1885. draw_networkx_edges
  1886. draw_networkx_edge_labels
  1887. """
  1888. import matplotlib.pyplot as plt
  1889. if ax is None:
  1890. ax = plt.gca()
  1891. if labels is None:
  1892. labels = {n: n for n in G.nodes()}
  1893. individual_params = set()
  1894. def check_individual_params(p_value, p_name):
  1895. if isinstance(p_value, dict):
  1896. if len(p_value) != len(labels):
  1897. raise ValueError(f"{p_name} must have the same length as labels.")
  1898. individual_params.add(p_name)
  1899. def get_param_value(node, p_value, p_name):
  1900. if p_name in individual_params:
  1901. return p_value[node]
  1902. return p_value
  1903. check_individual_params(font_size, "font_size")
  1904. check_individual_params(font_color, "font_color")
  1905. check_individual_params(font_weight, "font_weight")
  1906. check_individual_params(font_family, "font_family")
  1907. check_individual_params(alpha, "alpha")
  1908. text_items = {} # there is no text collection so we'll fake one
  1909. for n, label in labels.items():
  1910. (x, y) = pos[n]
  1911. if not isinstance(label, str):
  1912. label = str(label) # this makes "1" and 1 labeled the same
  1913. t = ax.text(
  1914. x,
  1915. y,
  1916. label,
  1917. size=get_param_value(n, font_size, "font_size"),
  1918. color=get_param_value(n, font_color, "font_color"),
  1919. family=get_param_value(n, font_family, "font_family"),
  1920. weight=get_param_value(n, font_weight, "font_weight"),
  1921. alpha=get_param_value(n, alpha, "alpha"),
  1922. horizontalalignment=horizontalalignment,
  1923. verticalalignment=verticalalignment,
  1924. transform=ax.transData,
  1925. bbox=bbox,
  1926. clip_on=clip_on,
  1927. )
  1928. text_items[n] = t
  1929. if hide_ticks:
  1930. ax.tick_params(
  1931. axis="both",
  1932. which="both",
  1933. bottom=False,
  1934. left=False,
  1935. labelbottom=False,
  1936. labelleft=False,
  1937. )
  1938. return text_items
  1939. def draw_networkx_edge_labels(
  1940. G,
  1941. pos,
  1942. edge_labels=None,
  1943. label_pos=0.5,
  1944. font_size=10,
  1945. font_color="k",
  1946. font_family="sans-serif",
  1947. font_weight="normal",
  1948. alpha=None,
  1949. bbox=None,
  1950. horizontalalignment="center",
  1951. verticalalignment="center",
  1952. ax=None,
  1953. rotate=True,
  1954. clip_on=True,
  1955. node_size=300,
  1956. nodelist=None,
  1957. connectionstyle="arc3",
  1958. hide_ticks=True,
  1959. ):
  1960. """Draw edge labels.
  1961. Parameters
  1962. ----------
  1963. G : graph
  1964. A networkx graph
  1965. pos : dictionary
  1966. A dictionary with nodes as keys and positions as values.
  1967. Positions should be sequences of length 2.
  1968. edge_labels : dictionary (default=None)
  1969. Edge labels in a dictionary of labels keyed by edge two-tuple.
  1970. Only labels for the keys in the dictionary are drawn.
  1971. label_pos : float (default=0.5)
  1972. Position of edge label along edge (0=head, 0.5=center, 1=tail)
  1973. font_size : int (default=10)
  1974. Font size for text labels
  1975. font_color : color (default='k' black)
  1976. Font color string. Color can be string or rgb (or rgba) tuple of
  1977. floats from 0-1.
  1978. font_weight : string (default='normal')
  1979. Font weight
  1980. font_family : string (default='sans-serif')
  1981. Font family
  1982. alpha : float or None (default=None)
  1983. The text transparency
  1984. bbox : Matplotlib bbox, optional
  1985. Specify text box properties (e.g. shape, color etc.) for edge labels.
  1986. Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}.
  1987. horizontalalignment : string (default='center')
  1988. Horizontal alignment {'center', 'right', 'left'}
  1989. verticalalignment : string (default='center')
  1990. Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}
  1991. ax : Matplotlib Axes object, optional
  1992. Draw the graph in the specified Matplotlib axes.
  1993. rotate : bool (default=True)
  1994. Rotate edge labels to lie parallel to edges
  1995. clip_on : bool (default=True)
  1996. Turn on clipping of edge labels at axis boundaries
  1997. node_size : scalar or array (default=300)
  1998. Size of nodes. If an array it must be the same length as nodelist.
  1999. nodelist : list, optional (default=G.nodes())
  2000. This provides the node order for the `node_size` array (if it is an array).
  2001. connectionstyle : string or iterable of strings (default="arc3")
  2002. Pass the connectionstyle parameter to create curved arc of rounding
  2003. radius rad. For example, connectionstyle='arc3,rad=0.2'.
  2004. See `matplotlib.patches.ConnectionStyle` and
  2005. `matplotlib.patches.FancyArrowPatch` for more info.
  2006. If Iterable, index indicates i'th edge key of MultiGraph
  2007. hide_ticks : bool, optional
  2008. Hide ticks of axes. When `True` (the default), ticks and ticklabels
  2009. are removed from the axes. To set ticks and tick labels to the pyplot default,
  2010. use ``hide_ticks=False``.
  2011. Returns
  2012. -------
  2013. dict
  2014. `dict` of labels keyed by edge
  2015. Examples
  2016. --------
  2017. >>> G = nx.dodecahedral_graph()
  2018. >>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G))
  2019. Also see the NetworkX drawing examples at
  2020. https://networkx.org/documentation/latest/auto_examples/index.html
  2021. See Also
  2022. --------
  2023. draw
  2024. draw_networkx
  2025. draw_networkx_nodes
  2026. draw_networkx_edges
  2027. draw_networkx_labels
  2028. """
  2029. import matplotlib as mpl
  2030. import matplotlib.pyplot as plt
  2031. import numpy as np
  2032. class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text):
  2033. pass
  2034. # use default box of white with white border
  2035. if bbox is None:
  2036. bbox = {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}
  2037. if isinstance(connectionstyle, str):
  2038. connectionstyle = [connectionstyle]
  2039. elif np.iterable(connectionstyle):
  2040. connectionstyle = list(connectionstyle)
  2041. else:
  2042. raise nx.NetworkXError(
  2043. "draw_networkx_edges arg `connectionstyle` must be"
  2044. "string or iterable of strings"
  2045. )
  2046. if ax is None:
  2047. ax = plt.gca()
  2048. if edge_labels is None:
  2049. kwds = {"keys": True} if G.is_multigraph() else {}
  2050. edge_labels = {tuple(edge): d for *edge, d in G.edges(data=True, **kwds)}
  2051. # NOTHING TO PLOT
  2052. if not edge_labels:
  2053. return {}
  2054. edgelist, labels = zip(*edge_labels.items())
  2055. if nodelist is None:
  2056. nodelist = list(G.nodes())
  2057. # set edge positions
  2058. edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])
  2059. if G.is_multigraph():
  2060. key_count = collections.defaultdict(lambda: itertools.count(0))
  2061. edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist]
  2062. else:
  2063. edge_indices = [0] * len(edgelist)
  2064. # Used to determine self loop mid-point
  2065. # Note, that this will not be accurate,
  2066. # if not drawing edge_labels for all edges drawn
  2067. h = 0
  2068. if edge_labels:
  2069. miny = np.amin(np.ravel(edge_pos[:, :, 1]))
  2070. maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
  2071. h = maxy - miny
  2072. selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max()
  2073. fancy_arrow_factory = FancyArrowFactory(
  2074. edge_pos,
  2075. edgelist,
  2076. nodelist,
  2077. edge_indices,
  2078. node_size,
  2079. selfloop_height,
  2080. connectionstyle,
  2081. ax=ax,
  2082. )
  2083. individual_params = {}
  2084. def check_individual_params(p_value, p_name):
  2085. # TODO should this be list or array (as in a numpy array)?
  2086. if isinstance(p_value, list):
  2087. if len(p_value) != len(edgelist):
  2088. raise ValueError(f"{p_name} must have the same length as edgelist.")
  2089. individual_params[p_name] = p_value.iter()
  2090. # Don't need to pass in an edge because these are lists, not dicts
  2091. def get_param_value(p_value, p_name):
  2092. if p_name in individual_params:
  2093. return next(individual_params[p_name])
  2094. return p_value
  2095. check_individual_params(font_size, "font_size")
  2096. check_individual_params(font_color, "font_color")
  2097. check_individual_params(font_weight, "font_weight")
  2098. check_individual_params(alpha, "alpha")
  2099. check_individual_params(horizontalalignment, "horizontalalignment")
  2100. check_individual_params(verticalalignment, "verticalalignment")
  2101. check_individual_params(rotate, "rotate")
  2102. check_individual_params(label_pos, "label_pos")
  2103. text_items = {}
  2104. for i, (edge, label) in enumerate(zip(edgelist, labels)):
  2105. if not isinstance(label, str):
  2106. label = str(label) # this makes "1" and 1 labeled the same
  2107. n1, n2 = edge[:2]
  2108. arrow = fancy_arrow_factory(i)
  2109. if n1 == n2:
  2110. connectionstyle_obj = arrow.get_connectionstyle()
  2111. posA = ax.transData.transform(pos[n1])
  2112. path_disp = connectionstyle_obj(posA, posA)
  2113. path_data = ax.transData.inverted().transform_path(path_disp)
  2114. x, y = path_data.vertices[0]
  2115. text_items[edge] = ax.text(
  2116. x,
  2117. y,
  2118. label,
  2119. size=get_param_value(font_size, "font_size"),
  2120. color=get_param_value(font_color, "font_color"),
  2121. family=get_param_value(font_family, "font_family"),
  2122. weight=get_param_value(font_weight, "font_weight"),
  2123. alpha=get_param_value(alpha, "alpha"),
  2124. horizontalalignment=get_param_value(
  2125. horizontalalignment, "horizontalalignment"
  2126. ),
  2127. verticalalignment=get_param_value(
  2128. verticalalignment, "verticalalignment"
  2129. ),
  2130. rotation=0,
  2131. transform=ax.transData,
  2132. bbox=bbox,
  2133. zorder=1,
  2134. clip_on=clip_on,
  2135. )
  2136. else:
  2137. text_items[edge] = CurvedArrowText(
  2138. arrow,
  2139. label,
  2140. size=get_param_value(font_size, "font_size"),
  2141. color=get_param_value(font_color, "font_color"),
  2142. family=get_param_value(font_family, "font_family"),
  2143. weight=get_param_value(font_weight, "font_weight"),
  2144. alpha=get_param_value(alpha, "alpha"),
  2145. horizontalalignment=get_param_value(
  2146. horizontalalignment, "horizontalalignment"
  2147. ),
  2148. verticalalignment=get_param_value(
  2149. verticalalignment, "verticalalignment"
  2150. ),
  2151. transform=ax.transData,
  2152. bbox=bbox,
  2153. zorder=1,
  2154. clip_on=clip_on,
  2155. label_pos=get_param_value(label_pos, "label_pos"),
  2156. labels_horizontal=not get_param_value(rotate, "rotate"),
  2157. ax=ax,
  2158. )
  2159. if hide_ticks:
  2160. ax.tick_params(
  2161. axis="both",
  2162. which="both",
  2163. bottom=False,
  2164. left=False,
  2165. labelbottom=False,
  2166. labelleft=False,
  2167. )
  2168. return text_items
  2169. def draw_bipartite(G, **kwargs):
  2170. """Draw the graph `G` with a bipartite layout.
  2171. This is a convenience function equivalent to::
  2172. nx.draw(G, pos=nx.bipartite_layout(G), **kwargs)
  2173. Parameters
  2174. ----------
  2175. G : graph
  2176. A networkx graph
  2177. kwargs : optional keywords
  2178. See `draw_networkx` for a description of optional keywords.
  2179. Raises
  2180. ------
  2181. NetworkXError :
  2182. If `G` is not bipartite.
  2183. Notes
  2184. -----
  2185. The layout is computed each time this function is called. For
  2186. repeated drawing it is much more efficient to call
  2187. `~networkx.drawing.layout.bipartite_layout` directly and reuse the result::
  2188. >>> G = nx.complete_bipartite_graph(3, 3)
  2189. >>> pos = nx.bipartite_layout(G)
  2190. >>> nx.draw(G, pos=pos) # Draw the original graph
  2191. >>> # Draw a subgraph, reusing the same node positions
  2192. >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
  2193. Examples
  2194. --------
  2195. >>> G = nx.complete_bipartite_graph(2, 5)
  2196. >>> nx.draw_bipartite(G)
  2197. See Also
  2198. --------
  2199. :func:`~networkx.drawing.layout.bipartite_layout`
  2200. """
  2201. draw(G, pos=nx.bipartite_layout(G), **kwargs)
  2202. def draw_circular(G, **kwargs):
  2203. """Draw the graph `G` with a circular layout.
  2204. This is a convenience function equivalent to::
  2205. nx.draw(G, pos=nx.circular_layout(G), **kwargs)
  2206. Parameters
  2207. ----------
  2208. G : graph
  2209. A networkx graph
  2210. kwargs : optional keywords
  2211. See `draw_networkx` for a description of optional keywords.
  2212. Notes
  2213. -----
  2214. The layout is computed each time this function is called. For
  2215. repeated drawing it is much more efficient to call
  2216. `~networkx.drawing.layout.circular_layout` directly and reuse the result::
  2217. >>> G = nx.complete_graph(5)
  2218. >>> pos = nx.circular_layout(G)
  2219. >>> nx.draw(G, pos=pos) # Draw the original graph
  2220. >>> # Draw a subgraph, reusing the same node positions
  2221. >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
  2222. Examples
  2223. --------
  2224. >>> G = nx.path_graph(5)
  2225. >>> nx.draw_circular(G)
  2226. See Also
  2227. --------
  2228. :func:`~networkx.drawing.layout.circular_layout`
  2229. """
  2230. draw(G, pos=nx.circular_layout(G), **kwargs)
  2231. def draw_kamada_kawai(G, **kwargs):
  2232. """Draw the graph `G` with a Kamada-Kawai force-directed layout.
  2233. This is a convenience function equivalent to::
  2234. nx.draw(G, pos=nx.kamada_kawai_layout(G), **kwargs)
  2235. Parameters
  2236. ----------
  2237. G : graph
  2238. A networkx graph
  2239. kwargs : optional keywords
  2240. See `draw_networkx` for a description of optional keywords.
  2241. Notes
  2242. -----
  2243. The layout is computed each time this function is called.
  2244. For repeated drawing it is much more efficient to call
  2245. `~networkx.drawing.layout.kamada_kawai_layout` directly and reuse the
  2246. result::
  2247. >>> G = nx.complete_graph(5)
  2248. >>> pos = nx.kamada_kawai_layout(G)
  2249. >>> nx.draw(G, pos=pos) # Draw the original graph
  2250. >>> # Draw a subgraph, reusing the same node positions
  2251. >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
  2252. Examples
  2253. --------
  2254. >>> G = nx.path_graph(5)
  2255. >>> nx.draw_kamada_kawai(G)
  2256. See Also
  2257. --------
  2258. :func:`~networkx.drawing.layout.kamada_kawai_layout`
  2259. """
  2260. draw(G, pos=nx.kamada_kawai_layout(G), **kwargs)
  2261. def draw_random(G, **kwargs):
  2262. """Draw the graph `G` with a random layout.
  2263. This is a convenience function equivalent to::
  2264. nx.draw(G, pos=nx.random_layout(G), **kwargs)
  2265. Parameters
  2266. ----------
  2267. G : graph
  2268. A networkx graph
  2269. kwargs : optional keywords
  2270. See `draw_networkx` for a description of optional keywords.
  2271. Notes
  2272. -----
  2273. The layout is computed each time this function is called.
  2274. For repeated drawing it is much more efficient to call
  2275. `~networkx.drawing.layout.random_layout` directly and reuse the result::
  2276. >>> G = nx.complete_graph(5)
  2277. >>> pos = nx.random_layout(G)
  2278. >>> nx.draw(G, pos=pos) # Draw the original graph
  2279. >>> # Draw a subgraph, reusing the same node positions
  2280. >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
  2281. Examples
  2282. --------
  2283. >>> G = nx.lollipop_graph(4, 3)
  2284. >>> nx.draw_random(G)
  2285. See Also
  2286. --------
  2287. :func:`~networkx.drawing.layout.random_layout`
  2288. """
  2289. draw(G, pos=nx.random_layout(G), **kwargs)
  2290. def draw_spectral(G, **kwargs):
  2291. """Draw the graph `G` with a spectral 2D layout.
  2292. This is a convenience function equivalent to::
  2293. nx.draw(G, pos=nx.spectral_layout(G), **kwargs)
  2294. For more information about how node positions are determined, see
  2295. `~networkx.drawing.layout.spectral_layout`.
  2296. Parameters
  2297. ----------
  2298. G : graph
  2299. A networkx graph
  2300. kwargs : optional keywords
  2301. See `draw_networkx` for a description of optional keywords.
  2302. Notes
  2303. -----
  2304. The layout is computed each time this function is called.
  2305. For repeated drawing it is much more efficient to call
  2306. `~networkx.drawing.layout.spectral_layout` directly and reuse the result::
  2307. >>> G = nx.complete_graph(5)
  2308. >>> pos = nx.spectral_layout(G)
  2309. >>> nx.draw(G, pos=pos) # Draw the original graph
  2310. >>> # Draw a subgraph, reusing the same node positions
  2311. >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
  2312. Examples
  2313. --------
  2314. >>> G = nx.path_graph(5)
  2315. >>> nx.draw_spectral(G)
  2316. See Also
  2317. --------
  2318. :func:`~networkx.drawing.layout.spectral_layout`
  2319. """
  2320. draw(G, pos=nx.spectral_layout(G), **kwargs)
  2321. def draw_spring(G, **kwargs):
  2322. """Draw the graph `G` with a spring layout.
  2323. This is a convenience function equivalent to::
  2324. nx.draw(G, pos=nx.spring_layout(G), **kwargs)
  2325. Parameters
  2326. ----------
  2327. G : graph
  2328. A networkx graph
  2329. kwargs : optional keywords
  2330. See `draw_networkx` for a description of optional keywords.
  2331. Notes
  2332. -----
  2333. `~networkx.drawing.layout.spring_layout` is also the default layout for
  2334. `draw`, so this function is equivalent to `draw`.
  2335. The layout is computed each time this function is called.
  2336. For repeated drawing it is much more efficient to call
  2337. `~networkx.drawing.layout.spring_layout` directly and reuse the result::
  2338. >>> G = nx.complete_graph(5)
  2339. >>> pos = nx.spring_layout(G)
  2340. >>> nx.draw(G, pos=pos) # Draw the original graph
  2341. >>> # Draw a subgraph, reusing the same node positions
  2342. >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
  2343. Examples
  2344. --------
  2345. >>> G = nx.path_graph(20)
  2346. >>> nx.draw_spring(G)
  2347. See Also
  2348. --------
  2349. draw
  2350. :func:`~networkx.drawing.layout.spring_layout`
  2351. """
  2352. draw(G, pos=nx.spring_layout(G), **kwargs)
  2353. def draw_shell(G, nlist=None, **kwargs):
  2354. """Draw networkx graph `G` with shell layout.
  2355. This is a convenience function equivalent to::
  2356. nx.draw(G, pos=nx.shell_layout(G, nlist=nlist), **kwargs)
  2357. Parameters
  2358. ----------
  2359. G : graph
  2360. A networkx graph
  2361. nlist : list of list of nodes, optional
  2362. A list containing lists of nodes representing the shells.
  2363. Default is `None`, meaning all nodes are in a single shell.
  2364. See `~networkx.drawing.layout.shell_layout` for details.
  2365. kwargs : optional keywords
  2366. See `draw_networkx` for a description of optional keywords.
  2367. Notes
  2368. -----
  2369. The layout is computed each time this function is called.
  2370. For repeated drawing it is much more efficient to call
  2371. `~networkx.drawing.layout.shell_layout` directly and reuse the result::
  2372. >>> G = nx.complete_graph(5)
  2373. >>> pos = nx.shell_layout(G)
  2374. >>> nx.draw(G, pos=pos) # Draw the original graph
  2375. >>> # Draw a subgraph, reusing the same node positions
  2376. >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
  2377. Examples
  2378. --------
  2379. >>> G = nx.path_graph(4)
  2380. >>> shells = [[0], [1, 2, 3]]
  2381. >>> nx.draw_shell(G, nlist=shells)
  2382. See Also
  2383. --------
  2384. :func:`~networkx.drawing.layout.shell_layout`
  2385. """
  2386. draw(G, pos=nx.shell_layout(G, nlist=nlist), **kwargs)
  2387. def draw_planar(G, **kwargs):
  2388. """Draw a planar networkx graph `G` with planar layout.
  2389. This is a convenience function equivalent to::
  2390. nx.draw(G, pos=nx.planar_layout(G), **kwargs)
  2391. Parameters
  2392. ----------
  2393. G : graph
  2394. A planar networkx graph
  2395. kwargs : optional keywords
  2396. See `draw_networkx` for a description of optional keywords.
  2397. Raises
  2398. ------
  2399. NetworkXException
  2400. When `G` is not planar
  2401. Notes
  2402. -----
  2403. The layout is computed each time this function is called.
  2404. For repeated drawing it is much more efficient to call
  2405. `~networkx.drawing.layout.planar_layout` directly and reuse the result::
  2406. >>> G = nx.path_graph(5)
  2407. >>> pos = nx.planar_layout(G)
  2408. >>> nx.draw(G, pos=pos) # Draw the original graph
  2409. >>> # Draw a subgraph, reusing the same node positions
  2410. >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")
  2411. Examples
  2412. --------
  2413. >>> G = nx.path_graph(4)
  2414. >>> nx.draw_planar(G)
  2415. See Also
  2416. --------
  2417. :func:`~networkx.drawing.layout.planar_layout`
  2418. """
  2419. draw(G, pos=nx.planar_layout(G), **kwargs)
  2420. def draw_forceatlas2(G, **kwargs):
  2421. """Draw a networkx graph with forceatlas2 layout.
  2422. This is a convenience function equivalent to::
  2423. nx.draw(G, pos=nx.forceatlas2_layout(G), **kwargs)
  2424. Parameters
  2425. ----------
  2426. G : graph
  2427. A networkx graph
  2428. kwargs : optional keywords
  2429. See networkx.draw_networkx() for a description of optional keywords,
  2430. with the exception of the pos parameter which is not used by this
  2431. function.
  2432. """
  2433. draw(G, pos=nx.forceatlas2_layout(G), **kwargs)
  2434. def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None):
  2435. """Apply an alpha (or list of alphas) to the colors provided.
  2436. Parameters
  2437. ----------
  2438. colors : color string or array of floats (default='r')
  2439. Color of element. Can be a single color format string,
  2440. or a sequence of colors with the same length as nodelist.
  2441. If numeric values are specified they will be mapped to
  2442. colors using the cmap and vmin,vmax parameters. See
  2443. matplotlib.scatter for more details.
  2444. alpha : float or array of floats
  2445. Alpha values for elements. This can be a single alpha value, in
  2446. which case it will be applied to all the elements of color. Otherwise,
  2447. if it is an array, the elements of alpha will be applied to the colors
  2448. in order (cycling through alpha multiple times if necessary).
  2449. elem_list : array of networkx objects
  2450. The list of elements which are being colored. These could be nodes,
  2451. edges or labels.
  2452. cmap : matplotlib colormap
  2453. Color map for use if colors is a list of floats corresponding to points
  2454. on a color mapping.
  2455. vmin, vmax : float
  2456. Minimum and maximum values for normalizing colors if a colormap is used
  2457. Returns
  2458. -------
  2459. rgba_colors : numpy ndarray
  2460. Array containing RGBA format values for each of the node colours.
  2461. """
  2462. from itertools import cycle, islice
  2463. import matplotlib as mpl
  2464. import matplotlib.cm # call as mpl.cm
  2465. import matplotlib.colors # call as mpl.colors
  2466. import numpy as np
  2467. # If we have been provided with a list of numbers as long as elem_list,
  2468. # apply the color mapping.
  2469. if len(colors) == len(elem_list) and isinstance(colors[0], Number):
  2470. mapper = mpl.cm.ScalarMappable(cmap=cmap)
  2471. mapper.set_clim(vmin, vmax)
  2472. rgba_colors = mapper.to_rgba(colors)
  2473. # Otherwise, convert colors to matplotlib's RGB using the colorConverter
  2474. # object. These are converted to numpy ndarrays to be consistent with the
  2475. # to_rgba method of ScalarMappable.
  2476. else:
  2477. try:
  2478. rgba_colors = np.array([mpl.colors.colorConverter.to_rgba(colors)])
  2479. except ValueError:
  2480. rgba_colors = np.array(
  2481. [mpl.colors.colorConverter.to_rgba(color) for color in colors]
  2482. )
  2483. # Set the final column of the rgba_colors to have the relevant alpha values
  2484. try:
  2485. # If alpha is longer than the number of colors, resize to the number of
  2486. # elements. Also, if rgba_colors.size (the number of elements of
  2487. # rgba_colors) is the same as the number of elements, resize the array,
  2488. # to avoid it being interpreted as a colormap by scatter()
  2489. if len(alpha) > len(rgba_colors) or rgba_colors.size == len(elem_list):
  2490. rgba_colors = np.resize(rgba_colors, (len(elem_list), 4))
  2491. rgba_colors[1:, 0] = rgba_colors[0, 0]
  2492. rgba_colors[1:, 1] = rgba_colors[0, 1]
  2493. rgba_colors[1:, 2] = rgba_colors[0, 2]
  2494. rgba_colors[:, 3] = list(islice(cycle(alpha), len(rgba_colors)))
  2495. except TypeError:
  2496. rgba_colors[:, -1] = alpha
  2497. return rgba_colors