test_native_complex.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. # This file is part of h5py, a Python interface to the HDF5 library.
  2. #
  3. # http://www.h5py.org
  4. #
  5. # Copyright 2008-2013 Andrew Collette and contributors
  6. #
  7. # License: Standard 3-clause BSD; see "license.txt" for full license terms
  8. # and contributor agreement.
  9. """
  10. Testing native complex number datatypes.
  11. """
  12. import sys
  13. import numpy as np
  14. import pytest
  15. import h5py
  16. from h5py.h5 import get_config # type: ignore
  17. from h5py import h5t
  18. from .common import make_name
  19. from .data_files import get_data_file_path
  cfg = get_config()

  # Module-wide skip conditions: every test here needs HDF5 >= 2.0 and an
  # h5py build whose config reports native complex datatype support.
  _NEEDS_HDF5_2 = pytest.mark.skipif(
      h5py.version.hdf5_version_tuple < (2, 0, 0),
      reason="Requires HDF5 >= 2.0",
  )
  _NEEDS_NATIVE_COMPLEX = pytest.mark.skipif(
      not cfg.has_native_complex,
      reason="Native HDF5 complex number datatypes not available",
  )
  pytestmark = [_NEEDS_HDF5_2, _NEEDS_NATIVE_COMPLEX]

  # Per-parameter mark for cases that need NumPy's ``complex256`` type.
  REQUIRE_NUMPY_COMPLEX256 = pytest.mark.skipif(
      not hasattr(np, "complex256"),
      reason="complex256 type is not available in numpy",
  )

  # HDF5 byte-order constant corresponding to the host interpreter's byte order.
  if sys.byteorder == "big":
      NATIVE_BYTE_ORDER = h5t.ORDER_BE
  else:
      NATIVE_BYTE_ORDER = h5t.ORDER_LE
  @pytest.mark.parametrize(
      "h5type_str, size, order",
      [
          # Half-precision pairs only get these low-level checks; they have no
          # entry in H5_VS_NUMPY_DTYPES.
          pytest.param("COMPLEX_IEEE_F16LE", 4, h5t.ORDER_LE, id="F16LE"),
          pytest.param("COMPLEX_IEEE_F16BE", 4, h5t.ORDER_BE, id="F16BE"),
          pytest.param("COMPLEX_IEEE_F32LE", 8, h5t.ORDER_LE, id="F32LE"),
          pytest.param("COMPLEX_IEEE_F32BE", 8, h5t.ORDER_BE, id="F32BE"),
          pytest.param("COMPLEX_IEEE_F64LE", 16, h5t.ORDER_LE, id="F64LE"),
          pytest.param("COMPLEX_IEEE_F64BE", 16, h5t.ORDER_BE, id="F64BE"),
          pytest.param(
              "NATIVE_FLOAT_COMPLEX", 8, NATIVE_BYTE_ORDER, id="float-native"
          ),
          pytest.param(
              "NATIVE_DOUBLE_COMPLEX", 16, NATIVE_BYTE_ORDER, id="double-native"
          ),
          pytest.param(
              "NATIVE_LDOUBLE_COMPLEX", 32, NATIVE_BYTE_ORDER,
              id="long-native",
              marks=REQUIRE_NUMPY_COMPLEX256,
          ),
      ],
  )
  def test_hdf5_dtype(h5type_str, size, order):
      """Low-level checks of HDF5 native complex number datatypes.

      Looks up the predefined type named *h5type_str* on ``h5py.h5t`` and
      checks its class, size and byte order.

      Parameters
      ----------
      h5type_str : str
          Attribute name of the predefined complex type on ``h5t``.
      size : int
          Expected total size in bytes (both parts together, e.g. 8 for a
          pair of 32-bit floats).
      order : int
          Expected ``h5t.ORDER_*`` byte-order constant.
      """
      h5type = getattr(h5t, h5type_str)
      assert isinstance(h5type, h5t.TypeComplexID)
      assert h5type.get_size() == size
      assert h5type.get_order() == order
  # Pairs of (predefined HDF5 complex type name on ``h5t``, equivalent NumPy
  # dtype string). The dtype string doubles as the pytest id.
  H5_VS_NUMPY_DTYPES = [
      pytest.param(h5_name, np_str, id=np_str)
      for h5_name, np_str in (
          ("COMPLEX_IEEE_F32LE", "<c8"),
          ("COMPLEX_IEEE_F32BE", ">c8"),
          ("COMPLEX_IEEE_F64LE", "<c16"),
          ("COMPLEX_IEEE_F64BE", ">c16"),
          ("NATIVE_FLOAT_COMPLEX", "=c8"),
          ("NATIVE_DOUBLE_COMPLEX", "=c16"),
      )
  ]
  # The extended-precision pair is skipped where NumPy lacks complex256.
  H5_VS_NUMPY_DTYPES.append(
      pytest.param(
          "NATIVE_LDOUBLE_COMPLEX", "=c32", id="=c32",
          marks=REQUIRE_NUMPY_COMPLEX256,
      )
  )
  @pytest.mark.parametrize("h5type_str, dt", H5_VS_NUMPY_DTYPES)
  def test_cmplx_type_trnslt(h5type_str, dt):
      """Check that a native HDF5 complex type maps to the expected NumPy dtype.

      *h5type_str* names a predefined type on ``h5t``; *dt* is the NumPy dtype
      string its ``.dtype`` property must equal.
      """
      translated = getattr(h5t, h5type_str).dtype
      wanted = np.dtype(dt)
      assert translated == wanted
  76. def check_compound_complex_datatype(datatype, np_dtype):
  77. """Check h5py's old complex number type"""
  78. assert isinstance(datatype, h5t.TypeCompoundID)
  79. assert datatype.get_nmembers() == 2
  80. assert datatype.get_member_name(0) == cfg._r_name
  81. assert datatype.get_member_name(1) == cfg._i_name
  82. assert isinstance(datatype.get_member_type(0), h5t.TypeFloatID)
  83. assert isinstance(datatype.get_member_type(1), h5t.TypeFloatID)
  84. assert (
  85. datatype.get_member_type(0).get_size()
  86. + datatype.get_member_type(1).get_size()
  87. ) == np_dtype.itemsize
  88. match (np_dtype.byteorder, sys.byteorder):
  89. case (">", _) | ("=", "big"):
  90. expected = h5t.ORDER_BE
  91. case ("<", _) | ("=", "little"):
  92. expected = h5t.ORDER_LE
  93. case _ as _unreachable:
  94. raise AssertionError
  95. assert datatype.get_member_type(0).get_order() == expected
  96. assert datatype.get_member_type(1).get_order() == expected
  @pytest.mark.parametrize("dt", ["<c8", ">c8", "<c16", ">c16"])
  def test_default_create(writable_file, dt):
      """Round-trip complex data written without an explicit HDF5 type.

      Writes a 100-element complex dataset plus a scalar complex attribute
      ``"c"`` using only the NumPy dtype *dt*. Checks that both were stored as
      h5py's compound complex type and that both read back unchanged.
      ``writable_file`` is a pytest fixture defined outside this module.
      """
      np_dt = np.dtype(dt)
      real_part = np.random.rand(100)
      imag_part = np.random.rand(100)
      expected_data = (real_part + 1j * imag_part).astype(dt)

      dset = writable_file.create_dataset(make_name(dt), data=expected_data)
      scalar = np.array(1.9 + 1j * 6.7, dtype=dt)
      dset.attrs["c"] = scalar

      check_compound_complex_datatype(dset.id.get_type(), np_dt)
      check_compound_complex_datatype(dset.attrs.get_id("c").get_type(), np_dt)

      np.testing.assert_array_equal(dset[...], expected_data)
      np.testing.assert_array_equal(dset.attrs["c"], scalar)
  @pytest.mark.parametrize("h5type_str, dt", H5_VS_NUMPY_DTYPES)
  def test_explicit_create(writable_file, h5type_str, dt):
      """Store complex data using an explicitly chosen native HDF5 complex type.

      The dataset and the scalar attribute ``"c"`` are both created with
      ``h5py.Datatype`` wrapping the predefined type *h5type_str*. Both must
      report ``TypeComplexID`` as their stored type and read back unchanged.
      ``writable_file`` is a pytest fixture defined outside this module.
      """
      n_points = 100
      native_type = getattr(h5t, h5type_str)
      np_dt = np.dtype(dt)
      real_part = np.random.rand(n_points)
      imag_part = np.random.rand(n_points)
      expected_data = (real_part + 1j * imag_part).astype(np_dt)

      dset = writable_file.create_dataset(
          make_name(dt), (n_points,), dtype=h5py.Datatype(native_type)
      )
      dset[:] = expected_data

      scalar = np.array(1.9 + 1j * 6.7, dtype=np_dt)
      dset.attrs.create("c", scalar, dtype=h5py.Datatype(native_type))

      assert isinstance(dset.id.get_type(), h5t.TypeComplexID)
      assert isinstance(dset.attrs.get_id("c").get_type(), h5t.TypeComplexID)

      np.testing.assert_array_equal(dset[...], expected_data)
      np.testing.assert_array_equal(dset.attrs["c"], scalar)
  def test_dtype_cmpnd_cmplx():
      """Check the dtype of two-field compound data read as complex numbers.

      Opens the bundled ``compound-dtype-complex.h5`` file read-only. For every
      top-level member that is a dataset, checks that its HDF5 type is the
      compound complex layout matching the dataset's NumPy dtype. Other
      members are skipped.
      """
      with h5py.File(get_data_file_path("compound-dtype-complex.h5"), mode="r") as fobj:
          for member in fobj.values():
              if not isinstance(member, h5py.Dataset):
                  continue
              check_compound_complex_datatype(member.id.get_type(), member.dtype)