test_fortran.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. ''' Tests for fortran sequential files '''
  2. import tempfile
  3. import shutil
  4. import os
  5. from os import path
  6. from glob import iglob
  7. import threading
  8. import re
  9. from numpy.testing import assert_equal, assert_allclose
  10. import numpy as np
  11. import pytest
  12. from scipy.io import (FortranFile,
  13. FortranEOFError,
  14. FortranFormattingError)
  15. DATA_PATH = path.join(path.dirname(__file__), 'data')
  16. @pytest.fixture
  17. def io_lock():
  18. return threading.Lock()
  19. def test_fortranfiles_read(io_lock):
  20. for filename in iglob(path.join(DATA_PATH, "fortran-*-*x*x*.dat")):
  21. m = re.search(r'fortran-([^-]+)-(\d+)x(\d+)x(\d+).dat', filename, re.I)
  22. if not m:
  23. raise RuntimeError(f"Couldn't match {filename} filename to regex")
  24. dims = (int(m.group(2)), int(m.group(3)), int(m.group(4)))
  25. dtype = m.group(1).replace('s', '<')
  26. with io_lock:
  27. f = FortranFile(filename, 'r', '<u4')
  28. data = f.read_record(dtype=dtype).reshape(dims, order='F')
  29. f.close()
  30. expected = np.arange(np.prod(dims)).reshape(dims).astype(dtype)
  31. assert_equal(data, expected)
  32. def test_fortranfiles_mixed_record(io_lock):
  33. filename = path.join(DATA_PATH, "fortran-mixed.dat")
  34. with io_lock:
  35. with FortranFile(filename, 'r', '<u4') as f:
  36. record = f.read_record('<i4,<f4,<i8,2<f8')
  37. assert_equal(record['f0'][0], 1)
  38. assert_allclose(record['f1'][0], 2.3)
  39. assert_equal(record['f2'][0], 4)
  40. assert_allclose(record['f3'][0], [5.6, 7.8])
  41. def test_fortranfiles_write():
  42. for filename in iglob(path.join(DATA_PATH, "fortran-*-*x*x*.dat")):
  43. m = re.search(r'fortran-([^-]+)-(\d+)x(\d+)x(\d+).dat', filename, re.I)
  44. if not m:
  45. raise RuntimeError(f"Couldn't match {filename} filename to regex")
  46. dims = (int(m.group(2)), int(m.group(3)), int(m.group(4)))
  47. dtype = m.group(1).replace('s', '<')
  48. data = np.arange(np.prod(dims)).reshape(dims).astype(dtype)
  49. tmpdir = tempfile.mkdtemp()
  50. try:
  51. testFile = path.join(str(threading.get_native_id()),
  52. tmpdir,path.basename(filename))
  53. f = FortranFile(testFile, 'w','<u4')
  54. f.write_record(data.T)
  55. f.close()
  56. originalfile = open(filename, 'rb')
  57. newfile = open(testFile, 'rb')
  58. assert_equal(originalfile.read(), newfile.read(),
  59. err_msg=filename)
  60. originalfile.close()
  61. newfile.close()
  62. finally:
  63. shutil.rmtree(tmpdir)
  64. def test_fortranfile_read_mixed_record(io_lock):
  65. # The data file fortran-3x3d-2i.dat contains the program that
  66. # produced it at the end.
  67. #
  68. # double precision :: a(3,3)
  69. # integer :: b(2)
  70. # ...
  71. # open(1, file='fortran-3x3d-2i.dat', form='unformatted')
  72. # write(1) a, b
  73. # close(1)
  74. #
  75. filename = path.join(DATA_PATH, "fortran-3x3d-2i.dat")
  76. with io_lock:
  77. with FortranFile(filename, 'r', '<u4') as f:
  78. record = f.read_record('(3,3)<f8', '2<i4')
  79. ax = np.arange(3*3).reshape(3, 3).astype(np.float64)
  80. bx = np.array([-1, -2], dtype=np.int32)
  81. assert_equal(record[0], ax.T)
  82. assert_equal(record[1], bx.T)
  83. def test_fortranfile_write_mixed_record(tmpdir):
  84. tf = path.join(str(tmpdir), str(threading.get_native_id()), 'test.dat')
  85. os.makedirs(path.dirname(tf), exist_ok=True)
  86. r1 = (('f4', 'f4', 'i4'), (np.float32(2), np.float32(3), np.int32(100)))
  87. r2 = (('4f4', '(3,3)f4', '8i4'),
  88. (np.random.randint(255, size=[4]).astype(np.float32),
  89. np.random.randint(255, size=[3, 3]).astype(np.float32),
  90. np.random.randint(255, size=[8]).astype(np.int32)))
  91. records = [r1, r2]
  92. for dtype, a in records:
  93. with FortranFile(tf, 'w') as f:
  94. f.write_record(*a)
  95. with FortranFile(tf, 'r') as f:
  96. b = f.read_record(*dtype)
  97. assert_equal(len(a), len(b))
  98. for aa, bb in zip(a, b):
  99. assert_equal(bb, aa)
  100. def read_unformatted_double(m, n, k, filename):
  101. """
  102. Read a Fortran-style unformatted binary file written with a single write() call,
  103. assuming it wraps the data with 4-byte record markers.
  104. Returns:
  105. np.ndarray of shape (m, n, k) with dtype float64
  106. Reference:
  107. Fortran implementation:
  108. https://github.com/scipy/scipy/blob/maintenance/1.15.x/scipy/io/_test_fortran.f#L1-L9
  109. """
  110. with open(filename.strip(), 'rb') as f:
  111. f.read(4) # Skip initial 4-byte record marker
  112. data = np.fromfile(f, dtype=np.float64, count=m * n * k)
  113. f.read(4) # Skip trailing 4-byte record marker
  114. if data.size != m * n * k:
  115. raise ValueError(f"Expected {m*n*k} elements, got {data.size}")
  116. return data.reshape((m, n, k), order='F') # Fortran column-major order
  117. def read_unformatted_mixed(m, n, k, filename):
  118. """
  119. Read a Fortran unformatted binary file that contains a mix of:
  120. - a double precision array a(m, n)
  121. - an integer array b(k)
  122. Assumes a single write(10) a, b was used and file is wrapped
  123. with Fortran record markers.
  124. Returns:
  125. a: np.ndarray of shape (m, n) with dtype float64
  126. b: np.ndarray of shape (k,) with dtype int32
  127. Reference:
  128. Fortran implementation:
  129. https://github.com/scipy/scipy/blob/maintenance/1.15.x/scipy/io/_test_fortran.f#L21-L30
  130. """
  131. with open(filename.strip(), 'rb') as f:
  132. f.read(4) # Skip initial 4-byte record marker
  133. # Read a(m,n): total m*n float64 values
  134. a_flat = np.fromfile(f, dtype=np.float64, count=m * n)
  135. # Read b(k): total k int32 values (assuming Fortran default integer*4)
  136. b = np.fromfile(f, dtype=np.int32, count=k)
  137. f.read(4) # Skip trailing 4-byte record marker
  138. # Reshape a to (m,n) Fortran-style
  139. a = a_flat.reshape((m, n), order='F')
  140. return a, b
  141. def read_unformatted_int(m, n, k, filename):
  142. """
  143. Read a Fortran unformatted binary file
  144. containing a 3D integer array (m, n, k).
  145. Assumes the array is written with a single
  146. write(10) a and wrapped with record markers.
  147. Returns:
  148. np.ndarray: 3D array of shape (m, n, k) with dtype int32
  149. Reference:
  150. Fortran implementation:
  151. https://github.com/scipy/scipy/blob/maintenance/1.15.x/scipy/io/_test_fortran.f#L11-L19
  152. """
  153. with open(filename.strip(), 'rb') as f:
  154. f.read(4) # Skip Fortran record marker at start
  155. # Read m*n*k integers (Fortran default = 4 bytes per integer)
  156. data = np.fromfile(f, dtype=np.int32, count=m * n * k)
  157. f.read(4) # Skip Fortran record marker at end
  158. if data.size != m * n * k:
  159. raise ValueError(f"Expected {m*n*k} elements, got {data.size}")
  160. return data.reshape((m, n, k), order='F') # Fortran-style column-major order
  161. def test_fortran_roundtrip(tmpdir, io_lock):
  162. filename = path.join(str(tmpdir), str(threading.get_native_id()),
  163. 'test.dat')
  164. os.makedirs(path.dirname(filename), exist_ok=True)
  165. rng = np.random.RandomState(1)
  166. # double precision
  167. m, n, k = 5, 3, 2
  168. a = rng.randn(m, n, k)
  169. with FortranFile(filename, 'w') as f:
  170. f.write_record(a.T)
  171. with io_lock:
  172. a2 = read_unformatted_double(m, n, k, filename)
  173. with FortranFile(filename, 'r') as f:
  174. a3 = f.read_record('(2,3,5)f8').T
  175. assert_equal(a2, a)
  176. assert_equal(a3, a)
  177. # integer
  178. m, n, k = 5, 3, 2
  179. a = rng.randn(m, n, k).astype(np.int32)
  180. with FortranFile(filename, 'w') as f:
  181. f.write_record(a.T)
  182. with io_lock:
  183. a2 = read_unformatted_int(m, n, k, filename)
  184. with FortranFile(filename, 'r') as f:
  185. a3 = f.read_record('(2,3,5)i4').T
  186. assert_equal(a2, a)
  187. assert_equal(a3, a)
  188. # mixed
  189. m, n, k = 5, 3, 2
  190. a = rng.randn(m, n)
  191. b = rng.randn(k).astype(np.intc)
  192. with FortranFile(filename, 'w') as f:
  193. f.write_record(a.T, b.T)
  194. with io_lock:
  195. a2, b2 = read_unformatted_mixed(m, n, k, filename)
  196. with FortranFile(filename, 'r') as f:
  197. a3, b3 = f.read_record('(3,5)f8', '2i4')
  198. a3 = a3.T
  199. assert_equal(a2, a)
  200. assert_equal(a3, a)
  201. assert_equal(b2, b)
  202. assert_equal(b3, b)
  203. def test_fortran_eof_ok(tmpdir):
  204. filename = path.join(str(tmpdir), str(threading.get_native_id()),
  205. "scratch")
  206. os.makedirs(path.dirname(filename), exist_ok=True)
  207. rng = np.random.RandomState(1)
  208. with FortranFile(filename, 'w') as f:
  209. f.write_record(rng.randn(5))
  210. f.write_record(rng.randn(3))
  211. with FortranFile(filename, 'r') as f:
  212. assert len(f.read_reals()) == 5
  213. assert len(f.read_reals()) == 3
  214. with pytest.raises(FortranEOFError):
  215. f.read_reals()
  216. def test_fortran_eof_broken_size(tmpdir):
  217. filename = path.join(str(tmpdir), str(threading.get_native_id()),
  218. "scratch")
  219. os.makedirs(path.dirname(filename), exist_ok=True)
  220. rng = np.random.RandomState(1)
  221. with FortranFile(filename, 'w') as f:
  222. f.write_record(rng.randn(5))
  223. f.write_record(rng.randn(3))
  224. with open(filename, "ab") as f:
  225. f.write(b"\xff")
  226. with FortranFile(filename, 'r') as f:
  227. assert len(f.read_reals()) == 5
  228. assert len(f.read_reals()) == 3
  229. with pytest.raises(FortranFormattingError):
  230. f.read_reals()
  231. def test_fortran_bogus_size(tmpdir):
  232. filename = path.join(str(tmpdir), str(threading.get_native_id()),
  233. "scratch")
  234. os.makedirs(path.dirname(filename), exist_ok=True)
  235. rng = np.random.RandomState(1)
  236. with FortranFile(filename, 'w') as f:
  237. f.write_record(rng.randn(5))
  238. f.write_record(rng.randn(3))
  239. with open(filename, "w+b") as f:
  240. f.write(b"\xff\xff")
  241. with FortranFile(filename, 'r') as f:
  242. with pytest.raises(FortranFormattingError):
  243. f.read_reals()
  244. def test_fortran_eof_broken_record(tmpdir):
  245. filename = path.join(str(tmpdir), str(threading.get_native_id()),
  246. "scratch")
  247. os.makedirs(path.dirname(filename), exist_ok=True)
  248. rng = np.random.RandomState(1)
  249. with FortranFile(filename, 'w') as f:
  250. f.write_record(rng.randn(5))
  251. f.write_record(rng.randn(3))
  252. with open(filename, "ab") as f:
  253. f.truncate(path.getsize(filename)-20)
  254. with FortranFile(filename, 'r') as f:
  255. assert len(f.read_reals()) == 5
  256. with pytest.raises(FortranFormattingError):
  257. f.read_reals()
  258. def test_fortran_eof_multidimensional(tmpdir):
  259. filename = path.join(str(tmpdir), str(threading.get_native_id()),
  260. "scratch")
  261. os.makedirs(path.dirname(filename), exist_ok=True)
  262. n, m, q = 3, 5, 7
  263. dt = np.dtype([("field", np.float64, (n, m))])
  264. a = np.zeros(q, dtype=dt)
  265. with FortranFile(filename, 'w') as f:
  266. f.write_record(a[0])
  267. f.write_record(a)
  268. f.write_record(a)
  269. with open(filename, "ab") as f:
  270. f.truncate(path.getsize(filename)-20)
  271. with FortranFile(filename, 'r') as f:
  272. assert len(f.read_record(dtype=dt)) == 1
  273. assert len(f.read_record(dtype=dt)) == q
  274. with pytest.raises(FortranFormattingError):
  275. f.read_record(dtype=dt)