test_packing.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. import dataclasses
  2. import typing
  3. import numpy as np
  4. import pytest
  5. from einops import EinopsError, asnumpy, pack, unpack
  6. from einops.tests import collect_test_backends
  7. rng = np.random.default_rng()
  8. def pack_unpack(xs, pattern):
  9. x, ps = pack(xs, pattern)
  10. unpacked = unpack(xs, ps, pattern)
  11. assert len(unpacked) == len(xs)
  12. for a, b in zip(unpacked, xs):
  13. assert np.allclose(asnumpy(a), asnumpy(b))
  14. def unpack_and_pack(x, ps, pattern: str):
  15. unpacked = unpack(x, ps, pattern)
  16. packed, ps2 = pack(unpacked, pattern=pattern)
  17. assert np.allclose(asnumpy(packed), asnumpy(x))
  18. return unpacked
  19. def unpack_and_pack_against_numpy(x, ps, pattern: str):
  20. capturer_backend = CaptureException()
  21. capturer_numpy = CaptureException()
  22. with capturer_backend:
  23. unpacked = unpack(x, ps, pattern)
  24. packed, ps2 = pack(unpacked, pattern=pattern)
  25. with capturer_numpy:
  26. x_np = asnumpy(x)
  27. unpacked_np = unpack(x_np, ps, pattern)
  28. packed_np, ps3 = pack(unpacked_np, pattern=pattern)
  29. assert type(capturer_numpy.exception) == type(capturer_backend.exception) # noqa E721
  30. if capturer_numpy.exception is not None:
  31. # both failed
  32. return
  33. else:
  34. # neither failed, check results are identical
  35. assert np.allclose(asnumpy(packed), asnumpy(x))
  36. assert np.allclose(asnumpy(packed_np), asnumpy(x))
  37. assert len(unpacked) == len(unpacked_np)
  38. for a, b in zip(unpacked, unpacked_np):
  39. assert np.allclose(asnumpy(a), b)
  40. class CaptureException:
  41. def __enter__(self):
  42. self.exception = None
  43. def __exit__(self, exc_type, exc_val, exc_tb):
  44. self.exception = exc_val
  45. return True
  46. def test_numpy_trivial(H=13, W=17):
  47. def rand(*shape):
  48. return rng.random(shape)
  49. def check(a, b):
  50. assert a.dtype == b.dtype
  51. assert a.shape == b.shape
  52. assert np.all(a == b)
  53. r, g, b = rand(3, H, W)
  54. embeddings = rand(H, W, 32)
  55. check(
  56. np.stack([r, g, b], axis=2),
  57. pack([r, g, b], "h w *")[0],
  58. )
  59. check(
  60. np.stack([r, g, b], axis=1),
  61. pack([r, g, b], "h * w")[0],
  62. )
  63. check(
  64. np.stack([r, g, b], axis=0),
  65. pack([r, g, b], "* h w")[0],
  66. )
  67. check(
  68. np.concatenate([r, g, b], axis=1),
  69. pack([r, g, b], "h *")[0],
  70. )
  71. check(
  72. np.concatenate([r, g, b], axis=0),
  73. pack([r, g, b], "* w")[0],
  74. )
  75. i = np.index_exp[:, :, None]
  76. check(
  77. np.concatenate([r[i], g[i], b[i], embeddings], axis=2),
  78. pack([r, g, b, embeddings], "h w *")[0],
  79. )
  80. with pytest.raises(EinopsError):
  81. pack([r, g, b, embeddings], "h w nonexisting_axis *")
  82. pack([r, g, b], "some_name_for_H some_name_for_w1 *")
  83. with pytest.raises(EinopsError):
  84. pack([r, g, b, embeddings], "h _w *") # no leading underscore
  85. with pytest.raises(EinopsError):
  86. pack([r, g, b, embeddings], "h_ w *") # no trailing underscore
  87. with pytest.raises(EinopsError):
  88. pack([r, g, b, embeddings], "1h_ w *")
  89. with pytest.raises(EinopsError):
  90. pack([r, g, b, embeddings], "1 w *")
  91. with pytest.raises(EinopsError):
  92. pack([r, g, b, embeddings], "h h *")
  93. # capital and non-capital are different
  94. pack([r, g, b, embeddings], "h H *")
  95. @dataclasses.dataclass
  96. class UnpackTestCase:
  97. shape: typing.Tuple[int, ...]
  98. pattern: str
  99. def dim(self):
  100. return self.pattern.split().index("*")
  101. def selfcheck(self):
  102. assert self.shape[self.dim()] == 5
  103. cases = [
  104. # NB: in all cases unpacked axis is of length 5.
  105. # that's actively used in tests below
  106. UnpackTestCase((5,), "*"),
  107. UnpackTestCase((5, 7), "* seven"),
  108. UnpackTestCase((7, 5), "seven *"),
  109. UnpackTestCase((5, 3, 4), "* three four"),
  110. UnpackTestCase((4, 5, 3), "four * three"),
  111. UnpackTestCase((3, 4, 5), "three four *"),
  112. ]
  113. def test_pack_unpack_with_numpy():
  114. case: UnpackTestCase
  115. for case in cases:
  116. shape = case.shape
  117. pattern = case.pattern
  118. x = rng.random(shape)
  119. # all correct, no minus 1
  120. unpack_and_pack(x, [[2], [1], [2]], pattern)
  121. # no -1, asking for wrong shapes
  122. with pytest.raises(EinopsError):
  123. unpack_and_pack(x, [[2], [1], [2]], pattern + " non_existent_axis")
  124. with pytest.raises(EinopsError):
  125. unpack_and_pack(x, [[2], [1], [1]], pattern)
  126. with pytest.raises(EinopsError):
  127. unpack_and_pack(x, [[4], [1], [1]], pattern)
  128. # all correct, with -1
  129. unpack_and_pack(x, [[2], [1], [-1]], pattern)
  130. unpack_and_pack(x, [[2], [-1], [2]], pattern)
  131. unpack_and_pack(x, [[-1], [1], [2]], pattern)
  132. _, _, last = unpack_and_pack(x, [[2], [3], [-1]], pattern)
  133. assert last.shape[case.dim()] == 0
  134. # asking for more elements than available
  135. with pytest.raises(EinopsError):
  136. unpack(x, [[2], [4], [-1]], pattern)
  137. # this one does not raise, because indexing x[2:1] just returns zero elements
  138. # with pytest.raises(EinopsError):
  139. # unpack(x, [[2], [-1], [4]], pattern)
  140. with pytest.raises(EinopsError):
  141. unpack(x, [[-1], [1], [5]], pattern)
  142. # all correct, -1 nested
  143. rs = unpack_and_pack(x, [[1, 2], [1, 1], [-1, 1]], pattern)
  144. assert all(len(r.shape) == len(x.shape) + 1 for r in rs)
  145. rs = unpack_and_pack(x, [[1, 2], [1, -1], [1, 1]], pattern)
  146. assert all(len(r.shape) == len(x.shape) + 1 for r in rs)
  147. rs = unpack_and_pack(x, [[2, -1], [1, 2], [1, 1]], pattern)
  148. assert all(len(r.shape) == len(x.shape) + 1 for r in rs)
  149. # asking for more elements, -1 nested
  150. with pytest.raises(EinopsError):
  151. unpack(x, [[-1, 2], [1], [5]], pattern)
  152. with pytest.raises(EinopsError):
  153. unpack(x, [[2, 2], [2], [5, -1]], pattern)
  154. # asking for non-divisible number of elements
  155. with pytest.raises(EinopsError):
  156. unpack(x, [[2, 1], [1], [3, -1]], pattern)
  157. with pytest.raises(EinopsError):
  158. unpack(x, [[2, 1], [3, -1], [1]], pattern)
  159. with pytest.raises(EinopsError):
  160. unpack(x, [[3, -1], [2, 1], [1]], pattern)
  161. # -1 takes zero
  162. unpack_and_pack(x, [[0], [5], [-1]], pattern)
  163. unpack_and_pack(x, [[0], [-1], [5]], pattern)
  164. unpack_and_pack(x, [[-1], [5], [0]], pattern)
  165. # -1 takes zero, -1
  166. unpack_and_pack(x, [[2, -1], [1, 5]], pattern)
  167. def test_pack_unpack_against_numpy():
  168. for backend in collect_test_backends(symbolic=False, layers=False):
  169. print(f"test packing against numpy for {backend.framework_name}")
  170. check_zero_len = True
  171. for case in cases:
  172. unpack_and_pack = unpack_and_pack_against_numpy
  173. shape = case.shape
  174. pattern = case.pattern
  175. x = rng.random(shape)
  176. x = backend.from_numpy(x)
  177. # all correct, no minus 1
  178. unpack_and_pack(x, [[2], [1], [2]], pattern)
  179. # no -1, asking for wrong shapes
  180. with pytest.raises(EinopsError):
  181. unpack(x, [[2], [1], [1]], pattern)
  182. with pytest.raises(EinopsError):
  183. unpack(x, [[4], [1], [1]], pattern)
  184. # all correct, with -1
  185. unpack_and_pack(x, [[2], [1], [-1]], pattern)
  186. unpack_and_pack(x, [[2], [-1], [2]], pattern)
  187. unpack_and_pack(x, [[-1], [1], [2]], pattern)
  188. # asking for more elements than available
  189. with pytest.raises(EinopsError):
  190. unpack(x, [[2], [4], [-1]], pattern)
  191. # this one does not raise, because indexing x[2:1] just returns zero elements
  192. # with pytest.raises(EinopsError):
  193. # unpack(x, [[2], [-1], [4]], pattern)
  194. with pytest.raises(EinopsError):
  195. unpack(x, [[-1], [1], [5]], pattern)
  196. # all correct, -1 nested
  197. unpack_and_pack(x, [[1, 2], [1, 1], [-1, 1]], pattern)
  198. unpack_and_pack(x, [[1, 2], [1, -1], [1, 1]], pattern)
  199. unpack_and_pack(x, [[2, -1], [1, 2], [1, 1]], pattern)
  200. # asking for more elements, -1 nested
  201. with pytest.raises(EinopsError):
  202. unpack(x, [[-1, 2], [1], [5]], pattern)
  203. with pytest.raises(EinopsError):
  204. unpack(x, [[2, 2], [2], [5, -1]], pattern)
  205. # asking for non-divisible number of elements
  206. with pytest.raises(EinopsError):
  207. unpack(x, [[2, 1], [1], [3, -1]], pattern)
  208. with pytest.raises(EinopsError):
  209. unpack(x, [[2, 1], [3, -1], [1]], pattern)
  210. with pytest.raises(EinopsError):
  211. unpack(x, [[3, -1], [2, 1], [1]], pattern)
  212. if check_zero_len:
  213. # -1 takes zero
  214. unpack_and_pack(x, [[2], [3], [-1]], pattern)
  215. unpack_and_pack(x, [[0], [5], [-1]], pattern)
  216. unpack_and_pack(x, [[0], [-1], [5]], pattern)
  217. unpack_and_pack(x, [[-1], [5], [0]], pattern)
  218. # -1 takes zero, -1
  219. unpack_and_pack(x, [[2, -1], [1, 5]], pattern)
  220. def test_pack_unpack_array_api():
  221. import numpy as xp
  222. from einops import array_api as AA
  223. if xp.__version__ < "2.0.0":
  224. pytest.skip()
  225. for case in cases:
  226. shape = case.shape
  227. pattern = case.pattern
  228. x_np = rng.random(shape)
  229. x_xp = xp.from_dlpack(x_np)
  230. for ps in [
  231. [[2], [1], [2]],
  232. [[1], [1], [-1]],
  233. [[1], [1], [-1, 3]],
  234. [[2, 1], [1, 1, 1], [-1]],
  235. ]:
  236. x_np_split = unpack(x_np, ps, pattern)
  237. x_xp_split = AA.unpack(x_xp, ps, pattern)
  238. for a, b in zip(x_np_split, x_xp_split):
  239. assert np.allclose(a, AA.asnumpy(b + 0))
  240. x_agg_np, ps1 = pack(x_np_split, pattern)
  241. x_agg_xp, ps2 = AA.pack(x_xp_split, pattern)
  242. assert ps1 == ps2
  243. assert np.allclose(x_agg_np, AA.asnumpy(x_agg_xp))
  244. for ps in [
  245. [[2, 3]],
  246. [[1], [5]],
  247. [[1], [5], [-1]],
  248. [[1], [2, 3]],
  249. [[1], [5], [-1, 2]],
  250. ]:
  251. with pytest.raises(EinopsError):
  252. unpack(x_np, ps, pattern)