| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312 |
- import dataclasses
- import typing
- import numpy as np
- import pytest
- from einops import EinopsError, asnumpy, pack, unpack
- from einops.tests import collect_test_backends
# Module-wide random generator used to build test inputs; deliberately unseeded.
rng = np.random.default_rng(seed=None)
def pack_unpack(xs, pattern):
    """Pack ``xs`` with ``pattern``, unpack the result and check the round trip.

    Every unpacked piece must equal the corresponding input in both shape and values.
    """
    x, ps = pack(xs, pattern)
    # Unpack the packed tensor ``x`` (not the original list) using the recorded shapes.
    unpacked = unpack(x, ps, pattern)
    assert len(unpacked) == len(xs)
    for a, b in zip(unpacked, xs):
        a_np, b_np = asnumpy(a), asnumpy(b)
        # np.allclose broadcasts its operands, so compare shapes explicitly.
        assert a_np.shape == b_np.shape
        assert np.allclose(a_np, b_np)
def unpack_and_pack(x, ps, pattern: str):
    """Unpack ``x`` into pieces described by ``ps``, re-pack them and check the round trip.

    Returns the list of unpacked pieces so callers can make further assertions on them.
    Any error raised by ``unpack``/``pack`` propagates to the caller.
    """
    unpacked = unpack(x, ps, pattern)
    packed, _ = pack(unpacked, pattern=pattern)
    packed_np, x_np = asnumpy(packed), asnumpy(x)
    # np.allclose broadcasts its operands, so a wrong shape could otherwise pass unnoticed.
    assert packed_np.shape == x_np.shape
    assert np.allclose(packed_np, x_np)
    return unpacked
def unpack_and_pack_against_numpy(x, ps, pattern: str):
    """Round-trip ``x`` through unpack/pack on its own backend and on numpy, and compare.

    Either both paths raise an exception of the same type (then nothing else is
    checked), or neither raises: both re-packed tensors must then equal ``x`` and the
    unpacked pieces of the two paths must agree in shape and values.
    """
    capturer_backend = CaptureException()
    capturer_numpy = CaptureException()
    with capturer_backend:
        unpacked = unpack(x, ps, pattern)
        packed, _ = pack(unpacked, pattern=pattern)
    with capturer_numpy:
        x_np = asnumpy(x)
        unpacked_np = unpack(x_np, ps, pattern)
        packed_np, _ = pack(unpacked_np, pattern=pattern)

    backend_exc = capturer_backend.exception
    numpy_exc = capturer_numpy.exception
    assert type(numpy_exc) is type(backend_exc), f"backend raised {backend_exc!r}, numpy raised {numpy_exc!r}"
    if numpy_exc is not None:
        # both failed in the same way
        return

    # neither failed: check results are identical (shapes included, since allclose broadcasts)
    x_ref = asnumpy(x)
    for result in (asnumpy(packed), asnumpy(packed_np)):
        assert result.shape == x_ref.shape
        assert np.allclose(result, x_ref)
    assert len(unpacked) == len(unpacked_np)
    for a, b in zip(unpacked, unpacked_np):
        a_np = asnumpy(a)
        assert a_np.shape == b.shape
        assert np.allclose(a_np, b)
class CaptureException:
    """Context manager that records an exception raised in its body instead of propagating it.

    After the ``with`` block, ``exception`` holds the raised exception, or ``None`` if
    the body completed normally. Only ``Exception`` subclasses are captured;
    ``KeyboardInterrupt``, ``SystemExit`` and other ``BaseException``s still propagate.
    """

    def __init__(self):
        # defined up front so the attribute exists even before the block is entered
        self.exception = None

    def __enter__(self):
        self.exception = None
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None and not issubclass(exc_type, Exception):
            # do not swallow interpreter-level exceptions (Ctrl-C, sys.exit, test-runner outcomes)
            return False
        self.exception = exc_val
        return True
def test_numpy_trivial(H=13, W=17):
    """Compare ``pack`` with equivalent numpy stack/concatenate calls and check
    that malformed patterns are rejected with ``EinopsError``."""

    def rand(*shape):
        return rng.random(shape)

    def check(expected, actual):
        # exact comparison: packing only rearranges values, it must never alter them
        assert expected.dtype == actual.dtype
        assert expected.shape == actual.shape
        assert np.all(expected == actual)

    r, g, b = rand(3, H, W)
    embeddings = rand(H, W, 32)
    channels = [r, g, b]

    # "*" not covered by any input axis: pack stacks along a brand-new axis
    for axis, pattern in [(2, "h w *"), (1, "h * w"), (0, "* h w")]:
        check(np.stack(channels, axis=axis), pack(channels, pattern)[0])

    # "*" replaces an existing axis: pack concatenates along it
    for axis, pattern in [(1, "h *"), (0, "* w")]:
        check(np.concatenate(channels, axis=axis), pack(channels, pattern)[0])

    # 2-d channels get a trailing unit axis so they line up with the 3-d embeddings
    as_column = np.index_exp[:, :, None]
    expected = np.concatenate([r[as_column], g[as_column], b[as_column], embeddings], axis=2)
    check(expected, pack([r, g, b, embeddings], "h w *")[0])

    with_embeddings = [r, g, b, embeddings]
    with pytest.raises(EinopsError):
        pack(with_embeddings, "h w nonexisting_axis *")
    # arbitrary identifiers are fine as axis names
    pack(channels, "some_name_for_H some_name_for_w1 *")

    invalid_patterns = [
        "h _w *",  # no leading underscore
        "h_ w *",  # no trailing underscore
        "1h_ w *",
        "1 w *",
        "h h *",
    ]
    for bad_pattern in invalid_patterns:
        with pytest.raises(EinopsError):
            pack(with_embeddings, bad_pattern)

    # capital and non-capital are different
    pack(with_embeddings, "h H *")
@dataclasses.dataclass
class UnpackTestCase:
    """One unpack scenario: an input ``shape`` and the einops ``pattern`` applied to it."""

    # full shape of the tensor to unpack
    shape: typing.Tuple[int, ...]
    # space-separated axis names with exactly one "*" marking the packed axis
    pattern: str

    def dim(self):
        # index of the "*" token, i.e. the position of the packed axis
        tokens = self.pattern.split()
        return tokens.index("*")

    def selfcheck(self):
        # the scenarios in this module expect the packed axis to have length 5
        packed_length = self.shape[self.dim()]
        assert packed_length == 5
cases = [
    # NB: in every case the packed ("*") axis has length 5;
    # the tests below rely on that when choosing packed shapes.
    UnpackTestCase((5,), "*"),
    UnpackTestCase((5, 7), "* seven"),
    UnpackTestCase((7, 5), "seven *"),
    UnpackTestCase((5, 3, 4), "* three four"),
    UnpackTestCase((4, 5, 3), "four * three"),
    UnpackTestCase((3, 4, 5), "three four *"),
]
# Enforce the invariant above at import time, so a mis-specified case fails loudly
# here instead of producing confusing failures in the unpack tests.
for _case in cases:
    _case.selfcheck()
def test_pack_unpack_with_numpy():
    """Round-trip every test case through unpack/pack on numpy arrays, covering
    explicit sizes, ``-1`` inference (flat and nested) and invalid requests."""
    for case in cases:
        pattern = case.pattern
        x = rng.random(case.shape)

        # all correct, no minus 1
        unpack_and_pack(x, [[2], [1], [2]], pattern)

        # no -1, asking for wrong shapes or unknown axes
        for bad_ps, bad_pattern in [
            ([[2], [1], [2]], pattern + " non_existent_axis"),
            ([[2], [1], [1]], pattern),
            ([[4], [1], [1]], pattern),
        ]:
            with pytest.raises(EinopsError):
                unpack_and_pack(x, bad_ps, bad_pattern)

        # all correct, with -1
        for good_ps in ([[2], [1], [-1]], [[2], [-1], [2]], [[-1], [1], [2]]):
            unpack_and_pack(x, good_ps, pattern)

        # -1 resolves to an empty piece
        _first, _middle, last = unpack_and_pack(x, [[2], [3], [-1]], pattern)
        assert last.shape[case.dim()] == 0

        # asking for more elements than available
        # ([[2], [-1], [4]] is not checked: slicing x[2:1] just yields zero elements, no error)
        for bad_ps in ([[2], [4], [-1]], [[-1], [1], [5]]):
            with pytest.raises(EinopsError):
                unpack(x, bad_ps, pattern)

        # all correct, -1 nested: every piece gains one extra axis
        for nested_ps in (
            [[1, 2], [1, 1], [-1, 1]],
            [[1, 2], [1, -1], [1, 1]],
            [[2, -1], [1, 2], [1, 1]],
        ):
            pieces = unpack_and_pack(x, nested_ps, pattern)
            assert all(len(piece.shape) == len(x.shape) + 1 for piece in pieces)

        # too many elements with a nested -1, or a non-divisible remainder
        for bad_ps in (
            [[-1, 2], [1], [5]],
            [[2, 2], [2], [5, -1]],
            [[2, 1], [1], [3, -1]],
            [[2, 1], [3, -1], [1]],
            [[3, -1], [2, 1], [1]],
        ):
            with pytest.raises(EinopsError):
                unpack(x, bad_ps, pattern)

        # -1 takes zero
        for zero_ps in ([[0], [5], [-1]], [[0], [-1], [5]], [[-1], [5], [0]]):
            unpack_and_pack(x, zero_ps, pattern)
        # nested -1 takes zero
        unpack_and_pack(x, [[2, -1], [1, 5]], pattern)
def test_pack_unpack_against_numpy():
    """Run the unpack/pack scenarios on every available backend and compare with numpy."""
    check = unpack_and_pack_against_numpy
    for backend in collect_test_backends(symbolic=False, layers=False):
        print(f"test packing against numpy for {backend.framework_name}")
        # NOTE(review): always True in this file; looks like a switch for backends
        # without zero-length support — confirm before removing
        check_zero_len = True

        for case in cases:
            pattern = case.pattern
            x = backend.from_numpy(rng.random(case.shape))

            # all correct, no minus 1
            check(x, [[2], [1], [2]], pattern)

            # no -1, asking for wrong shapes
            for bad_ps in ([[2], [1], [1]], [[4], [1], [1]]):
                with pytest.raises(EinopsError):
                    unpack(x, bad_ps, pattern)

            # all correct, with -1
            for good_ps in ([[2], [1], [-1]], [[2], [-1], [2]], [[-1], [1], [2]]):
                check(x, good_ps, pattern)

            # asking for more elements than available
            # ([[2], [-1], [4]] is not checked: slicing x[2:1] just yields zero elements, no error)
            for bad_ps in ([[2], [4], [-1]], [[-1], [1], [5]]):
                with pytest.raises(EinopsError):
                    unpack(x, bad_ps, pattern)

            # all correct, -1 nested
            for nested_ps in (
                [[1, 2], [1, 1], [-1, 1]],
                [[1, 2], [1, -1], [1, 1]],
                [[2, -1], [1, 2], [1, 1]],
            ):
                check(x, nested_ps, pattern)

            # too many elements with a nested -1, or a non-divisible remainder
            for bad_ps in (
                [[-1, 2], [1], [5]],
                [[2, 2], [2], [5, -1]],
                [[2, 1], [1], [3, -1]],
                [[2, 1], [3, -1], [1]],
                [[3, -1], [2, 1], [1]],
            ):
                with pytest.raises(EinopsError):
                    unpack(x, bad_ps, pattern)

            if check_zero_len:
                # -1 takes zero
                for zero_ps in ([[2], [3], [-1]], [[0], [5], [-1]], [[0], [-1], [5]], [[-1], [5], [0]]):
                    check(x, zero_ps, pattern)
                # nested -1 takes zero
                check(x, [[2, -1], [1, 5]], pattern)
def test_pack_unpack_array_api():
    """Check ``einops.array_api`` pack/unpack against the default numpy implementation.

    Skipped when numpy is older than 2.0.
    """
    import numpy as xp

    # NumpyVersion compares release components numerically; a plain string
    # comparison misorders versions such as "10.0.0" < "2.0.0".
    if np.lib.NumpyVersion(xp.__version__) < "2.0.0":
        pytest.skip("array API test requires numpy >= 2.0")

    # imported only once we know the test will actually run
    from einops import array_api as AA

    for case in cases:
        shape = case.shape
        pattern = case.pattern
        x_np = rng.random(shape)
        x_xp = xp.from_dlpack(x_np)
        for ps in [
            [[2], [1], [2]],
            [[1], [1], [-1]],
            [[1], [1], [-1, 3]],
            [[2, 1], [1, 1, 1], [-1]],
        ]:
            x_np_split = unpack(x_np, ps, pattern)
            x_xp_split = AA.unpack(x_xp, ps, pattern)
            for a, b in zip(x_np_split, x_xp_split):
                assert np.allclose(a, AA.asnumpy(b + 0))

            x_agg_np, ps1 = pack(x_np_split, pattern)
            x_agg_xp, ps2 = AA.pack(x_xp_split, pattern)
            assert ps1 == ps2
            assert np.allclose(x_agg_np, AA.asnumpy(x_agg_xp))

        # invalid packed shapes must be rejected
        # NOTE(review): only the numpy `unpack` is checked here; AA.unpack's
        # behaviour on these shapes is not exercised.
        for ps in [
            [[2, 3]],
            [[1], [5]],
            [[1], [5], [-1]],
            [[1], [2, 3]],
            [[1], [5], [-1, 2]],
        ]:
            with pytest.raises(EinopsError):
                unpack(x_np, ps, pattern)
|