common.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. # This file is part of h5py, a Python interface to the HDF5 library.
  2. #
  3. # http://www.h5py.org
  4. #
  5. # Copyright 2008-2013 Andrew Collette and contributors
  6. #
  7. # License: Standard 3-clause BSD; see "license.txt" for full license terms
  8. # and contributor agreement.
  9. import sys
  10. import os
  11. import shutil
  12. import inspect
  13. import tempfile
  14. import threading
  15. import subprocess
  16. from contextlib import contextmanager
  17. from functools import wraps
  18. import numpy as np
  19. from numpy.lib.recfunctions import repack_fields
  20. import h5py
  21. from h5py._objects import phil
  22. import unittest as ut
  # Check if non-ascii filenames are supported
  # Evidently this is the most reliable way to check
  # See also h5py issue #263 and ipython #466
  # To test for this, run the testsuite with LC_ALL=C
  def _probe_unicode_filenames():
      """Try to create (and then remove) a temp file whose name contains a
      non-ASCII character; return True on success, False on UnicodeError."""
      try:
          handle, path = tempfile.mkstemp(chr(0x03b7))
      except UnicodeError:
          return False
      os.close(handle)
      os.unlink(path)
      return True


  UNICODE_FILENAMES = _probe_unicode_filenames()
  del _probe_unicode_filenames
  37. class TestCase(ut.TestCase):
  38. """
  39. Base class for unit tests.
  40. """
  41. @classmethod
  42. def setUpClass(cls):
  43. cls.tempdir = tempfile.mkdtemp(prefix='h5py-test_')
  44. @classmethod
  45. def tearDownClass(cls):
  46. shutil.rmtree(cls.tempdir)
  47. def mktemp(self, suffix='.hdf5', prefix='tmp', dir=None):
  48. if dir is None:
  49. dir = self.tempdir
  50. return tempfile.mktemp(suffix, make_name(prefix), dir=dir)
  51. def mktemp_mpi(self, comm=None, suffix='.hdf5', prefix='', dir=None):
  52. if comm is None:
  53. from mpi4py import MPI
  54. comm = MPI.COMM_WORLD
  55. fname = None
  56. if comm.Get_rank() == 0:
  57. fname = self.mktemp(suffix, prefix, dir)
  58. fname = comm.bcast(fname, 0)
  59. return fname
  60. def setUp(self):
  61. self.f = h5py.File(self.mktemp(), 'w')
  62. def tearDown(self):
  63. try:
  64. if self.f:
  65. self.f.close()
  66. except:
  67. pass
  68. def assertSameElements(self, a, b):
  69. for x in a:
  70. match = False
  71. for y in b:
  72. if x == y:
  73. match = True
  74. if not match:
  75. raise AssertionError("Item '%s' appears in a but not b" % x)
  76. for x in b:
  77. match = False
  78. for y in a:
  79. if x == y:
  80. match = True
  81. if not match:
  82. raise AssertionError("Item '%s' appears in b but not a" % x)
  83. def assertArrayEqual(self, dset, arr, message=None, precision=None, check_alignment=True):
  84. """ Make sure dset and arr have the same shape, dtype and contents, to
  85. within the given precision, optionally ignoring differences in dtype alignment.
  86. Note that dset may be a NumPy array or an HDF5 dataset.
  87. """
  88. if precision is None:
  89. precision = 1e-5
  90. if message is None:
  91. message = ''
  92. else:
  93. message = ' (%s)' % message
  94. if np.isscalar(dset) or np.isscalar(arr):
  95. assert np.isscalar(dset) and np.isscalar(arr), \
  96. 'Scalar/array mismatch ("%r" vs "%r")%s' % (dset, arr, message)
  97. dset = np.asarray(dset)
  98. arr = np.asarray(arr)
  99. assert dset.shape == arr.shape, \
  100. "Shape mismatch (%s vs %s)%s" % (dset.shape, arr.shape, message)
  101. if dset.dtype != arr.dtype:
  102. if check_alignment:
  103. normalized_dset_dtype = dset.dtype
  104. normalized_arr_dtype = arr.dtype
  105. else:
  106. normalized_dset_dtype = repack_fields(dset.dtype)
  107. normalized_arr_dtype = repack_fields(arr.dtype)
  108. assert normalized_dset_dtype == normalized_arr_dtype, \
  109. "Dtype mismatch (%s vs %s)%s" % (normalized_dset_dtype, normalized_arr_dtype, message)
  110. if not check_alignment:
  111. if normalized_dset_dtype != dset.dtype:
  112. dset = repack_fields(np.asarray(dset))
  113. if normalized_arr_dtype != arr.dtype:
  114. arr = repack_fields(np.asarray(arr))
  115. if arr.dtype.names is not None:
  116. for n in arr.dtype.names:
  117. message = '[FIELD %s] %s' % (n, message)
  118. self.assertArrayEqual(dset[n], arr[n], message=message, precision=precision, check_alignment=check_alignment)
  119. elif arr.dtype.kind in ('i', 'f'):
  120. assert np.all(np.abs(dset[...] - arr[...]) < precision), \
  121. "Arrays differ by more than %.3f%s" % (precision, message)
  122. elif arr.dtype.kind == 'O':
  123. for v1, v2 in zip(dset.flat, arr.flat, strict=True):
  124. self.assertArrayEqual(v1, v2, message=message, precision=precision, check_alignment=check_alignment)
  125. else:
  126. assert np.all(dset[...] == arr[...]), \
  127. "Arrays are not equal (dtype %s) %s" % (arr.dtype.str, message)
  128. def assertNumpyBehavior(self, dset, arr, s, skip_fast_reader=False):
  129. """ Apply slicing arguments "s" to both dset and arr.
  130. Succeeds if the results of the slicing are identical, or the
  131. exception raised is of the same type for both.
  132. "arr" must be a Numpy array; "dset" may be a NumPy array or dataset.
  133. """
  134. exc = None
  135. try:
  136. arr_result = arr[s]
  137. except Exception as e:
  138. exc = type(e)
  139. s_fast = s if isinstance(s, tuple) else (s,)
  140. if exc is None:
  141. self.assertArrayEqual(dset[s], arr_result)
  142. if not skip_fast_reader:
  143. with phil:
  144. self.assertArrayEqual(
  145. dset._fast_reader.read(s_fast),
  146. arr_result,
  147. )
  148. else:
  149. with self.assertRaises(exc):
  150. dset[s]
  151. if not skip_fast_reader:
  152. with self.assertRaises(exc), phil:
  153. dset._fast_reader.read(s_fast)
  # (major, minor) of the installed NumPy as integers, for version gates.
  NUMPY_RELEASE_VERSION = tuple(map(int, np.__version__.split(".", 2)[:2]))
  @contextmanager
  def closed_tempfile(suffix='', text=None):
      """
      Context manager which yields the path to a closed temporary file with the
      suffix `suffix`. The file will be deleted on exiting the context. An
      additional argument `text` can be provided to have the file contain `text`.

      Deletion happens even if the body of the ``with`` block raises, and is
      best-effort: an ``OSError`` while removing the file (for example because
      the caller already deleted it) is ignored.
      """
      with tempfile.NamedTemporaryFile(
          'w+t', suffix=suffix, delete=False
      ) as test_file:
          file_name = test_file.name
          if text is not None:
              test_file.write(text)
              test_file.flush()
      # The handle is closed here; callers receive only the path.
      try:
          yield file_name
      finally:
          # A regular file must be removed with unlink (rmtree refuses
          # non-directories). Removal is best-effort, so OSError is ignored.
          try:
              os.unlink(file_name)
          except OSError:
              pass
  def insubprocess(f):
      """Run a test in its own subprocess.

      The wrapped test must take pytest's ``request`` fixture as its first
      argument. On the first call, the decorator launches
      ``python -m pytest <source file>::<test name>`` with a per-test marker
      environment variable set. Inside that child process the marker is
      present, so the original test body runs. A non-zero exit status of the
      child fails the test, and the child's combined stdout/stderr appears in
      the assertion message. Variables attached via ``@subproc_env`` are
      added to the child's environment.
      """
      @wraps(f)
      def wrapper(request, *args, **kwargs):
          node_id = inspect.getsourcefile(f) + "::" + request.node.name
          # Per-test marker variable name: path and id separators become "_".
          marker = ("IN_SUBPROCESS_" + node_id).translate(
              str.maketrans(dict.fromkeys("/\\,:.", "_"))
          )

          if os.environ.get(marker, None):
              # Already inside the child interpreter: execute the real test.
              return f(request, *args, **kwargs)

          child_env = os.environ.copy()
          child_env[marker] = '1'
          child_env.update(getattr(f, 'subproc_env', {}))
          command = [sys.executable, '-m', 'pytest', node_id]
          with closed_tempfile() as log_path:
              # stdout and stderr share one file so the failure report shows both.
              with open(log_path, 'w+t') as log_file:
                  status = subprocess.call(command, stdout=log_file,
                                           stderr=log_file, env=child_env)
              with open(log_path, 'rt') as log_file:
                  captured = log_file.read()
          assert status == 0, "\n" + captured

      return wrapper
  def subproc_env(d):
      """Set environment variables for the @insubprocess decorator.

      Returns a decorator that stores the mapping ``d`` on the decorated
      function as ``subproc_env`` and returns the same function object.
      """
      def attach(func):
          setattr(func, 'subproc_env', d)
          return func

      return attach
  # Identifier of the thread that imported this module; make_name and
  # is_main_thread compare the calling thread against it.
  MAIN_THREAD_ID = threading.get_ident()


  def make_name(template_or_prefix: str = "foo", /) -> str:
      """Return a static name, to be used e.g. as dataset name.

      When running in pytest-run-parallel, append a thread ID to the name.
      This allows running tests on shared resources: for example, two threads
      can write to separate datasets in the same File at the same time (the
      actual writes are still serialized by the `phil` lock).

      Calling this function twice from the same thread returns the same name.

      Parameters
      ----------
      template_or_prefix
          Either a prefix to which the thread ID may be appended, or a
          template containing exactly one "{}", which is replaced with the
          thread ID (or with an empty string in the main thread).
      """
      current = threading.get_ident()
      tag = f"-{current}" if current != MAIN_THREAD_ID else ""
      if "{}" not in template_or_prefix:
          return template_or_prefix + tag
      return template_or_prefix.format(tag)


  def is_main_thread() -> bool:
      """Return True if the calling test runs in the main thread, else False.

      This can be used to detect when a test is running under
      pytest-run-parallel, which spawns multiple separate threads to run
      the tests.
      """
      return MAIN_THREAD_ID == threading.get_ident()