_matrix_io.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. import numpy as np
  2. import scipy as sp
  3. __all__ = ['save_npz', 'load_npz']
  4. # Make loading safe vs. malicious input
  5. PICKLE_KWARGS = dict(allow_pickle=False)
  6. def save_npz(file, matrix, compressed=True):
  7. """ Save a sparse matrix or array to a file using ``.npz`` format.
  8. Parameters
  9. ----------
  10. file : str or file-like object
  11. Either the file name (string) or an open file (file-like object)
  12. where the data will be saved. If file is a string, the ``.npz``
  13. extension will be appended to the file name if it is not already
  14. there.
  15. matrix: spmatrix or sparray
  16. The sparse matrix or array to save.
  17. Supported formats: ``csc``, ``csr``, ``bsr``, ``dia`` or ``coo``.
  18. compressed : bool, optional
  19. Allow compressing the file. Default: True
  20. See Also
  21. --------
  22. scipy.sparse.load_npz: Load a sparse matrix from a file using ``.npz`` format.
  23. numpy.savez: Save several arrays into a ``.npz`` archive.
  24. numpy.savez_compressed : Save several arrays into a compressed ``.npz`` archive.
  25. Examples
  26. --------
  27. Store sparse matrix to disk, and load it again:
  28. >>> import numpy as np
  29. >>> import scipy as sp
  30. >>> sparse_matrix = sp.sparse.csc_matrix([[0, 0, 3], [4, 0, 0]])
  31. >>> sparse_matrix
  32. <Compressed Sparse Column sparse matrix of dtype 'int64'
  33. with 2 stored elements and shape (2, 3)>
  34. >>> sparse_matrix.toarray()
  35. array([[0, 0, 3],
  36. [4, 0, 0]], dtype=int64)
  37. >>> sp.sparse.save_npz('/tmp/sparse_matrix.npz', sparse_matrix)
  38. >>> sparse_matrix = sp.sparse.load_npz('/tmp/sparse_matrix.npz')
  39. >>> sparse_matrix
  40. <Compressed Sparse Column sparse matrix of dtype 'int64'
  41. with 2 stored elements and shape (2, 3)>
  42. >>> sparse_matrix.toarray()
  43. array([[0, 0, 3],
  44. [4, 0, 0]], dtype=int64)
  45. """
  46. arrays_dict = {}
  47. if matrix.format in ('csc', 'csr', 'bsr'):
  48. arrays_dict.update(indices=matrix.indices, indptr=matrix.indptr)
  49. elif matrix.format == 'dia':
  50. arrays_dict.update(offsets=matrix.offsets)
  51. elif matrix.format == 'coo':
  52. if matrix.ndim == 2:
  53. # TODO: After a few releases, switch 2D case to save with coords only.
  54. arrays_dict.update(row=matrix.row, col=matrix.col)
  55. else:
  56. arrays_dict.update(coords=matrix.coords)
  57. else:
  58. msg = f'Save is not implemented for sparse matrix of format {matrix.format}.'
  59. raise NotImplementedError(msg)
  60. arrays_dict.update(
  61. format=matrix.format.encode('ascii'),
  62. shape=matrix.shape,
  63. data=matrix.data
  64. )
  65. if isinstance(matrix, sp.sparse.sparray):
  66. arrays_dict.update(_is_array=True)
  67. if compressed:
  68. np.savez_compressed(file, **arrays_dict)
  69. else:
  70. np.savez(file, **arrays_dict)
  71. def load_npz(file):
  72. """ Load a sparse array/matrix from a file using ``.npz`` format.
  73. Parameters
  74. ----------
  75. file : str or file-like object
  76. Either the file name (string) or an open file (file-like object)
  77. where the data will be loaded.
  78. Returns
  79. -------
  80. result : csc_array, csr_array, bsr_array, dia_array or coo_array
  81. A sparse array/matrix containing the loaded data.
  82. Raises
  83. ------
  84. OSError
  85. If the input file does not exist or cannot be read.
  86. See Also
  87. --------
  88. scipy.sparse.save_npz: Save a sparse array/matrix to a file using ``.npz`` format.
  89. numpy.load: Load several arrays from a ``.npz`` archive.
  90. Examples
  91. --------
  92. Store sparse array/matrix to disk, and load it again:
  93. >>> import numpy as np
  94. >>> import scipy as sp
  95. >>> sparse_array = sp.sparse.csc_array([[0, 0, 3], [4, 0, 0]])
  96. >>> sparse_array
  97. <Compressed Sparse Column sparse array of dtype 'int64'
  98. with 2 stored elements and shape (2, 3)>
  99. >>> sparse_array.toarray()
  100. array([[0, 0, 3],
  101. [4, 0, 0]], dtype=int64)
  102. >>> sp.sparse.save_npz('/tmp/sparse_array.npz', sparse_array)
  103. >>> sparse_array = sp.sparse.load_npz('/tmp/sparse_array.npz')
  104. >>> sparse_array
  105. <Compressed Sparse Column sparse array of dtype 'int64'
  106. with 2 stored elements and shape (2, 3)>
  107. >>> sparse_array.toarray()
  108. array([[0, 0, 3],
  109. [4, 0, 0]], dtype=int64)
  110. In this example we force the result to be csr_array from csr_matrix
  111. >>> sparse_matrix = sp.sparse.csc_matrix([[0, 0, 3], [4, 0, 0]])
  112. >>> sp.sparse.save_npz('/tmp/sparse_matrix.npz', sparse_matrix)
  113. >>> tmp = sp.sparse.load_npz('/tmp/sparse_matrix.npz')
  114. >>> sparse_array = sp.sparse.csr_array(tmp)
  115. """
  116. with np.load(file, **PICKLE_KWARGS) as loaded:
  117. sparse_format = loaded.get('format')
  118. if sparse_format is None:
  119. raise ValueError(f'The file {file} does not contain '
  120. f'a sparse array or matrix.')
  121. sparse_format = sparse_format.item()
  122. if not isinstance(sparse_format, str):
  123. # Play safe with Python 2 vs 3 backward compatibility;
  124. # files saved with SciPy < 1.0.0 may contain unicode or bytes.
  125. sparse_format = sparse_format.decode('ascii')
  126. if loaded.get('_is_array'):
  127. sparse_type = sparse_format + '_array'
  128. else:
  129. sparse_type = sparse_format + '_matrix'
  130. try:
  131. cls = getattr(sp.sparse, f'{sparse_type}')
  132. except AttributeError as e:
  133. raise ValueError(f'Unknown format "{sparse_type}"') from e
  134. if sparse_format in ('csc', 'csr', 'bsr'):
  135. return cls((loaded['data'], loaded['indices'], loaded['indptr']),
  136. shape=loaded['shape'])
  137. elif sparse_format == 'dia':
  138. return cls((loaded['data'], loaded['offsets']), shape=loaded['shape'])
  139. elif sparse_format == 'coo':
  140. if 'coords' in loaded:
  141. return cls((loaded['data'], loaded['coords']), shape=loaded['shape'])
  142. return cls((loaded['data'], (loaded['row'], loaded['col'])),
  143. shape=loaded['shape'])
  144. else:
  145. raise NotImplementedError(f'Load is not implemented for '
  146. f'sparse matrix of format {sparse_format}.')