test_shape_base.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813
  1. import functools
  2. import sys
  3. import pytest
  4. import numpy as np
  5. from numpy import (
  6. apply_along_axis,
  7. apply_over_axes,
  8. array_split,
  9. column_stack,
  10. dsplit,
  11. dstack,
  12. expand_dims,
  13. hsplit,
  14. kron,
  15. put_along_axis,
  16. split,
  17. take_along_axis,
  18. tile,
  19. vsplit,
  20. )
  21. from numpy.exceptions import AxisError
  22. from numpy.testing import assert_, assert_array_equal, assert_equal, assert_raises
  23. IS_64BIT = sys.maxsize > 2**32
  24. def _add_keepdims(func):
  25. """ hack in keepdims behavior into a function taking an axis """
  26. @functools.wraps(func)
  27. def wrapped(a, axis, **kwargs):
  28. res = func(a, axis=axis, **kwargs)
  29. if axis is None:
  30. axis = 0 # res is now a scalar, so we can insert this anywhere
  31. return np.expand_dims(res, axis=axis)
  32. return wrapped
  33. class TestTakeAlongAxis:
  34. def test_argequivalent(self):
  35. """ Test it translates from arg<func> to <func> """
  36. from numpy.random import rand
  37. a = rand(3, 4, 5)
  38. funcs = [
  39. (np.sort, np.argsort, {}),
  40. (_add_keepdims(np.min), _add_keepdims(np.argmin), {}),
  41. (_add_keepdims(np.max), _add_keepdims(np.argmax), {}),
  42. #(np.partition, np.argpartition, dict(kth=2)),
  43. ]
  44. for func, argfunc, kwargs in funcs:
  45. for axis in list(range(a.ndim)) + [None]:
  46. a_func = func(a, axis=axis, **kwargs)
  47. ai_func = argfunc(a, axis=axis, **kwargs)
  48. assert_equal(a_func, take_along_axis(a, ai_func, axis=axis))
  49. def test_invalid(self):
  50. """ Test it errors when indices has too few dimensions """
  51. a = np.ones((10, 10))
  52. ai = np.ones((10, 2), dtype=np.intp)
  53. # sanity check
  54. take_along_axis(a, ai, axis=1)
  55. # not enough indices
  56. assert_raises(ValueError, take_along_axis, a, np.array(1), axis=1)
  57. # bool arrays not allowed
  58. assert_raises(IndexError, take_along_axis, a, ai.astype(bool), axis=1)
  59. # float arrays not allowed
  60. assert_raises(IndexError, take_along_axis, a, ai.astype(float), axis=1)
  61. # invalid axis
  62. assert_raises(AxisError, take_along_axis, a, ai, axis=10)
  63. # invalid indices
  64. assert_raises(ValueError, take_along_axis, a, ai, axis=None)
  65. def test_empty(self):
  66. """ Test everything is ok with empty results, even with inserted dims """
  67. a = np.ones((3, 4, 5))
  68. ai = np.ones((3, 0, 5), dtype=np.intp)
  69. actual = take_along_axis(a, ai, axis=1)
  70. assert_equal(actual.shape, ai.shape)
  71. def test_broadcast(self):
  72. """ Test that non-indexing dimensions are broadcast in both directions """
  73. a = np.ones((3, 4, 1))
  74. ai = np.ones((1, 2, 5), dtype=np.intp)
  75. actual = take_along_axis(a, ai, axis=1)
  76. assert_equal(actual.shape, (3, 2, 5))
  77. class TestPutAlongAxis:
  78. def test_replace_max(self):
  79. a_base = np.array([[10, 30, 20], [60, 40, 50]])
  80. for axis in list(range(a_base.ndim)) + [None]:
  81. # we mutate this in the loop
  82. a = a_base.copy()
  83. # replace the max with a small value
  84. i_max = _add_keepdims(np.argmax)(a, axis=axis)
  85. put_along_axis(a, i_max, -99, axis=axis)
  86. # find the new minimum, which should max
  87. i_min = _add_keepdims(np.argmin)(a, axis=axis)
  88. assert_equal(i_min, i_max)
  89. def test_broadcast(self):
  90. """ Test that non-indexing dimensions are broadcast in both directions """
  91. a = np.ones((3, 4, 1))
  92. ai = np.arange(10, dtype=np.intp).reshape((1, 2, 5)) % 4
  93. put_along_axis(a, ai, 20, axis=1)
  94. assert_equal(take_along_axis(a, ai, axis=1), 20)
  95. def test_invalid(self):
  96. """ Test invalid inputs """
  97. a_base = np.array([[10, 30, 20], [60, 40, 50]])
  98. indices = np.array([[0], [1]])
  99. values = np.array([[2], [1]])
  100. # sanity check
  101. a = a_base.copy()
  102. put_along_axis(a, indices, values, axis=0)
  103. assert np.all(a == [[2, 2, 2], [1, 1, 1]])
  104. # invalid indices
  105. a = a_base.copy()
  106. with assert_raises(ValueError) as exc:
  107. put_along_axis(a, indices, values, axis=None)
  108. assert "single dimension" in str(exc.exception)
  109. class TestApplyAlongAxis:
  110. def test_simple(self):
  111. a = np.ones((20, 10), 'd')
  112. assert_array_equal(
  113. apply_along_axis(len, 0, a), len(a) * np.ones(a.shape[1]))
  114. def test_simple101(self):
  115. a = np.ones((10, 101), 'd')
  116. assert_array_equal(
  117. apply_along_axis(len, 0, a), len(a) * np.ones(a.shape[1]))
  118. def test_3d(self):
  119. a = np.arange(27).reshape((3, 3, 3))
  120. assert_array_equal(apply_along_axis(np.sum, 0, a),
  121. [[27, 30, 33], [36, 39, 42], [45, 48, 51]])
  122. def test_preserve_subclass(self):
  123. def double(row):
  124. return row * 2
  125. class MyNDArray(np.ndarray):
  126. pass
  127. m = np.array([[0, 1], [2, 3]]).view(MyNDArray)
  128. expected = np.array([[0, 2], [4, 6]]).view(MyNDArray)
  129. result = apply_along_axis(double, 0, m)
  130. assert_(isinstance(result, MyNDArray))
  131. assert_array_equal(result, expected)
  132. result = apply_along_axis(double, 1, m)
  133. assert_(isinstance(result, MyNDArray))
  134. assert_array_equal(result, expected)
  135. def test_subclass(self):
  136. class MinimalSubclass(np.ndarray):
  137. data = 1
  138. def minimal_function(array):
  139. return array.data
  140. a = np.zeros((6, 3)).view(MinimalSubclass)
  141. assert_array_equal(
  142. apply_along_axis(minimal_function, 0, a), np.array([1, 1, 1])
  143. )
  144. def test_scalar_array(self, cls=np.ndarray):
  145. a = np.ones((6, 3)).view(cls)
  146. res = apply_along_axis(np.sum, 0, a)
  147. assert_(isinstance(res, cls))
  148. assert_array_equal(res, np.array([6, 6, 6]).view(cls))
  149. def test_0d_array(self, cls=np.ndarray):
  150. def sum_to_0d(x):
  151. """ Sum x, returning a 0d array of the same class """
  152. assert_equal(x.ndim, 1)
  153. return np.squeeze(np.sum(x, keepdims=True))
  154. a = np.ones((6, 3)).view(cls)
  155. res = apply_along_axis(sum_to_0d, 0, a)
  156. assert_(isinstance(res, cls))
  157. assert_array_equal(res, np.array([6, 6, 6]).view(cls))
  158. res = apply_along_axis(sum_to_0d, 1, a)
  159. assert_(isinstance(res, cls))
  160. assert_array_equal(res, np.array([3, 3, 3, 3, 3, 3]).view(cls))
  161. def test_axis_insertion(self, cls=np.ndarray):
  162. def f1to2(x):
  163. """produces an asymmetric non-square matrix from x"""
  164. assert_equal(x.ndim, 1)
  165. return (x[::-1] * x[1:, None]).view(cls)
  166. a2d = np.arange(6 * 3).reshape((6, 3))
  167. # 2d insertion along first axis
  168. actual = apply_along_axis(f1to2, 0, a2d)
  169. expected = np.stack([
  170. f1to2(a2d[:, i]) for i in range(a2d.shape[1])
  171. ], axis=-1).view(cls)
  172. assert_equal(type(actual), type(expected))
  173. assert_equal(actual, expected)
  174. # 2d insertion along last axis
  175. actual = apply_along_axis(f1to2, 1, a2d)
  176. expected = np.stack([
  177. f1to2(a2d[i, :]) for i in range(a2d.shape[0])
  178. ], axis=0).view(cls)
  179. assert_equal(type(actual), type(expected))
  180. assert_equal(actual, expected)
  181. # 3d insertion along middle axis
  182. a3d = np.arange(6 * 5 * 3).reshape((6, 5, 3))
  183. actual = apply_along_axis(f1to2, 1, a3d)
  184. expected = np.stack([
  185. np.stack([
  186. f1to2(a3d[i, :, j]) for i in range(a3d.shape[0])
  187. ], axis=0)
  188. for j in range(a3d.shape[2])
  189. ], axis=-1).view(cls)
  190. assert_equal(type(actual), type(expected))
  191. assert_equal(actual, expected)
  192. def test_subclass_preservation(self):
  193. class MinimalSubclass(np.ndarray):
  194. pass
  195. self.test_scalar_array(MinimalSubclass)
  196. self.test_0d_array(MinimalSubclass)
  197. self.test_axis_insertion(MinimalSubclass)
  198. def test_axis_insertion_ma(self):
  199. def f1to2(x):
  200. """produces an asymmetric non-square matrix from x"""
  201. assert_equal(x.ndim, 1)
  202. res = x[::-1] * x[1:, None]
  203. return np.ma.masked_where(res % 5 == 0, res)
  204. a = np.arange(6 * 3).reshape((6, 3))
  205. res = apply_along_axis(f1to2, 0, a)
  206. assert_(isinstance(res, np.ma.masked_array))
  207. assert_equal(res.ndim, 3)
  208. assert_array_equal(res[:, :, 0].mask, f1to2(a[:, 0]).mask)
  209. assert_array_equal(res[:, :, 1].mask, f1to2(a[:, 1]).mask)
  210. assert_array_equal(res[:, :, 2].mask, f1to2(a[:, 2]).mask)
  211. def test_tuple_func1d(self):
  212. def sample_1d(x):
  213. return x[1], x[0]
  214. res = np.apply_along_axis(sample_1d, 1, np.array([[1, 2], [3, 4]]))
  215. assert_array_equal(res, np.array([[2, 1], [4, 3]]))
  216. def test_empty(self):
  217. # can't apply_along_axis when there's no chance to call the function
  218. def never_call(x):
  219. assert_(False) # should never be reached
  220. a = np.empty((0, 0))
  221. assert_raises(ValueError, np.apply_along_axis, never_call, 0, a)
  222. assert_raises(ValueError, np.apply_along_axis, never_call, 1, a)
  223. # but it's sometimes ok with some non-zero dimensions
  224. def empty_to_1(x):
  225. assert_(len(x) == 0)
  226. return 1
  227. a = np.empty((10, 0))
  228. actual = np.apply_along_axis(empty_to_1, 1, a)
  229. assert_equal(actual, np.ones(10))
  230. assert_raises(ValueError, np.apply_along_axis, empty_to_1, 0, a)
  231. def test_with_iterable_object(self):
  232. # from issue 5248
  233. d = np.array([
  234. [{1, 11}, {2, 22}, {3, 33}],
  235. [{4, 44}, {5, 55}, {6, 66}]
  236. ])
  237. actual = np.apply_along_axis(lambda a: set.union(*a), 0, d)
  238. expected = np.array([{1, 11, 4, 44}, {2, 22, 5, 55}, {3, 33, 6, 66}])
  239. assert_equal(actual, expected)
  240. # issue 8642 - assert_equal doesn't detect this!
  241. for i in np.ndindex(actual.shape):
  242. assert_equal(type(actual[i]), type(expected[i]))
  243. class TestApplyOverAxes:
  244. def test_simple(self):
  245. a = np.arange(24).reshape(2, 3, 4)
  246. aoa_a = apply_over_axes(np.sum, a, [0, 2])
  247. assert_array_equal(aoa_a, np.array([[[60], [92], [124]]]))
  248. class TestExpandDims:
  249. def test_functionality(self):
  250. s = (2, 3, 4, 5)
  251. a = np.empty(s)
  252. for axis in range(-5, 4):
  253. b = expand_dims(a, axis)
  254. assert_(b.shape[axis] == 1)
  255. assert_(np.squeeze(b).shape == s)
  256. def test_axis_tuple(self):
  257. a = np.empty((3, 3, 3))
  258. assert np.expand_dims(a, axis=(0, 1, 2)).shape == (1, 1, 1, 3, 3, 3)
  259. assert np.expand_dims(a, axis=(0, -1, -2)).shape == (1, 3, 3, 3, 1, 1)
  260. assert np.expand_dims(a, axis=(0, 3, 5)).shape == (1, 3, 3, 1, 3, 1)
  261. assert np.expand_dims(a, axis=(0, -3, -5)).shape == (1, 1, 3, 1, 3, 3)
  262. def test_axis_out_of_range(self):
  263. s = (2, 3, 4, 5)
  264. a = np.empty(s)
  265. assert_raises(AxisError, expand_dims, a, -6)
  266. assert_raises(AxisError, expand_dims, a, 5)
  267. a = np.empty((3, 3, 3))
  268. assert_raises(AxisError, expand_dims, a, (0, -6))
  269. assert_raises(AxisError, expand_dims, a, (0, 5))
  270. def test_repeated_axis(self):
  271. a = np.empty((3, 3, 3))
  272. assert_raises(ValueError, expand_dims, a, axis=(1, 1))
  273. def test_subclasses(self):
  274. a = np.arange(10).reshape((2, 5))
  275. a = np.ma.array(a, mask=a % 3 == 0)
  276. expanded = np.expand_dims(a, axis=1)
  277. assert_(isinstance(expanded, np.ma.MaskedArray))
  278. assert_equal(expanded.shape, (2, 1, 5))
  279. assert_equal(expanded.mask.shape, (2, 1, 5))
  280. class TestArraySplit:
  281. def test_integer_0_split(self):
  282. a = np.arange(10)
  283. assert_raises(ValueError, array_split, a, 0)
  284. def test_integer_split(self):
  285. a = np.arange(10)
  286. res = array_split(a, 1)
  287. desired = [np.arange(10)]
  288. compare_results(res, desired)
  289. res = array_split(a, 2)
  290. desired = [np.arange(5), np.arange(5, 10)]
  291. compare_results(res, desired)
  292. res = array_split(a, 3)
  293. desired = [np.arange(4), np.arange(4, 7), np.arange(7, 10)]
  294. compare_results(res, desired)
  295. res = array_split(a, 4)
  296. desired = [np.arange(3), np.arange(3, 6), np.arange(6, 8),
  297. np.arange(8, 10)]
  298. compare_results(res, desired)
  299. res = array_split(a, 5)
  300. desired = [np.arange(2), np.arange(2, 4), np.arange(4, 6),
  301. np.arange(6, 8), np.arange(8, 10)]
  302. compare_results(res, desired)
  303. res = array_split(a, 6)
  304. desired = [np.arange(2), np.arange(2, 4), np.arange(4, 6),
  305. np.arange(6, 8), np.arange(8, 9), np.arange(9, 10)]
  306. compare_results(res, desired)
  307. res = array_split(a, 7)
  308. desired = [np.arange(2), np.arange(2, 4), np.arange(4, 6),
  309. np.arange(6, 7), np.arange(7, 8), np.arange(8, 9),
  310. np.arange(9, 10)]
  311. compare_results(res, desired)
  312. res = array_split(a, 8)
  313. desired = [np.arange(2), np.arange(2, 4), np.arange(4, 5),
  314. np.arange(5, 6), np.arange(6, 7), np.arange(7, 8),
  315. np.arange(8, 9), np.arange(9, 10)]
  316. compare_results(res, desired)
  317. res = array_split(a, 9)
  318. desired = [np.arange(2), np.arange(2, 3), np.arange(3, 4),
  319. np.arange(4, 5), np.arange(5, 6), np.arange(6, 7),
  320. np.arange(7, 8), np.arange(8, 9), np.arange(9, 10)]
  321. compare_results(res, desired)
  322. res = array_split(a, 10)
  323. desired = [np.arange(1), np.arange(1, 2), np.arange(2, 3),
  324. np.arange(3, 4), np.arange(4, 5), np.arange(5, 6),
  325. np.arange(6, 7), np.arange(7, 8), np.arange(8, 9),
  326. np.arange(9, 10)]
  327. compare_results(res, desired)
  328. res = array_split(a, 11)
  329. desired = [np.arange(1), np.arange(1, 2), np.arange(2, 3),
  330. np.arange(3, 4), np.arange(4, 5), np.arange(5, 6),
  331. np.arange(6, 7), np.arange(7, 8), np.arange(8, 9),
  332. np.arange(9, 10), np.array([])]
  333. compare_results(res, desired)
  334. def test_integer_split_2D_rows(self):
  335. a = np.array([np.arange(10), np.arange(10)])
  336. res = array_split(a, 3, axis=0)
  337. tgt = [np.array([np.arange(10)]), np.array([np.arange(10)]),
  338. np.zeros((0, 10))]
  339. compare_results(res, tgt)
  340. assert_(a.dtype.type is res[-1].dtype.type)
  341. # Same thing for manual splits:
  342. res = array_split(a, [0, 1], axis=0)
  343. tgt = [np.zeros((0, 10)), np.array([np.arange(10)]),
  344. np.array([np.arange(10)])]
  345. compare_results(res, tgt)
  346. assert_(a.dtype.type is res[-1].dtype.type)
  347. def test_integer_split_2D_cols(self):
  348. a = np.array([np.arange(10), np.arange(10)])
  349. res = array_split(a, 3, axis=-1)
  350. desired = [np.array([np.arange(4), np.arange(4)]),
  351. np.array([np.arange(4, 7), np.arange(4, 7)]),
  352. np.array([np.arange(7, 10), np.arange(7, 10)])]
  353. compare_results(res, desired)
  354. def test_integer_split_2D_default(self):
  355. """ This will fail if we change default axis
  356. """
  357. a = np.array([np.arange(10), np.arange(10)])
  358. res = array_split(a, 3)
  359. tgt = [np.array([np.arange(10)]), np.array([np.arange(10)]),
  360. np.zeros((0, 10))]
  361. compare_results(res, tgt)
  362. assert_(a.dtype.type is res[-1].dtype.type)
  363. # perhaps should check higher dimensions
  364. @pytest.mark.skipif(not IS_64BIT, reason="Needs 64bit platform")
  365. def test_integer_split_2D_rows_greater_max_int32(self):
  366. a = np.broadcast_to([0], (1 << 32, 2))
  367. res = array_split(a, 4)
  368. chunk = np.broadcast_to([0], (1 << 30, 2))
  369. tgt = [chunk] * 4
  370. for i in range(len(tgt)):
  371. assert_equal(res[i].shape, tgt[i].shape)
  372. def test_index_split_simple(self):
  373. a = np.arange(10)
  374. indices = [1, 5, 7]
  375. res = array_split(a, indices, axis=-1)
  376. desired = [np.arange(0, 1), np.arange(1, 5), np.arange(5, 7),
  377. np.arange(7, 10)]
  378. compare_results(res, desired)
  379. def test_index_split_low_bound(self):
  380. a = np.arange(10)
  381. indices = [0, 5, 7]
  382. res = array_split(a, indices, axis=-1)
  383. desired = [np.array([]), np.arange(0, 5), np.arange(5, 7),
  384. np.arange(7, 10)]
  385. compare_results(res, desired)
  386. def test_index_split_high_bound(self):
  387. a = np.arange(10)
  388. indices = [0, 5, 7, 10, 12]
  389. res = array_split(a, indices, axis=-1)
  390. desired = [np.array([]), np.arange(0, 5), np.arange(5, 7),
  391. np.arange(7, 10), np.array([]), np.array([])]
  392. compare_results(res, desired)
  393. class TestSplit:
  394. # The split function is essentially the same as array_split,
  395. # except that it test if splitting will result in an
  396. # equal split. Only test for this case.
  397. def test_equal_split(self):
  398. a = np.arange(10)
  399. res = split(a, 2)
  400. desired = [np.arange(5), np.arange(5, 10)]
  401. compare_results(res, desired)
  402. def test_unequal_split(self):
  403. a = np.arange(10)
  404. assert_raises(ValueError, split, a, 3)
  405. class TestColumnStack:
  406. def test_non_iterable(self):
  407. assert_raises(TypeError, column_stack, 1)
  408. def test_1D_arrays(self):
  409. # example from docstring
  410. a = np.array((1, 2, 3))
  411. b = np.array((2, 3, 4))
  412. expected = np.array([[1, 2],
  413. [2, 3],
  414. [3, 4]])
  415. actual = np.column_stack((a, b))
  416. assert_equal(actual, expected)
  417. def test_2D_arrays(self):
  418. # same as hstack 2D docstring example
  419. a = np.array([[1], [2], [3]])
  420. b = np.array([[2], [3], [4]])
  421. expected = np.array([[1, 2],
  422. [2, 3],
  423. [3, 4]])
  424. actual = np.column_stack((a, b))
  425. assert_equal(actual, expected)
  426. def test_generator(self):
  427. with pytest.raises(TypeError, match="arrays to stack must be"):
  428. column_stack(np.arange(3) for _ in range(2))
  429. class TestDstack:
  430. def test_non_iterable(self):
  431. assert_raises(TypeError, dstack, 1)
  432. def test_0D_array(self):
  433. a = np.array(1)
  434. b = np.array(2)
  435. res = dstack([a, b])
  436. desired = np.array([[[1, 2]]])
  437. assert_array_equal(res, desired)
  438. def test_1D_array(self):
  439. a = np.array([1])
  440. b = np.array([2])
  441. res = dstack([a, b])
  442. desired = np.array([[[1, 2]]])
  443. assert_array_equal(res, desired)
  444. def test_2D_array(self):
  445. a = np.array([[1], [2]])
  446. b = np.array([[1], [2]])
  447. res = dstack([a, b])
  448. desired = np.array([[[1, 1]], [[2, 2, ]]])
  449. assert_array_equal(res, desired)
  450. def test_2D_array2(self):
  451. a = np.array([1, 2])
  452. b = np.array([1, 2])
  453. res = dstack([a, b])
  454. desired = np.array([[[1, 1], [2, 2]]])
  455. assert_array_equal(res, desired)
  456. def test_generator(self):
  457. with pytest.raises(TypeError, match="arrays to stack must be"):
  458. dstack(np.arange(3) for _ in range(2))
# array_split has more comprehensive tests of splitting;
# only simple tests are done for hsplit, vsplit, and dsplit.
  461. class TestHsplit:
  462. """Only testing for integer splits.
  463. """
  464. def test_non_iterable(self):
  465. assert_raises(ValueError, hsplit, 1, 1)
  466. def test_0D_array(self):
  467. a = np.array(1)
  468. try:
  469. hsplit(a, 2)
  470. assert_(0)
  471. except ValueError:
  472. pass
  473. def test_1D_array(self):
  474. a = np.array([1, 2, 3, 4])
  475. res = hsplit(a, 2)
  476. desired = [np.array([1, 2]), np.array([3, 4])]
  477. compare_results(res, desired)
  478. def test_2D_array(self):
  479. a = np.array([[1, 2, 3, 4],
  480. [1, 2, 3, 4]])
  481. res = hsplit(a, 2)
  482. desired = [np.array([[1, 2], [1, 2]]), np.array([[3, 4], [3, 4]])]
  483. compare_results(res, desired)
  484. class TestVsplit:
  485. """Only testing for integer splits.
  486. """
  487. def test_non_iterable(self):
  488. assert_raises(ValueError, vsplit, 1, 1)
  489. def test_0D_array(self):
  490. a = np.array(1)
  491. assert_raises(ValueError, vsplit, a, 2)
  492. def test_1D_array(self):
  493. a = np.array([1, 2, 3, 4])
  494. try:
  495. vsplit(a, 2)
  496. assert_(0)
  497. except ValueError:
  498. pass
  499. def test_2D_array(self):
  500. a = np.array([[1, 2, 3, 4],
  501. [1, 2, 3, 4]])
  502. res = vsplit(a, 2)
  503. desired = [np.array([[1, 2, 3, 4]]), np.array([[1, 2, 3, 4]])]
  504. compare_results(res, desired)
  505. class TestDsplit:
  506. # Only testing for integer splits.
  507. def test_non_iterable(self):
  508. assert_raises(ValueError, dsplit, 1, 1)
  509. def test_0D_array(self):
  510. a = np.array(1)
  511. assert_raises(ValueError, dsplit, a, 2)
  512. def test_1D_array(self):
  513. a = np.array([1, 2, 3, 4])
  514. assert_raises(ValueError, dsplit, a, 2)
  515. def test_2D_array(self):
  516. a = np.array([[1, 2, 3, 4],
  517. [1, 2, 3, 4]])
  518. try:
  519. dsplit(a, 2)
  520. assert_(0)
  521. except ValueError:
  522. pass
  523. def test_3D_array(self):
  524. a = np.array([[[1, 2, 3, 4],
  525. [1, 2, 3, 4]],
  526. [[1, 2, 3, 4],
  527. [1, 2, 3, 4]]])
  528. res = dsplit(a, 2)
  529. desired = [np.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]]),
  530. np.array([[[3, 4], [3, 4]], [[3, 4], [3, 4]]])]
  531. compare_results(res, desired)
  532. class TestSqueeze:
  533. def test_basic(self):
  534. from numpy.random import rand
  535. a = rand(20, 10, 10, 1, 1)
  536. b = rand(20, 1, 10, 1, 20)
  537. c = rand(1, 1, 20, 10)
  538. assert_array_equal(np.squeeze(a), np.reshape(a, (20, 10, 10)))
  539. assert_array_equal(np.squeeze(b), np.reshape(b, (20, 10, 20)))
  540. assert_array_equal(np.squeeze(c), np.reshape(c, (20, 10)))
  541. # Squeezing to 0-dim should still give an ndarray
  542. a = [[[1.5]]]
  543. res = np.squeeze(a)
  544. assert_equal(res, 1.5)
  545. assert_equal(res.ndim, 0)
  546. assert_equal(type(res), np.ndarray)
  547. class TestKron:
  548. def test_basic(self):
  549. # Using 0-dimensional ndarray
  550. a = np.array(1)
  551. b = np.array([[1, 2], [3, 4]])
  552. k = np.array([[1, 2], [3, 4]])
  553. assert_array_equal(np.kron(a, b), k)
  554. a = np.array([[1, 2], [3, 4]])
  555. b = np.array(1)
  556. assert_array_equal(np.kron(a, b), k)
  557. # Using 1-dimensional ndarray
  558. a = np.array([3])
  559. b = np.array([[1, 2], [3, 4]])
  560. k = np.array([[3, 6], [9, 12]])
  561. assert_array_equal(np.kron(a, b), k)
  562. a = np.array([[1, 2], [3, 4]])
  563. b = np.array([3])
  564. assert_array_equal(np.kron(a, b), k)
  565. # Using 3-dimensional ndarray
  566. a = np.array([[[1]], [[2]]])
  567. b = np.array([[1, 2], [3, 4]])
  568. k = np.array([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
  569. assert_array_equal(np.kron(a, b), k)
  570. a = np.array([[1, 2], [3, 4]])
  571. b = np.array([[[1]], [[2]]])
  572. k = np.array([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
  573. assert_array_equal(np.kron(a, b), k)
  574. def test_return_type(self):
  575. class myarray(np.ndarray):
  576. __array_priority__ = 1.0
  577. a = np.ones([2, 2])
  578. ma = myarray(a.shape, a.dtype, a.data)
  579. assert_equal(type(kron(a, a)), np.ndarray)
  580. assert_equal(type(kron(ma, ma)), myarray)
  581. assert_equal(type(kron(a, ma)), myarray)
  582. assert_equal(type(kron(ma, a)), myarray)
  583. @pytest.mark.parametrize(
  584. "array_class", [np.asarray, np.asmatrix]
  585. )
  586. def test_kron_smoke(self, array_class):
  587. a = array_class(np.ones([3, 3]))
  588. b = array_class(np.ones([3, 3]))
  589. k = array_class(np.ones([9, 9]))
  590. assert_array_equal(np.kron(a, b), k)
  591. def test_kron_ma(self):
  592. x = np.ma.array([[1, 2], [3, 4]], mask=[[0, 1], [1, 0]])
  593. k = np.ma.array(np.diag([1, 4, 4, 16]),
  594. mask=~np.array(np.identity(4), dtype=bool))
  595. assert_array_equal(k, np.kron(x, x))
  596. @pytest.mark.parametrize(
  597. "shape_a,shape_b", [
  598. ((1, 1), (1, 1)),
  599. ((1, 2, 3), (4, 5, 6)),
  600. ((2, 2), (2, 2, 2)),
  601. ((1, 0), (1, 1)),
  602. ((2, 0, 2), (2, 2)),
  603. ((2, 0, 0, 2), (2, 0, 2)),
  604. ])
  605. def test_kron_shape(self, shape_a, shape_b):
  606. a = np.ones(shape_a)
  607. b = np.ones(shape_b)
  608. normalised_shape_a = (1,) * max(0, len(shape_b) - len(shape_a)) + shape_a
  609. normalised_shape_b = (1,) * max(0, len(shape_a) - len(shape_b)) + shape_b
  610. expected_shape = np.multiply(normalised_shape_a, normalised_shape_b)
  611. k = np.kron(a, b)
  612. assert np.array_equal(
  613. k.shape, expected_shape), "Unexpected shape from kron"
  614. class TestTile:
  615. def test_basic(self):
  616. a = np.array([0, 1, 2])
  617. b = [[1, 2], [3, 4]]
  618. assert_equal(tile(a, 2), [0, 1, 2, 0, 1, 2])
  619. assert_equal(tile(a, (2, 2)), [[0, 1, 2, 0, 1, 2], [0, 1, 2, 0, 1, 2]])
  620. assert_equal(tile(a, (1, 2)), [[0, 1, 2, 0, 1, 2]])
  621. assert_equal(tile(b, 2), [[1, 2, 1, 2], [3, 4, 3, 4]])
  622. assert_equal(tile(b, (2, 1)), [[1, 2], [3, 4], [1, 2], [3, 4]])
  623. assert_equal(tile(b, (2, 2)), [[1, 2, 1, 2], [3, 4, 3, 4],
  624. [1, 2, 1, 2], [3, 4, 3, 4]])
  625. def test_tile_one_repetition_on_array_gh4679(self):
  626. a = np.arange(5)
  627. b = tile(a, 1)
  628. b += 2
  629. assert_equal(a, np.arange(5))
  630. def test_empty(self):
  631. a = np.array([[[]]])
  632. b = np.array([[], []])
  633. c = tile(b, 2).shape
  634. d = tile(a, (3, 2, 5)).shape
  635. assert_equal(c, (2, 0))
  636. assert_equal(d, (3, 2, 0))
  637. def test_kroncompare(self):
  638. from numpy.random import randint
  639. reps = [(2,), (1, 2), (2, 1), (2, 2), (2, 3, 2), (3, 2)]
  640. shape = [(3,), (2, 3), (3, 4, 3), (3, 2, 3), (4, 3, 2, 4), (2, 2)]
  641. for s in shape:
  642. b = randint(0, 10, size=s)
  643. for r in reps:
  644. a = np.ones(r, b.dtype)
  645. large = tile(b, r)
  646. klarge = kron(a, b)
  647. assert_equal(large, klarge)
  648. class TestMayShareMemory:
  649. def test_basic(self):
  650. d = np.ones((50, 60))
  651. d2 = np.ones((30, 60, 6))
  652. assert_(np.may_share_memory(d, d))
  653. assert_(np.may_share_memory(d, d[::-1]))
  654. assert_(np.may_share_memory(d, d[::2]))
  655. assert_(np.may_share_memory(d, d[1:, ::-1]))
  656. assert_(not np.may_share_memory(d[::-1], d2))
  657. assert_(not np.may_share_memory(d[::2], d2))
  658. assert_(not np.may_share_memory(d[1:, ::-1], d2))
  659. assert_(np.may_share_memory(d2[1:, ::-1], d2))
  660. # Utility
  661. def compare_results(res, desired):
  662. """Compare lists of arrays."""
  663. for x, y in zip(res, desired, strict=False):
  664. assert_array_equal(x, y)