tree.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. from itertools import chain
  2. import networkx as nx
  3. __all__ = ["tree_data", "tree_graph"]
  4. def tree_data(G, root, ident="id", children="children"):
  5. """Returns data in tree format that is suitable for JSON serialization
  6. and use in JavaScript documents.
  7. Parameters
  8. ----------
  9. G : NetworkX graph
  10. G must be an oriented tree
  11. root : node
  12. The root of the tree
  13. ident : string
  14. Attribute name for storing NetworkX-internal graph data. `ident` must
  15. have a different value than `children`. The default is 'id'.
  16. children : string
  17. Attribute name for storing NetworkX-internal graph data. `children`
  18. must have a different value than `ident`. The default is 'children'.
  19. Returns
  20. -------
  21. data : dict
  22. A dictionary with node-link formatted data.
  23. Raises
  24. ------
  25. NetworkXError
  26. If `children` and `ident` attributes are identical.
  27. Examples
  28. --------
  29. >>> from networkx.readwrite import json_graph
  30. >>> G = nx.DiGraph([(1, 2)])
  31. >>> data = json_graph.tree_data(G, root=1)
  32. To serialize with json
  33. >>> import json
  34. >>> s = json.dumps(data)
  35. Notes
  36. -----
  37. Node attributes are stored in this format but keys
  38. for attributes must be strings if you want to serialize with JSON.
  39. Graph and edge attributes are not stored.
  40. See Also
  41. --------
  42. tree_graph, node_link_data, adjacency_data
  43. """
  44. if G.number_of_nodes() != G.number_of_edges() + 1:
  45. raise TypeError("G is not a tree.")
  46. if not G.is_directed():
  47. raise TypeError("G is not directed.")
  48. if not nx.is_weakly_connected(G):
  49. raise TypeError("G is not weakly connected.")
  50. if ident == children:
  51. raise nx.NetworkXError("The values for `id` and `children` must be different.")
  52. def add_children(n, G):
  53. nbrs = G[n]
  54. if len(nbrs) == 0:
  55. return []
  56. children_ = []
  57. for child in nbrs:
  58. d = {**G.nodes[child], ident: child}
  59. c = add_children(child, G)
  60. if c:
  61. d[children] = c
  62. children_.append(d)
  63. return children_
  64. return {**G.nodes[root], ident: root, children: add_children(root, G)}
  65. @nx._dispatchable(graphs=None, returns_graph=True)
  66. def tree_graph(data, ident="id", children="children"):
  67. """Returns graph from tree data format.
  68. Parameters
  69. ----------
  70. data : dict
  71. Tree formatted graph data
  72. ident : string
  73. Attribute name for storing NetworkX-internal graph data. `ident` must
  74. have a different value than `children`. The default is 'id'.
  75. children : string
  76. Attribute name for storing NetworkX-internal graph data. `children`
  77. must have a different value than `ident`. The default is 'children'.
  78. Returns
  79. -------
  80. G : NetworkX DiGraph
  81. Examples
  82. --------
  83. >>> from networkx.readwrite import json_graph
  84. >>> G = nx.DiGraph([(1, 2)])
  85. >>> data = json_graph.tree_data(G, root=1)
  86. >>> H = json_graph.tree_graph(data)
  87. See Also
  88. --------
  89. tree_data, node_link_data, adjacency_data
  90. """
  91. graph = nx.DiGraph()
  92. def add_children(parent, children_):
  93. for data in children_:
  94. child = data[ident]
  95. graph.add_edge(parent, child)
  96. grandchildren = data.get(children, [])
  97. if grandchildren:
  98. add_children(child, grandchildren)
  99. nodedata = {
  100. str(k): v for k, v in data.items() if k != ident and k != children
  101. }
  102. graph.add_node(child, **nodedata)
  103. root = data[ident]
  104. children_ = data.get(children, [])
  105. nodedata = {str(k): v for k, v in data.items() if k != ident and k != children}
  106. graph.add_node(root, **nodedata)
  107. add_children(root, children_)
  108. return graph