_mio5.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901
  1. ''' Classes for read / write of matlab (TM) 5 files
  2. The matfile specification last found here:
  3. https://www.mathworks.com/access/helpdesk/help/pdf_doc/matlab/matfile_format.pdf
  4. (as of December 5 2008)
  5. =================================
  6. Note on functions and mat files
  7. =================================
  8. The document above does not give any hints as to the storage of matlab
  9. function handles, or anonymous function handles. I had, therefore, to
  10. guess the format of matlab arrays of ``mxFUNCTION_CLASS`` and
  11. ``mxOPAQUE_CLASS`` by looking at example mat files.
  12. ``mxFUNCTION_CLASS`` stores all types of matlab functions. It seems to
  13. contain a struct matrix with a set pattern of fields. For anonymous
  14. functions, a sub-field of one of these fields seems to contain the
  15. well-named ``mxOPAQUE_CLASS``. This seems to contain:
  16. * array flags as for any matlab matrix
  17. * 3 int8 strings
  18. * a matrix
  19. It seems that whenever the mat file contains a ``mxOPAQUE_CLASS``
  20. instance, there is also an un-named matrix (name == '') at the end of
  21. the mat file. I'll call this the ``__function_workspace__`` matrix.
  22. When I saved two anonymous functions in a mat file, or appended another
  23. anonymous function to the mat file, there was still only one
  24. ``__function_workspace__`` un-named matrix at the end, but larger than
  25. that for a mat file with a single anonymous function, suggesting that
  26. the workspaces for the two functions had been merged.
  27. The ``__function_workspace__`` matrix appears to be of double class
  28. (``mxCLASS_DOUBLE``), but stored as uint8, the memory for which is in
  29. the format of a mini .mat file, without the first 124 bytes of the file
  30. header (the description and the subsystem_offset), but with the version
  31. U2 bytes, and the S2 endian test bytes. There follow 4 zero bytes,
  32. presumably for 8 byte padding, and then a series of ``miMATRIX``
  33. entries, as in a standard mat file. The ``miMATRIX`` entries appear to
  34. be series of un-named (name == '') matrices, and may also contain arrays
  35. of this same mini-mat format.
  36. I guess that:
  37. * saving an anonymous function back to a mat file will need the
  38. associated ``__function_workspace__`` matrix saved as well for the
  39. anonymous function to work correctly.
  40. * appending to a mat file that has a ``__function_workspace__`` would
  41. involve first pulling off this workspace, appending, checking whether
  42. there were any more anonymous functions appended, and then somehow
  43. merging the relevant workspaces, and saving at the end of the mat
  44. file.
  45. The mat files I was playing with are in ``tests/data``:
  46. * sqr.mat
  47. * parabola.mat
  48. * some_functions.mat
  49. See ``tests/test_mio.py:test_mio_funcs.py`` for the debugging
  50. script I was working with.
  51. Small fragments of current code adapted from matfile.py by Heiko
  52. Henkelmann; parts of the code for simplify_cells=True adapted from
  53. http://blog.nephics.com/2019/08/28/better-loadmat-for-scipy/.
  54. '''
  55. import math
  56. import os
  57. import time
  58. import sys
  59. import zlib
  60. from io import BytesIO
  61. import warnings
  62. import numpy as np
  63. import scipy.sparse
  64. from ._byteordercodes import native_code, swapped_code
  65. from ._miobase import (MatFileReader, docfiller, matdims, read_dtype,
  66. arr_to_chars, arr_dtype_number, MatWriteError,
  67. MatReadError, MatReadWarning, MatWriteWarning)
  68. # Reader object for matlab 5 format variables
  69. from ._mio5_utils import VarReader5
  70. # Constants and helper objects
  71. from ._mio5_params import (MatlabObject, MatlabFunction, MDTYPES, NP_TO_MTYPES,
  72. NP_TO_MXTYPES, miCOMPRESSED, miMATRIX, miINT8,
  73. miUTF8, miUINT32, mxCELL_CLASS, mxSTRUCT_CLASS,
  74. mxOBJECT_CLASS, mxCHAR_CLASS, mxSPARSE_CLASS,
  75. mxDOUBLE_CLASS, mclass_info, mat_struct)
  76. from ._streams import ZlibInputStream
  77. def _has_struct(elem):
  78. """Determine if elem is an array and if first array item is a struct."""
  79. return (isinstance(elem, np.ndarray) and (elem.size > 0) and (elem.ndim > 0) and
  80. isinstance(elem[0], mat_struct))
  81. def _inspect_cell_array(ndarray):
  82. """Construct lists from cell arrays (loaded as numpy ndarrays), recursing
  83. into items if they contain mat_struct objects."""
  84. elem_list = []
  85. for sub_elem in ndarray:
  86. if isinstance(sub_elem, mat_struct):
  87. elem_list.append(_matstruct_to_dict(sub_elem))
  88. elif _has_struct(sub_elem):
  89. elem_list.append(_inspect_cell_array(sub_elem))
  90. else:
  91. elem_list.append(sub_elem)
  92. return elem_list
  93. def _matstruct_to_dict(matobj):
  94. """Construct nested dicts from mat_struct objects."""
  95. d = {}
  96. for f in matobj._fieldnames:
  97. elem = matobj.__dict__[f]
  98. if isinstance(elem, mat_struct):
  99. d[f] = _matstruct_to_dict(elem)
  100. elif _has_struct(elem):
  101. d[f] = _inspect_cell_array(elem)
  102. else:
  103. d[f] = elem
  104. return d
  105. def _simplify_cells(d):
  106. """Convert mat objects in dict to nested dicts."""
  107. for key in d:
  108. if isinstance(d[key], mat_struct):
  109. d[key] = _matstruct_to_dict(d[key])
  110. elif _has_struct(d[key]):
  111. d[key] = _inspect_cell_array(d[key])
  112. return d
  113. class MatFile5Reader(MatFileReader):
  114. ''' Reader for Mat 5 mat files
  115. Adds the following attribute to base class
  116. uint16_codec - char codec to use for uint16 char arrays
  117. (defaults to system default codec)
  118. Uses variable reader that has the following standard interface (see
  119. abstract class in ``miobase``::
  120. __init__(self, file_reader)
  121. read_header(self)
  122. array_from_header(self)
  123. and added interface::
  124. set_stream(self, stream)
  125. read_full_tag(self)
  126. '''
  127. @docfiller
  128. def __init__(self,
  129. mat_stream,
  130. byte_order=None,
  131. mat_dtype=False,
  132. squeeze_me=False,
  133. chars_as_strings=True,
  134. matlab_compatible=False,
  135. struct_as_record=True,
  136. verify_compressed_data_integrity=True,
  137. uint16_codec=None,
  138. simplify_cells=False):
  139. '''Initializer for matlab 5 file format reader
  140. %(matstream_arg)s
  141. %(load_args)s
  142. %(struct_arg)s
  143. uint16_codec : {None, string}
  144. Set codec to use for uint16 char arrays (e.g., 'utf-8').
  145. Use system default codec if None
  146. '''
  147. super().__init__(
  148. mat_stream,
  149. byte_order,
  150. mat_dtype,
  151. squeeze_me,
  152. chars_as_strings,
  153. matlab_compatible,
  154. struct_as_record,
  155. verify_compressed_data_integrity,
  156. simplify_cells)
  157. # Set uint16 codec
  158. if not uint16_codec:
  159. uint16_codec = sys.getdefaultencoding()
  160. self.uint16_codec = uint16_codec
  161. # placeholders for readers - see initialize_read method
  162. self._file_reader = None
  163. self._matrix_reader = None
  164. def guess_byte_order(self):
  165. ''' Guess byte order.
  166. Sets stream pointer to 0'''
  167. self.mat_stream.seek(126)
  168. mi = self.mat_stream.read(2)
  169. self.mat_stream.seek(0)
  170. return mi == b'IM' and '<' or '>'
  171. def read_file_header(self):
  172. ''' Read in mat 5 file header '''
  173. hdict = {}
  174. hdr_dtype = MDTYPES[self.byte_order]['dtypes']['file_header']
  175. hdr = read_dtype(self.mat_stream, hdr_dtype)
  176. hdict['__header__'] = hdr['description'].item().strip(b' \t\n\000')
  177. v_major = hdr['version'] >> 8
  178. v_minor = hdr['version'] & 0xFF
  179. hdict['__version__'] = f'{v_major}.{v_minor}'
  180. return hdict
  181. def initialize_read(self):
  182. ''' Run when beginning read of variables
  183. Sets up readers from parameters in `self`
  184. '''
  185. # reader for top level stream. We need this extra top-level
  186. # reader because we use the matrix_reader object to contain
  187. # compressed matrices (so they have their own stream)
  188. self._file_reader = VarReader5(self)
  189. # reader for matrix streams
  190. self._matrix_reader = VarReader5(self)
  191. def read_var_header(self):
  192. ''' Read header, return header, next position
  193. Header has to define at least .name and .is_global
  194. Parameters
  195. ----------
  196. None
  197. Returns
  198. -------
  199. header : object
  200. object that can be passed to self.read_var_array, and that
  201. has attributes .name and .is_global
  202. next_position : int
  203. position in stream of next variable
  204. '''
  205. mdtype, byte_count = self._file_reader.read_full_tag()
  206. if not byte_count > 0:
  207. raise ValueError("Did not read any bytes")
  208. next_pos = self.mat_stream.tell() + byte_count
  209. if mdtype == miCOMPRESSED:
  210. # Make new stream from compressed data
  211. stream = ZlibInputStream(self.mat_stream, byte_count)
  212. self._matrix_reader.set_stream(stream)
  213. check_stream_limit = self.verify_compressed_data_integrity
  214. mdtype, byte_count = self._matrix_reader.read_full_tag()
  215. else:
  216. check_stream_limit = False
  217. self._matrix_reader.set_stream(self.mat_stream)
  218. if not mdtype == miMATRIX:
  219. raise TypeError(f'Expecting miMATRIX type here, got {mdtype}')
  220. header = self._matrix_reader.read_header(check_stream_limit)
  221. return header, next_pos
  222. def read_var_array(self, header, process=True):
  223. ''' Read array, given `header`
  224. Parameters
  225. ----------
  226. header : header object
  227. object with fields defining variable header
  228. process : {True, False} bool, optional
  229. If True, apply recursive post-processing during loading of
  230. array.
  231. Returns
  232. -------
  233. arr : array
  234. array with post-processing applied or not according to
  235. `process`.
  236. '''
  237. return self._matrix_reader.array_from_header(header, process)
  238. def get_variables(self, variable_names=None):
  239. ''' get variables from stream as dictionary
  240. variable_names - optional list of variable names to get
  241. If variable_names is None, then get all variables in file
  242. '''
  243. if isinstance(variable_names, str):
  244. variable_names = [variable_names]
  245. elif variable_names is not None:
  246. variable_names = list(variable_names)
  247. self.mat_stream.seek(0)
  248. # Here we pass all the parameters in self to the reading objects
  249. self.initialize_read()
  250. mdict = self.read_file_header()
  251. mdict['__globals__'] = []
  252. while not self.end_of_stream():
  253. hdr, next_position = self.read_var_header()
  254. name = 'None' if hdr.name is None else hdr.name.decode('latin1')
  255. if name in mdict:
  256. msg = (
  257. f'Duplicate variable name "{name}" in stream'
  258. " - replacing previous with new\nConsider"
  259. "scipy.io.matlab.varmats_from_mat to split "
  260. "file into single variable files"
  261. )
  262. warnings.warn(msg, MatReadWarning, stacklevel=2)
  263. if name == '':
  264. # can only be a matlab 7 function workspace
  265. name = '__function_workspace__'
  266. # We want to keep this raw because mat_dtype processing
  267. # will break the format (uint8 as mxDOUBLE_CLASS)
  268. process = False
  269. else:
  270. process = True
  271. if variable_names is not None and name not in variable_names:
  272. self.mat_stream.seek(next_position)
  273. continue
  274. try:
  275. res = self.read_var_array(hdr, process)
  276. except MatReadError as err:
  277. warnings.warn(
  278. f'Unreadable variable "{name}", because "{err}"',
  279. Warning, stacklevel=2)
  280. res = f"Read error: {err}"
  281. self.mat_stream.seek(next_position)
  282. mdict[name] = res
  283. if hdr.is_global:
  284. mdict['__globals__'].append(name)
  285. if variable_names is not None:
  286. variable_names.remove(name)
  287. if len(variable_names) == 0:
  288. break
  289. if self.simplify_cells:
  290. return _simplify_cells(mdict)
  291. else:
  292. return mdict
  293. def list_variables(self):
  294. ''' list variables from stream '''
  295. self.mat_stream.seek(0)
  296. # Here we pass all the parameters in self to the reading objects
  297. self.initialize_read()
  298. self.read_file_header()
  299. vars = []
  300. while not self.end_of_stream():
  301. hdr, next_position = self.read_var_header()
  302. name = 'None' if hdr.name is None else hdr.name.decode('latin1')
  303. if name == '':
  304. # can only be a matlab 7 function workspace
  305. name = '__function_workspace__'
  306. shape = self._matrix_reader.shape_from_header(hdr)
  307. if hdr.is_logical:
  308. info = 'logical'
  309. else:
  310. info = mclass_info.get(hdr.mclass, 'unknown')
  311. vars.append((name, shape, info))
  312. self.mat_stream.seek(next_position)
  313. return vars
  314. def varmats_from_mat(file_obj):
  315. """ Pull variables out of mat 5 file as a sequence of mat file objects
  316. This can be useful with a difficult mat file, containing unreadable
  317. variables. This routine pulls the variables out in raw form and puts them,
  318. unread, back into a file stream for saving or reading. Another use is the
  319. pathological case where there is more than one variable of the same name in
  320. the file; this routine returns the duplicates, whereas the standard reader
  321. will overwrite duplicates in the returned dictionary.
  322. The file pointer in `file_obj` will be undefined. File pointers for the
  323. returned file-like objects are set at 0.
  324. Parameters
  325. ----------
  326. file_obj : file-like
  327. file object containing mat file
  328. Returns
  329. -------
  330. named_mats : list
  331. list contains tuples of (name, BytesIO) where BytesIO is a file-like
  332. object containing mat file contents as for a single variable. The
  333. BytesIO contains a string with the original header and a single var. If
  334. ``var_file_obj`` is an individual BytesIO instance, then save as a mat
  335. file with something like ``open('test.mat',
  336. 'wb').write(var_file_obj.read())``
  337. Examples
  338. --------
  339. >>> import scipy.io
  340. >>> import numpy as np
  341. >>> from io import BytesIO
  342. >>> from scipy.io.matlab._mio5 import varmats_from_mat
  343. >>> mat_fileobj = BytesIO()
  344. >>> scipy.io.savemat(mat_fileobj, {'b': np.arange(10), 'a': 'a string'})
  345. >>> varmats = varmats_from_mat(mat_fileobj)
  346. >>> sorted([name for name, str_obj in varmats])
  347. ['a', 'b']
  348. """
  349. rdr = MatFile5Reader(file_obj)
  350. file_obj.seek(0)
  351. # Raw read of top-level file header
  352. hdr_len = MDTYPES[native_code]['dtypes']['file_header'].itemsize
  353. raw_hdr = file_obj.read(hdr_len)
  354. # Initialize variable reading
  355. file_obj.seek(0)
  356. rdr.initialize_read()
  357. rdr.read_file_header()
  358. next_position = file_obj.tell()
  359. named_mats = []
  360. while not rdr.end_of_stream():
  361. start_position = next_position
  362. hdr, next_position = rdr.read_var_header()
  363. name = 'None' if hdr.name is None else hdr.name.decode('latin1')
  364. # Read raw variable string
  365. file_obj.seek(start_position)
  366. byte_count = next_position - start_position
  367. var_str = file_obj.read(byte_count)
  368. # write to stringio object
  369. out_obj = BytesIO()
  370. out_obj.write(raw_hdr)
  371. out_obj.write(var_str)
  372. out_obj.seek(0)
  373. named_mats.append((name, out_obj))
  374. return named_mats
  375. class EmptyStructMarker:
  376. """ Class to indicate presence of empty matlab struct on output """
  377. def to_writeable(source):
  378. ''' Convert input object ``source`` to something we can write
  379. Parameters
  380. ----------
  381. source : object
  382. Returns
  383. -------
  384. arr : None or ndarray or EmptyStructMarker
  385. If `source` cannot be converted to something we can write to a matfile,
  386. return None. If `source` is equivalent to an empty dictionary, return
  387. ``EmptyStructMarker``. Otherwise return `source` converted to an
  388. ndarray with contents for writing to matfile.
  389. '''
  390. if isinstance(source, np.ndarray):
  391. return source
  392. if source is None:
  393. return None
  394. if hasattr(source, "__array__"):
  395. return np.asarray(source)
  396. # Objects that implement mappings
  397. is_mapping = (hasattr(source, 'keys') and hasattr(source, 'values') and
  398. hasattr(source, 'items'))
  399. # Objects that don't implement mappings, but do have dicts
  400. if isinstance(source, np.generic):
  401. # NumPy scalars are never mappings (PyPy issue workaround)
  402. pass
  403. elif not is_mapping and hasattr(source, '__dict__'):
  404. source = {key: value for key, value in source.__dict__.items()
  405. if not key.startswith('_')}
  406. is_mapping = True
  407. if is_mapping:
  408. dtype = []
  409. values = []
  410. for field, value in source.items():
  411. if isinstance(field, str):
  412. if field[0] not in '_0123456789':
  413. dtype.append((str(field), object))
  414. values.append(value)
  415. else:
  416. msg = (f"Starting field name with a underscore "
  417. f"or a digit ({field}) is ignored")
  418. warnings.warn(msg, MatWriteWarning, stacklevel=2)
  419. if dtype:
  420. return np.array([tuple(values)], dtype)
  421. else:
  422. return EmptyStructMarker
  423. # Next try and convert to an array
  424. try:
  425. narr = np.asanyarray(source)
  426. except ValueError:
  427. narr = np.asanyarray(source, dtype=object)
  428. if narr.dtype.type in (object, np.object_) and \
  429. narr.shape == () and narr == source:
  430. # No interesting conversion possible
  431. return None
  432. return narr
# Native byte ordered dtypes for convenience for writers. Each is looked up
# once from the MDTYPES table for the native byte-order code: the file
# header, the full element tag, the small-data element tag and the array
# flags sub-element.
NDT_FILE_HDR = MDTYPES[native_code]['dtypes']['file_header']
NDT_TAG_FULL = MDTYPES[native_code]['dtypes']['tag_full']
NDT_TAG_SMALL = MDTYPES[native_code]['dtypes']['tag_smalldata']
NDT_ARRAY_FLAGS = MDTYPES[native_code]['dtypes']['array_flags']
  438. class VarWriter5:
  439. ''' Generic matlab matrix writing class '''
  440. mat_tag = np.zeros((), NDT_TAG_FULL)
  441. mat_tag['mdtype'] = miMATRIX
  442. def __init__(self, file_writer):
  443. self.file_stream = file_writer.file_stream
  444. self.unicode_strings = file_writer.unicode_strings
  445. self.long_field_names = file_writer.long_field_names
  446. self.oned_as = file_writer.oned_as
  447. # These are used for top level writes, and unset after
  448. self._var_name = None
  449. self._var_is_global = False
  450. def write_bytes(self, arr):
  451. self.file_stream.write(arr.tobytes(order='F'))
  452. def write_string(self, s):
  453. self.file_stream.write(s)
  454. def write_element(self, arr, mdtype=None):
  455. ''' write tag and data '''
  456. if mdtype is None:
  457. mdtype = NP_TO_MTYPES[arr.dtype.str[1:]]
  458. # Array needs to be in native byte order
  459. if arr.dtype.byteorder == swapped_code:
  460. arr = arr.byteswap().view(arr.dtype.newbyteorder())
  461. byte_count = arr.size*arr.itemsize
  462. if byte_count <= 4:
  463. self.write_smalldata_element(arr, mdtype, byte_count)
  464. else:
  465. self.write_regular_element(arr, mdtype, byte_count)
  466. def write_smalldata_element(self, arr, mdtype, byte_count):
  467. # write tag with embedded data
  468. tag = np.zeros((), NDT_TAG_SMALL)
  469. tag['byte_count_mdtype'] = (byte_count << 16) + mdtype
  470. # if arr.tobytes is < 4, the element will be zero-padded as needed.
  471. tag['data'] = arr.tobytes(order='F')
  472. self.write_bytes(tag)
  473. def write_regular_element(self, arr, mdtype, byte_count):
  474. # write tag, data
  475. tag = np.zeros((), NDT_TAG_FULL)
  476. tag['mdtype'] = mdtype
  477. tag['byte_count'] = byte_count
  478. self.write_bytes(tag)
  479. self.write_bytes(arr)
  480. # pad to next 64-bit boundary
  481. bc_mod_8 = byte_count % 8
  482. if bc_mod_8:
  483. self.file_stream.write(b'\x00' * (8-bc_mod_8))
  484. def write_header(self,
  485. shape,
  486. mclass,
  487. is_complex=False,
  488. is_logical=False,
  489. nzmax=0):
  490. ''' Write header for given data options
  491. shape : sequence
  492. array shape
  493. mclass - mat5 matrix class
  494. is_complex - True if matrix is complex
  495. is_logical - True if matrix is logical
  496. nzmax - max non zero elements for sparse arrays
  497. We get the name and the global flag from the object, and reset
  498. them to defaults after we've used them
  499. '''
  500. # get name and is_global from one-shot object store
  501. name = self._var_name
  502. is_global = self._var_is_global
  503. # initialize the top-level matrix tag, store position
  504. self._mat_tag_pos = self.file_stream.tell()
  505. self.write_bytes(self.mat_tag)
  506. # write array flags (complex, global, logical, class, nzmax)
  507. af = np.zeros((), NDT_ARRAY_FLAGS)
  508. af['data_type'] = miUINT32
  509. af['byte_count'] = 8
  510. flags = is_complex << 3 | is_global << 2 | is_logical << 1
  511. af['flags_class'] = mclass | flags << 8
  512. af['nzmax'] = nzmax
  513. self.write_bytes(af)
  514. # shape
  515. self.write_element(np.array(shape, dtype='i4'))
  516. # write name
  517. name = np.asarray(name)
  518. if name == '': # empty string zero-terminated
  519. self.write_smalldata_element(name, miINT8, 0)
  520. else:
  521. self.write_element(name, miINT8)
  522. # reset the one-shot store to defaults
  523. self._var_name = ''
  524. self._var_is_global = False
  525. def update_matrix_tag(self, start_pos):
  526. curr_pos = self.file_stream.tell()
  527. self.file_stream.seek(start_pos)
  528. byte_count = curr_pos - start_pos - 8
  529. if byte_count >= 2**32:
  530. raise MatWriteError("Matrix too large to save with Matlab "
  531. "5 format")
  532. self.mat_tag['byte_count'] = byte_count
  533. self.write_bytes(self.mat_tag)
  534. self.file_stream.seek(curr_pos)
  535. def write_top(self, arr, name, is_global):
  536. """ Write variable at top level of mat file
  537. Parameters
  538. ----------
  539. arr : array_like
  540. array-like object to create writer for
  541. name : str, optional
  542. name as it will appear in matlab workspace
  543. default is empty string
  544. is_global : {False, True}, optional
  545. whether variable will be global on load into matlab
  546. """
  547. # these are set before the top-level header write, and unset at
  548. # the end of the same write, because they do not apply for lower levels
  549. self._var_is_global = is_global
  550. self._var_name = name
  551. # write the header and data
  552. self.write(arr)
  553. def write(self, arr):
  554. ''' Write `arr` to stream at top and sub levels
  555. Parameters
  556. ----------
  557. arr : array_like
  558. array-like object to create writer for
  559. '''
  560. # store position, so we can update the matrix tag
  561. mat_tag_pos = self.file_stream.tell()
  562. # First check if these are sparse
  563. if scipy.sparse.issparse(arr):
  564. self.write_sparse(arr)
  565. self.update_matrix_tag(mat_tag_pos)
  566. return
  567. # Try to convert things that aren't arrays
  568. narr = to_writeable(arr)
  569. if narr is None:
  570. raise TypeError(f'Could not convert {arr} (type {type(arr)}) to array')
  571. if isinstance(narr, MatlabObject):
  572. self.write_object(narr)
  573. elif isinstance(narr, MatlabFunction):
  574. raise MatWriteError('Cannot write matlab functions')
  575. elif narr is EmptyStructMarker: # empty struct array
  576. self.write_empty_struct()
  577. elif narr.dtype.fields: # struct array
  578. self.write_struct(narr)
  579. elif narr.dtype.hasobject: # cell array
  580. self.write_cells(narr)
  581. elif narr.dtype.kind in ('U', 'S'):
  582. if self.unicode_strings:
  583. codec = 'UTF8'
  584. else:
  585. codec = 'ascii'
  586. self.write_char(narr, codec)
  587. else:
  588. self.write_numeric(narr)
  589. self.update_matrix_tag(mat_tag_pos)
  590. def write_numeric(self, arr):
  591. imagf = arr.dtype.kind == 'c'
  592. logif = arr.dtype.kind == 'b'
  593. try:
  594. mclass = NP_TO_MXTYPES[arr.dtype.str[1:]]
  595. except KeyError:
  596. # No matching matlab type, probably complex256 / float128 / float96
  597. # Cast data to complex128 / float64.
  598. if imagf:
  599. arr = arr.astype('c128')
  600. elif logif:
  601. arr = arr.astype('i1') # Should only contain 0/1
  602. else:
  603. arr = arr.astype('f8')
  604. mclass = mxDOUBLE_CLASS
  605. self.write_header(matdims(arr, self.oned_as),
  606. mclass,
  607. is_complex=imagf,
  608. is_logical=logif)
  609. if imagf:
  610. self.write_element(arr.real)
  611. self.write_element(arr.imag)
  612. else:
  613. self.write_element(arr)
  614. def write_char(self, arr, codec='ascii'):
  615. ''' Write string array `arr` with given `codec`
  616. '''
  617. if arr.size == 0 or np.all(arr == ''):
  618. # This an empty string array or a string array containing
  619. # only empty strings. Matlab cannot distinguish between a
  620. # string array that is empty, and a string array containing
  621. # only empty strings, because it stores strings as arrays of
  622. # char. There is no way of having an array of char that is
  623. # not empty, but contains an empty string. We have to
  624. # special-case the array-with-empty-strings because even
  625. # empty strings have zero padding, which would otherwise
  626. # appear in matlab as a string with a space.
  627. shape = (0,) * np.max([arr.ndim, 2])
  628. self.write_header(shape, mxCHAR_CLASS)
  629. self.write_smalldata_element(arr, miUTF8, 0)
  630. return
  631. # non-empty string.
  632. #
  633. # Convert to char array
  634. arr = arr_to_chars(arr)
  635. # We have to write the shape directly, because we are going
  636. # recode the characters, and the resulting stream of chars
  637. # may have a different length
  638. shape = arr.shape
  639. self.write_header(shape, mxCHAR_CLASS)
  640. if arr.dtype.kind == 'U' and arr.size:
  641. # Make one long string from all the characters. We need to
  642. # transpose here, because we're flattening the array, before
  643. # we write the bytes. The bytes have to be written in
  644. # Fortran order.
  645. n_chars = math.prod(shape)
  646. st_arr = np.ndarray(shape=(),
  647. dtype=arr_dtype_number(arr, n_chars),
  648. buffer=arr.T.copy()) # Fortran order
  649. # Recode with codec to give byte string
  650. st = st_arr.item().encode(codec)
  651. # Reconstruct as 1-D byte array
  652. arr = np.ndarray(shape=(len(st),),
  653. dtype='S1',
  654. buffer=st)
  655. self.write_element(arr, mdtype=miUTF8)
  656. def write_sparse(self, arr):
  657. ''' Sparse matrices are 2D
  658. '''
  659. A = arr.tocsc() # convert to sparse CSC format
  660. A.sort_indices() # MATLAB expects sorted row indices
  661. is_complex = (A.dtype.kind == 'c')
  662. is_logical = (A.dtype.kind == 'b')
  663. nz = A.nnz
  664. self.write_header(matdims(arr, self.oned_as),
  665. mxSPARSE_CLASS,
  666. is_complex=is_complex,
  667. is_logical=is_logical,
  668. # matlab won't load file with 0 nzmax
  669. nzmax=1 if nz == 0 else nz)
  670. self.write_element(A.indices.astype('i4'))
  671. self.write_element(A.indptr.astype('i4'))
  672. self.write_element(A.data.real)
  673. if is_complex:
  674. self.write_element(A.data.imag)
  675. def write_cells(self, arr):
  676. self.write_header(matdims(arr, self.oned_as),
  677. mxCELL_CLASS)
  678. # loop over data, column major
  679. A = np.atleast_2d(arr).flatten('F')
  680. for el in A:
  681. self.write(el)
  682. def write_empty_struct(self):
  683. self.write_header((1, 1), mxSTRUCT_CLASS)
  684. # max field name length set to 1 in an example matlab struct
  685. self.write_element(np.array(1, dtype=np.int32))
  686. # Field names element is empty
  687. self.write_element(np.array([], dtype=np.int8))
  688. def write_struct(self, arr):
  689. self.write_header(matdims(arr, self.oned_as),
  690. mxSTRUCT_CLASS)
  691. self._write_items(arr)
  692. def _write_items(self, arr):
  693. # write fieldnames
  694. fieldnames = [f[0] for f in arr.dtype.descr]
  695. length = max([len(fieldname) for fieldname in fieldnames])+1
  696. max_length = (self.long_field_names and 64) or 32
  697. if length > max_length:
  698. raise ValueError(
  699. f"Field names are restricted to {max_length - 1} characters"
  700. )
  701. self.write_element(np.array([length], dtype='i4'))
  702. self.write_element(np.array(fieldnames, dtype=f'S{length}'), mdtype=miINT8)
  703. A = np.atleast_2d(arr).flatten('F')
  704. for el in A:
  705. for f in fieldnames:
  706. self.write(el[f])
  707. def write_object(self, arr):
  708. '''Same as writing structs, except different mx class, and extra
  709. classname element after header
  710. '''
  711. self.write_header(matdims(arr, self.oned_as),
  712. mxOBJECT_CLASS)
  713. self.write_element(np.array(arr.classname, dtype='S'),
  714. mdtype=miINT8)
  715. self._write_items(arr)
  716. class MatFile5Writer:
  717. ''' Class for writing mat5 files '''
  718. @docfiller
  719. def __init__(self, file_stream,
  720. do_compression=False,
  721. unicode_strings=False,
  722. global_vars=None,
  723. long_field_names=False,
  724. oned_as='row'):
  725. ''' Initialize writer for matlab 5 format files
  726. Parameters
  727. ----------
  728. %(do_compression)s
  729. %(unicode_strings)s
  730. global_vars : None or sequence of strings, optional
  731. Names of variables to be marked as global for matlab
  732. %(long_fields)s
  733. %(oned_as)s
  734. '''
  735. self.file_stream = file_stream
  736. self.do_compression = do_compression
  737. self.unicode_strings = unicode_strings
  738. if global_vars:
  739. self.global_vars = global_vars
  740. else:
  741. self.global_vars = []
  742. self.long_field_names = long_field_names
  743. self.oned_as = oned_as
  744. self._matrix_writer = None
  745. def write_file_header(self):
  746. # write header
  747. hdr = np.zeros((), NDT_FILE_HDR)
  748. hdr['description'] = (f'MATLAB 5.0 MAT-file Platform: {os.name}, '
  749. f'Created on: {time.asctime()}')
  750. hdr['version'] = 0x0100
  751. hdr['endian_test'] = np.ndarray(shape=(),
  752. dtype='S2',
  753. buffer=np.uint16(0x4d49))
  754. self.file_stream.write(hdr.tobytes())
  755. def put_variables(self, mdict, write_header=None):
  756. ''' Write variables in `mdict` to stream
  757. Parameters
  758. ----------
  759. mdict : mapping
  760. mapping with method ``items`` returns name, contents pairs where
  761. ``name`` which will appear in the matlab workspace in file load, and
  762. ``contents`` is something writeable to a matlab file, such as a NumPy
  763. array.
  764. write_header : {None, True, False}, optional
  765. If True, then write the matlab file header before writing the
  766. variables. If None (the default) then write the file header
  767. if we are at position 0 in the stream. By setting False
  768. here, and setting the stream position to the end of the file,
  769. you can append variables to a matlab file
  770. '''
  771. # write header if requested, or None and start of file
  772. if write_header is None:
  773. write_header = self.file_stream.tell() == 0
  774. if write_header:
  775. self.write_file_header()
  776. self._matrix_writer = VarWriter5(self)
  777. for name, var in mdict.items():
  778. if name[0] == '_':
  779. msg = (f"Starting field name with a "
  780. f"underscore ({name}) is ignored")
  781. warnings.warn(msg, MatWriteWarning, stacklevel=2)
  782. continue
  783. is_global = name in self.global_vars
  784. if self.do_compression:
  785. stream = BytesIO()
  786. self._matrix_writer.file_stream = stream
  787. self._matrix_writer.write_top(var, name.encode('latin1'), is_global)
  788. out_str = zlib.compress(stream.getvalue())
  789. tag = np.empty((), NDT_TAG_FULL)
  790. tag['mdtype'] = miCOMPRESSED
  791. tag['byte_count'] = len(out_str)
  792. self.file_stream.write(tag.tobytes())
  793. self.file_stream.write(out_str)
  794. else: # not compressing
  795. self._matrix_writer.write_top(var, name.encode('latin1'), is_global)