_validation.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import numpy as np
  2. from scipy.sparse import issparse
  3. from scipy.sparse._sputils import convert_pydata_sparse_to_scipy
  4. from scipy.sparse.csgraph._tools import (
  5. csgraph_to_dense, csgraph_from_dense,
  6. csgraph_masked_from_dense, csgraph_from_masked
  7. )
  8. DTYPE = np.float64
  9. def validate_graph(csgraph, directed, dtype=DTYPE,
  10. csr_output=True, dense_output=True,
  11. copy_if_dense=False, copy_if_sparse=False,
  12. null_value_in=0, null_value_out=np.inf,
  13. infinity_null=True, nan_null=True):
  14. """Routine for validation and conversion of csgraph inputs"""
  15. if not (csr_output or dense_output):
  16. raise ValueError("Internal: dense or csr output must be true")
  17. accept_fv = [null_value_in]
  18. if infinity_null:
  19. accept_fv.append(np.inf)
  20. if nan_null:
  21. accept_fv.append(np.nan)
  22. csgraph = convert_pydata_sparse_to_scipy(csgraph, accept_fv=accept_fv)
  23. # if undirected and csc storage, then transposing in-place
  24. # is quicker than later converting to csr.
  25. if (not directed) and issparse(csgraph) and csgraph.format == "csc":
  26. csgraph = csgraph.T
  27. if issparse(csgraph):
  28. if csr_output:
  29. csgraph = csgraph.tocsr(copy=copy_if_sparse).astype(DTYPE, copy=False)
  30. else:
  31. csgraph = csgraph_to_dense(csgraph, null_value=null_value_out)
  32. elif np.ma.isMaskedArray(csgraph):
  33. if dense_output:
  34. mask = csgraph.mask
  35. csgraph = np.array(csgraph.data, dtype=DTYPE, copy=copy_if_dense)
  36. csgraph[mask] = null_value_out
  37. else:
  38. csgraph = csgraph_from_masked(csgraph)
  39. else:
  40. if dense_output:
  41. csgraph = csgraph_masked_from_dense(csgraph,
  42. copy=copy_if_dense,
  43. null_value=null_value_in,
  44. nan_null=nan_null,
  45. infinity_null=infinity_null)
  46. mask = csgraph.mask
  47. csgraph = np.asarray(csgraph.data, dtype=DTYPE)
  48. csgraph[mask] = null_value_out
  49. else:
  50. csgraph = csgraph_from_dense(csgraph, null_value=null_value_in,
  51. infinity_null=infinity_null,
  52. nan_null=nan_null)
  53. if csgraph.ndim != 2:
  54. raise ValueError("compressed-sparse graph must be 2-D")
  55. if csgraph.shape[0] != csgraph.shape[1]:
  56. raise ValueError("compressed-sparse graph must be shape (N, N)")
  57. return csgraph