test_slicing.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. # This file is part of h5py, a Python interface to the HDF5 library.
  2. #
  3. # http://www.h5py.org
  4. #
  5. # Copyright 2008-2013 Andrew Collette and contributors
  6. #
  7. # License: Standard 3-clause BSD; see "license.txt" for full license terms
  8. # and contributor agreement.
  9. """
  10. Dataset slicing test module.
  11. Tests all supported slicing operations, including read/write and
  12. broadcasting operations. Does not test type conversion except for
  13. corner cases overlapping with slicing; for example, when selecting
  14. specific fields of a compound type.
  15. """
  16. import numpy as np
  17. from .common import TestCase, make_name
  18. import h5py
  19. from h5py import File, MultiBlockSlice
  class BaseSlicing(TestCase):
      """
      Common fixture for the slicing tests: every test runs against a fresh
      HDF5 file, opened for writing and stored as ``self.f``.
      """

      def setUp(self):
          temp_path = self.mktemp()
          self.f = File(temp_path, 'w')

      def tearDown(self):
          # Only close a truthy handle; presumably a File the test already
          # closed evaluates false and is skipped here.
          if not self.f:
              return
          self.f.close()
  class TestSingleElement(BaseSlicing):
      """
      Feature: Retrieving a single element works with NumPy semantics
      """

      def test_single_index(self):
          """ An integer index on a 1D dataset returns an array scalar (np.int8) """
          dset = self.f.create_dataset(make_name(), (1,), dtype='i1')
          element = dset[0]
          self.assertIsInstance(element, np.int8)

      def test_single_null(self):
          """ An empty-tuple index on a 1D dataset returns the whole ndarray """
          dset = self.f.create_dataset(make_name(), (1,), dtype='i1')
          everything = dset[()]
          self.assertIsInstance(everything, np.ndarray)
          self.assertEqual(everything.shape, (1,))

      def test_scalar_index(self):
          """ An Ellipsis index on a scalar dataset returns a 0-d ndarray """
          dset = self.f.create_dataset(make_name(), shape=(), dtype='f')
          zero_dim = dset[...]
          self.assertIsInstance(zero_dim, np.ndarray)
          self.assertEqual(zero_dim.shape, ())

      def test_scalar_null(self):
          """ An empty-tuple index on a scalar dataset returns an array scalar """
          dset = self.f.create_dataset(make_name(), shape=(), dtype='i1')
          value = dset[()]
          self.assertIsInstance(value, np.int8)

      def test_compound(self):
          """ A compound element reads back as numpy.void, not a tuple (issue 135) """
          record_type = np.dtype([('a','i4'),('b','f8')])
          expected = np.ones((4,), dtype=record_type)
          dset = self.f.create_dataset(make_name(), (4,), data=expected)
          self.assertEqual(dset[0], expected[0])
          self.assertIsInstance(dset[0], np.void)
  class TestObjectIndex(BaseSlicing):
      """
      Feature: numpy.object_ subtypes map to real Python objects
      (object references, region references and byte strings).
      """

      def test_reference(self):
          """ An element of a reference dataset is an h5py.Reference """
          dset = self.f.create_dataset(make_name(), (1,), dtype=h5py.ref_dtype)
          dset[0] = self.f.ref
          self.assertIs(type(dset[0]), h5py.Reference)

      def test_regref(self):
          """ An element of a region-reference dataset is an h5py.RegionReference
          """
          source = self.f.create_dataset(make_name("x"), (10,10), "f4")
          region = source.regionref[...]
          holder = self.f.create_dataset(make_name("y"), (1,), dtype=h5py.regionref_dtype)
          holder[0] = region
          self.assertIs(type(holder[0]), h5py.RegionReference)

      def test_reference_field(self):
          """ A reference stored in a compound field reads back as h5py.Reference """
          compound = np.dtype([('a', 'i'),('b', h5py.ref_dtype)])
          dset = self.f.create_dataset(make_name(), (1,), dtype=compound)
          dset[0] = (42, self.f['/'].ref)
          ref_field = dset[0][1]
          # Exact type check on purpose: isinstance does NOT work for this case.
          self.assertIs(type(ref_field), h5py.Reference)

      def test_scalar(self):
          """ Reading a scalar reference dataset with [()] yields an h5py.Reference """
          dset = self.f.create_dataset(make_name(), (), dtype=h5py.ref_dtype)
          dset[()] = self.f.ref
          self.assertIs(type(dset[()]), h5py.Reference)

      def test_bytestr(self):
          """ An element of an ASCII string dataset is a plain Python bytes object
          """
          ascii_str = h5py.string_dtype(encoding='ascii')
          dset = self.f.create_dataset(make_name(), (1,), dtype=ascii_str)
          dset[0] = b"Hello there!"
          self.assertIs(type(dset[0]), bytes)
  class TestSimpleSlicing(BaseSlicing):
      """
      Feature: Simple NumPy-style slices (start:stop:step) are supported.
      """

      def setUp(self):
          # BaseSlicing.setUp opens the temporary file as self.f (and its
          # tearDown closes it). Add a small 1D dataset mirroring self.arr so
          # reads can be compared against plain NumPy slicing.
          super().setUp()
          self.arr = np.arange(10)
          self.dset = self.f.create_dataset('x', data=self.arr)

      def test_negative_stop(self):
          """ Negative stop indexes work as they do in NumPy """
          self.assertArrayEqual(self.dset[2:-2], self.arr[2:-2])

      def test_write(self):
          """ Assigning to a 1D slice of a 2D dataset requires a matching shape.

          A (10,) source fits the selection dset[:, 0]; a (10, 1) source does
          not and raises TypeError.
          """
          dset = self.f.create_dataset(make_name(), (10, 2), "f4")
          x = np.zeros((10, 1))
          dset[:, 0] = x[:, 0]
          with self.assertRaises(TypeError):
              dset[:, 1] = x
  class TestArraySlicing(BaseSlicing):
      """
      Feature: Array types are handled appropriately
      """

      def test_read(self):
          """ Read arrays tack array dimensions onto end of shape tuple """
          dt = np.dtype('(3,)f8')
          dset = self.f.create_dataset(make_name(), (10,), dtype=dt)
          # The dataset keeps the (3,)f8 array dtype; reads expand it into a
          # trailing axis of a plain f8 array.
          self.assertEqual(dset.shape, (10,))
          self.assertEqual(dset.dtype, dt)

          # Full read
          out = dset[...]
          self.assertEqual(out.dtype, np.dtype('f8'))
          self.assertEqual(out.shape, (10,3))

          # Single element
          out = dset[0]
          self.assertEqual(out.dtype, np.dtype('f8'))
          self.assertEqual(out.shape, (3,))

          # Range
          out = dset[2:8:2]
          self.assertEqual(out.dtype, np.dtype('f8'))
          self.assertEqual(out.shape, (3,3))

      def test_write_broadcast(self):
          """ Array fill from constant is not supported (issue 211).
          """
          dt = np.dtype('(3,)i')
          dset = self.f.create_dataset(make_name(), (10,), dtype=dt)
          with self.assertRaises(TypeError):
              dset[...] = 42

      def test_write_element(self):
          """ Write a single element to the array
          Issue 211.
          """
          dt = np.dtype('(3,)f8')
          dset = self.f.create_dataset(make_name(), (10,), dtype=dt)
          data = np.array([1,2,3.0])
          dset[4] = data
          # np.all(out == data) broadcasts and can pass on a shape mismatch;
          # assertArrayEqual (from .common) is expected to compare shapes too
          # and reports the mismatch on failure.
          self.assertArrayEqual(dset[4], data)

      def test_write_slices(self):
          """ Write slices to array type """
          dt = np.dtype('(3,)i')
          data1 = np.ones((2,), dtype=dt)
          data2 = np.ones((4,5), dtype=dt)
          dset = self.f.create_dataset(make_name(), (10,9,11), dtype=dt)
          dset[0,0,2:4] = data1
          self.assertArrayEqual(dset[0,0,2:4], data1)
          dset[3, 1:5, 6:11] = data2
          self.assertArrayEqual(dset[3, 1:5, 6:11], data2)

      def test_roundtrip(self):
          """ Read the contents of an array and write them back
          Issue 211.
          """
          dt = np.dtype('(3,)f8')
          dset = self.f.create_dataset(make_name(), (10,), dtype=dt)
          out = dset[...]
          dset[...] = out
          # Shape-aware comparison rather than a broadcasting np.all(==).
          self.assertArrayEqual(dset[...], out)
  class TestZeroLengthSlicing(BaseSlicing):
      """
      Slices resulting in empty arrays
      """

      def test_slice_zero_length_dimension(self):
          """ Slice a dataset with a zero in its shape vector
          along the zero-length dimension """
          for i, shape in enumerate([(0,), (0, 3), (0, 2, 1)]):
              # subTest tags any failure with the shape being checked.
              with self.subTest(shape=shape):
                  dset = self.f.create_dataset(make_name(f"x{i}"), shape, dtype=int, maxshape=(None,)*len(shape))
                  self.assertEqual(dset.shape, shape)
                  out = dset[...]
                  self.assertIsInstance(out, np.ndarray)
                  self.assertEqual(out.shape, shape)
                  out = dset[:]
                  self.assertIsInstance(out, np.ndarray)
                  self.assertEqual(out.shape, shape)
                  if len(shape) > 1:
                      out = dset[:, :1]
                      self.assertIsInstance(out, np.ndarray)
                      self.assertEqual(out.shape[:2], (0, 1))

      def test_slice_other_dimension(self):
          """ Slice a dataset with a zero in its shape vector
          along a non-zero-length dimension """
          for i, shape in enumerate([(3, 0), (1, 2, 0), (2, 0, 1)]):
              with self.subTest(shape=shape):
                  dset = self.f.create_dataset(make_name(f"x{i}"), shape, dtype=int, maxshape=(None,)*len(shape))
                  self.assertEqual(dset.shape, shape)
                  out = dset[:1]
                  self.assertIsInstance(out, np.ndarray)
                  self.assertEqual(out.shape, (1,)+shape[1:])

      def test_slice_of_length_zero(self):
          """ Get a slice of length zero from a non-empty dataset """
          for i, shape in enumerate([(3,), (2, 2,), (2, 1, 5)]):
              with self.subTest(shape=shape):
                  dset = self.f.create_dataset(make_name(f"x{i}"), data=np.zeros(shape, int), maxshape=(None,)*len(shape))
                  self.assertEqual(dset.shape, shape)
                  out = dset[1:1]
                  self.assertIsInstance(out, np.ndarray)
                  self.assertEqual(out.shape, (0,)+shape[1:])
  class TestFieldNames(BaseSlicing):
      """
      Field names for read & write
      """

      # Class-level fixtures shared by every test. ``data`` is a single array
      # object, so tests that modify it work on ``self.data.copy()``.
      dt = np.dtype([('a', 'f'), ('b', 'i'), ('c', 'f4')])
      data = np.ones((100,), dtype=dt)

      def setUp(self):
          super().setUp()
          self.dset = self.f.create_dataset('x', (100,), dtype=self.dt)
          self.dset[...] = self.data

      def test_read(self):
          """ Test read with field selections """
          self.assertArrayEqual(self.dset['a'], self.data['a'])

      def test_unicode_names(self):
          """ Unicode (str) field names for read and write """
          self.assertArrayEqual(self.dset['a'], self.data['a'])
          data = self.data.copy()
          dset = self.f.create_dataset(make_name(), data=data)
          dset['a'] = 42
          data['a'] = 42
          self.assertArrayEqual(dset['a'], data['a'])

      def test_write(self):
          """ Test write with field selections """
          data = self.data.copy()
          dset = self.f.create_dataset(make_name(), data=data)
          data['a'] *= 2
          dset['a'] = data
          self.assertTrue(np.all(dset[...] == data))
          data['b'] *= 4
          dset['b'] = data
          self.assertTrue(np.all(dset[...] == data))
          data['a'] *= 3
          data['c'] *= 3
          dset['a','c'] = data
          self.assertTrue(np.all(dset[...] == data))

      def test_write_noncompound(self):
          """ Test write with non-compound source (single-field) """
          data = self.data.copy()
          dset = self.f.create_dataset(make_name(), data=data)
          data['b'] = 1.0
          dset['b'] = 1.0
          self.assertTrue(np.all(dset[...] == data))
  class TestMultiBlockSlice(BaseSlicing):
      """
      Feature: MultiBlockSlice selections resolve to the expected
      (start, stride, count, block) tuple against a 10-element dataset and
      read back the expected elements; invalid combinations raise ValueError.
      """

      def setUp(self):
          super().setUp()
          self.arr = np.arange(10)
          self.dset = self.f.create_dataset('x', data=self.arr)

      def _check_selection(self, mbslice, expected_indices, expected_values):
          # Resolve against the dataset length first, then read through it.
          self.assertEqual(mbslice.indices(10), expected_indices)
          np.testing.assert_array_equal(self.dset[mbslice], expected_values)

      def test_default(self):
          # No arguments: the whole dataset, as 10 contiguous blocks of size 1.
          self._check_selection(MultiBlockSlice(), (0, 1, 10, 1), self.arr)

      def test_default_explicit(self):
          explicit = MultiBlockSlice(start=0, count=10, stride=1, block=1)
          self._check_selection(explicit, (0, 1, 10, 1), self.arr)

      def test_start(self):
          self._check_selection(
              MultiBlockSlice(start=4), (4, 1, 6, 1), np.array([4, 5, 6, 7, 8, 9])
          )

      def test_count(self):
          self._check_selection(
              MultiBlockSlice(count=7), (0, 1, 7, 1), np.array([0, 1, 2, 3, 4, 5, 6])
          )

      def test_count_more_than_length_error(self):
          too_many = MultiBlockSlice(count=11)
          with self.assertRaises(ValueError):
              too_many.indices(10)

      def test_stride(self):
          self._check_selection(
              MultiBlockSlice(stride=2), (0, 2, 5, 1), np.array([0, 2, 4, 6, 8])
          )

      def test_stride_zero_error(self):
          # Must be rejected up front; otherwise a zero stride would lead to a
          # ZeroDivisionError while resolving the selection.
          with self.assertRaises(ValueError):
              MultiBlockSlice(stride=0, block=0).indices(10)

      def test_stride_block_equal(self):
          self._check_selection(
              MultiBlockSlice(stride=2, block=2), (0, 2, 5, 2), self.arr
          )

      def test_block_more_than_stride_error(self):
          # Blocks wider than the stride are rejected at construction time.
          for bad_kwargs in ({'block': 3}, {'stride': 2, 'block': 3}):
              with self.assertRaises(ValueError):
                  MultiBlockSlice(**bad_kwargs)

      def test_stride_more_than_block(self):
          self._check_selection(
              MultiBlockSlice(stride=3, block=2), (0, 3, 3, 2), np.array([0, 1, 3, 4, 6, 7])
          )

      def test_block_overruns_extent_error(self):
          # A fully specified selection has to fit inside the dataset extent.
          overrun = MultiBlockSlice(start=2, count=2, stride=5, block=4)
          with self.assertRaises(ValueError):
              overrun.indices(10)

      def test_fully_described(self):
          self._check_selection(
              MultiBlockSlice(start=1, count=2, stride=5, block=4),
              (1, 5, 2, 4),
              np.array([1, 2, 3, 4, 6, 7, 8, 9]),
          )

      def test_count_calculated(self):
          # Without a count, as many complete blocks as fit are selected.
          self._check_selection(
              MultiBlockSlice(start=1, stride=3, block=2), (1, 3, 3, 2), np.array([1, 2, 4, 5, 7, 8])
          )

      def test_zero_count_calculated_error(self):
          # Not even a single full block fits here, so resolving must fail.
          unfit = MultiBlockSlice(start=8, stride=4, block=3)
          with self.assertRaises(ValueError):
              unfit.indices(10)