# test_stride_tricks.py
  1. import pytest
  2. import numpy as np
  3. from numpy._core._rational_tests import rational
  4. from numpy.lib._stride_tricks_impl import (
  5. _broadcast_shape,
  6. as_strided,
  7. broadcast_arrays,
  8. broadcast_shapes,
  9. broadcast_to,
  10. sliding_window_view,
  11. )
  12. from numpy.testing import (
  13. assert_,
  14. assert_array_equal,
  15. assert_equal,
  16. assert_raises,
  17. assert_raises_regex,
  18. )
  19. def assert_shapes_correct(input_shapes, expected_shape):
  20. # Broadcast a list of arrays with the given input shapes and check the
  21. # common output shape.
  22. inarrays = [np.zeros(s) for s in input_shapes]
  23. outarrays = broadcast_arrays(*inarrays)
  24. outshapes = [a.shape for a in outarrays]
  25. expected = [expected_shape] * len(inarrays)
  26. assert_equal(outshapes, expected)
  27. def assert_incompatible_shapes_raise(input_shapes):
  28. # Broadcast a list of arrays with the given (incompatible) input shapes
  29. # and check that they raise a ValueError.
  30. inarrays = [np.zeros(s) for s in input_shapes]
  31. assert_raises(ValueError, broadcast_arrays, *inarrays)
  32. def assert_same_as_ufunc(shape0, shape1, transposed=False, flipped=False):
  33. # Broadcast two shapes against each other and check that the data layout
  34. # is the same as if a ufunc did the broadcasting.
  35. x0 = np.zeros(shape0, dtype=int)
  36. # Note that multiply.reduce's identity element is 1.0, so when shape1==(),
  37. # this gives the desired n==1.
  38. n = int(np.multiply.reduce(shape1))
  39. x1 = np.arange(n).reshape(shape1)
  40. if transposed:
  41. x0 = x0.T
  42. x1 = x1.T
  43. if flipped:
  44. x0 = x0[::-1]
  45. x1 = x1[::-1]
  46. # Use the add ufunc to do the broadcasting. Since we're adding 0s to x1, the
  47. # result should be exactly the same as the broadcasted view of x1.
  48. y = x0 + x1
  49. b0, b1 = broadcast_arrays(x0, x1)
  50. assert_array_equal(y, b1)
  51. def test_same():
  52. x = np.arange(10)
  53. y = np.arange(10)
  54. bx, by = broadcast_arrays(x, y)
  55. assert_array_equal(x, bx)
  56. assert_array_equal(y, by)
  57. def test_broadcast_kwargs():
  58. # ensure that a TypeError is appropriately raised when
  59. # np.broadcast_arrays() is called with any keyword
  60. # argument other than 'subok'
  61. x = np.arange(10)
  62. y = np.arange(10)
  63. with assert_raises_regex(TypeError, 'got an unexpected keyword'):
  64. broadcast_arrays(x, y, dtype='float64')
  65. def test_one_off():
  66. x = np.array([[1, 2, 3]])
  67. y = np.array([[1], [2], [3]])
  68. bx, by = broadcast_arrays(x, y)
  69. bx0 = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
  70. by0 = bx0.T
  71. assert_array_equal(bx0, bx)
  72. assert_array_equal(by0, by)
  73. def test_same_input_shapes():
  74. # Check that the final shape is just the input shape.
  75. data = [
  76. (),
  77. (1,),
  78. (3,),
  79. (0, 1),
  80. (0, 3),
  81. (1, 0),
  82. (3, 0),
  83. (1, 3),
  84. (3, 1),
  85. (3, 3),
  86. ]
  87. for shape in data:
  88. input_shapes = [shape]
  89. # Single input.
  90. assert_shapes_correct(input_shapes, shape)
  91. # Double input.
  92. input_shapes2 = [shape, shape]
  93. assert_shapes_correct(input_shapes2, shape)
  94. # Triple input.
  95. input_shapes3 = [shape, shape, shape]
  96. assert_shapes_correct(input_shapes3, shape)
  97. def test_two_compatible_by_ones_input_shapes():
  98. # Check that two different input shapes of the same length, but some have
  99. # ones, broadcast to the correct shape.
  100. data = [
  101. [[(1,), (3,)], (3,)],
  102. [[(1, 3), (3, 3)], (3, 3)],
  103. [[(3, 1), (3, 3)], (3, 3)],
  104. [[(1, 3), (3, 1)], (3, 3)],
  105. [[(1, 1), (3, 3)], (3, 3)],
  106. [[(1, 1), (1, 3)], (1, 3)],
  107. [[(1, 1), (3, 1)], (3, 1)],
  108. [[(1, 0), (0, 0)], (0, 0)],
  109. [[(0, 1), (0, 0)], (0, 0)],
  110. [[(1, 0), (0, 1)], (0, 0)],
  111. [[(1, 1), (0, 0)], (0, 0)],
  112. [[(1, 1), (1, 0)], (1, 0)],
  113. [[(1, 1), (0, 1)], (0, 1)],
  114. ]
  115. for input_shapes, expected_shape in data:
  116. assert_shapes_correct(input_shapes, expected_shape)
  117. # Reverse the input shapes since broadcasting should be symmetric.
  118. assert_shapes_correct(input_shapes[::-1], expected_shape)
  119. def test_two_compatible_by_prepending_ones_input_shapes():
  120. # Check that two different input shapes (of different lengths) broadcast
  121. # to the correct shape.
  122. data = [
  123. [[(), (3,)], (3,)],
  124. [[(3,), (3, 3)], (3, 3)],
  125. [[(3,), (3, 1)], (3, 3)],
  126. [[(1,), (3, 3)], (3, 3)],
  127. [[(), (3, 3)], (3, 3)],
  128. [[(1, 1), (3,)], (1, 3)],
  129. [[(1,), (3, 1)], (3, 1)],
  130. [[(1,), (1, 3)], (1, 3)],
  131. [[(), (1, 3)], (1, 3)],
  132. [[(), (3, 1)], (3, 1)],
  133. [[(), (0,)], (0,)],
  134. [[(0,), (0, 0)], (0, 0)],
  135. [[(0,), (0, 1)], (0, 0)],
  136. [[(1,), (0, 0)], (0, 0)],
  137. [[(), (0, 0)], (0, 0)],
  138. [[(1, 1), (0,)], (1, 0)],
  139. [[(1,), (0, 1)], (0, 1)],
  140. [[(1,), (1, 0)], (1, 0)],
  141. [[(), (1, 0)], (1, 0)],
  142. [[(), (0, 1)], (0, 1)],
  143. ]
  144. for input_shapes, expected_shape in data:
  145. assert_shapes_correct(input_shapes, expected_shape)
  146. # Reverse the input shapes since broadcasting should be symmetric.
  147. assert_shapes_correct(input_shapes[::-1], expected_shape)
  148. def test_incompatible_shapes_raise_valueerror():
  149. # Check that a ValueError is raised for incompatible shapes.
  150. data = [
  151. [(3,), (4,)],
  152. [(2, 3), (2,)],
  153. [(3,), (3,), (4,)],
  154. [(1, 3, 4), (2, 3, 3)],
  155. ]
  156. for input_shapes in data:
  157. assert_incompatible_shapes_raise(input_shapes)
  158. # Reverse the input shapes since broadcasting should be symmetric.
  159. assert_incompatible_shapes_raise(input_shapes[::-1])
  160. def test_same_as_ufunc():
  161. # Check that the data layout is the same as if a ufunc did the operation.
  162. data = [
  163. [[(1,), (3,)], (3,)],
  164. [[(1, 3), (3, 3)], (3, 3)],
  165. [[(3, 1), (3, 3)], (3, 3)],
  166. [[(1, 3), (3, 1)], (3, 3)],
  167. [[(1, 1), (3, 3)], (3, 3)],
  168. [[(1, 1), (1, 3)], (1, 3)],
  169. [[(1, 1), (3, 1)], (3, 1)],
  170. [[(1, 0), (0, 0)], (0, 0)],
  171. [[(0, 1), (0, 0)], (0, 0)],
  172. [[(1, 0), (0, 1)], (0, 0)],
  173. [[(1, 1), (0, 0)], (0, 0)],
  174. [[(1, 1), (1, 0)], (1, 0)],
  175. [[(1, 1), (0, 1)], (0, 1)],
  176. [[(), (3,)], (3,)],
  177. [[(3,), (3, 3)], (3, 3)],
  178. [[(3,), (3, 1)], (3, 3)],
  179. [[(1,), (3, 3)], (3, 3)],
  180. [[(), (3, 3)], (3, 3)],
  181. [[(1, 1), (3,)], (1, 3)],
  182. [[(1,), (3, 1)], (3, 1)],
  183. [[(1,), (1, 3)], (1, 3)],
  184. [[(), (1, 3)], (1, 3)],
  185. [[(), (3, 1)], (3, 1)],
  186. [[(), (0,)], (0,)],
  187. [[(0,), (0, 0)], (0, 0)],
  188. [[(0,), (0, 1)], (0, 0)],
  189. [[(1,), (0, 0)], (0, 0)],
  190. [[(), (0, 0)], (0, 0)],
  191. [[(1, 1), (0,)], (1, 0)],
  192. [[(1,), (0, 1)], (0, 1)],
  193. [[(1,), (1, 0)], (1, 0)],
  194. [[(), (1, 0)], (1, 0)],
  195. [[(), (0, 1)], (0, 1)],
  196. ]
  197. for input_shapes, expected_shape in data:
  198. assert_same_as_ufunc(input_shapes[0], input_shapes[1],
  199. f"Shapes: {input_shapes[0]} {input_shapes[1]}")
  200. # Reverse the input shapes since broadcasting should be symmetric.
  201. assert_same_as_ufunc(input_shapes[1], input_shapes[0])
  202. # Try them transposed, too.
  203. assert_same_as_ufunc(input_shapes[0], input_shapes[1], True)
  204. # ... and flipped for non-rank-0 inputs in order to test negative
  205. # strides.
  206. if () not in input_shapes:
  207. assert_same_as_ufunc(input_shapes[0], input_shapes[1], False, True)
  208. assert_same_as_ufunc(input_shapes[0], input_shapes[1], True, True)
  209. def test_broadcast_to_succeeds():
  210. data = [
  211. [np.array(0), (0,), np.array(0)],
  212. [np.array(0), (1,), np.zeros(1)],
  213. [np.array(0), (3,), np.zeros(3)],
  214. [np.ones(1), (1,), np.ones(1)],
  215. [np.ones(1), (2,), np.ones(2)],
  216. [np.ones(1), (1, 2, 3), np.ones((1, 2, 3))],
  217. [np.arange(3), (3,), np.arange(3)],
  218. [np.arange(3), (1, 3), np.arange(3).reshape(1, -1)],
  219. [np.arange(3), (2, 3), np.array([[0, 1, 2], [0, 1, 2]])],
  220. # test if shape is not a tuple
  221. [np.ones(0), 0, np.ones(0)],
  222. [np.ones(1), 1, np.ones(1)],
  223. [np.ones(1), 2, np.ones(2)],
  224. # these cases with size 0 are strange, but they reproduce the behavior
  225. # of broadcasting with ufuncs (see test_same_as_ufunc above)
  226. [np.ones(1), (0,), np.ones(0)],
  227. [np.ones((1, 2)), (0, 2), np.ones((0, 2))],
  228. [np.ones((2, 1)), (2, 0), np.ones((2, 0))],
  229. ]
  230. for input_array, shape, expected in data:
  231. actual = broadcast_to(input_array, shape)
  232. assert_array_equal(expected, actual)
  233. def test_broadcast_to_raises():
  234. data = [
  235. [(0,), ()],
  236. [(1,), ()],
  237. [(3,), ()],
  238. [(3,), (1,)],
  239. [(3,), (2,)],
  240. [(3,), (4,)],
  241. [(1, 2), (2, 1)],
  242. [(1, 1), (1,)],
  243. [(1,), -1],
  244. [(1,), (-1,)],
  245. [(1, 2), (-1, 2)],
  246. ]
  247. for orig_shape, target_shape in data:
  248. arr = np.zeros(orig_shape)
  249. assert_raises(ValueError, lambda: broadcast_to(arr, target_shape))
  250. def test_broadcast_shape():
  251. # tests internal _broadcast_shape
  252. # _broadcast_shape is already exercised indirectly by broadcast_arrays
  253. # _broadcast_shape is also exercised by the public broadcast_shapes function
  254. assert_equal(_broadcast_shape(), ())
  255. assert_equal(_broadcast_shape([1, 2]), (2,))
  256. assert_equal(_broadcast_shape(np.ones((1, 1))), (1, 1))
  257. assert_equal(_broadcast_shape(np.ones((1, 1)), np.ones((3, 4))), (3, 4))
  258. assert_equal(_broadcast_shape(*([np.ones((1, 2))] * 32)), (1, 2))
  259. assert_equal(_broadcast_shape(*([np.ones((1, 2))] * 100)), (1, 2))
  260. # regression tests for gh-5862
  261. assert_equal(_broadcast_shape(*([np.ones(2)] * 32 + [1])), (2,))
  262. bad_args = [np.ones(2)] * 32 + [np.ones(3)] * 32
  263. assert_raises(ValueError, lambda: _broadcast_shape(*bad_args))
  264. def test_broadcast_shapes_succeeds():
  265. # tests public broadcast_shapes
  266. data = [
  267. [[], ()],
  268. [[()], ()],
  269. [[(7,)], (7,)],
  270. [[(1, 2), (2,)], (1, 2)],
  271. [[(1, 1)], (1, 1)],
  272. [[(1, 1), (3, 4)], (3, 4)],
  273. [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)],
  274. [[(5, 6, 1)], (5, 6, 1)],
  275. [[(1, 3), (3, 1)], (3, 3)],
  276. [[(1, 0), (0, 0)], (0, 0)],
  277. [[(0, 1), (0, 0)], (0, 0)],
  278. [[(1, 0), (0, 1)], (0, 0)],
  279. [[(1, 1), (0, 0)], (0, 0)],
  280. [[(1, 1), (1, 0)], (1, 0)],
  281. [[(1, 1), (0, 1)], (0, 1)],
  282. [[(), (0,)], (0,)],
  283. [[(0,), (0, 0)], (0, 0)],
  284. [[(0,), (0, 1)], (0, 0)],
  285. [[(1,), (0, 0)], (0, 0)],
  286. [[(), (0, 0)], (0, 0)],
  287. [[(1, 1), (0,)], (1, 0)],
  288. [[(1,), (0, 1)], (0, 1)],
  289. [[(1,), (1, 0)], (1, 0)],
  290. [[(), (1, 0)], (1, 0)],
  291. [[(), (0, 1)], (0, 1)],
  292. [[(1,), (3,)], (3,)],
  293. [[2, (3, 2)], (3, 2)],
  294. ]
  295. for input_shapes, target_shape in data:
  296. assert_equal(broadcast_shapes(*input_shapes), target_shape)
  297. assert_equal(broadcast_shapes(*([(1, 2)] * 32)), (1, 2))
  298. assert_equal(broadcast_shapes(*([(1, 2)] * 100)), (1, 2))
  299. # regression tests for gh-5862
  300. assert_equal(broadcast_shapes(*([(2,)] * 32)), (2,))
  301. def test_broadcast_shapes_raises():
  302. # tests public broadcast_shapes
  303. data = [
  304. [(3,), (4,)],
  305. [(2, 3), (2,)],
  306. [(3,), (3,), (4,)],
  307. [(1, 3, 4), (2, 3, 3)],
  308. [(1, 2), (3, 1), (3, 2), (10, 5)],
  309. [2, (2, 3)],
  310. ]
  311. for input_shapes in data:
  312. assert_raises(ValueError, lambda: broadcast_shapes(*input_shapes))
  313. bad_args = [(2,)] * 32 + [(3,)] * 32
  314. assert_raises(ValueError, lambda: broadcast_shapes(*bad_args))
  315. def test_as_strided():
  316. a = np.array([None])
  317. a_view = as_strided(a)
  318. expected = np.array([None])
  319. assert_array_equal(a_view, np.array([None]))
  320. a = np.array([1, 2, 3, 4])
  321. a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,))
  322. expected = np.array([1, 3])
  323. assert_array_equal(a_view, expected)
  324. a = np.array([1, 2, 3, 4])
  325. a_view = as_strided(a, shape=(3, 4), strides=(0, 1 * a.itemsize))
  326. expected = np.array([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]])
  327. assert_array_equal(a_view, expected)
  328. # Regression test for gh-5081
  329. dt = np.dtype([('num', 'i4'), ('obj', 'O')])
  330. a = np.empty((4,), dtype=dt)
  331. a['num'] = np.arange(1, 5)
  332. a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
  333. expected_num = [[1, 2, 3, 4]] * 3
  334. expected_obj = [[None] * 4] * 3
  335. assert_equal(a_view.dtype, dt)
  336. assert_array_equal(expected_num, a_view['num'])
  337. assert_array_equal(expected_obj, a_view['obj'])
  338. # Make sure that void types without fields are kept unchanged
  339. a = np.empty((4,), dtype='V4')
  340. a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
  341. assert_equal(a.dtype, a_view.dtype)
  342. # Make sure that the only type that could fail is properly handled
  343. dt = np.dtype({'names': [''], 'formats': ['V4']})
  344. a = np.empty((4,), dtype=dt)
  345. a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
  346. assert_equal(a.dtype, a_view.dtype)
  347. # Custom dtypes should not be lost (gh-9161)
  348. r = [rational(i) for i in range(4)]
  349. a = np.array(r, dtype=rational)
  350. a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
  351. assert_equal(a.dtype, a_view.dtype)
  352. assert_array_equal([r] * 3, a_view)
  353. class TestSlidingWindowView:
  354. def test_1d(self):
  355. arr = np.arange(5)
  356. arr_view = sliding_window_view(arr, 2)
  357. expected = np.array([[0, 1],
  358. [1, 2],
  359. [2, 3],
  360. [3, 4]])
  361. assert_array_equal(arr_view, expected)
  362. def test_2d(self):
  363. i, j = np.ogrid[:3, :4]
  364. arr = 10 * i + j
  365. shape = (2, 2)
  366. arr_view = sliding_window_view(arr, shape)
  367. expected = np.array([[[[0, 1], [10, 11]],
  368. [[1, 2], [11, 12]],
  369. [[2, 3], [12, 13]]],
  370. [[[10, 11], [20, 21]],
  371. [[11, 12], [21, 22]],
  372. [[12, 13], [22, 23]]]])
  373. assert_array_equal(arr_view, expected)
  374. def test_2d_with_axis(self):
  375. i, j = np.ogrid[:3, :4]
  376. arr = 10 * i + j
  377. arr_view = sliding_window_view(arr, 3, 0)
  378. expected = np.array([[[0, 10, 20],
  379. [1, 11, 21],
  380. [2, 12, 22],
  381. [3, 13, 23]]])
  382. assert_array_equal(arr_view, expected)
  383. def test_2d_repeated_axis(self):
  384. i, j = np.ogrid[:3, :4]
  385. arr = 10 * i + j
  386. arr_view = sliding_window_view(arr, (2, 3), (1, 1))
  387. expected = np.array([[[[0, 1, 2],
  388. [1, 2, 3]]],
  389. [[[10, 11, 12],
  390. [11, 12, 13]]],
  391. [[[20, 21, 22],
  392. [21, 22, 23]]]])
  393. assert_array_equal(arr_view, expected)
  394. def test_2d_without_axis(self):
  395. i, j = np.ogrid[:4, :4]
  396. arr = 10 * i + j
  397. shape = (2, 3)
  398. arr_view = sliding_window_view(arr, shape)
  399. expected = np.array([[[[0, 1, 2], [10, 11, 12]],
  400. [[1, 2, 3], [11, 12, 13]]],
  401. [[[10, 11, 12], [20, 21, 22]],
  402. [[11, 12, 13], [21, 22, 23]]],
  403. [[[20, 21, 22], [30, 31, 32]],
  404. [[21, 22, 23], [31, 32, 33]]]])
  405. assert_array_equal(arr_view, expected)
  406. def test_errors(self):
  407. i, j = np.ogrid[:4, :4]
  408. arr = 10 * i + j
  409. with pytest.raises(ValueError, match='cannot contain negative values'):
  410. sliding_window_view(arr, (-1, 3))
  411. with pytest.raises(
  412. ValueError,
  413. match='must provide window_shape for all dimensions of `x`'):
  414. sliding_window_view(arr, (1,))
  415. with pytest.raises(
  416. ValueError,
  417. match='Must provide matching length window_shape and axis'):
  418. sliding_window_view(arr, (1, 3, 4), axis=(0, 1))
  419. with pytest.raises(
  420. ValueError,
  421. match='window shape cannot be larger than input array'):
  422. sliding_window_view(arr, (5, 5))
  423. def test_writeable(self):
  424. arr = np.arange(5)
  425. view = sliding_window_view(arr, 2, writeable=False)
  426. assert_(not view.flags.writeable)
  427. with pytest.raises(
  428. ValueError,
  429. match='assignment destination is read-only'):
  430. view[0, 0] = 3
  431. view = sliding_window_view(arr, 2, writeable=True)
  432. assert_(view.flags.writeable)
  433. view[0, 1] = 3
  434. assert_array_equal(arr, np.array([0, 3, 2, 3, 4]))
  435. def test_subok(self):
  436. class MyArray(np.ndarray):
  437. pass
  438. arr = np.arange(5).view(MyArray)
  439. assert_(not isinstance(sliding_window_view(arr, 2,
  440. subok=False),
  441. MyArray))
  442. assert_(isinstance(sliding_window_view(arr, 2, subok=True), MyArray))
  443. # Default behavior
  444. assert_(not isinstance(sliding_window_view(arr, 2), MyArray))
  445. def as_strided_writeable():
  446. arr = np.ones(10)
  447. view = as_strided(arr, writeable=False)
  448. assert_(not view.flags.writeable)
  449. # Check that writeable also is fine:
  450. view = as_strided(arr, writeable=True)
  451. assert_(view.flags.writeable)
  452. view[...] = 3
  453. assert_array_equal(arr, np.full_like(arr, 3))
  454. # Test that things do not break down for readonly:
  455. arr.flags.writeable = False
  456. view = as_strided(arr, writeable=False)
  457. view = as_strided(arr, writeable=True)
  458. assert_(not view.flags.writeable)
  459. class VerySimpleSubClass(np.ndarray):
  460. def __new__(cls, *args, **kwargs):
  461. return np.array(*args, subok=True, **kwargs).view(cls)
  462. class SimpleSubClass(VerySimpleSubClass):
  463. def __new__(cls, *args, **kwargs):
  464. self = np.array(*args, subok=True, **kwargs).view(cls)
  465. self.info = 'simple'
  466. return self
  467. def __array_finalize__(self, obj):
  468. self.info = getattr(obj, 'info', '') + ' finalized'
  469. def test_subclasses():
  470. # test that subclass is preserved only if subok=True
  471. a = VerySimpleSubClass([1, 2, 3, 4])
  472. assert_(type(a) is VerySimpleSubClass)
  473. a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,))
  474. assert_(type(a_view) is np.ndarray)
  475. a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,), subok=True)
  476. assert_(type(a_view) is VerySimpleSubClass)
  477. # test that if a subclass has __array_finalize__, it is used
  478. a = SimpleSubClass([1, 2, 3, 4])
  479. a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,), subok=True)
  480. assert_(type(a_view) is SimpleSubClass)
  481. assert_(a_view.info == 'simple finalized')
  482. # similar tests for broadcast_arrays
  483. b = np.arange(len(a)).reshape(-1, 1)
  484. a_view, b_view = broadcast_arrays(a, b)
  485. assert_(type(a_view) is np.ndarray)
  486. assert_(type(b_view) is np.ndarray)
  487. assert_(a_view.shape == b_view.shape)
  488. a_view, b_view = broadcast_arrays(a, b, subok=True)
  489. assert_(type(a_view) is SimpleSubClass)
  490. assert_(a_view.info == 'simple finalized')
  491. assert_(type(b_view) is np.ndarray)
  492. assert_(a_view.shape == b_view.shape)
  493. # and for broadcast_to
  494. shape = (2, 4)
  495. a_view = broadcast_to(a, shape)
  496. assert_(type(a_view) is np.ndarray)
  497. assert_(a_view.shape == shape)
  498. a_view = broadcast_to(a, shape, subok=True)
  499. assert_(type(a_view) is SimpleSubClass)
  500. assert_(a_view.info == 'simple finalized')
  501. assert_(a_view.shape == shape)
  502. def test_writeable():
  503. # broadcast_to should return a readonly array
  504. original = np.array([1, 2, 3])
  505. result = broadcast_to(original, (2, 3))
  506. assert_equal(result.flags.writeable, False)
  507. assert_raises(ValueError, result.__setitem__, slice(None), 0)
  508. # but the result of broadcast_arrays needs to be writeable, to
  509. # preserve backwards compatibility
  510. test_cases = [((False,), broadcast_arrays(original,)),
  511. ((True, False), broadcast_arrays(0, original))]
  512. for is_broadcast, results in test_cases:
  513. for array_is_broadcast, result in zip(is_broadcast, results):
  514. # This will change to False in a future version
  515. if array_is_broadcast:
  516. with pytest.warns(FutureWarning):
  517. assert_equal(result.flags.writeable, True)
  518. with pytest.warns(DeprecationWarning):
  519. result[:] = 0
  520. # Warning not emitted, writing to the array resets it
  521. assert_equal(result.flags.writeable, True)
  522. else:
  523. # No warning:
  524. assert_equal(result.flags.writeable, True)
  525. for results in [broadcast_arrays(original),
  526. broadcast_arrays(0, original)]:
  527. for result in results:
  528. # resets the warn_on_write DeprecationWarning
  529. result.flags.writeable = True
  530. # check: no warning emitted
  531. assert_equal(result.flags.writeable, True)
  532. result[:] = 0
  533. # keep readonly input readonly
  534. original.flags.writeable = False
  535. _, result = broadcast_arrays(0, original)
  536. assert_equal(result.flags.writeable, False)
  537. # regression test for GH6491
  538. shape = (2,)
  539. strides = [0]
  540. tricky_array = as_strided(np.array(0), shape, strides)
  541. other = np.zeros((1,))
  542. first, second = broadcast_arrays(tricky_array, other)
  543. assert_(first.shape == second.shape)
  544. def test_writeable_memoryview():
  545. # The result of broadcast_arrays exports as a non-writeable memoryview
  546. # because otherwise there is no good way to opt in to the new behaviour
  547. # (i.e. you would need to set writeable to False explicitly).
  548. # See gh-13929.
  549. original = np.array([1, 2, 3])
  550. test_cases = [((False, ), broadcast_arrays(original,)),
  551. ((True, False), broadcast_arrays(0, original))]
  552. for is_broadcast, results in test_cases:
  553. for array_is_broadcast, result in zip(is_broadcast, results):
  554. # This will change to False in a future version
  555. if array_is_broadcast:
  556. # memoryview(result, writable=True) will give warning but cannot
  557. # be tested using the python API.
  558. assert memoryview(result).readonly
  559. else:
  560. assert not memoryview(result).readonly
  561. def test_reference_types():
  562. input_array = np.array('a', dtype=object)
  563. expected = np.array(['a'] * 3, dtype=object)
  564. actual = broadcast_to(input_array, (3,))
  565. assert_array_equal(expected, actual)
  566. actual, _ = broadcast_arrays(input_array, np.ones(3))
  567. assert_array_equal(expected, actual)