dispatch_interface.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
# This file contains utilities for testing the dispatching feature.
# A full test of all dispatchable algorithms is performed by
# modifying the pytest invocation and setting an environment variable:
#     NETWORKX_TEST_BACKEND=nx_loopback pytest
# This is comprehensive, but only tests the `test_override_dispatch`
# function in networkx.utils.backends.
# To test the `_dispatchable` function directly, several tests scattered throughout
# NetworkX have been augmented to test both normal and dispatch mode.
# Searching for `dispatch_interface` should locate the specific tests.
  10. import networkx as nx
  11. from networkx import DiGraph, Graph, MultiDiGraph, MultiGraph, PlanarEmbedding
  12. from networkx.classes.reportviews import NodeView
  13. class LoopbackGraph(Graph):
  14. __networkx_backend__ = "nx_loopback"
  15. class LoopbackDiGraph(DiGraph):
  16. __networkx_backend__ = "nx_loopback"
  17. class LoopbackMultiGraph(MultiGraph):
  18. __networkx_backend__ = "nx_loopback"
  19. class LoopbackMultiDiGraph(MultiDiGraph):
  20. __networkx_backend__ = "nx_loopback"
  21. class LoopbackPlanarEmbedding(PlanarEmbedding):
  22. __networkx_backend__ = "nx_loopback"
  23. def convert(graph):
  24. if isinstance(graph, PlanarEmbedding):
  25. return LoopbackPlanarEmbedding(graph)
  26. if isinstance(graph, MultiDiGraph):
  27. return LoopbackMultiDiGraph(graph)
  28. if isinstance(graph, MultiGraph):
  29. return LoopbackMultiGraph(graph)
  30. if isinstance(graph, DiGraph):
  31. return LoopbackDiGraph(graph)
  32. if isinstance(graph, Graph):
  33. return LoopbackGraph(graph)
  34. raise TypeError(f"Unsupported type of graph: {type(graph)}")
  35. class LoopbackBackendInterface:
  36. def __getattr__(self, item):
  37. try:
  38. return nx.utils.backends._registered_algorithms[item].orig_func
  39. except KeyError:
  40. raise AttributeError(item) from None
  41. @staticmethod
  42. def graph__new__(cls, incoming_graph_data=None, **attr):
  43. # LoopbackGraph.__init__ will be called next since the returned
  44. # object is an instance of an nx.Graph. For more details, see:
  45. # https://docs.python.org/3/reference/datamodel.html#object.__new__
  46. return object.__new__(LoopbackGraph)
  47. @staticmethod
  48. def convert_from_nx(
  49. graph,
  50. *,
  51. edge_attrs=None,
  52. node_attrs=None,
  53. preserve_edge_attrs=None,
  54. preserve_node_attrs=None,
  55. preserve_graph_attrs=None,
  56. name=None,
  57. graph_name=None,
  58. ):
  59. if name in {
  60. # Raise if input graph changes. See test_dag.py::test_topological_sort6
  61. "lexicographical_topological_sort",
  62. "topological_generations",
  63. "topological_sort",
  64. # Would be nice to some day avoid these cutoffs of full testing
  65. }:
  66. return graph
  67. if isinstance(graph, NodeView):
  68. # Convert to a Graph with only nodes (no edges)
  69. new_graph = Graph()
  70. new_graph.add_nodes_from(graph.items())
  71. graph = new_graph
  72. G = LoopbackGraph()
  73. elif not isinstance(graph, Graph):
  74. raise TypeError(
  75. f"Bad type for graph argument {graph_name} in {name}: {type(graph)}"
  76. )
  77. elif graph.__class__ in {Graph, LoopbackGraph}:
  78. G = LoopbackGraph()
  79. elif graph.__class__ in {DiGraph, LoopbackDiGraph}:
  80. G = LoopbackDiGraph()
  81. elif graph.__class__ in {MultiGraph, LoopbackMultiGraph}:
  82. G = LoopbackMultiGraph()
  83. elif graph.__class__ in {MultiDiGraph, LoopbackMultiDiGraph}:
  84. G = LoopbackMultiDiGraph()
  85. elif graph.__class__ in {PlanarEmbedding, LoopbackPlanarEmbedding}:
  86. G = LoopbackDiGraph() # or LoopbackPlanarEmbedding
  87. else:
  88. # Would be nice to handle these better some day
  89. # nx.algorithms.approximation.kcomponents._AntiGraph
  90. # nx.classes.tests.test_multidigraph.MultiDiGraphSubClass
  91. # nx.classes.tests.test_multigraph.MultiGraphSubClass
  92. G = graph.__class__()
  93. if preserve_graph_attrs:
  94. G.graph.update(graph.graph)
  95. # add nodes
  96. G.add_nodes_from(graph)
  97. if preserve_node_attrs:
  98. for n, dd in G._node.items():
  99. dd.update(graph.nodes[n])
  100. elif node_attrs:
  101. for n, dd in G._node.items():
  102. dd.update(
  103. (attr, graph._node[n].get(attr, default))
  104. for attr, default in node_attrs.items()
  105. if default is not None or attr in graph._node[n]
  106. )
  107. # tools to build datadict and keydict
  108. if preserve_edge_attrs:
  109. def G_new_datadict(old_dd):
  110. return G.edge_attr_dict_factory(old_dd)
  111. elif edge_attrs:
  112. def G_new_datadict(old_dd):
  113. return G.edge_attr_dict_factory(
  114. (attr, old_dd.get(attr, default))
  115. for attr, default in edge_attrs.items()
  116. if default is not None or attr in old_dd
  117. )
  118. else:
  119. def G_new_datadict(old_dd):
  120. return G.edge_attr_dict_factory()
  121. if G.is_multigraph():
  122. def G_new_inner(keydict):
  123. kd = G.adjlist_inner_dict_factory(
  124. (k, G_new_datadict(dd)) for k, dd in keydict.items()
  125. )
  126. return kd
  127. else:
  128. G_new_inner = G_new_datadict
  129. # add edges keeping the same order in _adj and _pred
  130. G_adj = G._adj
  131. if G.is_directed():
  132. for n, nbrs in graph._adj.items():
  133. G_adj[n].update((nbr, G_new_inner(dd)) for nbr, dd in nbrs.items())
  134. # ensure same datadict for pred and adj; and pred order of graph._pred
  135. G_pred = G._pred
  136. for n, nbrs in graph._pred.items():
  137. G_pred[n].update((nbr, G_adj[nbr][n]) for nbr in nbrs)
  138. else: # undirected
  139. for n, nbrs in graph._adj.items():
  140. # ensure same datadict for both ways; and adj order of graph._adj
  141. G_adj[n].update(
  142. (nbr, G_adj[nbr][n] if n in G_adj[nbr] else G_new_inner(dd))
  143. for nbr, dd in nbrs.items()
  144. )
  145. return G
  146. @staticmethod
  147. def convert_to_nx(obj, *, name=None):
  148. return obj
  149. @staticmethod
  150. def on_start_tests(items):
  151. # Verify that items can be xfailed
  152. for item in items:
  153. assert hasattr(item, "add_marker")
  154. def can_run(self, name, args, kwargs):
  155. # It is unnecessary to define this function if algorithms are fully supported.
  156. # We include it for illustration purposes.
  157. return hasattr(self, name)
# Module-level singleton instance of the loopback backend interface.
# NOTE(review): presumably this is the object the "nx_loopback" backend
# registration resolves to — confirm against the package's entry points.
backend_interface = LoopbackBackendInterface()