gexf.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084
"""Read and write graphs in GEXF format.

.. warning::
    This parser uses the standard xml library present in Python, which is
    insecure - see :external+python:mod:`xml` for additional information.
    Only parse GEXF files you trust.

GEXF (Graph Exchange XML Format) is a language for describing complex
network structures, their associated data and dynamics.

This implementation does not support mixed graphs (directed and
undirected edges together).

Format
------
GEXF is an XML format. See http://gexf.net/schema.html for the
specification and http://gexf.net/basic.html for examples.
"""
  15. import itertools
  16. import time
  17. from xml.etree.ElementTree import (
  18. Element,
  19. ElementTree,
  20. SubElement,
  21. register_namespace,
  22. tostring,
  23. )
  24. import networkx as nx
  25. from networkx.utils import open_file
# Public names exported by ``from <this module> import *``.
__all__ = ["write_gexf", "read_gexf", "relabel_gexf_graph", "generate_gexf"]
  27. @open_file(1, mode="wb")
  28. def write_gexf(G, path, encoding="utf-8", prettyprint=True, version="1.2draft"):
  29. """Write G in GEXF format to path.
  30. "GEXF (Graph Exchange XML Format) is a language for describing
  31. complex networks structures, their associated data and dynamics" [1]_.
  32. Node attributes are checked according to the version of the GEXF
  33. schemas used for parameters which are not user defined,
  34. e.g. visualization 'viz' [2]_. See example for usage.
  35. .. warning::
  36. The `GEXF specification <https://gexf.net/schema.html>`_ reserves some
  37. keywords (e.g. ``id``, ``pid``, ``label``, etc.) for specifying node/edge
  38. metadata in the file format. Ensure NetworkX node/edge attribute names
  39. do not use these special keywords to guarantee all attributes are preserved
  40. as expected when roundtripping to/from GEXF format.
  41. Parameters
  42. ----------
  43. G : graph
  44. A NetworkX graph
  45. path : file or string
  46. File or file name to write.
  47. File names ending in .gz or .bz2 will be compressed.
  48. encoding : string (optional, default: 'utf-8')
  49. Encoding for text data.
  50. prettyprint : bool (optional, default: True)
  51. If True use line breaks and indenting in output XML.
  52. version: string (optional, default: '1.2draft')
  53. The version of GEXF to be used for nodes attributes checking
  54. Examples
  55. --------
  56. >>> G = nx.path_graph(4)
  57. >>> nx.write_gexf(G, "test.gexf")
  58. # visualization data
  59. >>> G.nodes[0]["viz"] = {"size": 54}
  60. >>> G.nodes[0]["viz"]["position"] = {"x": 0, "y": 1}
  61. >>> G.nodes[0]["viz"]["color"] = {"r": 0, "g": 0, "b": 256}
  62. Notes
  63. -----
  64. This implementation does not support mixed graphs (directed and undirected
  65. edges together).
  66. The node id attribute is set to be the string of the node label.
  67. If you want to specify an id use set it as node data, e.g.
  68. node['a']['id']=1 to set the id of node 'a' to 1.
  69. References
  70. ----------
  71. .. [1] GEXF File Format, http://gexf.net/
  72. .. [2] GEXF schema, http://gexf.net/schema.html
  73. """
  74. writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint, version=version)
  75. writer.add_graph(G)
  76. writer.write(path)
  77. def generate_gexf(G, encoding="utf-8", prettyprint=True, version="1.2draft"):
  78. """Generate lines of GEXF format representation of G.
  79. "GEXF (Graph Exchange XML Format) is a language for describing
  80. complex networks structures, their associated data and dynamics" [1]_.
  81. Parameters
  82. ----------
  83. G : graph
  84. A NetworkX graph
  85. encoding : string (optional, default: 'utf-8')
  86. Encoding for text data.
  87. prettyprint : bool (optional, default: True)
  88. If True use line breaks and indenting in output XML.
  89. version : string (default: 1.2draft)
  90. Version of GEFX File Format (see http://gexf.net/schema.html)
  91. Supported values: "1.1draft", "1.2draft"
  92. Examples
  93. --------
  94. >>> G = nx.path_graph(4)
  95. >>> linefeed = chr(10) # linefeed=\n
  96. >>> s = linefeed.join(nx.generate_gexf(G))
  97. >>> for line in nx.generate_gexf(G): # doctest: +SKIP
  98. ... print(line)
  99. Notes
  100. -----
  101. This implementation does not support mixed graphs (directed and undirected
  102. edges together).
  103. The node id attribute is set to be the string of the node label.
  104. If you want to specify an id use set it as node data, e.g.
  105. node['a']['id']=1 to set the id of node 'a' to 1.
  106. References
  107. ----------
  108. .. [1] GEXF File Format, https://gephi.org/gexf/format/
  109. """
  110. writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint, version=version)
  111. writer.add_graph(G)
  112. yield from str(writer).splitlines()
  113. @open_file(0, mode="rb")
  114. @nx._dispatchable(graphs=None, returns_graph=True)
  115. def read_gexf(path, node_type=None, relabel=False, version="1.2draft"):
  116. """Read graph in GEXF format from path.
  117. "GEXF (Graph Exchange XML Format) is a language for describing
  118. complex networks structures, their associated data and dynamics" [1]_.
  119. Parameters
  120. ----------
  121. path : file or string
  122. Filename or file handle to read.
  123. Filenames ending in .gz or .bz2 will be decompressed.
  124. node_type: Python type (default: None)
  125. Convert node ids to this type if not None.
  126. relabel : bool (default: False)
  127. If True relabel the nodes to use the GEXF node "label" attribute
  128. instead of the node "id" attribute as the NetworkX node label.
  129. version : string (default: 1.2draft)
  130. Version of GEFX File Format (see http://gexf.net/schema.html)
  131. Supported values: "1.1draft", "1.2draft"
  132. Returns
  133. -------
  134. graph: NetworkX graph
  135. If no parallel edges are found a Graph or DiGraph is returned.
  136. Otherwise a MultiGraph or MultiDiGraph is returned.
  137. Notes
  138. -----
  139. This implementation does not support mixed graphs (directed and undirected
  140. edges together).
  141. References
  142. ----------
  143. .. [1] GEXF File Format, http://gexf.net/
  144. """
  145. reader = GEXFReader(node_type=node_type, version=version)
  146. if relabel:
  147. G = relabel_gexf_graph(reader(path))
  148. else:
  149. G = reader(path)
  150. return G
  151. class GEXF:
  152. versions = {
  153. "1.1draft": {
  154. "NS_GEXF": "http://www.gexf.net/1.1draft",
  155. "NS_VIZ": "http://www.gexf.net/1.1draft/viz",
  156. "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
  157. "SCHEMALOCATION": " ".join(
  158. [
  159. "http://www.gexf.net/1.1draft",
  160. "http://www.gexf.net/1.1draft/gexf.xsd",
  161. ]
  162. ),
  163. "VERSION": "1.1",
  164. },
  165. "1.2draft": {
  166. "NS_GEXF": "http://www.gexf.net/1.2draft",
  167. "NS_VIZ": "http://www.gexf.net/1.2draft/viz",
  168. "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
  169. "SCHEMALOCATION": " ".join(
  170. [
  171. "http://www.gexf.net/1.2draft",
  172. "http://www.gexf.net/1.2draft/gexf.xsd",
  173. ]
  174. ),
  175. "VERSION": "1.2",
  176. },
  177. "1.3": {
  178. "NS_GEXF": "http://gexf.net/1.3",
  179. "NS_VIZ": "http://gexf.net/1.3/viz",
  180. "NS_XSI": "http://w3.org/2001/XMLSchema-instance",
  181. "SCHEMALOCATION": " ".join(
  182. [
  183. "http://gexf.net/1.3",
  184. "http://gexf.net/1.3/gexf.xsd",
  185. ]
  186. ),
  187. "VERSION": "1.3",
  188. },
  189. }
  190. def construct_types(self):
  191. types = [
  192. (int, "integer"),
  193. (float, "float"),
  194. (float, "double"),
  195. (bool, "boolean"),
  196. (list, "string"),
  197. (dict, "string"),
  198. (int, "long"),
  199. (str, "liststring"),
  200. (str, "anyURI"),
  201. (str, "string"),
  202. ]
  203. # These additions to types allow writing numpy types
  204. try:
  205. import numpy as np
  206. except ImportError:
  207. pass
  208. else:
  209. # prepend so that python types are created upon read (last entry wins)
  210. types = [
  211. (np.float64, "float"),
  212. (np.float32, "float"),
  213. (np.float16, "float"),
  214. (np.int_, "int"),
  215. (np.int8, "int"),
  216. (np.int16, "int"),
  217. (np.int32, "int"),
  218. (np.int64, "int"),
  219. (np.uint8, "int"),
  220. (np.uint16, "int"),
  221. (np.uint32, "int"),
  222. (np.uint64, "int"),
  223. (np.int_, "int"),
  224. (np.intc, "int"),
  225. (np.intp, "int"),
  226. ] + types
  227. self.xml_type = dict(types)
  228. self.python_type = dict(reversed(a) for a in types)
  229. # http://www.w3.org/TR/xmlschema-2/#boolean
  230. convert_bool = {
  231. "true": True,
  232. "false": False,
  233. "True": True,
  234. "False": False,
  235. "0": False,
  236. 0: False,
  237. "1": True,
  238. 1: True,
  239. }
  240. def set_version(self, version):
  241. d = self.versions.get(version)
  242. if d is None:
  243. raise nx.NetworkXError(f"Unknown GEXF version {version}.")
  244. self.NS_GEXF = d["NS_GEXF"]
  245. self.NS_VIZ = d["NS_VIZ"]
  246. self.NS_XSI = d["NS_XSI"]
  247. self.SCHEMALOCATION = d["SCHEMALOCATION"]
  248. self.VERSION = d["VERSION"]
  249. self.version = version
  250. class GEXFWriter(GEXF):
  251. # class for writing GEXF format files
  252. # use write_gexf() function
  253. def __init__(
  254. self, graph=None, encoding="utf-8", prettyprint=True, version="1.2draft"
  255. ):
  256. self.construct_types()
  257. self.prettyprint = prettyprint
  258. self.encoding = encoding
  259. self.set_version(version)
  260. self.xml = Element(
  261. "gexf",
  262. {
  263. "xmlns": self.NS_GEXF,
  264. "xmlns:xsi": self.NS_XSI,
  265. "xsi:schemaLocation": self.SCHEMALOCATION,
  266. "version": self.VERSION,
  267. },
  268. )
  269. # Make meta element a non-graph element
  270. # Also add lastmodifieddate as attribute, not tag
  271. meta_element = Element("meta")
  272. subelement_text = f"NetworkX {nx.__version__}"
  273. SubElement(meta_element, "creator").text = subelement_text
  274. meta_element.set("lastmodifieddate", time.strftime("%Y-%m-%d"))
  275. self.xml.append(meta_element)
  276. register_namespace("viz", self.NS_VIZ)
  277. # counters for edge and attribute identifiers
  278. self.edge_id = itertools.count()
  279. self.attr_id = itertools.count()
  280. self.all_edge_ids = set()
  281. # default attributes are stored in dictionaries
  282. self.attr = {}
  283. self.attr["node"] = {}
  284. self.attr["edge"] = {}
  285. self.attr["node"]["dynamic"] = {}
  286. self.attr["node"]["static"] = {}
  287. self.attr["edge"]["dynamic"] = {}
  288. self.attr["edge"]["static"] = {}
  289. if graph is not None:
  290. self.add_graph(graph)
  291. def __str__(self):
  292. if self.prettyprint:
  293. self.indent(self.xml)
  294. s = tostring(self.xml).decode(self.encoding)
  295. return s
  296. def add_graph(self, G):
  297. # first pass through G collecting edge ids
  298. for u, v, dd in G.edges(data=True):
  299. eid = dd.get("id")
  300. if eid is not None:
  301. self.all_edge_ids.add(str(eid))
  302. # set graph attributes
  303. if G.graph.get("mode") == "dynamic":
  304. mode = "dynamic"
  305. else:
  306. mode = "static"
  307. # Add a graph element to the XML
  308. if G.is_directed():
  309. default = "directed"
  310. else:
  311. default = "undirected"
  312. name = G.graph.get("name", "")
  313. graph_element = Element("graph", defaultedgetype=default, mode=mode, name=name)
  314. self.graph_element = graph_element
  315. self.add_nodes(G, graph_element)
  316. self.add_edges(G, graph_element)
  317. self.xml.append(graph_element)
  318. def add_nodes(self, G, graph_element):
  319. nodes_element = Element("nodes")
  320. for node, data in G.nodes(data=True):
  321. node_data = data.copy()
  322. node_id = str(node_data.pop("id", node))
  323. kw = {"id": node_id}
  324. label = str(node_data.pop("label", node))
  325. kw["label"] = label
  326. try:
  327. pid = node_data.pop("pid")
  328. kw["pid"] = str(pid)
  329. except KeyError:
  330. pass
  331. try:
  332. start = node_data.pop("start")
  333. kw["start"] = str(start)
  334. self.alter_graph_mode_timeformat(start)
  335. except KeyError:
  336. pass
  337. try:
  338. end = node_data.pop("end")
  339. kw["end"] = str(end)
  340. self.alter_graph_mode_timeformat(end)
  341. except KeyError:
  342. pass
  343. # add node element with attributes
  344. node_element = Element("node", **kw)
  345. # add node element and attr subelements
  346. default = G.graph.get("node_default", {})
  347. node_data = self.add_parents(node_element, node_data)
  348. if self.VERSION == "1.1":
  349. node_data = self.add_slices(node_element, node_data)
  350. else:
  351. node_data = self.add_spells(node_element, node_data)
  352. node_data = self.add_viz(node_element, node_data)
  353. node_data = self.add_attributes("node", node_element, node_data, default)
  354. nodes_element.append(node_element)
  355. graph_element.append(nodes_element)
  356. def add_edges(self, G, graph_element):
  357. def edge_key_data(G):
  358. # helper function to unify multigraph and graph edge iterator
  359. if G.is_multigraph():
  360. for u, v, key, data in G.edges(data=True, keys=True):
  361. edge_data = data.copy()
  362. edge_data.update(key=key)
  363. edge_id = edge_data.pop("id", None)
  364. if edge_id is None:
  365. edge_id = next(self.edge_id)
  366. while str(edge_id) in self.all_edge_ids:
  367. edge_id = next(self.edge_id)
  368. self.all_edge_ids.add(str(edge_id))
  369. yield u, v, edge_id, edge_data
  370. else:
  371. for u, v, data in G.edges(data=True):
  372. edge_data = data.copy()
  373. edge_id = edge_data.pop("id", None)
  374. if edge_id is None:
  375. edge_id = next(self.edge_id)
  376. while str(edge_id) in self.all_edge_ids:
  377. edge_id = next(self.edge_id)
  378. self.all_edge_ids.add(str(edge_id))
  379. yield u, v, edge_id, edge_data
  380. edges_element = Element("edges")
  381. for u, v, key, edge_data in edge_key_data(G):
  382. kw = {"id": str(key)}
  383. try:
  384. edge_label = edge_data.pop("label")
  385. kw["label"] = str(edge_label)
  386. except KeyError:
  387. pass
  388. try:
  389. edge_weight = edge_data.pop("weight")
  390. kw["weight"] = str(edge_weight)
  391. except KeyError:
  392. pass
  393. try:
  394. edge_type = edge_data.pop("type")
  395. kw["type"] = str(edge_type)
  396. except KeyError:
  397. pass
  398. try:
  399. start = edge_data.pop("start")
  400. kw["start"] = str(start)
  401. self.alter_graph_mode_timeformat(start)
  402. except KeyError:
  403. pass
  404. try:
  405. end = edge_data.pop("end")
  406. kw["end"] = str(end)
  407. self.alter_graph_mode_timeformat(end)
  408. except KeyError:
  409. pass
  410. source_id = str(G.nodes[u].get("id", u))
  411. target_id = str(G.nodes[v].get("id", v))
  412. edge_element = Element("edge", source=source_id, target=target_id, **kw)
  413. default = G.graph.get("edge_default", {})
  414. if self.VERSION == "1.1":
  415. edge_data = self.add_slices(edge_element, edge_data)
  416. else:
  417. edge_data = self.add_spells(edge_element, edge_data)
  418. edge_data = self.add_viz(edge_element, edge_data)
  419. edge_data = self.add_attributes("edge", edge_element, edge_data, default)
  420. edges_element.append(edge_element)
  421. graph_element.append(edges_element)
  422. def add_attributes(self, node_or_edge, xml_obj, data, default):
  423. # Add attrvalues to node or edge
  424. attvalues = Element("attvalues")
  425. if len(data) == 0:
  426. return data
  427. mode = "static"
  428. for k, v in data.items():
  429. # rename generic multigraph key to avoid any name conflict
  430. if k == "key":
  431. k = "networkx_key"
  432. val_type = type(v)
  433. if val_type not in self.xml_type:
  434. raise TypeError(f"attribute value type is not allowed: {val_type}")
  435. if isinstance(v, list):
  436. # dynamic data
  437. for val, start, end in v:
  438. val_type = type(val)
  439. if start is not None or end is not None:
  440. mode = "dynamic"
  441. self.alter_graph_mode_timeformat(start)
  442. self.alter_graph_mode_timeformat(end)
  443. break
  444. attr_id = self.get_attr_id(
  445. str(k), self.xml_type[val_type], node_or_edge, default, mode
  446. )
  447. for val, start, end in v:
  448. e = Element("attvalue")
  449. e.attrib["for"] = attr_id
  450. e.attrib["value"] = str(val)
  451. # Handle nan, inf, -inf differently
  452. if val_type is float:
  453. if e.attrib["value"] == "inf":
  454. e.attrib["value"] = "INF"
  455. elif e.attrib["value"] == "nan":
  456. e.attrib["value"] = "NaN"
  457. elif e.attrib["value"] == "-inf":
  458. e.attrib["value"] = "-INF"
  459. if start is not None:
  460. e.attrib["start"] = str(start)
  461. if end is not None:
  462. e.attrib["end"] = str(end)
  463. attvalues.append(e)
  464. else:
  465. # static data
  466. mode = "static"
  467. attr_id = self.get_attr_id(
  468. str(k), self.xml_type[val_type], node_or_edge, default, mode
  469. )
  470. e = Element("attvalue")
  471. e.attrib["for"] = attr_id
  472. if isinstance(v, bool):
  473. e.attrib["value"] = str(v).lower()
  474. else:
  475. e.attrib["value"] = str(v)
  476. # Handle float nan, inf, -inf differently
  477. if val_type is float:
  478. if e.attrib["value"] == "inf":
  479. e.attrib["value"] = "INF"
  480. elif e.attrib["value"] == "nan":
  481. e.attrib["value"] = "NaN"
  482. elif e.attrib["value"] == "-inf":
  483. e.attrib["value"] = "-INF"
  484. attvalues.append(e)
  485. xml_obj.append(attvalues)
  486. return data
  487. def get_attr_id(self, title, attr_type, edge_or_node, default, mode):
  488. # find the id of the attribute or generate a new id
  489. try:
  490. return self.attr[edge_or_node][mode][title]
  491. except KeyError:
  492. # generate new id
  493. new_id = str(next(self.attr_id))
  494. self.attr[edge_or_node][mode][title] = new_id
  495. attr_kwargs = {"id": new_id, "title": title, "type": attr_type}
  496. attribute = Element("attribute", **attr_kwargs)
  497. # add subelement for data default value if present
  498. default_title = default.get(title)
  499. if default_title is not None:
  500. default_element = Element("default")
  501. default_element.text = str(default_title)
  502. attribute.append(default_element)
  503. # new insert it into the XML
  504. attributes_element = None
  505. for a in self.graph_element.findall("attributes"):
  506. # find existing attributes element by class and mode
  507. a_class = a.get("class")
  508. a_mode = a.get("mode", "static")
  509. if a_class == edge_or_node and a_mode == mode:
  510. attributes_element = a
  511. if attributes_element is None:
  512. # create new attributes element
  513. attr_kwargs = {"mode": mode, "class": edge_or_node}
  514. attributes_element = Element("attributes", **attr_kwargs)
  515. self.graph_element.insert(0, attributes_element)
  516. attributes_element.append(attribute)
  517. return new_id
  518. def add_viz(self, element, node_data):
  519. viz = node_data.pop("viz", False)
  520. if viz:
  521. color = viz.get("color")
  522. if color is not None:
  523. if self.VERSION == "1.1":
  524. e = Element(
  525. f"{{{self.NS_VIZ}}}color",
  526. r=str(color.get("r")),
  527. g=str(color.get("g")),
  528. b=str(color.get("b")),
  529. )
  530. else:
  531. e = Element(
  532. f"{{{self.NS_VIZ}}}color",
  533. r=str(color.get("r")),
  534. g=str(color.get("g")),
  535. b=str(color.get("b")),
  536. a=str(color.get("a", 1.0)),
  537. )
  538. element.append(e)
  539. size = viz.get("size")
  540. if size is not None:
  541. e = Element(f"{{{self.NS_VIZ}}}size", value=str(size))
  542. element.append(e)
  543. thickness = viz.get("thickness")
  544. if thickness is not None:
  545. e = Element(f"{{{self.NS_VIZ}}}thickness", value=str(thickness))
  546. element.append(e)
  547. shape = viz.get("shape")
  548. if shape is not None:
  549. if shape.startswith("http"):
  550. e = Element(
  551. f"{{{self.NS_VIZ}}}shape", value="image", uri=str(shape)
  552. )
  553. else:
  554. e = Element(f"{{{self.NS_VIZ}}}shape", value=str(shape))
  555. element.append(e)
  556. position = viz.get("position")
  557. if position is not None:
  558. e = Element(
  559. f"{{{self.NS_VIZ}}}position",
  560. x=str(position.get("x")),
  561. y=str(position.get("y")),
  562. z=str(position.get("z")),
  563. )
  564. element.append(e)
  565. return node_data
  566. def add_parents(self, node_element, node_data):
  567. parents = node_data.pop("parents", False)
  568. if parents:
  569. parents_element = Element("parents")
  570. for p in parents:
  571. e = Element("parent")
  572. e.attrib["for"] = str(p)
  573. parents_element.append(e)
  574. node_element.append(parents_element)
  575. return node_data
  576. def add_slices(self, node_or_edge_element, node_or_edge_data):
  577. slices = node_or_edge_data.pop("slices", False)
  578. if slices:
  579. slices_element = Element("slices")
  580. for start, end in slices:
  581. e = Element("slice", start=str(start), end=str(end))
  582. slices_element.append(e)
  583. node_or_edge_element.append(slices_element)
  584. return node_or_edge_data
  585. def add_spells(self, node_or_edge_element, node_or_edge_data):
  586. spells = node_or_edge_data.pop("spells", False)
  587. if spells:
  588. spells_element = Element("spells")
  589. for start, end in spells:
  590. e = Element("spell")
  591. if start is not None:
  592. e.attrib["start"] = str(start)
  593. self.alter_graph_mode_timeformat(start)
  594. if end is not None:
  595. e.attrib["end"] = str(end)
  596. self.alter_graph_mode_timeformat(end)
  597. spells_element.append(e)
  598. node_or_edge_element.append(spells_element)
  599. return node_or_edge_data
  600. def alter_graph_mode_timeformat(self, start_or_end):
  601. # If 'start' or 'end' appears, set timeformat
  602. if start_or_end is not None:
  603. if isinstance(start_or_end, str):
  604. timeformat = "date"
  605. elif isinstance(start_or_end, float):
  606. timeformat = "double"
  607. elif isinstance(start_or_end, int):
  608. timeformat = "long"
  609. else:
  610. raise nx.NetworkXError(
  611. "timeformat should be of the type int, float or str"
  612. )
  613. self.graph_element.set("timeformat", timeformat)
  614. # If Graph mode is static, alter to dynamic
  615. if self.graph_element.get("mode") == "static":
  616. self.graph_element.set("mode", "dynamic")
  617. def write(self, fh):
  618. # Serialize graph G in GEXF to the open fh
  619. if self.prettyprint:
  620. self.indent(self.xml)
  621. document = ElementTree(self.xml)
  622. document.write(fh, encoding=self.encoding, xml_declaration=True)
  623. def indent(self, elem, level=0):
  624. # in-place prettyprint formatter
  625. i = "\n" + " " * level
  626. if len(elem):
  627. if not elem.text or not elem.text.strip():
  628. elem.text = i + " "
  629. if not elem.tail or not elem.tail.strip():
  630. elem.tail = i
  631. for elem in elem:
  632. self.indent(elem, level + 1)
  633. if not elem.tail or not elem.tail.strip():
  634. elem.tail = i
  635. else:
  636. if level and (not elem.tail or not elem.tail.strip()):
  637. elem.tail = i
  638. class GEXFReader(GEXF):
  639. # Class to read GEXF format files
  640. # use read_gexf() function
  641. def __init__(self, node_type=None, version="1.2draft"):
  642. self.construct_types()
  643. self.node_type = node_type
  644. # assume simple graph and test for multigraph on read
  645. self.simple_graph = True
  646. self.set_version(version)
  647. def __call__(self, stream):
  648. self.xml = ElementTree(file=stream)
  649. g = self.xml.find(f"{{{self.NS_GEXF}}}graph")
  650. if g is not None:
  651. return self.make_graph(g)
  652. # try all the versions
  653. for version in self.versions:
  654. self.set_version(version)
  655. g = self.xml.find(f"{{{self.NS_GEXF}}}graph")
  656. if g is not None:
  657. return self.make_graph(g)
  658. raise nx.NetworkXError("No <graph> element in GEXF file.")
  659. def make_graph(self, graph_xml):
  660. # start with empty DiGraph or MultiDiGraph
  661. edgedefault = graph_xml.get("defaultedgetype", None)
  662. if edgedefault == "directed":
  663. G = nx.MultiDiGraph()
  664. else:
  665. G = nx.MultiGraph()
  666. # graph attributes
  667. graph_name = graph_xml.get("name", "")
  668. if graph_name != "":
  669. G.graph["name"] = graph_name
  670. graph_start = graph_xml.get("start")
  671. if graph_start is not None:
  672. G.graph["start"] = graph_start
  673. graph_end = graph_xml.get("end")
  674. if graph_end is not None:
  675. G.graph["end"] = graph_end
  676. graph_mode = graph_xml.get("mode", "")
  677. if graph_mode == "dynamic":
  678. G.graph["mode"] = "dynamic"
  679. else:
  680. G.graph["mode"] = "static"
  681. # timeformat
  682. self.timeformat = graph_xml.get("timeformat")
  683. if self.timeformat == "date":
  684. self.timeformat = "string"
  685. # node and edge attributes
  686. attributes_elements = graph_xml.findall(f"{{{self.NS_GEXF}}}attributes")
  687. # dictionaries to hold attributes and attribute defaults
  688. node_attr = {}
  689. node_default = {}
  690. edge_attr = {}
  691. edge_default = {}
  692. for a in attributes_elements:
  693. attr_class = a.get("class")
  694. if attr_class == "node":
  695. na, nd = self.find_gexf_attributes(a)
  696. node_attr.update(na)
  697. node_default.update(nd)
  698. G.graph["node_default"] = node_default
  699. elif attr_class == "edge":
  700. ea, ed = self.find_gexf_attributes(a)
  701. edge_attr.update(ea)
  702. edge_default.update(ed)
  703. G.graph["edge_default"] = edge_default
  704. else:
  705. raise # unknown attribute class
  706. # Hack to handle Gephi0.7beta bug
  707. # add weight attribute
  708. ea = {"weight": {"type": "double", "mode": "static", "title": "weight"}}
  709. ed = {}
  710. edge_attr.update(ea)
  711. edge_default.update(ed)
  712. G.graph["edge_default"] = edge_default
  713. # add nodes
  714. nodes_element = graph_xml.find(f"{{{self.NS_GEXF}}}nodes")
  715. if nodes_element is not None:
  716. for node_xml in nodes_element.findall(f"{{{self.NS_GEXF}}}node"):
  717. self.add_node(G, node_xml, node_attr)
  718. # add edges
  719. edges_element = graph_xml.find(f"{{{self.NS_GEXF}}}edges")
  720. if edges_element is not None:
  721. for edge_xml in edges_element.findall(f"{{{self.NS_GEXF}}}edge"):
  722. self.add_edge(G, edge_xml, edge_attr)
  723. # switch to Graph or DiGraph if no parallel edges were found.
  724. if self.simple_graph:
  725. if G.is_directed():
  726. G = nx.DiGraph(G)
  727. else:
  728. G = nx.Graph(G)
  729. return G
  730. def add_node(self, G, node_xml, node_attr, node_pid=None):
  731. # add a single node with attributes to the graph
  732. # get attributes and subattributues for node
  733. data = self.decode_attr_elements(node_attr, node_xml)
  734. data = self.add_parents(data, node_xml) # add any parents
  735. if self.VERSION == "1.1":
  736. data = self.add_slices(data, node_xml) # add slices
  737. else:
  738. data = self.add_spells(data, node_xml) # add spells
  739. data = self.add_viz(data, node_xml) # add viz
  740. data = self.add_start_end(data, node_xml) # add start/end
  741. # find the node id and cast it to the appropriate type
  742. node_id = node_xml.get("id")
  743. if self.node_type is not None:
  744. node_id = self.node_type(node_id)
  745. # every node should have a label
  746. node_label = node_xml.get("label")
  747. data["label"] = node_label
  748. # parent node id
  749. node_pid = node_xml.get("pid", node_pid)
  750. if node_pid is not None:
  751. data["pid"] = node_pid
  752. # check for subnodes, recursive
  753. subnodes = node_xml.find(f"{{{self.NS_GEXF}}}nodes")
  754. if subnodes is not None:
  755. for node_xml in subnodes.findall(f"{{{self.NS_GEXF}}}node"):
  756. self.add_node(G, node_xml, node_attr, node_pid=node_id)
  757. G.add_node(node_id, **data)
  758. def add_start_end(self, data, xml):
  759. # start and end times
  760. ttype = self.timeformat
  761. node_start = xml.get("start")
  762. if node_start is not None:
  763. data["start"] = self.python_type[ttype](node_start)
  764. node_end = xml.get("end")
  765. if node_end is not None:
  766. data["end"] = self.python_type[ttype](node_end)
  767. return data
  768. def add_viz(self, data, node_xml):
  769. # add viz element for node
  770. viz = {}
  771. color = node_xml.find(f"{{{self.NS_VIZ}}}color")
  772. if color is not None:
  773. if self.VERSION == "1.1":
  774. viz["color"] = {
  775. "r": int(color.get("r")),
  776. "g": int(color.get("g")),
  777. "b": int(color.get("b")),
  778. }
  779. else:
  780. viz["color"] = {
  781. "r": int(color.get("r")),
  782. "g": int(color.get("g")),
  783. "b": int(color.get("b")),
  784. "a": float(color.get("a", 1)),
  785. }
  786. size = node_xml.find(f"{{{self.NS_VIZ}}}size")
  787. if size is not None:
  788. viz["size"] = float(size.get("value"))
  789. thickness = node_xml.find(f"{{{self.NS_VIZ}}}thickness")
  790. if thickness is not None:
  791. viz["thickness"] = float(thickness.get("value"))
  792. shape = node_xml.find(f"{{{self.NS_VIZ}}}shape")
  793. if shape is not None:
  794. viz["shape"] = shape.get("shape")
  795. if viz["shape"] == "image":
  796. viz["shape"] = shape.get("uri")
  797. position = node_xml.find(f"{{{self.NS_VIZ}}}position")
  798. if position is not None:
  799. viz["position"] = {
  800. "x": float(position.get("x", 0)),
  801. "y": float(position.get("y", 0)),
  802. "z": float(position.get("z", 0)),
  803. }
  804. if len(viz) > 0:
  805. data["viz"] = viz
  806. return data
  807. def add_parents(self, data, node_xml):
  808. parents_element = node_xml.find(f"{{{self.NS_GEXF}}}parents")
  809. if parents_element is not None:
  810. data["parents"] = []
  811. for p in parents_element.findall(f"{{{self.NS_GEXF}}}parent"):
  812. parent = p.get("for")
  813. data["parents"].append(parent)
  814. return data
  815. def add_slices(self, data, node_or_edge_xml):
  816. slices_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}slices")
  817. if slices_element is not None:
  818. data["slices"] = []
  819. for s in slices_element.findall(f"{{{self.NS_GEXF}}}slice"):
  820. start = s.get("start")
  821. end = s.get("end")
  822. data["slices"].append((start, end))
  823. return data
  824. def add_spells(self, data, node_or_edge_xml):
  825. spells_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}spells")
  826. if spells_element is not None:
  827. data["spells"] = []
  828. ttype = self.timeformat
  829. for s in spells_element.findall(f"{{{self.NS_GEXF}}}spell"):
  830. start = self.python_type[ttype](s.get("start"))
  831. end = self.python_type[ttype](s.get("end"))
  832. data["spells"].append((start, end))
  833. return data
  834. def add_edge(self, G, edge_element, edge_attr):
  835. # add an edge to the graph
  836. # raise error if we find mixed directed and undirected edges
  837. edge_direction = edge_element.get("type")
  838. if G.is_directed() and edge_direction == "undirected":
  839. raise nx.NetworkXError("Undirected edge found in directed graph.")
  840. if (not G.is_directed()) and edge_direction == "directed":
  841. raise nx.NetworkXError("Directed edge found in undirected graph.")
  842. # Get source and target and recast type if required
  843. source = edge_element.get("source")
  844. target = edge_element.get("target")
  845. if self.node_type is not None:
  846. source = self.node_type(source)
  847. target = self.node_type(target)
  848. data = self.decode_attr_elements(edge_attr, edge_element)
  849. data = self.add_start_end(data, edge_element)
  850. if self.VERSION == "1.1":
  851. data = self.add_slices(data, edge_element) # add slices
  852. else:
  853. data = self.add_spells(data, edge_element) # add spells
  854. # GEXF stores edge ids as an attribute
  855. # NetworkX uses them as keys in multigraphs
  856. # if networkx_key is not specified as an attribute
  857. edge_id = edge_element.get("id")
  858. if edge_id is not None:
  859. data["id"] = edge_id
  860. # check if there is a 'multigraph_key' and use that as edge_id
  861. multigraph_key = data.pop("networkx_key", None)
  862. if multigraph_key is not None:
  863. edge_id = multigraph_key
  864. weight = edge_element.get("weight")
  865. if weight is not None:
  866. data["weight"] = float(weight)
  867. edge_label = edge_element.get("label")
  868. if edge_label is not None:
  869. data["label"] = edge_label
  870. if G.has_edge(source, target):
  871. # seen this edge before - this is a multigraph
  872. self.simple_graph = False
  873. G.add_edge(source, target, key=edge_id, **data)
  874. if edge_direction == "mutual":
  875. G.add_edge(target, source, key=edge_id, **data)
  876. def decode_attr_elements(self, gexf_keys, obj_xml):
  877. # Use the key information to decode the attr XML
  878. attr = {}
  879. # look for outer '<attvalues>' element
  880. attr_element = obj_xml.find(f"{{{self.NS_GEXF}}}attvalues")
  881. if attr_element is not None:
  882. # loop over <attvalue> elements
  883. for a in attr_element.findall(f"{{{self.NS_GEXF}}}attvalue"):
  884. key = a.get("for") # for is required
  885. try: # should be in our gexf_keys dictionary
  886. title = gexf_keys[key]["title"]
  887. except KeyError as err:
  888. raise nx.NetworkXError(f"No attribute defined for={key}.") from err
  889. atype = gexf_keys[key]["type"]
  890. value = a.get("value")
  891. if atype == "boolean":
  892. value = self.convert_bool[value]
  893. else:
  894. value = self.python_type[atype](value)
  895. if gexf_keys[key]["mode"] == "dynamic":
  896. # for dynamic graphs use list of three-tuples
  897. # [(value1,start1,end1), (value2,start2,end2), etc]
  898. ttype = self.timeformat
  899. start = self.python_type[ttype](a.get("start"))
  900. end = self.python_type[ttype](a.get("end"))
  901. if title in attr:
  902. attr[title].append((value, start, end))
  903. else:
  904. attr[title] = [(value, start, end)]
  905. else:
  906. # for static graphs just assign the value
  907. attr[title] = value
  908. return attr
  909. def find_gexf_attributes(self, attributes_element):
  910. # Extract all the attributes and defaults
  911. attrs = {}
  912. defaults = {}
  913. mode = attributes_element.get("mode")
  914. for k in attributes_element.findall(f"{{{self.NS_GEXF}}}attribute"):
  915. attr_id = k.get("id")
  916. title = k.get("title")
  917. atype = k.get("type")
  918. attrs[attr_id] = {"title": title, "type": atype, "mode": mode}
  919. # check for the 'default' subelement of key element and add
  920. default = k.find(f"{{{self.NS_GEXF}}}default")
  921. if default is not None:
  922. if atype == "boolean":
  923. value = self.convert_bool[default.text]
  924. else:
  925. value = self.python_type[atype](default.text)
  926. defaults[title] = value
  927. return attrs, defaults
  928. def relabel_gexf_graph(G):
  929. """Relabel graph using "label" node keyword for node label.
  930. Parameters
  931. ----------
  932. G : graph
  933. A NetworkX graph read from GEXF data
  934. Returns
  935. -------
  936. H : graph
  937. A NetworkX graph with relabeled nodes
  938. Raises
  939. ------
  940. NetworkXError
  941. If node labels are missing or not unique while relabel=True.
  942. Notes
  943. -----
  944. This function relabels the nodes in a NetworkX graph with the
  945. "label" attribute. It also handles relabeling the specific GEXF
  946. node attributes "parents", and "pid".
  947. """
  948. # build mapping of node labels, do some error checking
  949. try:
  950. mapping = [(u, G.nodes[u]["label"]) for u in G]
  951. except KeyError as err:
  952. raise nx.NetworkXError(
  953. "Failed to relabel nodes: missing node labels found. Use relabel=False."
  954. ) from err
  955. x, y = zip(*mapping)
  956. if len(set(y)) != len(G):
  957. raise nx.NetworkXError(
  958. "Failed to relabel nodes: duplicate node labels found. Use relabel=False."
  959. )
  960. mapping = dict(mapping)
  961. H = nx.relabel_nodes(G, mapping)
  962. # relabel attributes
  963. for n in G:
  964. m = mapping[n]
  965. H.nodes[m]["id"] = n
  966. H.nodes[m].pop("label")
  967. if "pid" in H.nodes[m]:
  968. H.nodes[m]["pid"] = mapping[G.nodes[n]["pid"]]
  969. if "parents" in H.nodes[m]:
  970. H.nodes[m]["parents"] = [mapping[p] for p in G.nodes[n]["parents"]]
  971. return H