| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
777277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978 |
- """
- **********
- Matplotlib
- **********
- Draw networks with matplotlib.
- Examples
- --------
- >>> G = nx.complete_graph(5)
- >>> nx.draw(G)
- See Also
- --------
- - :doc:`matplotlib <matplotlib:index>`
- - :func:`matplotlib.pyplot.scatter`
- - :obj:`matplotlib.patches.FancyArrowPatch`
- """
- import collections
- import itertools
- import math
- from numbers import Number
- import networkx as nx
# Public API of this module: the names exported by ``from <module> import *``.
# Top-level definitions not listed here (e.g. ``CurvedArrowTextBase``) are not
# exported by a star import.
__all__ = [
    "display",
    "apply_matplotlib_colors",
    "draw",
    "draw_networkx",
    "draw_networkx_nodes",
    "draw_networkx_edges",
    "draw_networkx_labels",
    "draw_networkx_edge_labels",
    "draw_bipartite",
    "draw_circular",
    "draw_kamada_kawai",
    "draw_random",
    "draw_spectral",
    "draw_spring",
    "draw_planar",
    "draw_shell",
    "draw_forceatlas2",
]
def apply_matplotlib_colors(
    G, src_attr, dest_attr, map, vmin=None, vmax=None, nodes=True
):
    """
    Apply colors from a matplotlib colormap to a graph.

    Reads the values stored under `src_attr`, maps each one through a
    matplotlib colormap to produce an RGBA color, and writes that color to
    `dest_attr`. The graph is modified in place.

    Parameters
    ----------
    G : nx.Graph
        The graph to read and compute colors for.
    src_attr : str or other attribute name
        The name of the attribute to read from the graph.
    dest_attr : str or other attribute name
        The name of the attribute to write to on the graph.
    map : matplotlib.colormap
        The matplotlib colormap to use.
    vmin : float, default None
        The minimum value for scaling the colormap. If `None`, find the
        minimum value of `src_attr`.
    vmax : float, default None
        The maximum value for scaling the colormap. If `None`, find the
        maximum value of `src_attr`.
    nodes : bool, default True
        If True, `src_attr` and `dest_attr` are node attributes. If False,
        they are edge attributes (on multigraphs, edges are addressed as
        ``(u, v, key)``).

    Raises
    ------
    KeyError
        If a node (or edge) is missing `src_attr`.

    Notes
    -----
    If `vmin` or `vmax` must be computed and the graph has no nodes (or no
    edges when ``nodes=False``), there is nothing to color and the function
    returns without modifying the graph.
    """
    if nodes:
        items = G.nodes()
    elif G.is_multigraph():
        items = G.edges(keys=True)
    else:
        items = G.edges()

    if vmin is None or vmax is None:
        vals = [items[item][src_attr] for item in items]
        if not vals:
            # Nothing to color; also avoids min()/max() of an empty sequence.
            return
        if vmin is None:
            vmin = min(vals)
        if vmax is None:
            vmax = max(vals)

    # Imported lazily so that matplotlib is only required when colors are
    # actually computed.
    import matplotlib as mpl

    mapper = mpl.cm.ScalarMappable(cmap=map)
    mapper.set_clim(vmin, vmax)

    def do_map(value):
        # Cast numpy scalars returned by to_rgba to plain floats
        return tuple(float(channel) for channel in mapper.to_rgba(value))

    if nodes:
        nx.set_node_attributes(
            G, {n: do_map(G.nodes[n][src_attr]) for n in G.nodes()}, dest_attr
        )
    else:
        nx.set_edge_attributes(
            G, {e: do_map(G.edges[e][src_attr]) for e in items}, dest_attr
        )
class CurvedArrowTextBase:
    """Mixin that keeps a text label positioned along a ``FancyArrowPatch``.

    The label's position and rotation are computed from the arrow's path in
    display coordinates and recomputed on every :meth:`draw`, so the label
    follows the arrow when the view changes. This class calls
    ``super().__init__(x, y, *args, rotation=..., **kwargs)``,
    ``set_position``, ``set_rotation`` and ``super().draw``. It is therefore
    meant to be combined with a text artist class that provides them
    (presumably ``matplotlib.text.Text``).

    Parameters
    ----------
    arrow : matplotlib.patches.FancyArrowPatch
        The arrow to follow. It must have been created from ``posA``/``posB``
        rather than from a custom path.
    *args, **kwargs
        Forwarded to the text class's ``__init__``.
    label_pos : float, default 0.5
        Fraction of the way along the arrow at which to place the text
        (0 is the start, 1 is the end).
    labels_horizontal : bool, default False
        If True, keep the text horizontal. Otherwise rotate it to lie
        parallel to the arrow.
    ax : matplotlib.axes.Axes, optional
        Axes the text is attached to. Defaults to the current axes.
    """

    def __init__(
        self,
        arrow,
        *args,
        label_pos=0.5,
        labels_horizontal=False,
        ax=None,
        **kwargs,
    ):
        # Bind to FancyArrowPatch
        self.arrow = arrow
        # how far along the text should be on the curve,
        # 0 is at start, 1 is at end etc.
        self.label_pos = label_pos
        self.labels_horizontal = labels_horizontal
        if ax is None:
            # `plt` is not a module-level name here, so import pyplot lazily
            # (like the other function-scope matplotlib imports in this module).
            import matplotlib.pyplot as plt

            ax = plt.gca()
        self.ax = ax
        self.x, self.y, self.angle = self._update_text_pos_angle(arrow)
        # Create text object
        super().__init__(self.x, self.y, *args, rotation=self.angle, **kwargs)
        # Bind to axis
        self.ax.add_artist(self)

    def _get_arrow_path_disp(self, arrow):
        """
        Return the arrow's connection path in display coordinates.

        This is part of FancyArrowPatch._get_path_in_displaycoord.
        It omits the second part of that method, where the path is converted
        to a polygon based on width.
        The transform is taken from ax, not the object, because the object
        has not been added yet and doesn't have a transform.

        Raises
        ------
        ValueError
            If the arrow was built from a custom path instead of posA/posB.
        """
        dpi_cor = arrow._dpi_cor
        trans_data = self.ax.transData
        if arrow._posA_posB is None:
            raise ValueError(
                "Can only draw labels for fancy arrows with "
                "posA and posB inputs, not custom path"
            )
        posA = arrow._convert_xy_units(arrow._posA_posB[0])
        posB = arrow._convert_xy_units(arrow._posA_posB[1])
        (posA, posB) = trans_data.transform((posA, posB))
        _path = arrow.get_connectionstyle()(
            posA,
            posB,
            patchA=arrow.patchA,
            patchB=arrow.patchB,
            shrinkA=arrow.shrinkA * dpi_cor,
            shrinkB=arrow.shrinkB * dpi_cor,
        )
        # Return is in display coordinates
        return _path

    def _update_text_pos_angle(self, arrow):
        """Return ``(x, y, angle)`` for the label.

        ``(x, y)`` is the point a fraction ``self.label_pos`` along the
        arrow's path, converted back to data coordinates. ``angle`` is in
        degrees. It is 0 when ``self.labels_horizontal`` is set; otherwise it
        is parallel to the path, folded into [-90, 90] so the text is never
        upside down.

        Raises
        ------
        TypeError
            If the arrow's connection style is not Arc3, Angle3, Angle, Arc
            or Bar.
        """
        # Fractional label position
        # Text position at a proportion t along the line in display coords
        # default is 0.5 so text appears at the halfway point
        import matplotlib as mpl
        import numpy as np

        t = self.label_pos
        tt = 1 - t
        path_disp = self._get_arrow_path_disp(arrow)
        conn = arrow.get_connectionstyle()
        # 1. Calculate x and y
        points = path_disp.vertices
        if is_curve := isinstance(
            conn,
            mpl.patches.ConnectionStyle.Angle3 | mpl.patches.ConnectionStyle.Arc3,
        ):
            # Arc3 or Angle3 type Connection Styles - quadratic Bezier curve
            (x1, y1), (cx, cy), (x2, y2) = points
            x = tt**2 * x1 + 2 * t * tt * cx + t**2 * x2
            y = tt**2 * y1 + 2 * t * tt * cy + t**2 * y2
        else:
            if not isinstance(
                conn,
                mpl.patches.ConnectionStyle.Angle
                | mpl.patches.ConnectionStyle.Arc
                | mpl.patches.ConnectionStyle.Bar,
            ):
                msg = f"invalid connection style: {type(conn)}"
                raise TypeError(msg)
            # A. Collect lines
            codes = path_disp.codes
            lines = [
                points[i - 1 : i + 1]
                for i in range(1, len(points))
                if codes[i] == mpl.path.Path.LINETO
            ]
            # B. If more than one line, find the right one and position in it
            if (nlines := len(lines)) != 1:
                dists = [math.dist(*line) for line in lines]
                dist_tot = sum(dists)
                cdist = 0
                last_cut = 0
                i_last = nlines - 1
                for i, dist in enumerate(dists):
                    cdist += dist
                    cut = cdist / dist_tot
                    if i == i_last or t < cut:
                        # Rescale t to a fraction of the selected segment
                        t = (t - last_cut) / (dist / dist_tot)
                        tt = 1 - t
                        lines = [lines[i]]
                        break
                    last_cut = cut
            [[(cx1, cy1), (cx2, cy2)]] = lines
            x = cx1 * tt + cx2 * t
            y = cy1 * tt + cy2 * t
        # 2. Calculate Angle
        if self.labels_horizontal:
            # Horizontal text labels
            angle = 0
        else:
            # Labels parallel to curve
            if is_curve:
                # Derivative of the quadratic Bezier curve at t
                change_x = 2 * tt * (cx - x1) + 2 * t * (x2 - cx)
                change_y = 2 * tt * (cy - y1) + 2 * t * (y2 - cy)
            else:
                change_x = (cx2 - cx1) / 2
                change_y = (cy2 - cy1) / 2
            angle = np.arctan2(change_y, change_x) / (2 * np.pi) * 360
            # Text is "right way up"
            if angle > 90:
                angle -= 180
            elif angle < -90:
                angle += 180
        (x, y) = self.ax.transData.inverted().transform((x, y))
        return x, y, angle

    def draw(self, renderer):
        """Recompute position and rotation from the arrow, then draw the text."""
        # recalculate the text position and angle
        self.x, self.y, self.angle = self._update_text_pos_angle(self.arrow)
        self.set_position((self.x, self.y))
        self.set_rotation(self.angle)
        # redraw text
        super().draw(renderer)
- def display(
- G,
- canvas=None,
- **kwargs,
- ):
- """Draw the graph G.
- Draw the graph as a collection of nodes connected by edges.
- The exact details of what the graph looks like are controlled by the below
- attributes. All nodes and nodes at the end of visible edges must have a
- position set, but nearly all other node and edge attributes are options and
- nodes or edges missing the attribute will use the default listed below. A more
- complete description of each parameter is given below this summary.
- .. list-table:: Default Visualization Attributes
- :widths: 25 25 50
- :header-rows: 1
- * - Parameter
- - Default Attribute
- - Default Value
- * - node_pos
- - `"pos"`
- - If there is no position, a layout will be calculated with `nx.spring_layout`.
- * - node_visible
- - `"visible"`
- - True
- * - node_color
- - `"color"`
- - #1f78b4
- * - node_size
- - `"size"`
- - 300
- * - node_label
- - `"label"`
- - Dict describing the node label. Defaults create a black text with
- the node name as the label. The dict respects these keys and defaults:
- * size : 12
- * color : black
- * family : sans serif
- * weight : normal
- * alpha : 1.0
- * h_align : center
- * v_align : center
- * bbox : Dict describing a `matplotlib.patches.FancyBboxPatch`.
- Default is None.
- * - node_shape
- - `"shape"`
- - "o"
- * - node_alpha
- - `"alpha"`
- - 1.0
- * - node_border_width
- - `"border_width"`
- - 1.0
- * - node_border_color
- - `"border_color"`
- - Matching node_color
- * - edge_visible
- - `"visible"`
- - True
- * - edge_width
- - `"width"`
- - 1.0
- * - edge_color
- - `"color"`
- - Black (#000000)
- * - edge_label
- - `"label"`
- - Dict describing the edge label. Defaults create black text with a
- white bounding box. The dictionary respects these keys and defaults:
- * size : 12
- * color : black
- * family : sans serif
- * weight : normal
- * alpha : 1.0
- * bbox : Dict describing a `matplotlib.patches.FancyBboxPatch`.
- Default {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}
- * h_align : "center"
- * v_align : "center"
- * pos : 0.5
- * rotate : True
- * - edge_style
- - `"style"`
- - "-"
- * - edge_alpha
- - `"alpha"`
- - 1.0
- * - edge_arrowstyle
- - `"arrowstyle"`
- - ``"-|>"`` if `G` is directed else ``"-"``
- * - edge_arrowsize
- - `"arrowsize"`
- - 10 if `G` is directed else 0
- * - edge_curvature
- - `"curvature"`
- - arc3
- * - edge_source_margin
- - `"source_margin"`
- - 0
- * - edge_target_margin
- - `"target_margin"`
- - 0
- Parameters
- ----------
- G : graph
- A networkx graph
- canvas : Matplotlib Axes object, optional
- Draw the graph in specified Matplotlib axes
- node_pos : string or function, default "pos"
- A string naming the node attribute storing the position of nodes as a tuple.
- Or a function to be called with input `G` which returns the layout as a dict keyed
- by node to position tuple like the NetworkX layout functions.
- If no node in the graph has the attribute, a spring layout is calculated.
- node_visible : string or bool, default visible
- A string naming the node attribute which stores if a node should be drawn.
- If `True`, all nodes will be visible while if `False` no nodes will be visible.
- If incomplete, nodes missing this attribute will be shown by default.
- node_color : string, default "color"
- A string naming the node attribute which stores the color of each node.
- Visible nodes without this attribute will use '#1f78b4' as a default.
- node_size : string or number, default "size"
- A string naming the node attribute which stores the size of each node.
- Visible nodes without this attribute will use a default size of 300.
- node_label : string or bool, default "label"
- A string naming the node attribute which stores the label of each node.
- The attribute value can be a string, False (no label for that node),
- True (the node is the label) or a dict keyed by node to the label.
- If a dict is specified, these keys are read to further control the label:
- * label : The text of the label; default: name of the node
- * size : Font size of the label; default: 12
- * color : Font color of the label; default: black
- * family : Font family of the label; default: "sans-serif"
- * weight : Font weight of the label; default: "normal"
- * alpha : Alpha value of the label; default: 1.0
- * h_align : The horizontal alignment of the label.
- one of "left", "center", "right"; default: "center"
- * v_align : The vertical alignment of the label.
- one of "top", "center", "bottom"; default: "center"
- * bbox : A dict of parameters for `matplotlib.patches.FancyBboxPatch`.
- Visible nodes without this attribute will be treated as if the value was True.
- node_shape : string, default "shape"
- A string naming the node attribute which stores the shape of each node.
- The values of this attribute are expected to be one of the matplotlib shapes,
- one of 'so^>v<dph8'. Visible nodes without this attribute will use 'o'.
- node_alpha : string, default "alpha"
- A string naming the node attribute which stores the alpha of each node.
- The values of this attribute are expected to be floats between 0.0 and 1.0.
- Visible nodes without this attribute will be treated as if the value was 1.0.
- node_border_width : string, default "border_width"
- A string naming the node attribute storing the width of the border of the node.
- The values of this attribute are expected to be numeric. Visible nodes without
- this attribute will use the assumed default of 1.0.
- node_border_color : string, default "border_color"
- A string naming the node attribute which stores the color of the border of the node.
- Visible nodes missing this attribute will use the final node_color value.
- edge_visible : string or bool, default "visible"
- A string naming the edge attribute which stores if an edge should be drawn.
- If `True`, all edges will be drawn while if `False` no edges will be visible.
- If incomplete, edges missing this attribute will be shown by default. Values
- of this attribute are expected to be booleans.
- edge_width : string or int, default "width"
- A string naming the edge attribute which stores the width of each edge.
- Visible edges without this attribute will use a default width of 1.0.
- edge_color : string or color, default "color"
- A string naming the edge attribute which stores the color of each edge.
- Visible edges without this attribute will be drawn black. Each color can be
- a string or rgb (or rgba) tuple of floats from 0.0 to 1.0.
- edge_label : string, default "label"
- A string naming the edge attribute which stores the label of each edge.
- The values of this attribute can be a string, number or False or None. In
- the latter two cases, no edge label is displayed.
- If a dict is specified, these keys are read to further control the label:
- * label : The text of the label, or the name of an edge attribute holding the label.
- * size : Font size of the label; default: 12
- * color : Font color of the label; default: black
- * family : Font family of the label; default: "sans-serif"
- * weight : Font weight of the label; default: "normal"
- * alpha : Alpha value of the label; default: 1.0
- * h_align : The horizontal alignment of the label.
- one of "left", "center", "right"; default: "center"
- * v_align : The vertical alignment of the label.
- one of "top", "center", "bottom"; default: "center"
- * bbox : A dict of parameters for `matplotlib.patches.FancyBboxPatch`.
- * rotate : Whether to rotate labels to lie parallel to the edge, default: True.
- * pos : A float showing how far along the edge to put the label; default: 0.5.
- edge_style : string, default "style"
- A string naming the edge attribute which stores the style of each edge.
- Visible edges without this attribute will be drawn solid. Values of this
- attribute can be line styles, e.g. '-', '--', '-.' or ':' or words like 'solid'
- or 'dashed'. If no edge in the graph has this attribute and it is a non-default
- value, assume that it describes the edge style for all edges in the graph.
- edge_alpha : string or float, default "alpha"
- A string naming the edge attribute which stores the alpha value of each edge.
- Visible edges without this attribute will use an alpha value of 1.0.
- edge_arrowstyle : string, default "arrowstyle"
- A string naming the edge attribute which stores the type of arrowhead to use for
- each edge. Visible edges without this attribute use ``"-"`` for undirected graphs
- and ``"-|>"`` for directed graphs.
- See `matplotlib.patches.ArrowStyle` for more options
- edge_arrowsize : string or int, default "arrowsize"
- A string naming the edge attribute which stores the size of the arrowhead for each
- edge. Visible edges without this attribute will use a default value of 10.
- edge_curvature : string, default "curvature"
- A string naming the edge attribute storing the curvature and connection style
- of each edge. Visible edges without this attribute will use "arc3" as a default
- value, resulting in a straight line between the two nodes. Curvature can be given
- as 'arc3,rad=0.2' to specify both the style and radius of curvature.
- Please see `matplotlib.patches.ConnectionStyle` and
- `matplotlib.patches.FancyArrowPatch` for more information.
- edge_source_margin : string or int, default "source_margin"
- A string naming the edge attribute which stores the minimum margin (gap) between
- the source node and the start of the edge. Visible edges without this attribute
- will use a default value of 0.
- edge_target_margin : string or int, default "target_margin"
- A string naming the edge attribute which stores the minimum margin (gap) between
- the target node and the end of the edge. Visible edges without this attribute
- will use a default value of 0.
- hide_ticks : bool, default True
- Whether to remove the ticks from the axes of the matplotlib object.
- Raises
- ------
- NetworkXError
- If a node or edge is missing a required parameter such as `pos` or
- if `display` receives an argument not listed above.
- ValueError
- If a node or edge has an invalid color format, i.e. not a color string,
- rgb tuple or rgba tuple.
- Returns
- -------
- The input graph. This is potentially useful for dispatching visualization
- functions.
- """
- from collections import Counter
- import matplotlib as mpl
- import matplotlib.pyplot as plt
- import numpy as np
- defaults = {
- "node_pos": None,
- "node_visible": True,
- "node_color": "#1f78b4",
- "node_size": 300,
- "node_label": {
- "size": 12,
- "color": "#000000",
- "family": "sans-serif",
- "weight": "normal",
- "alpha": 1.0,
- "h_align": "center",
- "v_align": "center",
- "bbox": None,
- },
- "node_shape": "o",
- "node_alpha": 1.0,
- "node_border_width": 1.0,
- "node_border_color": "face",
- "edge_visible": True,
- "edge_width": 1.0,
- "edge_color": "#000000",
- "edge_label": {
- "size": 12,
- "color": "#000000",
- "family": "sans-serif",
- "weight": "normal",
- "alpha": 1.0,
- "bbox": {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)},
- "h_align": "center",
- "v_align": "center",
- "pos": 0.5,
- "rotate": True,
- },
- "edge_style": "-",
- "edge_alpha": 1.0,
- "edge_arrowstyle": "-|>" if G.is_directed() else "-",
- "edge_arrowsize": 10 if G.is_directed() else 0,
- "edge_curvature": "arc3",
- "edge_source_margin": 0,
- "edge_target_margin": 0,
- "hide_ticks": True,
- }
- # Check arguments
- for kwarg in kwargs:
- if kwarg not in defaults:
- raise nx.NetworkXError(
- f"Unrecognized visualization keyword argument: {kwarg}"
- )
- if canvas is None:
- canvas = plt.gca()
- if kwargs.get("hide_ticks", defaults["hide_ticks"]):
- canvas.tick_params(
- axis="both",
- which="both",
- bottom=False,
- left=False,
- labelbottom=False,
- labelleft=False,
- )
- ### Helper methods and classes
def node_property_sequence(seq, attr):
    """Return a list of node attribute values for `seq`, using a default if needed.

    Parameters
    ----------
    seq : iterable of nodes
        Nodes to look up in ``node_subgraph``, in output order.
    attr : str
        Property name without the ``node_`` prefix, e.g. ``"color"``. If the
        ``node_<attr>`` keyword argument was passed to ``display``, its value
        is used instead. That value is either the name of a node attribute,
        or (when no node in ``node_subgraph`` has such an attribute) a
        literal value returned for every node.

    Returns
    -------
    list
        One value per element of `seq`. Nodes lacking the attribute get
        ``defaults["node_<attr>"]``.

    Raises
    ------
    NetworkXError
        If ``defaults["node_<attr>"]`` is None (e.g. ``node_pos``) and a node
        in `seq` lacks the attribute.
    """
    # All node attribute parameters start with "node_"
    param_name = f"node_{attr}"
    default = defaults[param_name]
    attr = kwargs.get(param_name, attr)
    if default is None:
        # raise instead of using non-existent default value
        # NOTE(review): `seq` is iterated here and again below, so it must be
        # a re-iterable collection rather than a one-shot iterator.
        for n in seq:
            if attr not in node_subgraph.nodes[n]:
                raise nx.NetworkXError(f"Attribute '{attr}' missing for node {n}")
    # If `attr` is not a graph attr and was explicitly passed as an argument
    # it must be a user-default value. Allow attr=None to tell draw to skip
    # attributes which are on the graph
    if (
        attr is not None
        and nx.get_node_attributes(node_subgraph, attr) == {}
        and any(attr == v for k, v in kwargs.items() if "node" in k)
    ):
        return [attr for _ in seq]
    return [node_subgraph.nodes[n].get(attr, default) for n in seq]
def compute_colors(color, alpha):
    """Convert `color` into an RGBA 4-tuple.

    * Strings are converted with matplotlib's color converter. A non-default
      `alpha` replaces the alpha channel of the converted color.
    * RGB 3-tuples get `alpha` appended as their fourth channel.
    * RGBA 4-tuples are returned unchanged.

    Raises
    ------
    ValueError
        If `color` is neither a string nor a 3- or 4-tuple.
    """
    if isinstance(color, tuple):
        if len(color) == 4:
            return color
        if len(color) == 3:
            red, green, blue = color
            return (red, green, blue, alpha)
    elif isinstance(color, str):
        converted = mpl.colors.colorConverter.to_rgba(color)
        # Using a non-default alpha value overrides any alpha value in the color
        if alpha == defaults["node_alpha"]:
            return converted
        return (*converted[:3], alpha)
    raise ValueError(f"Invalid format for color: {color}")
# Decide which edges can be drawn as part of a single line collection.
#
# Any value other than the ones checked below forces a fancy arrow patch:
# - an arrow style other than "-" (including the default "-|>" for directed
#   graphs); arrow size only matters once a real arrow style is used
# - a connection style (curvature) other than "arc3"
# - a non-zero source_margin or target_margin
# Self-loops always use fancy arrow patches as well.
def collection_compatible(e):
    """Return whether edge `e` can be drawn in a line collection."""
    plain_requirements = (
        ("arrowstyle", "-"),
        ("curvature", "arc3"),
        ("source_margin", 0),
        ("target_margin", 0),
    )
    if not all(
        get_edge_attr(e, name) == expected for name, expected in plain_requirements
    ):
        return False
    # Self-loops will use fancy arrow patches
    return e[0] != e[1]
def edge_property_sequence(seq, attr):
    """Return a list of edge attribute values for `seq`, using a default if needed.

    Parameters
    ----------
    seq : iterable of edges
        Edges to look up in ``edge_subgraph``, in output order.
    attr : str
        Property name without the ``edge_`` prefix, e.g. ``"width"``. If the
        ``edge_<attr>`` keyword argument was passed to ``display``, its value
        is used instead. That value is either the name of an edge attribute,
        or (when no edge in ``edge_subgraph`` has such an attribute) a
        literal value returned for every edge.

    Returns
    -------
    list
        One value per element of `seq`. Edges lacking the attribute get
        ``defaults["edge_<attr>"]``.

    Raises
    ------
    NetworkXError
        If ``defaults["edge_<attr>"]`` is None and an edge in `seq` lacks the
        attribute. None of the ``edge_*`` entries in ``defaults`` are
        currently None, so this path is not reached with the current defaults.
    """
    param_name = f"edge_{attr}"
    default = defaults[param_name]
    attr = kwargs.get(param_name, attr)
    if default is None:
        # raise instead of using non-existent default value
        for e in seq:
            if attr not in edge_subgraph.edges[e]:
                raise nx.NetworkXError(f"Attribute '{attr}' missing for edge {e}")
    # A value that is not an edge attribute of the graph but was passed as an
    # edge_* keyword argument is treated as a literal value for every edge.
    if (
        attr is not None
        and nx.get_edge_attributes(edge_subgraph, attr) == {}
        and any(attr == v for k, v in kwargs.items() if "edge" in k)
    ):
        return [attr for _ in seq]
    return [edge_subgraph.edges[e].get(attr, default) for e in seq]
def get_edge_attr(e, attr):
    """Return the final value of edge property `attr` for edge `e`.

    `attr` is the property name without the ``edge_`` prefix. If the
    ``edge_<attr>`` keyword argument was passed to ``display``, its value
    replaces the attribute name. If that name is not an edge attribute
    anywhere in ``edge_subgraph`` but equals one of the keyword-argument
    values, it is returned as a literal value. Otherwise the edge's own
    attribute is returned, falling back to ``defaults["edge_<attr>"]``.

    Raises
    ------
    NetworkXError
        If ``defaults["edge_<attr>"]`` is None and `e` lacks the attribute.
    """
    param_name = f"edge_{attr}"
    default = defaults[param_name]
    attr = kwargs.get(param_name, attr)
    if default is None and attr not in edge_subgraph.edges[e]:
        raise nx.NetworkXError(f"Attribute '{attr}' missing from edge {e}")
    # NOTE(review): unlike edge_property_sequence, this compares against every
    # keyword-argument value, not only the edge_* ones — confirm intent.
    # nx.get_edge_attributes presumably scans all edges, so calling this once
    # per edge is quadratic in the number of edges.
    if (
        attr is not None
        and nx.get_edge_attributes(edge_subgraph, attr) == {}
        and attr in kwargs.values()
    ):
        return attr
    return edge_subgraph.edges[e].get(attr, default)
- def get_node_attr(n, attr, use_edge_subgraph=True):
- """Return the final node attribute value, using default if not None"""
- subgraph = edge_subgraph if use_edge_subgraph else node_subgraph
- param_name = f"node_{attr}"
- default = defaults[param_name]
- attr = kwargs.get(param_name, attr)
- if default is None and attr not in subgraph.nodes[n]:
- raise nx.NetworkXError(f"Attribute '{attr}' missing from node {n}")
- if (
- attr is not None
- and nx.get_node_attributes(subgraph, attr) == {}
- and attr in kwargs.values()
- ):
- return attr
- return subgraph.nodes[n].get(attr, default)
- # Taken from ConnectionStyleFactory
def self_loop(edge_index, node_size):
    """Return a connection-style callable that draws a self-loop.

    The loop height is derived from ``node_size`` and ``edge_index % 4``
    selects one of four 90-degree rotations, so up to four loops on the same
    node stay visible.

    The returned callable has the ``(posA, posB, *args, **kwargs)`` signature
    Matplotlib uses for connection styles. It raises ``nx.NetworkXError`` when
    ``posA`` and ``posB`` differ, because it is only valid for self-loops.
    """

    def self_loop_connection(posA, posB, *args, **kwargs):
        if not np.all(posA == posB):
            # Trailing space so the two literals join into a readable sentence.
            raise nx.NetworkXError(
                "`self_loop` connection style method "
                "is only to be used for self-loops"
            )
        # this is called with _screen space_ values
        # so convert back to data space
        data_loc = canvas.transData.inverted().transform(posA)
        # Scale self loop based on the size of the base node
        # Size of nodes are given in points ** 2 and each point is 1/72 of an inch
        # NOTE(review): v_shift is a length in inches but is added to data
        # coordinates below; confirm this scaling is intended.
        v_shift = np.sqrt(node_size) / 72
        h_shift = v_shift * 0.5
        # put the top of the loop first so arrow is not hidden by node
        path = np.asarray(
            [
                # 1
                [0, v_shift],
                # 4 4 4
                [h_shift, v_shift],
                [h_shift, 0],
                [0, 0],
                # 4 4 4
                [-h_shift, 0],
                [-h_shift, v_shift],
                [0, v_shift],
            ]
        )
        # Rotate self loop 90 deg. if more than 1
        # This will allow for maximum of 4 visible self loops
        if edge_index % 4:
            x, y = path.T
            for _ in range(edge_index % 4):
                x, y = y, -x
            path = np.array([x, y]).T
        # Path codes: 1 = MOVETO, then two cubic Bezier (CURVE4) segments.
        return mpl.path.Path(
            canvas.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4]
        )

    return self_loop_connection
- def to_marker_edge(size, marker):
- if marker in "s^>v<d":
- return np.sqrt(2 * size) / 2
- else:
- return np.sqrt(size) / 2
def build_fancy_arrow(e):
    """Build the ``FancyArrowPatch`` used to draw edge ``e``.

    ``e`` is ``(u, v)`` or, for multigraphs, ``(u, v, key)``.

    The tail is pulled back from ``u`` by the larger of two values:
    ``to_marker_edge`` of ``u``'s size/shape, and the edge's
    ``source_margin``. The head is pulled back from ``v`` in the same way,
    using ``v``'s size/shape and the edge's ``target_margin``.

    Self-loops use the ``self_loop`` connection style; every other edge uses
    the edge's ``curvature`` attribute.
    """
    source, target = e[0], e[1]

    # Clearance between the source node and the arrow tail.
    source_margin = max(
        to_marker_edge(
            get_node_attr(source, "size"),
            get_node_attr(source, "shape"),
        ),
        get_edge_attr(e, "source_margin"),
    )
    # Clearance between the arrow head and the target node.
    target_margin = max(
        to_marker_edge(
            get_node_attr(target, "size"),
            get_node_attr(target, "shape"),
        ),
        get_edge_attr(e, "target_margin"),
    )

    if source != target:
        connectionstyle = get_edge_attr(e, "curvature")
    else:
        # NOTE(review): `e[2] % 4` assumes integer multigraph keys; a string
        # key would raise here. Confirm how non-integer keys should be placed.
        connectionstyle = self_loop(
            0 if len(e) == 2 else e[2] % 4,
            get_node_attr(source, "size"),
        )

    return mpl.patches.FancyArrowPatch(
        edge_subgraph.nodes[source][pos],
        edge_subgraph.nodes[target][pos],
        arrowstyle=get_edge_attr(e, "arrowstyle"),
        connectionstyle=connectionstyle,
        color=get_edge_attr(e, "color"),
        linestyle=get_edge_attr(e, "style"),
        linewidth=get_edge_attr(e, "width"),
        mutation_scale=get_edge_attr(e, "arrowsize"),
        shrinkA=source_margin,
        # The head's clearance must come from the *target* node.
        shrinkB=target_margin,
        zorder=1,
    )
- class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text):
- pass
- ### Draw the nodes first
- node_visible = kwargs.get("node_visible", "visible")
- if isinstance(node_visible, bool):
- if node_visible:
- visible_nodes = G.nodes()
- else:
- visible_nodes = []
- else:
- visible_nodes = [
- n for n, v in nx.get_node_attributes(G, node_visible, True).items() if v
- ]
- node_subgraph = G.subgraph(visible_nodes)
- # Ignore the default dict value since that's for default values to use, not
- # default attribute name
- pos = kwargs.get("node_pos", "pos")
- default_display_pos_attr = "display's position attribute name"
- if callable(pos):
- nx.set_node_attributes(
- node_subgraph, pos(node_subgraph), default_display_pos_attr
- )
- pos = default_display_pos_attr
- kwargs["node_pos"] = default_display_pos_attr
- elif nx.get_node_attributes(G, pos) == {}:
- nx.set_node_attributes(
- node_subgraph, nx.spring_layout(node_subgraph), default_display_pos_attr
- )
- pos = default_display_pos_attr
- kwargs["node_pos"] = default_display_pos_attr
- # Each shape requires a new scatter object since they can't have different
- # shapes.
- if len(visible_nodes) > 0:
- node_shape = kwargs.get("node_shape", "shape")
- for shape in Counter(
- nx.get_node_attributes(
- node_subgraph, node_shape, defaults["node_shape"]
- ).values()
- ):
- # Filter position just on this shape.
- nodes_with_shape = [
- n
- for n, s in node_subgraph.nodes(data=node_shape)
- if s == shape or (s is None and shape == defaults["node_shape"])
- ]
- # There are two property sequences to create before hand.
- # 1. position, since it is used for x and y parameters to scatter
- # 2. edgecolor, since the spaeical 'face' parameter value can only be
- # be passed in as the sole string, not part of a list of strings.
- position = np.asarray(node_property_sequence(nodes_with_shape, "pos"))
- color = np.asarray(
- [
- compute_colors(c, a)
- for c, a in zip(
- node_property_sequence(nodes_with_shape, "color"),
- node_property_sequence(nodes_with_shape, "alpha"),
- )
- ]
- )
- border_color = np.asarray(
- [
- (
- c
- if (
- c := get_node_attr(
- n,
- "border_color",
- False,
- )
- )
- != "face"
- else color[i]
- )
- for i, n in enumerate(nodes_with_shape)
- ]
- )
- canvas.scatter(
- position[:, 0],
- position[:, 1],
- s=node_property_sequence(nodes_with_shape, "size"),
- c=color,
- marker=shape,
- linewidths=node_property_sequence(nodes_with_shape, "border_width"),
- edgecolors=border_color,
- zorder=2,
- )
- ### Draw node labels
- node_label = kwargs.get("node_label", "label")
- # Plot labels if node_label is not None and not False
- if node_label is not None and node_label is not False:
- default_dict = {}
- if isinstance(node_label, dict):
- default_dict = node_label
- node_label = None
- for n, lbl in node_subgraph.nodes(data=node_label):
- if lbl is False:
- continue
- # We work with label dicts down here...
- if not isinstance(lbl, dict):
- lbl = {"label": lbl if lbl is not None else n}
- lbl_text = lbl.get("label", n)
- if not isinstance(lbl_text, str):
- lbl_text = str(lbl_text)
- lbl.update(default_dict)
- x, y = node_subgraph.nodes[n][pos]
- canvas.text(
- x,
- y,
- lbl_text,
- size=lbl.get("size", defaults["node_label"]["size"]),
- color=lbl.get("color", defaults["node_label"]["color"]),
- family=lbl.get("family", defaults["node_label"]["family"]),
- weight=lbl.get("weight", defaults["node_label"]["weight"]),
- horizontalalignment=lbl.get(
- "h_align", defaults["node_label"]["h_align"]
- ),
- verticalalignment=lbl.get("v_align", defaults["node_label"]["v_align"]),
- transform=canvas.transData,
- bbox=lbl.get("bbox", defaults["node_label"]["bbox"]),
- )
- ### Draw edges
- edge_visible = kwargs.get("edge_visible", "visible")
- if isinstance(edge_visible, bool):
- if edge_visible:
- visible_edges = G.edges()
- else:
- visible_edges = []
- else:
- visible_edges = [
- e for e, v in nx.get_edge_attributes(G, edge_visible, True).items() if v
- ]
- edge_subgraph = G.edge_subgraph(visible_edges)
- nx.set_node_attributes(
- edge_subgraph, nx.get_node_attributes(node_subgraph, pos), name=pos
- )
- collection_edges = (
- [e for e in edge_subgraph.edges(keys=True) if collection_compatible(e)]
- if edge_subgraph.is_multigraph()
- else [e for e in edge_subgraph.edges() if collection_compatible(e)]
- )
- non_collection_edges = (
- [e for e in edge_subgraph.edges(keys=True) if not collection_compatible(e)]
- if edge_subgraph.is_multigraph()
- else [e for e in edge_subgraph.edges() if not collection_compatible(e)]
- )
- edge_position = np.asarray(
- [
- (
- get_node_attr(u, "pos", use_edge_subgraph=True),
- get_node_attr(v, "pos", use_edge_subgraph=True),
- )
- for u, v, *_ in collection_edges
- ]
- )
- # Only plot a line collection if needed
- if len(collection_edges) > 0:
- edge_collection = mpl.collections.LineCollection(
- edge_position,
- colors=edge_property_sequence(collection_edges, "color"),
- linewidths=edge_property_sequence(collection_edges, "width"),
- linestyle=edge_property_sequence(collection_edges, "style"),
- alpha=edge_property_sequence(collection_edges, "alpha"),
- antialiaseds=(1,),
- zorder=1,
- )
- canvas.add_collection(edge_collection)
- fancy_arrows = {}
- if len(non_collection_edges) > 0:
- for e in non_collection_edges:
- # Cache results for use in edge labels
- fancy_arrows[e] = build_fancy_arrow(e)
- canvas.add_patch(fancy_arrows[e])
- ### Draw edge labels
- edge_label = kwargs.get("edge_label", "label")
- default_dict = {}
- if isinstance(edge_label, dict):
- default_dict = edge_label
- # Restore the default label attribute key of 'label'
- edge_label = "label"
- # Handle multigraphs
- edge_label_data = (
- edge_subgraph.edges(data=edge_label, keys=True)
- if edge_subgraph.is_multigraph()
- else edge_subgraph.edges(data=edge_label)
- )
- if edge_label is not None and edge_label is not False:
- for *e, lbl in edge_label_data:
- e = tuple(e)
- # I'm not sure how I want to handle None here... For now it means no label
- if lbl is False or lbl is None:
- continue
- if not isinstance(lbl, dict):
- lbl = {"label": lbl}
- lbl.update(default_dict)
- lbl_text = lbl.get("label")
- if not isinstance(lbl_text, str):
- lbl_text = str(lbl_text)
- # In the old code, every non-self-loop is placed via a fancy arrow patch
- # Only compute a new fancy arrow if needed by caching the results from
- # edge placement.
- try:
- arrow = fancy_arrows[e]
- except KeyError:
- arrow = build_fancy_arrow(e)
- if e[0] == e[1]:
- # Taken directly from draw_networkx_edge_labels
- connectionstyle_obj = arrow.get_connectionstyle()
- posA = canvas.transData.transform(edge_subgraph.nodes[e[0]][pos])
- path_disp = connectionstyle_obj(posA, posA)
- path_data = canvas.transData.inverted().transform_path(path_disp)
- x, y = path_data.vertices[0]
- canvas.text(
- x,
- y,
- lbl_text,
- size=lbl.get("size", defaults["edge_label"]["size"]),
- color=lbl.get("color", defaults["edge_label"]["color"]),
- family=lbl.get("family", defaults["edge_label"]["family"]),
- weight=lbl.get("weight", defaults["edge_label"]["weight"]),
- alpha=lbl.get("alpha", defaults["edge_label"]["alpha"]),
- horizontalalignment=lbl.get(
- "h_align", defaults["edge_label"]["h_align"]
- ),
- verticalalignment=lbl.get(
- "v_align", defaults["edge_label"]["v_align"]
- ),
- rotation=0,
- transform=canvas.transData,
- bbox=lbl.get("bbox", defaults["edge_label"]["bbox"]),
- zorder=1,
- )
- continue
- CurvedArrowText(
- arrow,
- lbl_text,
- size=lbl.get("size", defaults["edge_label"]["size"]),
- color=lbl.get("color", defaults["edge_label"]["color"]),
- family=lbl.get("family", defaults["edge_label"]["family"]),
- weight=lbl.get("weight", defaults["edge_label"]["weight"]),
- alpha=lbl.get("alpha", defaults["edge_label"]["alpha"]),
- bbox=lbl.get("bbox", defaults["edge_label"]["bbox"]),
- horizontalalignment=lbl.get(
- "h_align", defaults["edge_label"]["h_align"]
- ),
- verticalalignment=lbl.get("v_align", defaults["edge_label"]["v_align"]),
- label_pos=lbl.get("pos", defaults["edge_label"]["pos"]),
- labels_horizontal=lbl.get("rotate", defaults["edge_label"]["rotate"]),
- transform=canvas.transData,
- zorder=1,
- ax=canvas,
- )
- # If we had to add an attribute, remove it here
- if pos == default_display_pos_attr:
- nx.remove_node_attributes(G, default_display_pos_attr)
- return G
def draw(G, pos=None, ax=None, **kwds):
    """Draw the graph G with Matplotlib.

    Draw the graph as a simple representation with no node
    labels or edge labels and using the full Matplotlib figure area
    and no axis labels by default. See draw_networkx() for more
    full-featured drawing that allows title, axis labels etc.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary, optional
        A dictionary with nodes as keys and positions as values.
        If not specified a spring layout positioning will be computed.
        See :py:mod:`networkx.drawing.layout` for functions that
        compute node positions.

    ax : Matplotlib Axes object, optional
        Draw the graph in specified Matplotlib axes. When omitted, the
        current figure's current axes are used, or an axes covering the whole
        figure is added if the figure has none.

    kwds : optional keywords
        See networkx.draw_networkx() for a description of optional keywords.
        Unless ``with_labels`` is given explicitly, node labels are drawn only
        when a ``labels`` keyword is supplied.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout

    See Also
    --------
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels

    Notes
    -----
    This function has the same name as pylab.draw and pyplot.draw
    so beware when using `from networkx import *`
    since you might overwrite the pylab.draw function.

    With pyplot use

    >>> import matplotlib.pyplot as plt
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)  # networkx draw()
    >>> plt.draw()  # pyplot draw()

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html
    """
    import matplotlib.pyplot as plt

    # Work on the figure that owns `ax`, or on the current figure.
    fig = plt.gcf() if ax is None else ax.get_figure()
    fig.set_facecolor("w")

    if ax is None:
        # Reuse an existing axes; otherwise span the entire figure.
        ax = fig.gca() if fig.axes else fig.add_axes((0, 0, 1, 1))

    # Labels are off by default here, unless a `labels` mapping is passed.
    kwds.setdefault("with_labels", "labels" in kwds)

    draw_networkx(G, pos=pos, ax=ax, **kwds)
    ax.set_axis_off()
    plt.draw_if_interactive()
def draw_networkx(G, pos=None, arrows=None, with_labels=True, **kwds):
    r"""Draw the graph G using Matplotlib.

    Draw the graph with Matplotlib with options for node positions,
    labeling, titles, and many other drawing features.
    See draw() for simple drawing without labels or axes.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary, optional
        A dictionary with nodes as keys and positions as values.
        If not specified a spring layout positioning will be computed.
        See :py:mod:`networkx.drawing.layout` for functions that
        compute node positions.

    arrows : bool or None, optional (default=None)
        If `None`, directed graphs draw arrowheads with
        `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
        via `~matplotlib.collections.LineCollection` for speed.
        If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
        If `False`, draw edges using LineCollection (linear and fast).
        For directed graphs, if True draw arrowheads.
        Note: Arrows will be the same color as edges.

    arrowstyle : str (default='-\|>' for directed graphs)
        For directed graphs, choose the style of the arrowheads.
        For undirected graphs default to '-'
        See `matplotlib.patches.ArrowStyle` for more options.

    arrowsize : int or list (default=10)
        For directed graphs, choose the size of the arrow head's length and
        width. A list of values assigns a different arrow size to each edge.
        See `matplotlib.patches.FancyArrowPatch` for attribute `mutation_scale`
        for more info.

    with_labels : bool (default=True)
        Set to True to draw labels on the nodes.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    nodelist : list (default=list(G))
        Draw only specified nodes

    edgelist : list (default=list(G.edges()))
        Draw only specified edges

    node_size : scalar or array (default=300)
        Size of nodes. If an array is specified it must be the
        same length as nodelist.

    node_color : color or array of colors (default='#1f78b4')
        Node color. Can be a single color or a sequence of colors with the same
        length as nodelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

    node_shape : string (default='o')
        The shape of the node. Specification is as matplotlib.scatter
        marker, one of 'so^>v<dph8'.

    alpha : float or None (default=None)
        The node and edge transparency

    cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of nodes

    vmin,vmax : float, optional
        Minimum and maximum for node colormap scaling

    linewidths : scalar or sequence (default=1.0)
        Line width of symbol border

    width : float or array of floats (default=1.0)
        Line width of edges

    edge_color : color or array of colors (default='k')
        Edge color. Can be a single color or a sequence of colors with the same
        length as edgelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.

    edge_cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of edges

    edge_vmin,edge_vmax : floats, optional
        Minimum and maximum for edge colormap scaling

    style : string (default=solid line)
        Edge line style e.g.: '-', '--', '-.', ':'
        or words like 'solid' or 'dashed'.
        (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)

    labels : dictionary (default=None)
        Node labels in a dictionary of text labels keyed by node

    font_size : int (default=12 for nodes, 10 for edges)
        Font size for text labels

    font_color : color (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    label : string, optional
        Label for graph legend

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    kwds : optional keywords
        See networkx.draw_networkx_nodes(), networkx.draw_networkx_edges(), and
        networkx.draw_networkx_labels() for a description of optional keywords.
        A keyword accepted by none of those functions raises ``ValueError``.

    Notes
    -----
    For directed graphs, arrows are drawn at the head end. Arrows can be
    turned off with keyword arrows=False.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout

    >>> import matplotlib.pyplot as plt
    >>> limits = plt.axis("off")  # turn off axis

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels
    """
    from inspect import signature

    import matplotlib.pyplot as plt

    # The accepted keywords are the union of the component functions'
    # parameters, minus the ones this function consumes itself.
    node_params = signature(draw_networkx_nodes).parameters.keys()
    edge_params = signature(draw_networkx_edges).parameters.keys()
    label_params = signature(draw_networkx_labels).parameters.keys()
    accepted = (node_params | edge_params | label_params) - {
        "G",
        "pos",
        "arrows",
        "with_labels",
    }

    unknown = [name for name in kwds if name not in accepted]
    if unknown:
        invalid_args = ", ".join(unknown)
        raise ValueError(f"Received invalid argument(s): {invalid_args}")

    def _keywords_for(params):
        """Return the subset of ``kwds`` that a component function accepts."""
        return {name: value for name, value in kwds.items() if name in params}

    if pos is None:
        pos = nx.drawing.spring_layout(G)  # default to spring layout

    draw_networkx_nodes(G, pos, **_keywords_for(node_params))
    draw_networkx_edges(G, pos, arrows=arrows, **_keywords_for(edge_params))
    if with_labels:
        draw_networkx_labels(G, pos, **_keywords_for(label_params))
    plt.draw_if_interactive()
def draw_networkx_nodes(
    G,
    pos,
    nodelist=None,
    node_size=300,
    node_color="#1f78b4",
    node_shape="o",
    alpha=None,
    cmap=None,
    vmin=None,
    vmax=None,
    ax=None,
    linewidths=None,
    edgecolors=None,
    label=None,
    margins=None,
    hide_ticks=True,
):
    """Draw the nodes of the graph G.

    This draws only the nodes of the graph G.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    nodelist : list (default list(G))
        Draw only specified nodes

    node_size : scalar or array (default=300)
        Size of nodes. If an array it must be the same length as nodelist.

    node_color : color or array of colors (default='#1f78b4')
        Node color. Can be a single color or a sequence of colors with the same
        length as nodelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

    node_shape : string or list/array of strings (default='o')
        The shape of the node. Specification is as matplotlib.scatter
        marker, one of 'so^>v<dph8'. A list or array gives one shape per
        node in nodelist.

    alpha : float or array of floats (default=None)
        The node transparency. This can be a single alpha value,
        in which case it will be applied to all the nodes of color. Otherwise,
        if it is an array, the elements of alpha will be applied to the colors
        in order (cycling through alpha multiple times if necessary).

    cmap : Matplotlib colormap (default=None)
        Colormap for mapping intensities of nodes

    vmin,vmax : floats or None (default=None)
        Minimum and maximum for node colormap scaling

    linewidths : [None | scalar | sequence] (default=None)
        Line width of symbol border. ``None`` leaves the choice to Matplotlib.

    edgecolors : [None | scalar | sequence] (default = node_color)
        Colors of node borders. Can be a single color or a sequence of colors with the
        same length as nodelist. Color can be string or rgb (or rgba) tuple of floats
        from 0-1. If numeric values are specified they will be mapped to colors
        using the cmap and vmin,vmax parameters. See `~matplotlib.pyplot.scatter` for more details.

    label : [None | string]
        Label for legend

    margins : float or 2-tuple, optional
        Sets the padding for axis autoscaling. Increase margin to prevent
        clipping for nodes that are near the edges of an image. Values should
        be in the range ``[0, 1]``. See :meth:`matplotlib.axes.Axes.margins`
        for details. The default is `None`, which uses the Matplotlib default.

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Returns
    -------
    matplotlib.collections.PathCollection
        `PathCollection` of the nodes. When several node shapes are used, each
        shape is drawn as its own collection and the collection of the last
        shape (in sorted order) is returned.

    Notes
    -----
    Each distinct node shape is drawn with a separate scatter call. When
    there is more than one shape, per-node sequences (``node_size``,
    ``node_color``, ``linewidths``, ``edgecolors``) whose length matches
    ``nodelist`` are split so each call receives only its own nodes' values.
    Numeric ``node_color`` values are then scaled per call, so pass ``vmin``
    and ``vmax`` to share one color scale across shapes.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nodes = nx.draw_networkx_nodes(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels
    """
    from collections.abc import Iterable, Sized

    import matplotlib as mpl
    import matplotlib.collections  # call as mpl.collections
    import matplotlib.colors  # call as mpl.colors
    import matplotlib.pyplot as plt
    import numpy as np

    if ax is None:
        ax = plt.gca()

    if nodelist is None:
        nodelist = list(G)

    if len(nodelist) == 0:  # empty nodelist, no drawing
        return mpl.collections.PathCollection(None)

    try:
        xy = np.asarray([pos[v] for v in nodelist])
    except KeyError as err:
        raise nx.NetworkXError(f"Node {err} has no position.") from err

    if isinstance(alpha, Iterable):
        node_color = apply_alpha(node_color, alpha, nodelist, cmap, vmin, vmax)
        alpha = None

    if not isinstance(node_shape, np.ndarray) and not isinstance(node_shape, list):
        node_shape = np.array([node_shape for _ in range(len(nodelist))])
    elif isinstance(node_shape, list):
        node_shape = np.asarray(node_shape)

    n_nodes = len(nodelist)

    def _select(values, mask, is_color=False):
        """Keep only the entries of a per-node sequence where ``mask`` is True.

        Values that are not per-node sequences are returned unchanged because
        they apply to every node. This covers ``None``, strings, scalars, a
        single color spec when ``is_color`` is set, and sequences whose length
        differs from ``nodelist``.
        """
        if isinstance(values, str) or not np.iterable(values):
            return values
        if not isinstance(values, Sized) or len(values) != n_nodes:
            return values
        if is_color and mpl.colors.is_color_like(values):
            # e.g. one RGB tuple shared by all nodes; never split it.
            return values
        return [value for value, keep in zip(values, mask) if keep]

    shapes = np.unique(node_shape)
    # With a single shape one scatter call receives every node, so the
    # arguments are passed through exactly as the caller supplied them.
    split = len(shapes) > 1
    for shape in shapes:
        mask = node_shape == shape
        if split:
            sizes = _select(node_size, mask)
            colors = _select(node_color, mask, is_color=True)
            widths = _select(linewidths, mask)
            borders = _select(edgecolors, mask, is_color=True)
        else:
            sizes, colors, widths, borders = (
                node_size,
                node_color,
                linewidths,
                edgecolors,
            )
        node_collection = ax.scatter(
            xy[mask, 0],
            xy[mask, 1],
            s=sizes,
            c=colors,
            marker=shape,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            alpha=alpha,
            linewidths=widths,
            edgecolors=borders,
            label=label,
        )
        # Keep every shape's collection above the edges, not only the last.
        node_collection.set_zorder(2)

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    if margins is not None:
        if isinstance(margins, Iterable):
            ax.margins(*margins)
        else:
            ax.margins(margins)

    return node_collection
class FancyArrowFactory:
    """Draw arrows with `matplotlib.patches.FancyarrowPatch`.

    An instance is configured once with positions and styling for a list of
    edges. Calling it with an edge index ``i`` returns a new
    ``FancyArrowPatch`` for ``edge_pos[i]``, with that edge's margins, arrow
    size, color, width, line style, arrow style and connection style resolved.
    """

    class ConnectionStyleFactory:
        """Build connection styles for ordinary edges and for self-loops."""

        def __init__(self, connectionstyles, selfloop_height, ax=None):
            import matplotlib as mpl
            import matplotlib.path  # call as mpl.path
            import numpy as np

            self.ax = ax
            self.mpl = mpl
            self.np = np
            self.base_connection_styles = [
                mpl.patches.ConnectionStyle(cs) for cs in connectionstyles
            ]
            self.n = len(self.base_connection_styles)
            self.selfloop_height = selfloop_height

        def curved(self, edge_index):
            """Return the base connection style for ``edge_index``, cycling."""
            return self.base_connection_styles[edge_index % self.n]

        def self_loop(self, edge_index):
            """Return a connection-style callable that draws a self-loop.

            The loop height is ``0.1 * selfloop_height`` in data coordinates.
            ``edge_index % 4`` rotates the loop in 90-degree steps, so up to
            four loops on the same node stay visible.
            """

            def self_loop_connection(posA, posB, *args, **kwargs):
                if not self.np.all(posA == posB):
                    # Trailing space so the two literals join into a sentence.
                    raise nx.NetworkXError(
                        "`self_loop` connection style method "
                        "is only to be used for self-loops"
                    )
                # this is called with _screen space_ values
                # so convert back to data space
                data_loc = self.ax.transData.inverted().transform(posA)
                v_shift = 0.1 * self.selfloop_height
                h_shift = v_shift * 0.5
                # put the top of the loop first so arrow is not hidden by node
                path = self.np.asarray(
                    [
                        # 1
                        [0, v_shift],
                        # 4 4 4
                        [h_shift, v_shift],
                        [h_shift, 0],
                        [0, 0],
                        # 4 4 4
                        [-h_shift, 0],
                        [-h_shift, v_shift],
                        [0, v_shift],
                    ]
                )
                # Rotate self loop 90 deg. if more than 1
                # This will allow for maximum of 4 visible self loops
                if edge_index % 4:
                    x, y = path.T
                    for _ in range(edge_index % 4):
                        x, y = y, -x
                    path = self.np.array([x, y]).T
                # Path codes: 1 = MOVETO, then two cubic Bezier (CURVE4) segments.
                return self.mpl.path.Path(
                    self.ax.transData.transform(data_loc + path), [1, 4, 4, 4, 4, 4, 4]
                )

            return self_loop_connection

    def __init__(
        self,
        edge_pos,
        edgelist,
        nodelist,
        edge_indices,
        node_size,
        selfloop_height,
        connectionstyle="arc3",
        node_shape="o",
        arrowstyle="-",
        arrowsize=10,
        edge_color="k",
        alpha=None,
        linewidth=1.0,
        style="solid",
        min_source_margin=0,
        min_target_margin=0,
        ax=None,
    ):
        import matplotlib as mpl
        import matplotlib.patches  # call as mpl.patches
        import matplotlib.pyplot as plt
        import numpy as np

        if isinstance(connectionstyle, str):
            connectionstyle = [connectionstyle]
        elif np.iterable(connectionstyle):
            connectionstyle = list(connectionstyle)
        else:
            msg = "ConnectionStyleFactory arg `connectionstyle` must be str or iterable"
            raise nx.NetworkXError(msg)
        self.ax = ax
        self.mpl = mpl
        self.np = np
        self.edge_pos = edge_pos
        self.edgelist = edgelist
        self.nodelist = nodelist
        # node -> first position in nodelist; built lazily by _node_index_of.
        self._node_index = None
        self.node_shape = node_shape
        self.min_source_margin = min_source_margin
        self.min_target_margin = min_target_margin
        self.edge_indices = edge_indices
        self.node_size = node_size
        self.connectionstyle_factory = self.ConnectionStyleFactory(
            connectionstyle, selfloop_height, ax
        )
        self.arrowstyle = arrowstyle
        self.arrowsize = arrowsize
        self.arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha)
        self.linewidth = linewidth
        self.style = style
        if isinstance(arrowsize, list) and len(arrowsize) != len(edge_pos):
            raise ValueError("arrowsize should have the same length as edgelist")

    def _node_index_of(self, node):
        """Return the position of the first occurrence of ``node`` in nodelist.

        This matches ``self.nodelist.index(node)``, including the
        ``ValueError`` raised for a missing node. The lookup is backed by a
        dict built on first use, so resolving every edge's endpoints walks
        ``nodelist`` once instead of once per lookup. It assumes
        ``nodelist`` is not mutated after the first call.
        """
        if self._node_index is None:
            index = {}
            for position, member in enumerate(self.nodelist):
                index.setdefault(member, position)
            self._node_index = index
        try:
            return self._node_index[node]
        except KeyError:
            raise ValueError(f"{node!r} is not in list") from None

    def __call__(self, i):
        """Return a new ``FancyArrowPatch`` for the ``i``-th edge."""
        (x1, y1), (x2, y2) = self.edge_pos[i]
        shrink_source = 0  # space from source to tail
        shrink_target = 0  # space from head to target

        if (
            self.np.iterable(self.min_source_margin)
            and not isinstance(self.min_source_margin, str)
            and not isinstance(self.min_source_margin, tuple)
        ):
            min_source_margin = self.min_source_margin[i]
        else:
            min_source_margin = self.min_source_margin

        if (
            self.np.iterable(self.min_target_margin)
            and not isinstance(self.min_target_margin, str)
            and not isinstance(self.min_target_margin, tuple)
        ):
            min_target_margin = self.min_target_margin[i]
        else:
            min_target_margin = self.min_target_margin

        if self.np.iterable(self.node_size):  # many node sizes
            source, target = self.edgelist[i][:2]
            source_node_size = self.node_size[self._node_index_of(source)]
            target_node_size = self.node_size[self._node_index_of(target)]
            shrink_source = self.to_marker_edge(source_node_size, self.node_shape)
            shrink_target = self.to_marker_edge(target_node_size, self.node_shape)
        else:
            shrink_source = self.to_marker_edge(self.node_size, self.node_shape)
            shrink_target = shrink_source
        shrink_source = max(shrink_source, min_source_margin)
        shrink_target = max(shrink_target, min_target_margin)

        # scale factor of arrow head
        if isinstance(self.arrowsize, list):
            mutation_scale = self.arrowsize[i]
        else:
            mutation_scale = self.arrowsize

        if len(self.arrow_colors) > i:
            arrow_color = self.arrow_colors[i]
        elif len(self.arrow_colors) == 1:
            arrow_color = self.arrow_colors[0]
        else:  # Cycle through colors
            arrow_color = self.arrow_colors[i % len(self.arrow_colors)]

        if self.np.iterable(self.linewidth):
            if len(self.linewidth) > i:
                linewidth = self.linewidth[i]
            else:
                linewidth = self.linewidth[i % len(self.linewidth)]
        else:
            linewidth = self.linewidth

        if (
            self.np.iterable(self.style)
            and not isinstance(self.style, str)
            and not isinstance(self.style, tuple)
        ):
            if len(self.style) > i:
                linestyle = self.style[i]
            else:  # Cycle through styles
                linestyle = self.style[i % len(self.style)]
        else:
            linestyle = self.style

        if x1 == x2 and y1 == y2:
            connectionstyle = self.connectionstyle_factory.self_loop(
                self.edge_indices[i]
            )
        else:
            connectionstyle = self.connectionstyle_factory.curved(self.edge_indices[i])

        if (
            self.np.iterable(self.arrowstyle)
            and not isinstance(self.arrowstyle, str)
            and not isinstance(self.arrowstyle, tuple)
        ):
            arrowstyle = self.arrowstyle[i]
        else:
            arrowstyle = self.arrowstyle

        return self.mpl.patches.FancyArrowPatch(
            (x1, y1),
            (x2, y2),
            arrowstyle=arrowstyle,
            shrinkA=shrink_source,
            shrinkB=shrink_target,
            mutation_scale=mutation_scale,
            color=arrow_color,
            linewidth=linewidth,
            connectionstyle=connectionstyle,
            linestyle=linestyle,
            zorder=1,  # arrows go behind nodes
        )

    def to_marker_edge(self, marker_size, marker):
        """Return the shrink distance for a node marker of ``marker_size``."""
        if marker in "s^>v<d":  # `large` markers need extra space
            return self.np.sqrt(2 * marker_size) / 2
        else:
            return self.np.sqrt(marker_size) / 2
def draw_networkx_edges(
    G,
    pos,
    edgelist=None,
    width=1.0,
    edge_color="k",
    style="solid",
    alpha=None,
    arrowstyle=None,
    arrowsize=10,
    edge_cmap=None,
    edge_vmin=None,
    edge_vmax=None,
    ax=None,
    arrows=None,
    label=None,
    node_size=300,
    nodelist=None,
    node_shape="o",
    connectionstyle="arc3",
    min_source_margin=0,
    min_target_margin=0,
    hide_ticks=True,
):
    r"""Draw the edges of the graph G.

    This draws only the edges of the graph G.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    edgelist : collection of edge tuples (default=G.edges())
        Draw only specified edges

    width : float or array of floats (default=1.0)
        Line width of edges

    edge_color : color or array of colors (default='k')
        Edge color. Can be a single color or a sequence of colors with the same
        length as edgelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.

    style : string or array of strings (default='solid')
        Edge line style e.g.: '-', '--', '-.', ':'
        or words like 'solid' or 'dashed'.
        Can be a single style or a sequence of styles with the same
        length as the edge list.
        If fewer styles than edges are given the styles will cycle.
        If more styles than edges are given the styles will be used sequentially
        and not be exhausted.
        Also, `(offset, onoffseq)` tuples can be used as style instead of strings.
        (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)

    alpha : float or array of floats (default=None)
        The edge transparency. This can be a single alpha value,
        in which case it will be applied to all specified edges. Otherwise,
        if it is an array, the elements of alpha will be applied to the colors
        in order (cycling through alpha multiple times if necessary).

    edge_cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of edges

    edge_vmin,edge_vmax : floats, optional
        Minimum and maximum for edge colormap scaling

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    arrows : bool or None, optional (default=None)
        If `None`, directed graphs and multigraphs draw edges with
        `~matplotlib.patches.FancyArrowPatch` (with arrowheads when `G` is
        directed), while undirected simple graphs draw edges via
        `~matplotlib.collections.LineCollection` for speed.
        If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
        If `False`, draw edges using LineCollection (linear and fast).

        Note: Arrowheads will be the same color as edges.

    arrowstyle : str or list of strs (default='-\|>' for directed graphs)
        For directed graphs and `arrows==True` defaults to '-\|>',
        For undirected graphs default to '-'.

        See `matplotlib.patches.ArrowStyle` for more options.

    arrowsize : int or list of ints (default=10)
        For directed graphs, choose the size of the arrow head's length and
        width. See `matplotlib.patches.FancyArrowPatch` for attribute
        `mutation_scale` for more info.

    connectionstyle : string or iterable of strings (default="arc3")
        Pass the connectionstyle parameter to create curved arc of rounding
        radius rad. For example, connectionstyle='arc3,rad=0.2'.
        See `matplotlib.patches.ConnectionStyle` and
        `matplotlib.patches.FancyArrowPatch` for more info.
        If Iterable, index indicates i'th edge key of MultiGraph

    node_size : scalar or array (default=300)
        Size of nodes. Though the nodes are not drawn with this function, the
        node size is used in determining edge positioning.

    nodelist : list, optional (default=G.nodes())
        This provides the node order for the `node_size` array (if it is an array).

    node_shape : string (default='o')
        The marker used for nodes, used in determining edge positioning.
        Specification is as a `matplotlib.markers` marker, e.g. one of 'so^>v<dph8'.

    label : None or string
        Label for legend

    min_source_margin : int or list of ints (default=0)
        The minimum margin (gap) at the beginning of the edge at the source.

    min_target_margin : int or list of ints (default=0)
        The minimum margin (gap) at the end of the edge at the target.

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Returns
    -------
    matplotlib.collections.LineCollection or a list of matplotlib.patches.FancyArrowPatch
        If ``arrows=True``, a list of FancyArrowPatches is returned.
        If ``arrows=False``, a LineCollection is returned.
        If ``arrows=None`` (the default), then a LineCollection is returned if
        `G` is undirected and not a multigraph, otherwise returns a list of
        FancyArrowPatches.
        If there are no edges to draw, an empty list is returned.

    Raises
    ------
    TypeError
        If `arrows` is neither a bool nor None.
    NetworkXError
        If `connectionstyle` is neither a string nor an iterable.

    Notes
    -----
    For directed graphs, arrows are drawn at the head end. Arrows can be
    turned off with keyword arrows=False or by passing an arrowstyle without
    an arrow on the end.

    Be sure to include `node_size` as a keyword argument; arrows are
    drawn considering the size of nodes.

    Self-loops are always drawn with `~matplotlib.patches.FancyArrowPatch`
    regardless of the value of `arrows` or whether `G` is directed.
    When a LineCollection is used (``arrows=False``, or ``arrows=None`` and
    `G` is an undirected simple graph), the FancyArrowPatches corresponding
    to the self-loops are not explicitly returned. They should instead be
    accessed via the ``Axes.patches`` attribute (see examples).

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))

    >>> G = nx.DiGraph()
    >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)])
    >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
    >>> alphas = [0.3, 0.4, 0.5]
    >>> for i, arc in enumerate(arcs):  # change alpha values of arcs
    ...     arc.set_alpha(alphas[i])

    The FancyArrowPatches corresponding to self-loops are not always
    returned, but can always be accessed via the ``patches`` attribute of the
    `matplotlib.Axes` object.

    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots()
    >>> G = nx.Graph([(0, 1), (0, 0)])  # Self-loop at node 0
    >>> edge_collection = nx.draw_networkx_edges(G, pos=nx.circular_layout(G), ax=ax)
    >>> self_loop_fap = ax.patches[0]

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_labels
    draw_networkx_edge_labels

    """
    import warnings

    import matplotlib as mpl
    import matplotlib.collections  # call as mpl.collections
    import matplotlib.colors  # call as mpl.colors
    import matplotlib.pyplot as plt
    import numpy as np

    # The default behavior is to use LineCollection to draw edges for
    # undirected simple graphs (for performance reasons) and use
    # FancyArrowPatches for directed graphs and multigraphs.
    # The `arrows` keyword can be used to override the default behavior
    if arrows is None:
        use_linecollection = not (G.is_directed() or G.is_multigraph())
    else:
        if not isinstance(arrows, bool):
            raise TypeError("Argument `arrows` must be of type bool or None")
        use_linecollection = not arrows

    if isinstance(connectionstyle, str):
        connectionstyle = [connectionstyle]
    elif np.iterable(connectionstyle):
        connectionstyle = list(connectionstyle)
    else:
        msg = "draw_networkx_edges arg `connectionstyle` must be str or iterable"
        raise nx.NetworkXError(msg)

    # Some kwargs only apply to FancyArrowPatches. Warn users when they use
    # non-default values for these kwargs when LineCollection is being used
    # instead of silently ignoring the specified option
    if use_linecollection:
        msg = (
            "\n\nThe {0} keyword argument is not applicable when drawing edges\n"
            "with LineCollection.\n\n"
            "To make this warning go away, either specify `arrows=True` to\n"
            "force FancyArrowPatches or use the default values.\n"
            "Note that using FancyArrowPatches may be slow for large graphs.\n"
        )
        if arrowstyle is not None:
            warnings.warn(msg.format("arrowstyle"), category=UserWarning, stacklevel=2)
        if arrowsize != 10:
            warnings.warn(msg.format("arrowsize"), category=UserWarning, stacklevel=2)
        if min_source_margin != 0:
            warnings.warn(
                msg.format("min_source_margin"), category=UserWarning, stacklevel=2
            )
        if min_target_margin != 0:
            warnings.warn(
                msg.format("min_target_margin"), category=UserWarning, stacklevel=2
            )
        if any(cs != "arc3" for cs in connectionstyle):
            warnings.warn(
                msg.format("connectionstyle"), category=UserWarning, stacklevel=2
            )

    # NOTE: Arrowstyle modification must occur after the warnings section
    # (otherwise the `arrowstyle is not None` check above would always fire).
    if arrowstyle is None:
        arrowstyle = "-|>" if G.is_directed() else "-"

    if ax is None:
        ax = plt.gca()

    if edgelist is None:
        edgelist = list(G.edges)  # (u, v, k) for multigraph (u, v) otherwise

    if len(edgelist):
        if G.is_multigraph():
            # Number parallel edges between the same (u, v) pair 0, 1, 2, ...
            # so each one can pick its own entry of `connectionstyle`.
            key_count = collections.defaultdict(lambda: itertools.count(0))
            edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist]
        else:
            edge_indices = [0] * len(edgelist)
    else:  # no edges!
        return []

    if nodelist is None:
        nodelist = list(G.nodes())

    # FancyArrowPatch handles color=None different from LineCollection
    if edge_color is None:
        edge_color = "k"

    # set edge positions
    edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])

    # Check if edge_color is an array of floats and map to edge_cmap.
    # This is the only case handled differently from matplotlib
    if (
        np.iterable(edge_color)
        and (len(edge_color) == len(edge_pos))
        and np.all([isinstance(c, Number) for c in edge_color])
    ):
        if edge_cmap is not None:
            assert isinstance(edge_cmap, mpl.colors.Colormap)
        else:
            edge_cmap = plt.get_cmap()
        if edge_vmin is None:
            edge_vmin = min(edge_color)
        if edge_vmax is None:
            edge_vmax = max(edge_color)
        color_normal = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax)
        edge_color = [edge_cmap(color_normal(e)) for e in edge_color]

    # compute initial view
    minx = np.amin(np.ravel(edge_pos[:, :, 0]))
    maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
    miny = np.amin(np.ravel(edge_pos[:, :, 1]))
    maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
    w = maxx - minx
    h = maxy - miny

    # Self-loops are scaled by view extent, except in cases the extent
    # is 0, e.g. for a single node. In this case, fall back to scaling
    # by the maximum node size
    selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max()
    fancy_arrow_factory = FancyArrowFactory(
        edge_pos,
        edgelist,
        nodelist,
        edge_indices,
        node_size,
        selfloop_height,
        connectionstyle,
        node_shape,
        arrowstyle,
        arrowsize,
        edge_color,
        alpha,
        width,
        style,
        min_source_margin,
        min_target_margin,
        ax=ax,
    )

    # Draw the edges
    if use_linecollection:
        edge_collection = mpl.collections.LineCollection(
            edge_pos,
            colors=edge_color,
            linewidths=width,
            antialiaseds=(1,),
            linestyle=style,
            alpha=alpha,
        )
        edge_collection.set_cmap(edge_cmap)
        edge_collection.set_clim(edge_vmin, edge_vmax)
        edge_collection.set_zorder(1)  # edges go behind nodes
        edge_collection.set_label(label)
        ax.add_collection(edge_collection)
        edge_viz_obj = edge_collection

        # Make sure selfloop edges are also drawn
        # ---------------------------------------
        # A self-loop is a zero-length segment in the LineCollection, so each
        # one is added separately as a FancyArrowPatch. Scanning `edgelist`
        # directly (instead of testing `nx.selfloop_edges(G)` for membership in
        # it) also catches self-loops given as lists and multigraph self-loops
        # that carry a key, and avoids a quadratic lookup.
        for i, e in enumerate(edgelist):
            if e[0] == e[1]:
                ax.add_patch(fancy_arrow_factory(i))
    else:
        edge_viz_obj = []
        for i in range(len(edgelist)):
            arrow = fancy_arrow_factory(i)
            ax.add_patch(arrow)
            edge_viz_obj.append(arrow)

    # update view after drawing
    padx, pady = 0.05 * w, 0.05 * h
    corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
    ax.update_datalim(corners)
    ax.autoscale_view()

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    return edge_viz_obj
def draw_networkx_labels(
    G,
    pos,
    labels=None,
    font_size=12,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    clip_on=True,
    hide_ticks=True,
):
    """Draw node labels on the graph G.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    labels : dictionary (default={n: n for n in G})
        Node labels in a dictionary of text labels keyed by node.
        Node-keys in labels should appear as keys in `pos`.
        If needed use: `{n:lab for n,lab in labels.items() if n in pos}`

    font_size : int or dictionary of nodes to ints (default=12)
        Font size for text labels.

    font_color : color or dictionary of nodes to colors (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string or dictionary of nodes to strings (default='normal')
        Font weight.

    font_family : string or dictionary of nodes to strings (default='sans-serif')
        Font family.

    alpha : float or None or dictionary of nodes to floats (default=None)
        The text transparency.

    bbox : Matplotlib bbox, (default is Matplotlib's ax.text default)
        Specify text box properties (e.g. shape, color etc.) for node labels.

    horizontalalignment : string or dictionary of nodes to strings (default='center')
        Horizontal alignment {'center', 'right', 'left'}.

    verticalalignment : string or dictionary of nodes to strings (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    clip_on : bool (default=True)
        Turn on clipping of node labels at axis boundaries

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Every parameter above that accepts a dictionary of nodes applies its
    values per node; such a dictionary must have the same length as `labels`
    and contain every labeled node.

    Returns
    -------
    dict
        `dict` of labels keyed on the nodes

    Raises
    ------
    ValueError
        If a per-node dictionary parameter does not have the same length as
        `labels`.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> labels = nx.draw_networkx_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_edge_labels
    """
    import matplotlib.pyplot as plt

    if ax is None:
        ax = plt.gca()

    if labels is None:
        labels = {n: n for n in G.nodes()}

    # Names of the parameters that were given per node (as a dict).
    individual_params = set()

    def check_individual_params(p_value, p_name):
        # A dict value means "one value per labeled node".
        if isinstance(p_value, dict):
            if len(p_value) != len(labels):
                raise ValueError(f"{p_name} must have the same length as labels.")
            individual_params.add(p_name)

    def get_param_value(node, p_value, p_name):
        # Look up the node's own value for per-node params, else the scalar.
        if p_name in individual_params:
            return p_value[node]
        return p_value

    check_individual_params(font_size, "font_size")
    check_individual_params(font_color, "font_color")
    check_individual_params(font_weight, "font_weight")
    check_individual_params(font_family, "font_family")
    check_individual_params(alpha, "alpha")
    # Alignments previously had to be scalars even though the docs promised
    # per-node values; support dicts like every other per-node parameter.
    check_individual_params(horizontalalignment, "horizontalalignment")
    check_individual_params(verticalalignment, "verticalalignment")

    text_items = {}  # there is no text collection so we'll fake one
    for n, label in labels.items():
        (x, y) = pos[n]
        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same
        t = ax.text(
            x,
            y,
            label,
            size=get_param_value(n, font_size, "font_size"),
            color=get_param_value(n, font_color, "font_color"),
            family=get_param_value(n, font_family, "font_family"),
            weight=get_param_value(n, font_weight, "font_weight"),
            alpha=get_param_value(n, alpha, "alpha"),
            horizontalalignment=get_param_value(
                n, horizontalalignment, "horizontalalignment"
            ),
            verticalalignment=get_param_value(
                n, verticalalignment, "verticalalignment"
            ),
            transform=ax.transData,
            bbox=bbox,
            clip_on=clip_on,
        )
        text_items[n] = t

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    return text_items
def draw_networkx_edge_labels(
    G,
    pos,
    edge_labels=None,
    label_pos=0.5,
    font_size=10,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    rotate=True,
    clip_on=True,
    node_size=300,
    nodelist=None,
    connectionstyle="arc3",
    hide_ticks=True,
):
    """Draw edge labels.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    edge_labels : dictionary (default=None)
        Edge labels in a dictionary of labels keyed by edge two-tuple
        (``(u, v, key)`` three-tuples for multigraphs).
        Only labels for the keys in the dictionary are drawn.
        If None, every edge of `G` is labeled with its edge data dictionary.

    label_pos : float or list of floats (default=0.5)
        Position of edge label along edge (0=head, 0.5=center, 1=tail)

    font_size : int or list of ints (default=10)
        Font size for text labels

    font_color : color or list of colors (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string or list of strings (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    alpha : float or None or list of floats (default=None)
        The text transparency

    bbox : Matplotlib bbox, optional
        Specify text box properties (e.g. shape, color etc.) for edge labels.
        Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}.

    horizontalalignment : string or list of strings (default='center')
        Horizontal alignment {'center', 'right', 'left'}

    verticalalignment : string or list of strings (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    rotate : bool or list of bools (default=True)
        Rotate edge labels to lie parallel to edges

    clip_on : bool (default=True)
        Turn on clipping of edge labels at axis boundaries

    node_size : scalar or array (default=300)
        Size of nodes. If an array it must be the same length as nodelist.

    nodelist : list, optional (default=G.nodes())
        This provides the node order for the `node_size` array (if it is an array).

    connectionstyle : string or iterable of strings (default="arc3")
        Pass the connectionstyle parameter to create curved arc of rounding
        radius rad. For example, connectionstyle='arc3,rad=0.2'.
        See `matplotlib.patches.ConnectionStyle` and
        `matplotlib.patches.FancyArrowPatch` for more info.
        If Iterable, index indicates i'th edge key of MultiGraph

    hide_ticks : bool, optional
        Hide ticks of axes. When `True` (the default), ticks and ticklabels
        are removed from the axes. To set ticks and tick labels to the pyplot default,
        use ``hide_ticks=False``.

    Parameters documented as accepting a list take one value per labeled
    edge, in the iteration order of `edge_labels`; such a list must have the
    same length as `edge_labels`.

    Returns
    -------
    dict
        `dict` of labels keyed by edge

    Raises
    ------
    ValueError
        If a list-valued parameter does not have one entry per labeled edge.
    NetworkXError
        If `connectionstyle` is neither a string nor an iterable.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    """
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import numpy as np

    class CurvedArrowText(CurvedArrowTextBase, mpl.text.Text):
        pass

    # use default box of white with white border
    if bbox is None:
        bbox = {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}

    if isinstance(connectionstyle, str):
        connectionstyle = [connectionstyle]
    elif np.iterable(connectionstyle):
        connectionstyle = list(connectionstyle)
    else:
        raise nx.NetworkXError(
            "draw_networkx_edge_labels arg `connectionstyle` must be "
            "string or iterable of strings"
        )

    if ax is None:
        ax = plt.gca()

    if edge_labels is None:
        kwds = {"keys": True} if G.is_multigraph() else {}
        edge_labels = {tuple(edge): d for *edge, d in G.edges(data=True, **kwds)}
    # NOTHING TO PLOT
    if not edge_labels:
        return {}
    edgelist, labels = zip(*edge_labels.items())

    if nodelist is None:
        nodelist = list(G.nodes())

    # set edge positions
    edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])

    if G.is_multigraph():
        # Number parallel edges between the same (u, v) pair 0, 1, 2, ...
        key_count = collections.defaultdict(lambda: itertools.count(0))
        edge_indices = [next(key_count[tuple(e[:2])]) for e in edgelist]
    else:
        edge_indices = [0] * len(edgelist)

    # Used to determine self loop mid-point
    # Note, that this will not be accurate,
    # if not drawing edge_labels for all edges drawn
    # (`edge_labels` is known to be non-empty here.)
    miny = np.amin(np.ravel(edge_pos[:, :, 1]))
    maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
    h = maxy - miny
    selfloop_height = h if h != 0 else 0.005 * np.array(node_size).max()
    fancy_arrow_factory = FancyArrowFactory(
        edge_pos,
        edgelist,
        nodelist,
        edge_indices,
        node_size,
        selfloop_height,
        connectionstyle,
        ax=ax,
    )

    # List-valued parameters are turned into iterators that are advanced
    # exactly once per edge, in `edgelist` order.
    individual_params = {}

    def check_individual_params(p_value, p_name):
        # TODO should this be list or array (as in a numpy array)?
        if isinstance(p_value, list):
            if len(p_value) != len(edgelist):
                raise ValueError(f"{p_name} must have the same length as edgelist.")
            # Lists have no `.iter()` method; use the builtin.
            individual_params[p_name] = iter(p_value)

    # Don't need to pass in an edge because these are lists, not dicts
    def get_param_value(p_value, p_name):
        if p_name in individual_params:
            return next(individual_params[p_name])
        return p_value

    check_individual_params(font_size, "font_size")
    check_individual_params(font_color, "font_color")
    check_individual_params(font_weight, "font_weight")
    check_individual_params(alpha, "alpha")
    check_individual_params(horizontalalignment, "horizontalalignment")
    check_individual_params(verticalalignment, "verticalalignment")
    check_individual_params(rotate, "rotate")
    check_individual_params(label_pos, "label_pos")

    text_items = {}
    for i, (edge, label) in enumerate(zip(edgelist, labels)):
        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same

        # Fetch every per-edge parameter once for EVERY edge -- including the
        # ones a self-loop label does not use (label_pos, rotate) -- so that
        # list-valued parameters stay aligned with `edgelist`.
        text_kwargs = {
            "size": get_param_value(font_size, "font_size"),
            "color": get_param_value(font_color, "font_color"),
            "family": get_param_value(font_family, "font_family"),
            "weight": get_param_value(font_weight, "font_weight"),
            "alpha": get_param_value(alpha, "alpha"),
            "horizontalalignment": get_param_value(
                horizontalalignment, "horizontalalignment"
            ),
            "verticalalignment": get_param_value(
                verticalalignment, "verticalalignment"
            ),
            "transform": ax.transData,
            "bbox": bbox,
            "zorder": 1,
            "clip_on": clip_on,
        }
        edge_label_pos = get_param_value(label_pos, "label_pos")
        edge_rotate = get_param_value(rotate, "rotate")

        n1, n2 = edge[:2]
        arrow = fancy_arrow_factory(i)
        if n1 == n2:
            # Self-loop: place the label at the first vertex of the loop's
            # connection path, converted back from display to data coords.
            connectionstyle_obj = arrow.get_connectionstyle()
            posA = ax.transData.transform(pos[n1])
            path_disp = connectionstyle_obj(posA, posA)
            path_data = ax.transData.inverted().transform_path(path_disp)
            x, y = path_data.vertices[0]
            text_items[edge] = ax.text(x, y, label, rotation=0, **text_kwargs)
        else:
            text_items[edge] = CurvedArrowText(
                arrow,
                label,
                label_pos=edge_label_pos,
                labels_horizontal=not edge_rotate,
                ax=ax,
                **text_kwargs,
            )

    if hide_ticks:
        ax.tick_params(
            axis="both",
            which="both",
            bottom=False,
            left=False,
            labelbottom=False,
            labelleft=False,
        )

    return text_items
def draw_bipartite(G, **kwargs):
    """Draw `G` using node positions from a bipartite layout.

    Shorthand for::

        nx.draw(G, pos=nx.bipartite_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Raises
    ------
    NetworkXError :
        If `G` is not bipartite.

    Notes
    -----
    Node positions are recomputed on every call. To draw the same graph (or
    its subgraphs) several times, compute the positions once with
    `~networkx.drawing.layout.bipartite_layout` and pass them to `draw`::

        >>> G = nx.complete_bipartite_graph(3, 3)
        >>> pos = nx.bipartite_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.complete_bipartite_graph(2, 5)
    >>> nx.draw_bipartite(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.bipartite_layout`
    """
    node_positions = nx.bipartite_layout(G)
    draw(G, pos=node_positions, **kwargs)
def draw_circular(G, **kwargs):
    """Draw `G` with its nodes placed on a circle.

    Shorthand for::

        nx.draw(G, pos=nx.circular_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Notes
    -----
    Node positions are recomputed on every call. To draw repeatedly, compute
    them once with `~networkx.drawing.layout.circular_layout` and reuse them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.circular_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_circular(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.circular_layout`
    """
    node_positions = nx.circular_layout(G)
    draw(G, pos=node_positions, **kwargs)
def draw_kamada_kawai(G, **kwargs):
    """Draw `G` with node positions from the Kamada-Kawai force-directed layout.

    Shorthand for::

        nx.draw(G, pos=nx.kamada_kawai_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Notes
    -----
    Node positions are recomputed on every call. To draw repeatedly, compute
    them once with `~networkx.drawing.layout.kamada_kawai_layout` and reuse
    them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.kamada_kawai_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_kamada_kawai(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.kamada_kawai_layout`
    """
    node_positions = nx.kamada_kawai_layout(G)
    draw(G, pos=node_positions, **kwargs)
def draw_random(G, **kwargs):
    """Draw `G` with nodes placed at random positions.

    Shorthand for::

        nx.draw(G, pos=nx.random_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Notes
    -----
    Node positions are recomputed on every call. To draw repeatedly, compute
    them once with `~networkx.drawing.layout.random_layout` and reuse them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.random_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.lollipop_graph(4, 3)
    >>> nx.draw_random(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.random_layout`
    """
    node_positions = nx.random_layout(G)
    draw(G, pos=node_positions, **kwargs)
def draw_spectral(G, **kwargs):
    """Draw `G` with node positions from a 2D spectral layout.

    Shorthand for::

        nx.draw(G, pos=nx.spectral_layout(G), **kwargs)

    How the positions are derived is described in
    `~networkx.drawing.layout.spectral_layout`.

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Notes
    -----
    Node positions are recomputed on every call. To draw repeatedly, compute
    them once with `~networkx.drawing.layout.spectral_layout` and reuse them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.spectral_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_spectral(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.spectral_layout`
    """
    node_positions = nx.spectral_layout(G)
    draw(G, pos=node_positions, **kwargs)
def draw_spring(G, **kwargs):
    """Draw `G` with node positions from a spring (force-directed) layout.

    Shorthand for::

        nx.draw(G, pos=nx.spring_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Notes
    -----
    `~networkx.drawing.layout.spring_layout` is also what `draw` uses when no
    positions are given, so this function is equivalent to `draw`.

    Node positions are recomputed on every call. To draw repeatedly, compute
    them once with `~networkx.drawing.layout.spring_layout` and reuse them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.spring_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(20)
    >>> nx.draw_spring(G)

    See Also
    --------
    draw
    :func:`~networkx.drawing.layout.spring_layout`
    """
    node_positions = nx.spring_layout(G)
    draw(G, pos=node_positions, **kwargs)
def draw_shell(G, nlist=None, **kwargs):
    """Draw `G` with nodes arranged in concentric shells.

    Shorthand for::

        nx.draw(G, pos=nx.shell_layout(G, nlist=nlist), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    nlist : list of list of nodes, optional
        One inner list of nodes per shell. The default, `None`, puts all
        nodes in a single shell. See `~networkx.drawing.layout.shell_layout`
        for details.

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Notes
    -----
    Node positions are recomputed on every call. To draw repeatedly, compute
    them once with `~networkx.drawing.layout.shell_layout` and reuse them::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.shell_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> shells = [[0], [1, 2, 3]]
    >>> nx.draw_shell(G, nlist=shells)

    See Also
    --------
    :func:`~networkx.drawing.layout.shell_layout`
    """
    shell_positions = nx.shell_layout(G, nlist=nlist)
    draw(G, pos=shell_positions, **kwargs)
def draw_planar(G, **kwargs):
    """Draw a planar graph `G` without edge crossings using a planar layout.

    Shorthand for::

        nx.draw(G, pos=nx.planar_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A planar networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see `draw_networkx` for the accepted keywords.

    Raises
    ------
    NetworkXException
        When `G` is not planar

    Notes
    -----
    Node positions are recomputed on every call. To draw repeatedly, compute
    them once with `~networkx.drawing.layout.planar_layout` and reuse them::

        >>> G = nx.path_graph(5)
        >>> pos = nx.planar_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.draw_planar(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.planar_layout`
    """
    node_positions = nx.planar_layout(G)
    draw(G, pos=node_positions, **kwargs)
def draw_forceatlas2(G, **kwargs):
    """Draw `G` with node positions from the ForceAtlas2 layout.

    Shorthand for::

        nx.draw(G, pos=nx.forceatlas2_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Forwarded to `draw`; see networkx.draw_networkx() for the accepted
        keywords. `pos` is not accepted, because this function computes the
        positions itself.
    """
    node_positions = nx.forceatlas2_layout(G)
    draw(G, pos=node_positions, **kwargs)
def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None):
    """Apply an alpha (or list of alphas) to the colors provided.

    Parameters
    ----------
    colors : color string or array of floats
        Color of element. Can be a single color format string,
        or a sequence of colors with the same length as `elem_list`.
        If numeric values are specified they will be mapped to
        colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

        A sequence whose first entry is a number and whose length equals
        ``len(elem_list)`` is always value-mapped through the colormap, so
        e.g. an RGB triple ``(1, 0, 0)`` passed with three elements is
        treated as three values to colormap, not as one color.

    alpha : float or array of floats
        Alpha values for elements. This can be a single alpha value, in
        which case it will be applied to all the elements of color. Otherwise,
        if it is an array, the elements of alpha will be applied to the colors
        in order (cycling through alpha multiple times if necessary).

    elem_list : array of networkx objects
        The list of elements which are being colored. These could be nodes,
        edges or labels.

    cmap : matplotlib colormap
        Color map for use if colors is a list of floats corresponding to points
        on a color mapping.

    vmin, vmax : float
        Minimum and maximum values for normalizing colors if a colormap is used

    Returns
    -------
    rgba_colors : numpy ndarray
        Array of RGBA values, one row per color, with the alpha values
        stored in the last column.

    Notes
    -----
    When `alpha` is a sequence and either it is longer than the number of
    colors, or the RGBA array holds exactly ``len(elem_list)`` values in
    total, the colors are expanded to one row per element by repeating
    them cyclically, just as the alpha values are cycled. A scalar `alpha`
    never triggers this expansion.
    """
    from itertools import cycle, islice

    import matplotlib as mpl
    import matplotlib.cm  # call as mpl.cm
    import matplotlib.colors  # call as mpl.colors
    import numpy as np

    # If we have been provided with a list of numbers as long as elem_list,
    # apply the color mapping.
    if len(colors) == len(elem_list) and isinstance(colors[0], Number):
        mapper = mpl.cm.ScalarMappable(cmap=cmap)
        mapper.set_clim(vmin, vmax)
        rgba_colors = mapper.to_rgba(colors)
    else:
        # Otherwise treat `colors` as color specifications: first as one
        # color, and if matplotlib rejects that, element by element. Both
        # results are ndarrays, consistent with ScalarMappable.to_rgba.
        try:
            rgba_colors = np.array([mpl.colors.to_rgba(colors)])
        except ValueError:
            rgba_colors = np.array([mpl.colors.to_rgba(color) for color in colors])

    # Only the length probe distinguishes a scalar alpha from a sequence;
    # errors raised while applying a sequence alpha are not masked.
    try:
        n_alpha = len(alpha)
    except TypeError:
        # Scalar alpha: broadcast it over the alpha column as-is.
        rgba_colors[:, -1] = alpha
        return rgba_colors

    # If alpha is longer than the number of colors, resize to the number of
    # elements. Also, if rgba_colors.size (the number of values in
    # rgba_colors) is the same as the number of elements, resize the array
    # to avoid it being interpreted as a colormap by scatter().
    # np.resize fills the larger array with repeated copies of the input
    # rows, so the colors cycle; the first color must NOT be copied over the
    # other rows, or distinct colors would collapse into one.
    if n_alpha > len(rgba_colors) or rgba_colors.size == len(elem_list):
        rgba_colors = np.resize(rgba_colors, (len(elem_list), 4))
    rgba_colors[:, 3] = list(islice(cycle(alpha), len(rgba_colors)))
    return rgba_colors
|