test_arrayterator.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. from functools import reduce
  2. from operator import mul
  3. import numpy as np
  4. from numpy.lib import Arrayterator
  5. from numpy.random import randint
  6. from numpy.testing import assert_
  7. def test():
  8. np.random.seed(np.arange(10))
  9. # Create a random array
  10. ndims = randint(5) + 1
  11. shape = tuple(randint(10) + 1 for dim in range(ndims))
  12. els = reduce(mul, shape)
  13. a = np.arange(els).reshape(shape)
  14. buf_size = randint(2 * els)
  15. b = Arrayterator(a, buf_size)
  16. # Check that each block has at most ``buf_size`` elements
  17. for block in b:
  18. assert_(len(block.flat) <= (buf_size or els))
  19. # Check that all elements are iterated correctly
  20. assert_(list(b.flat) == list(a.flat))
  21. # Slice arrayterator
  22. start = [randint(dim) for dim in shape]
  23. stop = [randint(dim) + 1 for dim in shape]
  24. step = [randint(dim) + 1 for dim in shape]
  25. slice_ = tuple(slice(*t) for t in zip(start, stop, step))
  26. c = b[slice_]
  27. d = a[slice_]
  28. # Check that each block has at most ``buf_size`` elements
  29. for block in c:
  30. assert_(len(block.flat) <= (buf_size or els))
  31. # Check that the arrayterator is sliced correctly
  32. assert_(np.all(c.__array__() == d))
  33. # Check that all elements are iterated correctly
  34. assert_(list(c.flat) == list(d.flat))