test_ops.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
  1. import itertools
  2. import numpy as np
  3. import pytest
  4. from einops import EinopsError
  5. from einops.einops import _enumerate_directions, rearrange, reduce, repeat
  6. from einops.tests import FLOAT_REDUCTIONS as REDUCTIONS
  7. from einops.tests import collect_test_backends, is_backend_tested
# Backends exercised by the tests in this module: the "imperative" ones are driven through
# from_numpy / to_numpy, the "symbolic" ones through create_symbol / eval_symbol.
imp_op_backends = collect_test_backends(symbolic=False, layers=False)
sym_op_backends = collect_test_backends(symbolic=True, layers=False)
# Shared generator for random shapes, permutations and unknown (None) symbol dimensions.
# NOTE(review): created without a seed, so randomly chosen cases differ between runs
# and a failure may not reproduce on re-run.
rng = np.random.default_rng()
# Patterns that must return the 5-d test input unchanged (asserted in test_ellipsis_ops_numpy).
# Whitespace inside the strings is intentional: spacing variants are part of what is tested.
identity_patterns = [
    "...->...",
    "a b c d e-> a b c d e",
    "a b c d e ...-> ... a b c d e",
    "a b c d e ...-> a ... b c d e",
    "... a b c d e -> ... a b c d e",
    "a ... e-> a ... e",
    "a ... -> a ... ",
    "a ... c d e -> a (...) c d e",
]
# Pairs (explicit axes, ellipsis form) of rearrange patterns that must give identical
# results on the 5-d test input.
equivalent_rearrange_patterns = [
    ("a b c d e -> (a b) c d e", "a b ... -> (a b) ... "),
    ("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e"),
    ("a b c d e -> a b c d e", "... -> ... "),
    ("a b c d e -> (a b c d e)", "... -> (...)"),
    ("a b c d e -> b (c d e) a", "a b ... -> b (...) a"),
    ("a b c d e -> b (a c d) e", "a b ... e -> b (a ...) e"),
]
# Pairs of reduce patterns that must give identical results for the same reduction
# on the 5-d test input.
equivalent_reduction_patterns = [
    ("a b c d e -> ", " ... -> "),
    ("a b c d e -> (e a)", "a ... e -> (e a)"),
    ("a b c d e -> d (a e)", " a b c d e ... -> d (a e) "),
    ("a b c d e -> (a b)", " ... c d e -> (...) "),
]
  35. def test_collapsed_ellipsis_errors_out():
  36. x = np.zeros([1, 1, 1, 1, 1])
  37. rearrange(x, "a b c d ... -> a b c ... d")
  38. with pytest.raises(EinopsError):
  39. rearrange(x, "a b c d (...) -> a b c ... d")
  40. rearrange(x, "... -> (...)")
  41. with pytest.raises(EinopsError):
  42. rearrange(x, "(...) -> (...)")
  43. def test_ellipsis_ops_numpy():
  44. x = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
  45. for pattern in identity_patterns:
  46. assert np.array_equal(x, rearrange(x, pattern)), pattern
  47. for pattern1, pattern2 in equivalent_rearrange_patterns:
  48. assert np.array_equal(rearrange(x, pattern1), rearrange(x, pattern2))
  49. for reduction in ["min", "max", "sum"]:
  50. for pattern1, pattern2 in equivalent_reduction_patterns:
  51. assert np.array_equal(reduce(x, pattern1, reduction=reduction), reduce(x, pattern2, reduction=reduction))
  52. # now just check coincidence with numpy
  53. all_rearrange_patterns = [*identity_patterns]
  54. for pattern_pairs in equivalent_rearrange_patterns:
  55. all_rearrange_patterns.extend(pattern_pairs)
  56. def check_op_against_numpy(backend, numpy_input, pattern, axes_lengths, reduction="rearrange", is_symbolic=False):
  57. """
  58. Helper to test result of operation (rearrange or transpose) against numpy
  59. if reduction == 'rearrange', rearrange op is tested, otherwise reduce
  60. """
  61. def operation(x):
  62. if reduction == "rearrange":
  63. return rearrange(x, pattern, **axes_lengths)
  64. else:
  65. return reduce(x, pattern, reduction, **axes_lengths)
  66. numpy_result = operation(numpy_input)
  67. check_equal = np.array_equal
  68. p_none_dimension = 0.5
  69. if is_symbolic:
  70. symbol_shape = [d if rng.random() >= p_none_dimension else None for d in numpy_input.shape]
  71. symbol = backend.create_symbol(shape=symbol_shape)
  72. result_symbol = operation(symbol)
  73. backend_result = backend.eval_symbol(result_symbol, [(symbol, numpy_input)])
  74. else:
  75. backend_result = operation(backend.from_numpy(numpy_input))
  76. backend_result = backend.to_numpy(backend_result)
  77. check_equal(numpy_result, backend_result)
  78. def test_ellipsis_ops_imperative():
  79. """Checking various patterns against numpy"""
  80. x = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
  81. for is_symbolic in [True, False]:
  82. for backend in collect_test_backends(symbolic=is_symbolic, layers=False):
  83. for pattern in identity_patterns + list(itertools.chain(*equivalent_rearrange_patterns)):
  84. check_op_against_numpy(
  85. backend, x, pattern, axes_lengths={}, reduction="rearrange", is_symbolic=is_symbolic
  86. )
  87. for reduction in ["min", "max", "sum"]:
  88. for pattern in itertools.chain(*equivalent_reduction_patterns):
  89. check_op_against_numpy(
  90. backend, x, pattern, axes_lengths={}, reduction=reduction, is_symbolic=is_symbolic
  91. )
  92. def test_rearrange_array_api():
  93. import numpy as xp
  94. from einops import array_api as AA
  95. if xp.__version__ < "2.0.0":
  96. pytest.skip()
  97. x = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
  98. for pattern in identity_patterns + list(itertools.chain(*equivalent_rearrange_patterns)):
  99. expected = rearrange(x, pattern)
  100. result = AA.rearrange(xp.from_dlpack(x), pattern)
  101. assert np.array_equal(AA.asnumpy(result + 0), expected)
  102. def test_reduce_array_api():
  103. import numpy as xp
  104. from einops import array_api as AA
  105. if xp.__version__ < "2.0.0":
  106. pytest.skip()
  107. x = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
  108. for pattern in itertools.chain(*equivalent_reduction_patterns):
  109. for reduction in ["min", "max", "sum"]:
  110. expected = reduce(x, pattern, reduction=reduction)
  111. result = AA.reduce(xp.from_dlpack(x), pattern, reduction=reduction)
  112. assert np.array_equal(AA.asnumpy(np.asarray(result + 0)), expected)
  113. def test_rearrange_consistency_numpy():
  114. shape = [1, 2, 3, 5, 7, 11]
  115. x = np.arange(np.prod(shape)).reshape(shape)
  116. for pattern in [
  117. "a b c d e f -> a b c d e f",
  118. "b a c d e f -> a b d e f c",
  119. "a b c d e f -> f e d c b a",
  120. "a b c d e f -> (f e) d (c b a)",
  121. "a b c d e f -> (f e d c b a)",
  122. ]:
  123. result = rearrange(x, pattern)
  124. assert len(np.setdiff1d(x, result)) == 0
  125. assert result.dtype == x.dtype
  126. result = rearrange(x, "a b c d e f -> a (b) (c d e) f")
  127. assert np.array_equal(x.flatten(), result.flatten())
  128. result = rearrange(x, "a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11")
  129. assert np.array_equal(x, result)
  130. result1 = rearrange(x, "a b c d e f -> f e d c b a")
  131. result2 = rearrange(x, "f e d c b a -> a b c d e f")
  132. assert np.array_equal(result1, result2)
  133. result = rearrange(rearrange(x, "a b c d e f -> (f d) c (e b) a"), "(f d) c (e b) a -> a b c d e f", b=2, d=5)
  134. assert np.array_equal(x, result)
  135. sizes = dict(zip("abcdef", shape))
  136. temp = rearrange(x, "a b c d e f -> (f d) c (e b) a", **sizes)
  137. result = rearrange(temp, "(f d) c (e b) a -> a b c d e f", **sizes)
  138. assert np.array_equal(x, result)
  139. x2 = np.arange(2 * 3 * 4).reshape([2, 3, 4])
  140. result = rearrange(x2, "a b c -> b c a")
  141. assert x2[1, 2, 3] == result[2, 3, 1]
  142. assert x2[0, 1, 2] == result[1, 2, 0]
  143. def test_rearrange_permutations_numpy():
  144. # tests random permutation of axes against two independent numpy ways
  145. for n_axes in range(1, 10):
  146. input = np.arange(2**n_axes).reshape([2] * n_axes)
  147. permutation = rng.permutation(n_axes)
  148. left_expression = " ".join("i" + str(axis) for axis in range(n_axes))
  149. right_expression = " ".join("i" + str(axis) for axis in permutation)
  150. expression = left_expression + " -> " + right_expression
  151. result = rearrange(input, expression)
  152. for pick in rng.integers(0, 2, [10, n_axes]):
  153. assert input[tuple(pick)] == result[tuple(pick[permutation])]
  154. for n_axes in range(1, 10):
  155. input = np.arange(2**n_axes).reshape([2] * n_axes)
  156. permutation = rng.permutation(n_axes)
  157. left_expression = " ".join("i" + str(axis) for axis in range(n_axes)[::-1])
  158. right_expression = " ".join("i" + str(axis) for axis in permutation[::-1])
  159. expression = left_expression + " -> " + right_expression
  160. result = rearrange(input, expression)
  161. assert result.shape == input.shape
  162. expected_result = np.zeros_like(input)
  163. for original_axis, result_axis in enumerate(permutation):
  164. expected_result |= ((input >> original_axis) & 1) << result_axis
  165. assert np.array_equal(result, expected_result)
  166. def test_reduction_imperatives():
  167. for backend in imp_op_backends:
  168. print("Reduction tests for ", backend.framework_name)
  169. for reduction in REDUCTIONS:
  170. # slight redundancy for simpler order - numpy version is evaluated multiple times
  171. input = np.arange(2 * 3 * 4 * 5 * 6, dtype="int64").reshape([2, 3, 4, 5, 6])
  172. if reduction in ["mean", "prod"]:
  173. input = input / input.astype("float64").mean()
  174. test_cases = [
  175. ["a b c d e -> ", {}, getattr(input, reduction)()],
  176. ["a ... -> ", {}, getattr(input, reduction)()],
  177. ["(a1 a2) ... (e1 e2) -> ", dict(a1=1, e2=2), getattr(input, reduction)()],
  178. [
  179. "a b c d e -> (e c) a",
  180. {},
  181. getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]),
  182. ],
  183. [
  184. "a ... c d e -> (e c) a",
  185. {},
  186. getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]),
  187. ],
  188. [
  189. "a b c d e ... -> (e c) a",
  190. {},
  191. getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]),
  192. ],
  193. ["a b c d e -> (e c a)", {}, getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1])],
  194. ["(a a2) ... -> (a2 a) ...", dict(a2=1), input],
  195. ]
  196. for pattern, axes_lengths, expected_result in test_cases:
  197. result = reduce(backend.from_numpy(input.copy()), pattern, reduction=reduction, **axes_lengths)
  198. result = backend.to_numpy(result)
  199. assert np.allclose(result, expected_result), f"Failed at {pattern}"
  200. def test_reduction_symbolic():
  201. for backend in sym_op_backends:
  202. print("Reduction tests for ", backend.framework_name)
  203. for reduction in REDUCTIONS:
  204. input = np.arange(2 * 3 * 4 * 5 * 6, dtype="int64").reshape([2, 3, 4, 5, 6])
  205. input = input / input.astype("float64").mean()
  206. # slight redundancy for simpler order - numpy version is evaluated multiple times
  207. test_cases = [
  208. ["a b c d e -> ", {}, getattr(input, reduction)()],
  209. ["a ... -> ", {}, getattr(input, reduction)()],
  210. ["(a a2) ... (e e2) -> ", dict(a2=1, e2=1), getattr(input, reduction)()],
  211. [
  212. "a b c d e -> (e c) a",
  213. {},
  214. getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]),
  215. ],
  216. [
  217. "a ... c d e -> (e c) a",
  218. {},
  219. getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]),
  220. ],
  221. [
  222. "a b c d e ... -> (e c) a",
  223. {},
  224. getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]),
  225. ],
  226. ["a b c d e -> (e c a)", {}, getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1])],
  227. ["(a a2) ... -> (a2 a) ...", dict(a2=1), input],
  228. ]
  229. for pattern, axes_lengths, expected_numpy_result in test_cases:
  230. shapes = [input.shape, [None for _ in input.shape]]
  231. for shape in shapes:
  232. sym = backend.create_symbol(shape)
  233. result_sym = reduce(sym, pattern, reduction=reduction, **axes_lengths)
  234. result = backend.eval_symbol(result_sym, [(sym, input)])
  235. assert np.allclose(result, expected_numpy_result)
  236. if True:
  237. shape = []
  238. _axes_lengths = {**axes_lengths}
  239. for axis, length in zip("abcde", input.shape):
  240. # filling as much as possible with Nones
  241. if axis in pattern:
  242. shape.append(None)
  243. _axes_lengths[axis] = length
  244. else:
  245. shape.append(length)
  246. sym = backend.create_symbol(shape)
  247. result_sym = reduce(sym, pattern, reduction=reduction, **_axes_lengths)
  248. result = backend.eval_symbol(result_sym, [(sym, input)])
  249. assert np.allclose(result, expected_numpy_result)
  250. def test_reduction_stress_imperatives():
  251. for backend in imp_op_backends:
  252. print("Stress-testing reduction for ", backend.framework_name)
  253. for reduction in [*REDUCTIONS, "rearrange"]:
  254. dtype = "int64"
  255. coincide = np.array_equal
  256. if reduction in ["mean", "prod"]:
  257. dtype = "float64"
  258. coincide = np.allclose
  259. max_dim = 11
  260. if "oneflow" in backend.framework_name:
  261. max_dim = 7
  262. if "paddle" in backend.framework_name:
  263. max_dim = 9
  264. for n_axes in range(max_dim):
  265. shape = rng.integers(2, 4, size=n_axes)
  266. permutation = rng.permutation(n_axes)
  267. skipped = 0 if reduction == "rearrange" else rng.integers(n_axes + 1)
  268. left = " ".join("x" + str(i) for i in range(n_axes))
  269. right = " ".join("x" + str(i) for i in permutation[skipped:])
  270. pattern = left + "->" + right
  271. x = np.arange(1, 1 + np.prod(shape), dtype=dtype).reshape(shape)
  272. if reduction == "prod":
  273. x /= x.mean() # to avoid overflows
  274. result1 = reduce(x, pattern, reduction=reduction)
  275. result2 = x.transpose(permutation)
  276. if skipped > 0:
  277. result2 = getattr(result2, reduction)(axis=tuple(range(skipped)))
  278. assert coincide(result1, result2)
  279. check_op_against_numpy(backend, x, pattern, reduction=reduction, axes_lengths={}, is_symbolic=False)
  280. def test_reduction_with_callable_imperatives():
  281. x_numpy = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]).astype("float32")
  282. x_numpy /= x_numpy.max()
  283. def logsumexp_torch(x, tuple_of_axes):
  284. return x.logsumexp(tuple_of_axes)
  285. def logsumexp_tf(x, tuple_of_axes):
  286. import tensorflow as tf
  287. return tf.reduce_logsumexp(x, tuple_of_axes)
  288. def logsumexp_keras(x, tuple_of_axes):
  289. import tensorflow.keras.backend as k
  290. return k.logsumexp(x, tuple_of_axes)
  291. def logsumexp_numpy(x, tuple_of_axes):
  292. # very naive logsumexp to compare to
  293. minused = x.max(tuple_of_axes)
  294. y = x - x.max(tuple_of_axes, keepdims=True)
  295. y = np.exp(y)
  296. y = np.sum(y, axis=tuple_of_axes)
  297. return np.log(y) + minused
  298. from einops._backends import NumpyBackend, TensorflowBackend, TFKerasBackend, TorchBackend
  299. backend2callback = {
  300. TorchBackend.framework_name: logsumexp_torch,
  301. TensorflowBackend.framework_name: logsumexp_tf,
  302. TFKerasBackend.framework_name: logsumexp_keras,
  303. NumpyBackend.framework_name: logsumexp_numpy,
  304. }
  305. for backend in imp_op_backends:
  306. if backend.framework_name not in backend2callback:
  307. continue
  308. backend_callback = backend2callback[backend.framework_name]
  309. x_backend = backend.from_numpy(x_numpy)
  310. for pattern1, pattern2 in equivalent_reduction_patterns:
  311. print("Test reduction with callable for ", backend.framework_name, pattern1, pattern2)
  312. output_numpy = reduce(x_numpy, pattern1, reduction=logsumexp_numpy)
  313. output_backend = reduce(x_backend, pattern1, reduction=backend_callback)
  314. assert np.allclose(
  315. output_numpy,
  316. backend.to_numpy(output_backend),
  317. )
  318. def test_enumerating_directions():
  319. for backend in imp_op_backends:
  320. print("testing directions for", backend.framework_name)
  321. for shape in [[], [1], [1, 1, 1], [2, 3, 5, 7]]:
  322. x = np.arange(np.prod(shape)).reshape(shape)
  323. axes1 = _enumerate_directions(x)
  324. axes2 = _enumerate_directions(backend.from_numpy(x))
  325. assert len(axes1) == len(axes2) == len(shape)
  326. axes2 = [backend.to_numpy(ax) for ax in axes2]
  327. for ax1, ax2 in zip(axes1, axes2):
  328. assert ax1.shape == ax2.shape
  329. assert np.allclose(ax1, ax2)
  330. def test_concatenations_and_stacking():
  331. for backend in imp_op_backends:
  332. print("testing shapes for ", backend.framework_name)
  333. for n_arrays in [1, 2, 5]:
  334. shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
  335. for shape in shapes:
  336. arrays1 = [np.arange(i, i + np.prod(shape)).reshape(shape) for i in range(n_arrays)]
  337. arrays2 = [backend.from_numpy(array) for array in arrays1]
  338. result0 = np.asarray(arrays1)
  339. result1 = rearrange(arrays1, "...->...")
  340. result2 = rearrange(arrays2, "...->...")
  341. assert np.array_equal(result0, result1)
  342. assert np.array_equal(result1, backend.to_numpy(result2))
  343. result1 = rearrange(arrays1, "b ... -> ... b")
  344. result2 = rearrange(arrays2, "b ... -> ... b")
  345. assert np.array_equal(result1, backend.to_numpy(result2))
  346. def test_gradients_imperatives():
  347. # lazy - just checking reductions
  348. for reduction in REDUCTIONS:
  349. if reduction in ("any", "all"):
  350. continue # non-differentiable ops
  351. x = np.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype("float32")
  352. results = {}
  353. for backend in imp_op_backends:
  354. y0 = backend.from_numpy(x)
  355. if not hasattr(y0, "grad"):
  356. continue
  357. y1 = reduce(y0, "a b c -> c a", reduction=reduction)
  358. y2 = reduce(y1, "c a -> a c", reduction=reduction)
  359. y3 = reduce(y2, "a (c1 c2) -> a", reduction=reduction, c1=2)
  360. y4 = reduce(y3, "... -> ", reduction=reduction)
  361. y4.backward()
  362. grad = backend.to_numpy(y0.grad)
  363. results[backend.framework_name] = grad
  364. print("comparing gradients for", results.keys())
  365. for name1, grad1 in results.items():
  366. for name2, grad2 in results.items():
  367. assert np.allclose(grad1, grad2), [name1, name2, "provided different gradients"]
  368. def test_tiling_imperatives():
  369. for backend in imp_op_backends:
  370. print("Tiling tests for ", backend.framework_name)
  371. input = np.arange(2 * 3 * 5, dtype="int64").reshape([2, 1, 3, 1, 5])
  372. test_cases = [
  373. (1, 1, 1, 1, 1),
  374. (1, 2, 1, 3, 1),
  375. (3, 1, 1, 4, 1),
  376. ]
  377. for repeats in test_cases:
  378. expected = np.tile(input, repeats)
  379. converted = backend.from_numpy(input)
  380. repeated = backend.tile(converted, repeats)
  381. result = backend.to_numpy(repeated)
  382. assert np.array_equal(result, expected)
  383. def test_tiling_symbolic():
  384. for backend in sym_op_backends:
  385. print("Tiling tests for ", backend.framework_name)
  386. input = np.arange(2 * 3 * 5, dtype="int64").reshape([2, 1, 3, 1, 5])
  387. test_cases = [
  388. (1, 1, 1, 1, 1),
  389. (1, 2, 1, 3, 1),
  390. (3, 1, 1, 4, 1),
  391. ]
  392. for repeats in test_cases:
  393. expected = np.tile(input, repeats)
  394. sym = backend.create_symbol(input.shape)
  395. result = backend.eval_symbol(backend.tile(sym, repeats), [[sym, input]])
  396. assert np.array_equal(result, expected)
  397. sym = backend.create_symbol([None] * len(input.shape))
  398. result = backend.eval_symbol(backend.tile(sym, repeats), [[sym, input]])
  399. assert np.array_equal(result, expected)
# (pattern, axis lengths) cases for `repeat`, shared by test_repeat_numpy (via check_reversion),
# test_repeat_imperatives, test_repeat_symbolic and test_repeat_array_api.
repeat_test_cases = [
    # all assume that input has shape [2, 3, 5]
    ("a b c -> c a b", dict()),
    ("a b c -> (c copy a b)", dict(copy=2, a=2, b=3, c=5)),
    ("a b c -> (a copy) b c ", dict(copy=1)),
    ("a b c -> (c a) (copy1 b copy2)", dict(a=2, copy1=1, copy2=2)),
    ("a ... -> a ... copy", dict(copy=4)),
    ("... c -> ... (copy1 c copy2)", dict(copy1=1, copy2=2)),
    ("... -> ... ", dict()),
    (" ... -> copy1 ... copy2 ", dict(copy1=2, copy2=3)),
    ("a b c -> copy1 a copy2 b c () ", dict(copy1=2, copy2=1)),
]
  412. def check_reversion(x, repeat_pattern, **sizes):
  413. """Checks repeat pattern by running reduction"""
  414. left, right = repeat_pattern.split("->")
  415. reduce_pattern = right + "->" + left
  416. repeated = repeat(x, repeat_pattern, **sizes)
  417. reduced_min = reduce(repeated, reduce_pattern, reduction="min", **sizes)
  418. reduced_max = reduce(repeated, reduce_pattern, reduction="max", **sizes)
  419. assert np.array_equal(x, reduced_min)
  420. assert np.array_equal(x, reduced_max)
  421. def test_repeat_numpy():
  422. # check repeat vs reduce. Repeat works ok if reverse reduction with min and max work well
  423. x = np.arange(2 * 3 * 5).reshape([2, 3, 5])
  424. x1 = repeat(x, "a b c -> copy a b c ", copy=1)
  425. assert np.array_equal(x[None], x1)
  426. for pattern, axis_dimensions in repeat_test_cases:
  427. check_reversion(x, pattern, **axis_dimensions)
  428. def test_repeat_imperatives():
  429. x = np.arange(2 * 3 * 5).reshape([2, 3, 5])
  430. for backend in imp_op_backends:
  431. print("Repeat tests for ", backend.framework_name)
  432. for pattern, axis_dimensions in repeat_test_cases:
  433. expected = repeat(x, pattern, **axis_dimensions)
  434. converted = backend.from_numpy(x)
  435. repeated = repeat(converted, pattern, **axis_dimensions)
  436. result = backend.to_numpy(repeated)
  437. assert np.array_equal(result, expected)
  438. def test_repeat_symbolic():
  439. x = np.arange(2 * 3 * 5).reshape([2, 3, 5])
  440. for backend in sym_op_backends:
  441. print("Repeat tests for ", backend.framework_name)
  442. for pattern, axis_dimensions in repeat_test_cases:
  443. expected = repeat(x, pattern, **axis_dimensions)
  444. sym = backend.create_symbol(x.shape)
  445. result = backend.eval_symbol(repeat(sym, pattern, **axis_dimensions), [[sym, x]])
  446. assert np.array_equal(result, expected)
  447. def test_repeat_array_api():
  448. import numpy as xp
  449. from einops import array_api as AA
  450. if xp.__version__ < "2.0.0":
  451. pytest.skip()
  452. x = np.arange(2 * 3 * 5).reshape([2, 3, 5])
  453. for pattern, axis_dimensions in repeat_test_cases:
  454. expected = repeat(x, pattern, **axis_dimensions)
  455. result = AA.repeat(xp.from_dlpack(x), pattern, **axis_dimensions)
  456. assert np.array_equal(AA.asnumpy(result + 0), expected)
# (pattern, axis lengths) cases for `repeat` with anonymous (numeric or "()") axes,
# checked through check_reversion in test_anonymous_axes.
test_cases_repeat_anonymous = [
    # all assume that input has shape [1, 2, 4, 6]
    ("a b c d -> c a d b", dict()),
    ("a b c d -> (c 2 d a b)", dict(a=1, c=4, d=6)),
    ("1 b c d -> (d copy 1) 3 b c ", dict(copy=3)),
    ("1 ... -> 3 ... ", dict()),
    ("() ... d -> 1 (copy1 d copy2) ... ", dict(copy1=2, copy2=3)),
    ("1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)", dict()),
]
  466. def test_anonymous_axes():
  467. x = np.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6])
  468. for pattern, axis_dimensions in test_cases_repeat_anonymous:
  469. check_reversion(x, pattern, **axis_dimensions)
  470. def test_list_inputs():
  471. x = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
  472. assert np.array_equal(
  473. rearrange(list(x), "... -> (...)"),
  474. rearrange(x, "... -> (...)"),
  475. )
  476. assert np.array_equal(
  477. reduce(list(x), "a ... e -> (...)", "min"),
  478. reduce(x, "a ... e -> (...)", "min"),
  479. )
  480. assert np.array_equal(
  481. repeat(list(x), "... -> b (...)", b=3),
  482. repeat(x, "... -> b (...)", b=3),
  483. )
  484. def test_torch_compile_with_dynamic_shape():
  485. if not is_backend_tested("torch"):
  486. pytest.skip()
  487. import torch
  488. # somewhat reasonable debug messages
  489. torch._dynamo.config.verbose = True
  490. def func1(x):
  491. # test contains ellipsis
  492. a, b, c, *other = x.shape
  493. x = rearrange(x, "(a a2) b c ... -> b (c a2) (a ...)", a2=2)
  494. # test contains passing expression as axis length
  495. x = reduce(x, "b ca2 A -> b A", "sum", ca2=c * 2)
  496. return x
  497. # seems can't test static and dynamic in the same test run.
  498. func1_compiled_static = torch.compile(func1, dynamic=False, fullgraph=True)
  499. func1_compiled_dynamic = torch.compile(func1, dynamic=True, fullgraph=True)
  500. x = torch.randn(size=[4, 5, 6, 3])
  501. assert torch.allclose(func1_compiled_static(x), func1(x), atol=1e-5)
  502. assert torch.allclose(func1_compiled_dynamic(x), func1(x), atol=1e-5)
  503. # check with input of different dimensionality, and with all shape elements changed
  504. x = torch.randn(size=[6, 3, 4, 2, 3])
  505. assert torch.allclose(func1_compiled_static(x), func1(x), atol=1e-5)
  506. assert torch.allclose(func1_compiled_dynamic(x), func1(x), atol=1e-5)
  507. def bit_count(x):
  508. return sum((x >> i) & 1 for i in range(20))
  509. def test_reduction_imperatives_booleans():
  510. """Checks that any/all reduction works in all frameworks"""
  511. x_np = np.asarray([(bit_count(x) % 2) == 0 for x in range(2**6)]).reshape([2] * 6)
  512. for backend in imp_op_backends:
  513. print("Reduction any/all tests for ", backend.framework_name)
  514. for axis in range(6):
  515. expected_result_any = np.any(x_np, axis=axis, keepdims=True)
  516. expected_result_all = np.all(x_np, axis=axis, keepdims=True)
  517. assert not np.array_equal(expected_result_any, expected_result_all)
  518. axes = list("abcdef")
  519. axes_in = list(axes)
  520. axes_out = list(axes)
  521. axes_out[axis] = "1"
  522. pattern = (" ".join(axes_in)) + " -> " + (" ".join(axes_out))
  523. res_any = reduce(backend.from_numpy(x_np), pattern, reduction="any")
  524. res_all = reduce(backend.from_numpy(x_np), pattern, reduction="all")
  525. assert np.array_equal(expected_result_any, backend.to_numpy(res_any))
  526. assert np.array_equal(expected_result_all, backend.to_numpy(res_all))
  527. # expected result: any/all
  528. expected_result_any = np.any(x_np, axis=(0, 1), keepdims=True)
  529. expected_result_all = np.all(x_np, axis=(0, 1), keepdims=True)
  530. pattern = "a b ... -> 1 1 ..."
  531. res_any = reduce(backend.from_numpy(x_np), pattern, reduction="any")
  532. res_all = reduce(backend.from_numpy(x_np), pattern, reduction="all")
  533. assert np.array_equal(expected_result_any, backend.to_numpy(res_any))
  534. assert np.array_equal(expected_result_all, backend.to_numpy(res_all))