node_classification.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
"""This module provides functions for the node classification problem.

The functions in this module are not imported
into the top level `networkx` namespace.
You can access these functions by importing
the `networkx.algorithms.node_classification` module,
then accessing the functions as attributes of `node_classification`.
For example:

>>> from networkx.algorithms import node_classification
>>> G = nx.path_graph(4)
>>> G.edges()
EdgeView([(0, 1), (1, 2), (2, 3)])
>>> G.nodes[0]["label"] = "A"
>>> G.nodes[3]["label"] = "B"
>>> node_classification.harmonic_function(G)
['A', 'A', 'B', 'B']

References
----------
Zhu, X., Ghahramani, Z., & Lafferty, J. (2003, August).
Semi-supervised learning using gaussian fields and harmonic functions.
In ICML (Vol. 3, pp. 912-919).
"""

import networkx as nx

# Public API of this module; `_get_label_info` is a private helper.
__all__ = ["harmonic_function", "local_and_global_consistency"]
  24. @nx.utils.not_implemented_for("directed")
  25. @nx._dispatchable(node_attrs="label_name")
  26. def harmonic_function(G, max_iter=30, label_name="label"):
  27. """Node classification by Harmonic function
  28. Function for computing Harmonic function algorithm by Zhu et al.
  29. Parameters
  30. ----------
  31. G : NetworkX Graph
  32. max_iter : int
  33. maximum number of iterations allowed
  34. label_name : string
  35. name of target labels to predict
  36. Returns
  37. -------
  38. predicted : list
  39. List of length ``len(G)`` with the predicted labels for each node.
  40. Raises
  41. ------
  42. NetworkXError
  43. If no nodes in `G` have attribute `label_name`.
  44. Examples
  45. --------
  46. >>> from networkx.algorithms import node_classification
  47. >>> G = nx.path_graph(4)
  48. >>> G.nodes[0]["label"] = "A"
  49. >>> G.nodes[3]["label"] = "B"
  50. >>> G.nodes(data=True)
  51. NodeDataView({0: {'label': 'A'}, 1: {}, 2: {}, 3: {'label': 'B'}})
  52. >>> G.edges()
  53. EdgeView([(0, 1), (1, 2), (2, 3)])
  54. >>> predicted = node_classification.harmonic_function(G)
  55. >>> predicted
  56. ['A', 'A', 'B', 'B']
  57. References
  58. ----------
  59. Zhu, X., Ghahramani, Z., & Lafferty, J. (2003, August).
  60. Semi-supervised learning using gaussian fields and harmonic functions.
  61. In ICML (Vol. 3, pp. 912-919).
  62. """
  63. import numpy as np
  64. import scipy as sp
  65. X = nx.to_scipy_sparse_array(G) # adjacency matrix
  66. labels, label_dict = _get_label_info(G, label_name)
  67. if labels.shape[0] == 0:
  68. raise nx.NetworkXError(
  69. f"No node on the input graph is labeled by '{label_name}'."
  70. )
  71. n_samples = X.shape[0]
  72. n_classes = label_dict.shape[0]
  73. F = np.zeros((n_samples, n_classes))
  74. # Build propagation matrix
  75. degrees = X.sum(axis=0)
  76. degrees[degrees == 0] = 1 # Avoid division by 0
  77. D = sp.sparse.dia_array((1.0 / degrees, 0), shape=(n_samples, n_samples)).tocsr()
  78. P = (D @ X).tolil()
  79. P[labels[:, 0]] = 0 # labels[:, 0] indicates IDs of labeled nodes
  80. # Build base matrix
  81. B = np.zeros((n_samples, n_classes))
  82. B[labels[:, 0], labels[:, 1]] = 1
  83. for _ in range(max_iter):
  84. F = (P @ F) + B
  85. return label_dict[np.argmax(F, axis=1)].tolist()
  86. @nx.utils.not_implemented_for("directed")
  87. @nx._dispatchable(node_attrs="label_name")
  88. def local_and_global_consistency(G, alpha=0.99, max_iter=30, label_name="label"):
  89. """Node classification by Local and Global Consistency
  90. Function for computing Local and global consistency algorithm by Zhou et al.
  91. Parameters
  92. ----------
  93. G : NetworkX Graph
  94. alpha : float
  95. Clamping factor
  96. max_iter : int
  97. Maximum number of iterations allowed
  98. label_name : string
  99. Name of target labels to predict
  100. Returns
  101. -------
  102. predicted : list
  103. List of length ``len(G)`` with the predicted labels for each node.
  104. Raises
  105. ------
  106. NetworkXError
  107. If no nodes in `G` have attribute `label_name`.
  108. Examples
  109. --------
  110. >>> from networkx.algorithms import node_classification
  111. >>> G = nx.path_graph(4)
  112. >>> G.nodes[0]["label"] = "A"
  113. >>> G.nodes[3]["label"] = "B"
  114. >>> G.nodes(data=True)
  115. NodeDataView({0: {'label': 'A'}, 1: {}, 2: {}, 3: {'label': 'B'}})
  116. >>> G.edges()
  117. EdgeView([(0, 1), (1, 2), (2, 3)])
  118. >>> predicted = node_classification.local_and_global_consistency(G)
  119. >>> predicted
  120. ['A', 'A', 'B', 'B']
  121. References
  122. ----------
  123. Zhou, D., Bousquet, O., Lal, T. N., Weston, J., & Schölkopf, B. (2004).
  124. Learning with local and global consistency.
  125. Advances in neural information processing systems, 16(16), 321-328.
  126. """
  127. import numpy as np
  128. import scipy as sp
  129. X = nx.to_scipy_sparse_array(G) # adjacency matrix
  130. labels, label_dict = _get_label_info(G, label_name)
  131. if labels.shape[0] == 0:
  132. raise nx.NetworkXError(
  133. f"No node on the input graph is labeled by '{label_name}'."
  134. )
  135. n_samples = X.shape[0]
  136. n_classes = label_dict.shape[0]
  137. F = np.zeros((n_samples, n_classes))
  138. # Build propagation matrix
  139. degrees = X.sum(axis=0)
  140. degrees[degrees == 0] = 1 # Avoid division by 0
  141. D2 = sp.sparse.dia_array(
  142. (1.0 / np.sqrt(degrees), 0), shape=(n_samples, n_samples)
  143. ).tocsr()
  144. P = alpha * ((D2 @ X) @ D2)
  145. # Build base matrix
  146. B = np.zeros((n_samples, n_classes))
  147. B[labels[:, 0], labels[:, 1]] = 1 - alpha
  148. for _ in range(max_iter):
  149. F = (P @ F) + B
  150. return label_dict[np.argmax(F, axis=1)].tolist()
  151. def _get_label_info(G, label_name):
  152. """Get and return information of labels from the input graph
  153. Parameters
  154. ----------
  155. G : Network X graph
  156. label_name : string
  157. Name of the target label
  158. Returns
  159. -------
  160. labels : numpy array, shape = [n_labeled_samples, 2]
  161. Array of pairs of labeled node ID and label ID
  162. label_dict : numpy array, shape = [n_classes]
  163. Array of labels
  164. i-th element contains the label corresponding label ID `i`
  165. """
  166. import numpy as np
  167. labels = []
  168. label_to_id = {}
  169. lid = 0
  170. for i, n in enumerate(G.nodes(data=True)):
  171. if label_name in n[1]:
  172. label = n[1][label_name]
  173. if label not in label_to_id:
  174. label_to_id[label] = lid
  175. lid += 1
  176. labels.append([i, label_to_id[label]])
  177. labels = np.array(labels)
  178. label_dict = np.array(
  179. [label for label, _ in sorted(label_to_id.items(), key=lambda x: x[1])]
  180. )
  181. return (labels, label_dict)