_mmio.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968
  1. """
  2. Matrix Market I/O in Python.
  3. See http://math.nist.gov/MatrixMarket/formats.html
  4. for information about the Matrix Market format.
  5. """
  6. #
  7. # Author: Pearu Peterson <pearu@cens.ioc.ee>
  8. # Created: October, 2004
  9. #
  10. # References:
  11. # http://math.nist.gov/MatrixMarket/
  12. #
  13. import os
  14. import numpy as np
  15. from numpy import (asarray, real, imag, conj, zeros, ndarray, concatenate,
  16. ones, can_cast)
  17. from scipy.sparse import coo_array, issparse, coo_matrix
  18. __all__ = ['mminfo', 'mmread', 'mmwrite', 'MMFile']
  19. # -----------------------------------------------------------------------------
  20. def asstr(s):
  21. if isinstance(s, bytes):
  22. return s.decode('latin1')
  23. return str(s)
  24. def mminfo(source):
  25. """
  26. Return size and storage parameters from Matrix Market file-like 'source'.
  27. Parameters
  28. ----------
  29. source : str or file-like
  30. Matrix Market filename (extension .mtx) or open file-like object
  31. Returns
  32. -------
  33. rows : int
  34. Number of matrix rows.
  35. cols : int
  36. Number of matrix columns.
  37. entries : int
  38. Number of non-zero entries of a sparse matrix
  39. or rows*cols for a dense matrix.
  40. format : str
  41. Either 'coordinate' or 'array'.
  42. field : str
  43. Either 'real', 'complex', 'pattern', or 'integer'.
  44. symmetry : str
  45. Either 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
  46. Examples
  47. --------
  48. >>> from io import StringIO
  49. >>> from scipy.io import mminfo
  50. >>> text = '''%%MatrixMarket matrix coordinate real general
  51. ... 5 5 7
  52. ... 2 3 1.0
  53. ... 3 4 2.0
  54. ... 3 5 3.0
  55. ... 4 1 4.0
  56. ... 4 2 5.0
  57. ... 4 3 6.0
  58. ... 4 4 7.0
  59. ... '''
  60. ``mminfo(source)`` returns the number of rows, number of columns,
  61. format, field type and symmetry attribute of the source file.
  62. >>> mminfo(StringIO(text))
  63. (5, 5, 7, 'coordinate', 'real', 'general')
  64. """
  65. return MMFile.info(source)
  66. # -----------------------------------------------------------------------------
  67. def mmread(source, *, spmatrix=True):
  68. """
  69. Reads the contents of a Matrix Market file-like 'source' into a matrix.
  70. Parameters
  71. ----------
  72. source : str or file-like
  73. Matrix Market filename (extensions .mtx, .mtz.gz)
  74. or open file-like object.
  75. spmatrix : bool, optional (default: True)
  76. If ``True``, return sparse matrix. Otherwise return sparse array.
  77. Returns
  78. -------
  79. a : ndarray or coo_array or coo_matrix
  80. Dense or sparse array depending on the matrix format in the
  81. Matrix Market file.
  82. Examples
  83. --------
  84. >>> from io import StringIO
  85. >>> from scipy.io import mmread
  86. >>> text = '''%%MatrixMarket matrix coordinate real general
  87. ... 5 5 7
  88. ... 2 3 1.0
  89. ... 3 4 2.0
  90. ... 3 5 3.0
  91. ... 4 1 4.0
  92. ... 4 2 5.0
  93. ... 4 3 6.0
  94. ... 4 4 7.0
  95. ... '''
  96. ``mmread(source)`` returns the data as sparse matrix in COO format.
  97. >>> m = mmread(StringIO(text), spmatrix=False)
  98. >>> m
  99. <COOrdinate sparse array of dtype 'float64'
  100. with 7 stored elements and shape (5, 5)>
  101. >>> m.toarray()
  102. array([[0., 0., 0., 0., 0.],
  103. [0., 0., 1., 0., 0.],
  104. [0., 0., 0., 2., 3.],
  105. [4., 5., 6., 7., 0.],
  106. [0., 0., 0., 0., 0.]])
  107. """
  108. return MMFile().read(source, spmatrix=spmatrix)
  109. # -----------------------------------------------------------------------------
  110. def mmwrite(target, a, comment='', field=None, precision=None, symmetry=None):
  111. r"""
  112. Writes the sparse or dense array `a` to Matrix Market file-like `target`.
  113. Parameters
  114. ----------
  115. target : str or file-like
  116. Matrix Market filename (extension .mtx) or open file-like object.
  117. a : array like
  118. Sparse or dense 2-D array.
  119. comment : str, optional
  120. Comments to be prepended to the Matrix Market file.
  121. field : None or str, optional
  122. Either 'real', 'complex', 'pattern', or 'integer'.
  123. precision : None or int, optional
  124. Number of digits to display for real or complex values.
  125. symmetry : None or str, optional
  126. Either 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
  127. If symmetry is None the symmetry type of 'a' is determined by its
  128. values.
  129. Returns
  130. -------
  131. None
  132. Examples
  133. --------
  134. >>> from io import BytesIO
  135. >>> import numpy as np
  136. >>> from scipy.sparse import coo_array
  137. >>> from scipy.io import mmwrite
  138. Write a small NumPy array to a matrix market file. The file will be
  139. written in the ``'array'`` format.
  140. >>> a = np.array([[1.0, 0, 0, 0], [0, 2.5, 0, 6.25]])
  141. >>> target = BytesIO()
  142. >>> mmwrite(target, a)
  143. >>> print(target.getvalue().decode('latin1'))
  144. %%MatrixMarket matrix array real general
  145. %
  146. 2 4
  147. 1
  148. 0
  149. 0
  150. 2.5
  151. 0
  152. 0
  153. 0
  154. 6.25
  155. Add a comment to the output file, and set the precision to 3.
  156. >>> target = BytesIO()
  157. >>> mmwrite(target, a, comment='\n Some test data.\n', precision=3)
  158. >>> print(target.getvalue().decode('latin1'))
  159. %%MatrixMarket matrix array real general
  160. %
  161. % Some test data.
  162. %
  163. 2 4
  164. 1.00e+00
  165. 0.00e+00
  166. 0.00e+00
  167. 2.50e+00
  168. 0.00e+00
  169. 0.00e+00
  170. 0.00e+00
  171. 6.25e+00
  172. Convert to a sparse matrix before calling ``mmwrite``. This will
  173. result in the output format being ``'coordinate'`` rather than
  174. ``'array'``.
  175. >>> target = BytesIO()
  176. >>> mmwrite(target, coo_array(a), precision=3)
  177. >>> print(target.getvalue().decode('latin1'))
  178. %%MatrixMarket matrix coordinate real general
  179. %
  180. 2 4 3
  181. 1 1 1.00e+00
  182. 2 2 2.50e+00
  183. 2 4 6.25e+00
  184. Write a complex Hermitian array to a matrix market file. Note that
  185. only six values are actually written to the file; the other values
  186. are implied by the symmetry.
  187. >>> z = np.array([[3, 1+2j, 4-3j], [1-2j, 1, -5j], [4+3j, 5j, 2.5]])
  188. >>> z
  189. array([[ 3. +0.j, 1. +2.j, 4. -3.j],
  190. [ 1. -2.j, 1. +0.j, -0. -5.j],
  191. [ 4. +3.j, 0. +5.j, 2.5+0.j]])
  192. >>> target = BytesIO()
  193. >>> mmwrite(target, z, precision=2)
  194. >>> print(target.getvalue().decode('latin1'))
  195. %%MatrixMarket matrix array complex hermitian
  196. %
  197. 3 3
  198. 3.0e+00 0.0e+00
  199. 1.0e+00 -2.0e+00
  200. 4.0e+00 3.0e+00
  201. 1.0e+00 0.0e+00
  202. 0.0e+00 5.0e+00
  203. 2.5e+00 0.0e+00
  204. """
  205. MMFile().write(target, a, comment, field, precision, symmetry)
  206. ###############################################################################
  207. class MMFile:
  208. __slots__ = ('_rows',
  209. '_cols',
  210. '_entries',
  211. '_format',
  212. '_field',
  213. '_symmetry')
  214. @property
  215. def rows(self):
  216. return self._rows
  217. @property
  218. def cols(self):
  219. return self._cols
  220. @property
  221. def entries(self):
  222. return self._entries
  223. @property
  224. def format(self):
  225. return self._format
  226. @property
  227. def field(self):
  228. return self._field
  229. @property
  230. def symmetry(self):
  231. return self._symmetry
  232. @property
  233. def has_symmetry(self):
  234. return self._symmetry in (self.SYMMETRY_SYMMETRIC,
  235. self.SYMMETRY_SKEW_SYMMETRIC,
  236. self.SYMMETRY_HERMITIAN)
  237. # format values
  238. FORMAT_COORDINATE = 'coordinate'
  239. FORMAT_ARRAY = 'array'
  240. FORMAT_VALUES = (FORMAT_COORDINATE, FORMAT_ARRAY)
  241. @classmethod
  242. def _validate_format(self, format):
  243. if format not in self.FORMAT_VALUES:
  244. msg = f'unknown format type {format}, must be one of {self.FORMAT_VALUES}'
  245. raise ValueError(msg)
  246. # field values
  247. FIELD_INTEGER = 'integer'
  248. FIELD_UNSIGNED = 'unsigned-integer'
  249. FIELD_REAL = 'real'
  250. FIELD_COMPLEX = 'complex'
  251. FIELD_PATTERN = 'pattern'
  252. FIELD_VALUES = (FIELD_INTEGER, FIELD_UNSIGNED, FIELD_REAL, FIELD_COMPLEX,
  253. FIELD_PATTERN)
  254. @classmethod
  255. def _validate_field(self, field):
  256. if field not in self.FIELD_VALUES:
  257. msg = f'unknown field type {field}, must be one of {self.FIELD_VALUES}'
  258. raise ValueError(msg)
  259. # symmetry values
  260. SYMMETRY_GENERAL = 'general'
  261. SYMMETRY_SYMMETRIC = 'symmetric'
  262. SYMMETRY_SKEW_SYMMETRIC = 'skew-symmetric'
  263. SYMMETRY_HERMITIAN = 'hermitian'
  264. SYMMETRY_VALUES = (SYMMETRY_GENERAL, SYMMETRY_SYMMETRIC,
  265. SYMMETRY_SKEW_SYMMETRIC, SYMMETRY_HERMITIAN)
  266. @classmethod
  267. def _validate_symmetry(self, symmetry):
  268. if symmetry not in self.SYMMETRY_VALUES:
  269. raise ValueError(f'unknown symmetry type {symmetry}, '
  270. f'must be one of {self.SYMMETRY_VALUES}')
  271. DTYPES_BY_FIELD = {FIELD_INTEGER: 'intp',
  272. FIELD_UNSIGNED: 'uint64',
  273. FIELD_REAL: 'd',
  274. FIELD_COMPLEX: 'D',
  275. FIELD_PATTERN: 'd'}
  276. # -------------------------------------------------------------------------
  277. @staticmethod
  278. def reader():
  279. pass
  280. # -------------------------------------------------------------------------
  281. @staticmethod
  282. def writer():
  283. pass
  284. # -------------------------------------------------------------------------
  285. @classmethod
  286. def info(self, source):
  287. """
  288. Return size, storage parameters from Matrix Market file-like 'source'.
  289. Parameters
  290. ----------
  291. source : str or file-like
  292. Matrix Market filename (extension .mtx) or open file-like object
  293. Returns
  294. -------
  295. rows : int
  296. Number of matrix rows.
  297. cols : int
  298. Number of matrix columns.
  299. entries : int
  300. Number of non-zero entries of a sparse matrix
  301. or rows*cols for a dense matrix.
  302. format : str
  303. Either 'coordinate' or 'array'.
  304. field : str
  305. Either 'real', 'complex', 'pattern', or 'integer'.
  306. symmetry : str
  307. Either 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
  308. """
  309. stream, close_it = self._open(source)
  310. try:
  311. # read and validate header line
  312. line = stream.readline()
  313. mmid, matrix, format, field, symmetry = \
  314. (asstr(part.strip()) for part in line.split())
  315. if not mmid.startswith('%%MatrixMarket'):
  316. raise ValueError('source is not in Matrix Market format')
  317. if not matrix.lower() == 'matrix':
  318. raise ValueError("Problem reading file header: " + line)
  319. # http://math.nist.gov/MatrixMarket/formats.html
  320. if format.lower() == 'array':
  321. format = self.FORMAT_ARRAY
  322. elif format.lower() == 'coordinate':
  323. format = self.FORMAT_COORDINATE
  324. # skip comments
  325. # line.startswith('%')
  326. while line:
  327. if line.lstrip() and line.lstrip()[0] in ['%', 37]:
  328. line = stream.readline()
  329. else:
  330. break
  331. # skip empty lines
  332. while not line.strip():
  333. line = stream.readline()
  334. split_line = line.split()
  335. if format == self.FORMAT_ARRAY:
  336. if not len(split_line) == 2:
  337. raise ValueError("Header line not of length 2: " +
  338. line.decode('ascii'))
  339. rows, cols = map(int, split_line)
  340. entries = rows * cols
  341. else:
  342. if not len(split_line) == 3:
  343. raise ValueError("Header line not of length 3: " +
  344. line.decode('ascii'))
  345. rows, cols, entries = map(int, split_line)
  346. return (rows, cols, entries, format, field.lower(),
  347. symmetry.lower())
  348. finally:
  349. if close_it:
  350. stream.close()
  351. # -------------------------------------------------------------------------
  352. @staticmethod
  353. def _open(filespec, mode='rb'):
  354. """ Return an open file stream for reading based on source.
  355. If source is a file name, open it (after trying to find it with mtx and
  356. gzipped mtx extensions). Otherwise, just return source.
  357. Parameters
  358. ----------
  359. filespec : str or file-like
  360. String giving file name or file-like object
  361. mode : str, optional
  362. Mode with which to open file, if `filespec` is a file name.
  363. Returns
  364. -------
  365. fobj : file-like
  366. Open file-like object.
  367. close_it : bool
  368. True if the calling function should close this file when done,
  369. false otherwise.
  370. """
  371. # If 'filespec' is path-like (str, pathlib.Path, os.DirEntry, other class
  372. # implementing a '__fspath__' method), try to convert it to str. If this
  373. # fails by throwing a 'TypeError', assume it's an open file handle and
  374. # return it as-is.
  375. try:
  376. filespec = os.fspath(filespec)
  377. except TypeError:
  378. return filespec, False
  379. # 'filespec' is definitely a str now
  380. # open for reading
  381. if mode[0] == 'r':
  382. # determine filename plus extension
  383. if not os.path.isfile(filespec):
  384. if os.path.isfile(filespec+'.mtx'):
  385. filespec = filespec + '.mtx'
  386. elif os.path.isfile(filespec+'.mtx.gz'):
  387. filespec = filespec + '.mtx.gz'
  388. elif os.path.isfile(filespec+'.mtx.bz2'):
  389. filespec = filespec + '.mtx.bz2'
  390. # open filename
  391. if filespec.endswith('.gz'):
  392. import gzip
  393. stream = gzip.open(filespec, mode)
  394. elif filespec.endswith('.bz2'):
  395. import bz2
  396. stream = bz2.BZ2File(filespec, 'rb')
  397. else:
  398. stream = open(filespec, mode)
  399. # open for writing
  400. else:
  401. if filespec[-4:] != '.mtx':
  402. filespec = filespec + '.mtx'
  403. stream = open(filespec, mode)
  404. return stream, True
  405. # -------------------------------------------------------------------------
  406. @staticmethod
  407. def _get_symmetry(a):
  408. m, n = a.shape
  409. if m != n:
  410. return MMFile.SYMMETRY_GENERAL
  411. issymm = True
  412. isskew = True
  413. isherm = a.dtype.char in 'FD'
  414. # sparse input
  415. if issparse(a):
  416. # check if number of nonzero entries of lower and upper triangle
  417. # matrix are equal
  418. a = a.tocoo()
  419. (row, col) = a.nonzero()
  420. if (row < col).sum() != (row > col).sum():
  421. return MMFile.SYMMETRY_GENERAL
  422. # define iterator over symmetric pair entries
  423. a = a.todok()
  424. def symm_iterator():
  425. for ((i, j), aij) in a.items():
  426. if i > j:
  427. aji = a[j, i]
  428. yield (aij, aji, False)
  429. elif i == j:
  430. yield (aij, aij, True)
  431. # non-sparse input
  432. else:
  433. # define iterator over symmetric pair entries
  434. def symm_iterator():
  435. for j in range(n):
  436. for i in range(j, n):
  437. aij, aji = a[i][j], a[j][i]
  438. yield (aij, aji, i == j)
  439. # check for symmetry
  440. # yields aij, aji, is_diagonal
  441. for (aij, aji, is_diagonal) in symm_iterator():
  442. if isskew and is_diagonal and aij != 0:
  443. isskew = False
  444. else:
  445. if issymm and aij != aji:
  446. issymm = False
  447. with np.errstate(over="ignore"):
  448. # This can give a warning for uint dtypes, so silence that
  449. if isskew and aij != -aji:
  450. isskew = False
  451. if isherm and aij != conj(aji):
  452. isherm = False
  453. if not (issymm or isskew or isherm):
  454. break
  455. # return symmetry value
  456. if issymm:
  457. return MMFile.SYMMETRY_SYMMETRIC
  458. if isskew:
  459. return MMFile.SYMMETRY_SKEW_SYMMETRIC
  460. if isherm:
  461. return MMFile.SYMMETRY_HERMITIAN
  462. return MMFile.SYMMETRY_GENERAL
  463. # -------------------------------------------------------------------------
  464. @staticmethod
  465. def _field_template(field, precision):
  466. return {MMFile.FIELD_REAL: '%%.%ie\n' % precision,
  467. MMFile.FIELD_INTEGER: '%i\n',
  468. MMFile.FIELD_UNSIGNED: '%u\n',
  469. MMFile.FIELD_COMPLEX: '%%.%ie %%.%ie\n' %
  470. (precision, precision)
  471. }.get(field, None)
  472. # -------------------------------------------------------------------------
  473. def __init__(self, **kwargs):
  474. self._init_attrs(**kwargs)
  475. # -------------------------------------------------------------------------
  476. def read(self, source, *, spmatrix=True):
  477. """
  478. Reads the contents of a Matrix Market file-like 'source' into a matrix.
  479. Parameters
  480. ----------
  481. source : str or file-like
  482. Matrix Market filename (extensions .mtx, .mtz.gz)
  483. or open file object.
  484. spmatrix : bool, optional (default: True)
  485. If ``True``, return sparse matrix. Otherwise return sparse array.
  486. Returns
  487. -------
  488. a : ndarray or coo_array or coo_matrix
  489. Dense or sparse array depending on the matrix format in the
  490. Matrix Market file.
  491. """
  492. stream, close_it = self._open(source)
  493. try:
  494. self._parse_header(stream)
  495. data = self._parse_body(stream)
  496. finally:
  497. if close_it:
  498. stream.close()
  499. if spmatrix and isinstance(data, coo_array):
  500. data = coo_matrix(data)
  501. return data
  502. # -------------------------------------------------------------------------
  503. def write(self, target, a, comment='', field=None, precision=None,
  504. symmetry=None):
  505. """
  506. Writes sparse or dense array `a` to Matrix Market file-like `target`.
  507. Parameters
  508. ----------
  509. target : str or file-like
  510. Matrix Market filename (extension .mtx) or open file-like object.
  511. a : array like
  512. Sparse or dense 2-D array.
  513. comment : str, optional
  514. Comments to be prepended to the Matrix Market file.
  515. field : None or str, optional
  516. Either 'real', 'complex', 'pattern', or 'integer'.
  517. precision : None or int, optional
  518. Number of digits to display for real or complex values.
  519. symmetry : None or str, optional
  520. Either 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
  521. If symmetry is None the symmetry type of 'a' is determined by its
  522. values.
  523. """
  524. stream, close_it = self._open(target, 'wb')
  525. try:
  526. self._write(stream, a, comment, field, precision, symmetry)
  527. finally:
  528. if close_it:
  529. stream.close()
  530. else:
  531. stream.flush()
  532. # -------------------------------------------------------------------------
  533. def _init_attrs(self, **kwargs):
  534. """
  535. Initialize each attributes with the corresponding keyword arg value
  536. or a default of None
  537. """
  538. attrs = self.__class__.__slots__
  539. public_attrs = [attr[1:] for attr in attrs]
  540. invalid_keys = set(kwargs.keys()) - set(public_attrs)
  541. if invalid_keys:
  542. raise ValueError(f"found {tuple(invalid_keys)} invalid keyword "
  543. f"arguments, please only use {public_attrs}")
  544. for attr in attrs:
  545. setattr(self, attr, kwargs.get(attr[1:], None))
  546. # -------------------------------------------------------------------------
  547. def _parse_header(self, stream):
  548. rows, cols, entries, format, field, symmetry = \
  549. self.__class__.info(stream)
  550. self._init_attrs(rows=rows, cols=cols, entries=entries, format=format,
  551. field=field, symmetry=symmetry)
  552. # -------------------------------------------------------------------------
  553. def _parse_body(self, stream):
  554. rows, cols, entries, format, field, symm = (self.rows, self.cols,
  555. self.entries, self.format,
  556. self.field, self.symmetry)
  557. dtype = self.DTYPES_BY_FIELD.get(field, None)
  558. has_symmetry = self.has_symmetry
  559. is_integer = field == self.FIELD_INTEGER
  560. is_unsigned_integer = field == self.FIELD_UNSIGNED
  561. is_complex = field == self.FIELD_COMPLEX
  562. is_skew = symm == self.SYMMETRY_SKEW_SYMMETRIC
  563. is_herm = symm == self.SYMMETRY_HERMITIAN
  564. is_pattern = field == self.FIELD_PATTERN
  565. if format == self.FORMAT_ARRAY:
  566. a = zeros((rows, cols), dtype=dtype)
  567. line = 1
  568. i, j = 0, 0
  569. if is_skew:
  570. a[i, j] = 0
  571. if i < rows - 1:
  572. i += 1
  573. while line:
  574. line = stream.readline()
  575. # line.startswith('%')
  576. if not line or line[0] in ['%', 37] or not line.strip():
  577. continue
  578. if is_integer:
  579. aij = int(line)
  580. elif is_unsigned_integer:
  581. aij = int(line)
  582. elif is_complex:
  583. aij = complex(*map(float, line.split()))
  584. else:
  585. aij = float(line)
  586. a[i, j] = aij
  587. if has_symmetry and i != j:
  588. if is_skew:
  589. a[j, i] = -aij
  590. elif is_herm:
  591. a[j, i] = conj(aij)
  592. else:
  593. a[j, i] = aij
  594. if i < rows-1:
  595. i = i + 1
  596. else:
  597. j = j + 1
  598. if not has_symmetry:
  599. i = 0
  600. else:
  601. i = j
  602. if is_skew:
  603. a[i, j] = 0
  604. if i < rows-1:
  605. i += 1
  606. if is_skew:
  607. if not (i in [0, j] and j == cols - 1):
  608. raise ValueError("Parse error, did not read all lines.")
  609. else:
  610. if not (i in [0, j] and j == cols):
  611. raise ValueError("Parse error, did not read all lines.")
  612. elif format == self.FORMAT_COORDINATE:
  613. # Read sparse COOrdinate format
  614. if entries == 0:
  615. # empty matrix
  616. return coo_array((rows, cols), dtype=dtype)
  617. I = zeros(entries, dtype='intc')
  618. J = zeros(entries, dtype='intc')
  619. if is_pattern:
  620. V = ones(entries, dtype='int8')
  621. elif is_integer:
  622. V = zeros(entries, dtype='intp')
  623. elif is_unsigned_integer:
  624. V = zeros(entries, dtype='uint64')
  625. elif is_complex:
  626. V = zeros(entries, dtype='complex')
  627. else:
  628. V = zeros(entries, dtype='float')
  629. entry_number = 0
  630. for line in stream:
  631. # line.startswith('%')
  632. if not line or line[0] in ['%', 37] or not line.strip():
  633. continue
  634. if entry_number+1 > entries:
  635. raise ValueError("'entries' in header is smaller than "
  636. "number of entries")
  637. l = line.split()
  638. I[entry_number], J[entry_number] = map(int, l[:2])
  639. if not is_pattern:
  640. if is_integer:
  641. V[entry_number] = int(l[2])
  642. elif is_unsigned_integer:
  643. V[entry_number] = int(l[2])
  644. elif is_complex:
  645. V[entry_number] = complex(*map(float, l[2:]))
  646. else:
  647. V[entry_number] = float(l[2])
  648. entry_number += 1
  649. if entry_number < entries:
  650. raise ValueError("'entries' in header is larger than "
  651. "number of entries")
  652. I -= 1 # adjust indices (base 1 -> base 0)
  653. J -= 1
  654. if has_symmetry:
  655. mask = (I != J) # off diagonal mask
  656. od_I = I[mask]
  657. od_J = J[mask]
  658. od_V = V[mask]
  659. I = concatenate((I, od_J))
  660. J = concatenate((J, od_I))
  661. if is_skew:
  662. od_V *= -1
  663. elif is_herm:
  664. od_V = od_V.conjugate()
  665. V = concatenate((V, od_V))
  666. a = coo_array((V, (I, J)), shape=(rows, cols), dtype=dtype)
  667. else:
  668. raise NotImplementedError(format)
  669. return a
  670. # ------------------------------------------------------------------------
  671. def _write(self, stream, a, comment='', field=None, precision=None,
  672. symmetry=None):
  673. if isinstance(a, list) or isinstance(a, ndarray) or \
  674. isinstance(a, tuple) or hasattr(a, '__array__'):
  675. rep = self.FORMAT_ARRAY
  676. a = asarray(a)
  677. if len(a.shape) != 2:
  678. raise ValueError('Expected 2 dimensional array')
  679. rows, cols = a.shape
  680. if field is not None:
  681. if field == self.FIELD_INTEGER:
  682. if not can_cast(a.dtype, 'intp'):
  683. raise OverflowError("mmwrite does not support integer "
  684. "dtypes larger than native 'intp'.")
  685. a = a.astype('intp')
  686. elif field == self.FIELD_REAL:
  687. if a.dtype.char not in 'fd':
  688. a = a.astype('d')
  689. elif field == self.FIELD_COMPLEX:
  690. if a.dtype.char not in 'FD':
  691. a = a.astype('D')
  692. else:
  693. if not issparse(a):
  694. raise ValueError(f'unknown matrix type: {type(a)}')
  695. rep = 'coordinate'
  696. rows, cols = a.shape
  697. typecode = a.dtype.char
  698. if precision is None:
  699. if typecode in 'fF':
  700. precision = 8
  701. else:
  702. precision = 16
  703. if field is None:
  704. kind = a.dtype.kind
  705. if kind == 'i':
  706. if not can_cast(a.dtype, 'intp'):
  707. raise OverflowError("mmwrite does not support integer "
  708. "dtypes larger than native 'intp'.")
  709. field = 'integer'
  710. elif kind == 'f':
  711. field = 'real'
  712. elif kind == 'c':
  713. field = 'complex'
  714. elif kind == 'u':
  715. field = 'unsigned-integer'
  716. else:
  717. raise TypeError('unexpected dtype kind ' + kind)
  718. if symmetry is None:
  719. symmetry = self._get_symmetry(a)
  720. # validate rep, field, and symmetry
  721. self.__class__._validate_format(rep)
  722. self.__class__._validate_field(field)
  723. self.__class__._validate_symmetry(symmetry)
  724. # write initial header line
  725. data = f'%%MatrixMarket matrix {rep} {field} {symmetry}\n'
  726. stream.write(data.encode('latin1'))
  727. # write comments
  728. for line in comment.split('\n'):
  729. data = f'%{line}\n'
  730. stream.write(data.encode('latin1'))
  731. template = self._field_template(field, precision)
  732. # write dense format
  733. if rep == self.FORMAT_ARRAY:
  734. # write shape spec
  735. data = '%i %i\n' % (rows, cols)
  736. stream.write(data.encode('latin1'))
  737. if field in (self.FIELD_INTEGER, self.FIELD_REAL,
  738. self.FIELD_UNSIGNED):
  739. if symmetry == self.SYMMETRY_GENERAL:
  740. for j in range(cols):
  741. for i in range(rows):
  742. data = template % a[i, j]
  743. stream.write(data.encode('latin1'))
  744. elif symmetry == self.SYMMETRY_SKEW_SYMMETRIC:
  745. for j in range(cols):
  746. for i in range(j + 1, rows):
  747. data = template % a[i, j]
  748. stream.write(data.encode('latin1'))
  749. else:
  750. for j in range(cols):
  751. for i in range(j, rows):
  752. data = template % a[i, j]
  753. stream.write(data.encode('latin1'))
  754. elif field == self.FIELD_COMPLEX:
  755. if symmetry == self.SYMMETRY_GENERAL:
  756. for j in range(cols):
  757. for i in range(rows):
  758. aij = a[i, j]
  759. data = template % (real(aij), imag(aij))
  760. stream.write(data.encode('latin1'))
  761. else:
  762. for j in range(cols):
  763. for i in range(j, rows):
  764. aij = a[i, j]
  765. data = template % (real(aij), imag(aij))
  766. stream.write(data.encode('latin1'))
  767. elif field == self.FIELD_PATTERN:
  768. raise ValueError('pattern type inconsisted with dense format')
  769. else:
  770. raise TypeError(f'Unknown field type {field}')
  771. # write sparse format
  772. else:
  773. coo = a.tocoo() # convert to COOrdinate format
  774. # if symmetry format used, remove values above main diagonal
  775. if symmetry != self.SYMMETRY_GENERAL:
  776. lower_triangle_mask = coo.row >= coo.col
  777. coo = coo_array((coo.data[lower_triangle_mask],
  778. (coo.row[lower_triangle_mask],
  779. coo.col[lower_triangle_mask])),
  780. shape=coo.shape)
  781. # write shape spec
  782. data = '%i %i %i\n' % (rows, cols, coo.nnz)
  783. stream.write(data.encode('latin1'))
  784. template = self._field_template(field, precision-1)
  785. if field == self.FIELD_PATTERN:
  786. for r, c in zip(coo.row+1, coo.col+1):
  787. data = "%i %i\n" % (r, c)
  788. stream.write(data.encode('latin1'))
  789. elif field in (self.FIELD_INTEGER, self.FIELD_REAL,
  790. self.FIELD_UNSIGNED):
  791. for r, c, d in zip(coo.row+1, coo.col+1, coo.data):
  792. data = ("%i %i " % (r, c)) + (template % d)
  793. stream.write(data.encode('latin1'))
  794. elif field == self.FIELD_COMPLEX:
  795. for r, c, d in zip(coo.row+1, coo.col+1, coo.data):
  796. data = ("%i %i " % (r, c)) + (template % (d.real, d.imag))
  797. stream.write(data.encode('latin1'))
  798. else:
  799. raise TypeError(f'Unknown field type {field}')
  800. def _is_fromfile_compatible(stream):
  801. """
  802. Check whether `stream` is compatible with numpy.fromfile.
  803. Passing a gzipped file object to ``fromfile/fromstring`` doesn't work with
  804. Python 3.
  805. """
  806. bad_cls = []
  807. try:
  808. import gzip
  809. bad_cls.append(gzip.GzipFile)
  810. except ImportError:
  811. pass
  812. try:
  813. import bz2
  814. bad_cls.append(bz2.BZ2File)
  815. except ImportError:
  816. pass
  817. bad_cls = tuple(bad_cls)
  818. return not isinstance(stream, bad_cls)