test_backends.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. import pickle
  2. import pytest
  3. import networkx as nx
  4. sp = pytest.importorskip("scipy")
  5. pytest.importorskip("numpy")
  6. @nx._dispatchable(implemented_by_nx=False)
  7. def _stub_func(G):
  8. raise NotImplementedError("_stub_func is a stub")
  9. def test_dispatch_kwds_vs_args():
  10. G = nx.path_graph(4)
  11. nx.pagerank(G)
  12. nx.pagerank(G=G)
  13. with pytest.raises(TypeError):
  14. nx.pagerank()
  15. def test_pickle():
  16. count = 0
  17. for name, func in nx.utils.backends._registered_algorithms.items():
  18. pickled = pickle.dumps(func.__wrapped__)
  19. assert pickle.loads(pickled) is func.__wrapped__
  20. try:
  21. # Some functions can't be pickled, but it's not b/c of _dispatchable
  22. pickled = pickle.dumps(func)
  23. except pickle.PicklingError:
  24. continue
  25. assert pickle.loads(pickled) is func
  26. count += 1
  27. assert count > 0
  28. assert pickle.loads(pickle.dumps(nx.inverse_line_graph)) is nx.inverse_line_graph
@pytest.mark.skipif(
    "not nx.config.backend_priority.algos "
    "or nx.config.backend_priority.algos[0] != 'nx_loopback'"
)
def test_graph_converter_needs_backend():
    """Check dispatching of the graph-building ``nx.from_scipy_sparse_array``.

    Runs only when ``nx_loopback`` is first in ``backend_priority.algos``.
    ``LoopbackBackendInterface`` is temporarily patched with two things:
    a ``from_scipy_sparse_array`` that records each call, and a
    ``convert_to_nx`` that returns ``nx.Graph`` objects. Both patches are
    undone in ``finally``.
    """
    # When testing, `nx.from_scipy_sparse_array` will *always* call the backend
    # implementation if it's implemented. If `backend=` isn't given, then the result
    # will be converted back to NetworkX via `convert_to_nx`.
    # If not testing, then calling `nx.from_scipy_sparse_array` w/o `backend=` will
    # always call the original version. `backend=` is *required* to call the backend.
    from networkx.classes.tests.dispatch_interface import (
        LoopbackBackendInterface,
        LoopbackGraph,
    )

    # Small symmetric weighted adjacency matrix with a zero diagonal.
    A = sp.sparse.coo_array([[0, 3, 2], [3, 0, 1], [2, 1, 0]])
    # One entry is appended per call to the patched backend implementation.
    side_effects = []

    def from_scipy_sparse_array(self, *args, **kwargs):
        side_effects.append(1)  # Just to prove this was called
        # `__getattr__` is called explicitly. Once this function is assigned
        # on the class, a normal `self.from_scipy_sparse_array` lookup would
        # find it again and recurse. What `__getattr__` returns is defined in
        # dispatch_interface (not shown); presumably the NetworkX version.
        return self.convert_from_nx(
            self.__getattr__("from_scipy_sparse_array")(*args, **kwargs),
            preserve_edge_attrs=True,
            preserve_node_attrs=True,
            preserve_graph_attrs=True,
        )

    @staticmethod
    def convert_to_nx(obj, *, name=None):
        # Plain nx.Graph objects pass through; anything else is copied into one.
        if type(obj) is nx.Graph:
            return obj
        return nx.Graph(obj)

    # *This mutates LoopbackBackendInterface!*
    # NOTE(review): assumes the original convert_to_nx is a staticmethod (not
    # visible here). Reading it through the class then yields the bare
    # function, which is why `finally` re-wraps it with staticmethod().
    orig_convert_to_nx = LoopbackBackendInterface.convert_to_nx
    LoopbackBackendInterface.convert_to_nx = convert_to_nx
    LoopbackBackendInterface.from_scipy_sparse_array = from_scipy_sparse_array

    try:
        # Nothing has been dispatched yet.
        assert side_effects == []
        # No `backend=`: the patched backend runs, and the result comes back
        # as a NetworkX graph.
        assert type(nx.from_scipy_sparse_array(A)) is nx.Graph
        assert side_effects == [1]
        # Explicit loopback backend: the backend result is returned unconverted.
        assert (
            type(nx.from_scipy_sparse_array(A, backend="nx_loopback")) is LoopbackGraph
        )
        assert side_effects == [1, 1]
        # backend="networkx" is default implementation
        assert type(nx.from_scipy_sparse_array(A, backend="networkx")) is nx.Graph
        # ...so the patched backend function is not called again.
        assert side_effects == [1, 1]
    finally:
        # Undo both class patches so later tests see the original interface.
        LoopbackBackendInterface.convert_to_nx = staticmethod(orig_convert_to_nx)
        del LoopbackBackendInterface.from_scipy_sparse_array
    # An unknown backend name is reported as an ImportError.
    with pytest.raises(ImportError, match="backend is not installed"):
        nx.from_scipy_sparse_array(A, backend="bad-backend-name")
  78. @pytest.mark.skipif(
  79. "not nx.config.backend_priority.algos "
  80. "or nx.config.backend_priority.algos[0] != 'nx_loopback'"
  81. )
  82. def test_networkx_backend():
  83. """Test using `backend="networkx"` in a dispatchable function."""
  84. # (Implementing this test is harder than it should be)
  85. from networkx.classes.tests.dispatch_interface import (
  86. LoopbackBackendInterface,
  87. LoopbackGraph,
  88. )
  89. G = LoopbackGraph()
  90. G.add_edges_from([(0, 1), (1, 2), (1, 3), (2, 4)])
  91. @staticmethod
  92. def convert_to_nx(obj, *, name=None):
  93. if isinstance(obj, LoopbackGraph):
  94. new_graph = nx.Graph()
  95. new_graph.__dict__.update(obj.__dict__)
  96. return new_graph
  97. return obj
  98. # *This mutates LoopbackBackendInterface!*
  99. # This uses the same trick as in the previous test.
  100. orig_convert_to_nx = LoopbackBackendInterface.convert_to_nx
  101. LoopbackBackendInterface.convert_to_nx = convert_to_nx
  102. try:
  103. G2 = nx.ego_graph(G, 0, backend="networkx")
  104. assert type(G2) is nx.Graph
  105. finally:
  106. LoopbackBackendInterface.convert_to_nx = staticmethod(orig_convert_to_nx)
  107. def test_dispatchable_are_functions():
  108. assert type(nx.pagerank) is type(nx.pagerank.orig_func)
  109. @pytest.mark.skipif("not nx.utils.backends.backends")
  110. def test_mixing_backend_graphs():
  111. from networkx.classes.tests import dispatch_interface
  112. G = nx.Graph()
  113. G.add_edge(1, 2)
  114. G.add_edge(2, 3)
  115. H = nx.Graph()
  116. H.add_edge(2, 3)
  117. rv = nx.intersection(G, H)
  118. assert set(nx.intersection(G, H)) == {2, 3}
  119. G2 = dispatch_interface.convert(G)
  120. H2 = dispatch_interface.convert(H)
  121. if "nx_loopback" in nx.config.backend_priority:
  122. # Auto-convert
  123. assert set(nx.intersection(G2, H)) == {2, 3}
  124. assert set(nx.intersection(G, H2)) == {2, 3}
  125. elif not nx.config.backend_priority and "nx_loopback" not in nx.config.backends:
  126. # G2 and H2 are backend objects for a backend that is not registered!
  127. with pytest.raises(ImportError, match="backend is not installed"):
  128. nx.intersection(G2, H)
  129. with pytest.raises(ImportError, match="backend is not installed"):
  130. nx.intersection(G, H2)
  131. # It would be nice to test passing graphs from *different* backends,
  132. # but we are not set up to do this yet.
  133. def test_bad_backend_name():
  134. """Using `backend=` raises with unknown backend even if there are no backends."""
  135. with pytest.raises(
  136. ImportError, match="'this_backend_does_not_exist' backend is not installed"
  137. ):
  138. nx.null_graph(backend="this_backend_does_not_exist")
  139. def test_not_implemented_by_nx():
  140. assert "networkx" in nx.pagerank.backends
  141. assert "networkx" not in _stub_func.backends
  142. if "nx_loopback" in nx.config.backends:
  143. from networkx.classes.tests.dispatch_interface import LoopbackBackendInterface
  144. def stub_func_implementation(G):
  145. return True
  146. LoopbackBackendInterface._stub_func = staticmethod(stub_func_implementation)
  147. try:
  148. assert _stub_func(nx.Graph()) is True
  149. finally:
  150. del LoopbackBackendInterface._stub_func
  151. with pytest.raises(NotImplementedError):
  152. _stub_func(nx.Graph())
  153. @pytest.mark.skipif(
  154. "not nx.config.backend_priority.algos "
  155. "or nx.config.backend_priority.algos[0] != 'nx_loopback'"
  156. )
  157. def test_dispatch_graph_new():
  158. from networkx.classes.tests.dispatch_interface import LoopbackGraph
  159. G = nx.Graph()
  160. assert not isinstance(G, LoopbackGraph)
  161. # `backend=` argument that gets passed to __init__ is ignored.
  162. # Best practice is that it should not be in the `.graph` dict.
  163. G = nx.Graph(backend="networkx")
  164. assert type(G) is nx.Graph
  165. assert "backend" not in G.graph
  166. G = nx.Graph(backend="nx_loopback")
  167. assert isinstance(G, LoopbackGraph)
  168. assert "backend" not in G.graph
  169. # Args are passed
  170. G1 = nx.Graph([(0, 1), (1, 2)])
  171. assert not isinstance(G1, LoopbackGraph)
  172. G2 = nx.Graph([(0, 1), (1, 2)], backend="nx_loopback")
  173. assert isinstance(G2, LoopbackGraph)
  174. assert nx.utils.misc.graphs_equal(G1, G2)
  175. # Test config for automatic usage
  176. with nx.config.backend_priority(classes=["nx_loopback"]):
  177. G = nx.Graph()
  178. assert isinstance(G, LoopbackGraph)
  179. # LoopbackDiGraph __new__ is not implemented
  180. G = nx.DiGraph()
  181. assert not isinstance(G, LoopbackGraph)
  182. G = nx.Graph()
  183. assert not isinstance(G, LoopbackGraph)