| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480 |
- import pickle
- from collections import namedtuple
- import numpy as np
- import pytest
- from einops import EinopsError, rearrange, reduce
- from einops.tests import FLOAT_REDUCTIONS as REDUCTIONS
- from einops.tests import collect_test_backends, is_backend_tested
- __author__ = "Alex Rogozhnikov"
- testcase = namedtuple("testcase", ["pattern", "axes_lengths", "input_shape", "wrong_shapes"])
- rearrangement_patterns = [
- testcase(
- "b c h w -> b (c h w)",
- dict(c=20),
- (10, 20, 30, 40),
- [(), (10,), (10, 10, 10), (10, 21, 30, 40), [1, 20, 1, 1, 1]],
- ),
- testcase(
- "b c (h1 h2) (w1 w2) -> b (c h2 w2) h1 w1",
- dict(h2=2, w2=2),
- (10, 20, 30, 40),
- [(), (1, 1, 1, 1), (1, 10, 3), ()],
- ),
- testcase(
- "b ... c -> c b ...",
- dict(b=10),
- (10, 20, 30),
- [(), (10,), (5, 10)],
- ),
- ]
- def test_rearrange_imperative():
- for backend in collect_test_backends(symbolic=False, layers=True):
- print("Test layer for ", backend.framework_name)
- for pattern, axes_lengths, input_shape, wrong_shapes in rearrangement_patterns:
- x = np.arange(np.prod(input_shape), dtype="float32").reshape(input_shape)
- result_numpy = rearrange(x, pattern, **axes_lengths)
- layer = backend.layers().Rearrange(pattern, **axes_lengths)
- for shape in wrong_shapes:
- try:
- layer(backend.from_numpy(np.zeros(shape, dtype="float32")))
- except BaseException:
- pass
- else:
- raise AssertionError("Failure expected")
- # simple pickling / unpickling
- layer2 = pickle.loads(pickle.dumps(layer))
- result1 = backend.to_numpy(layer(backend.from_numpy(x)))
- result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
- assert np.allclose(result_numpy, result1)
- assert np.allclose(result1, result2)
- just_sum = backend.layers().Reduce("...->", reduction="sum")
- variable = backend.from_numpy(x)
- result = just_sum(layer(variable))
- result.backward()
- assert np.allclose(backend.to_numpy(variable.grad), 1)
- def test_rearrange_symbolic():
- for backend in collect_test_backends(symbolic=True, layers=True):
- print("Test layer for ", backend.framework_name)
- for pattern, axes_lengths, input_shape, _wrong_shapes in rearrangement_patterns:
- x = np.arange(np.prod(input_shape), dtype="float32").reshape(input_shape)
- result_numpy = rearrange(x, pattern, **axes_lengths)
- layer = backend.layers().Rearrange(pattern, **axes_lengths)
- input_shape_of_nones = [None] * len(input_shape)
- shapes = [input_shape, input_shape_of_nones]
- for shape in shapes:
- symbol = backend.create_symbol(shape)
- eval_inputs = [(symbol, x)]
- result_symbol1 = layer(symbol)
- result1 = backend.eval_symbol(result_symbol1, eval_inputs)
- assert np.allclose(result_numpy, result1)
- layer2 = pickle.loads(pickle.dumps(layer))
- result_symbol2 = layer2(symbol)
- result2 = backend.eval_symbol(result_symbol2, eval_inputs)
- assert np.allclose(result1, result2)
- # now testing back-propagation
- just_sum = backend.layers().Reduce("...->", reduction="sum")
- result_sum1 = backend.eval_symbol(just_sum(result_symbol1), eval_inputs)
- result_sum2 = np.sum(x)
- assert np.allclose(result_sum1, result_sum2)
- reduction_patterns = [
- *rearrangement_patterns,
- testcase("b c h w -> b ()", dict(b=10), (10, 20, 30, 40), [(10,), (10, 20, 30)]),
- testcase("b c (h1 h2) (w1 w2) -> b c h1 w1", dict(h1=15, h2=2, w2=2), (10, 20, 30, 40), [(10, 20, 31, 40)]),
- testcase("b ... c -> b", dict(b=10), (10, 20, 30, 40), [(10,), (11, 10)]),
- ]
- def test_reduce_imperative():
- for backend in collect_test_backends(symbolic=False, layers=True):
- print("Test layer for ", backend.framework_name)
- for reduction in REDUCTIONS:
- for pattern, axes_lengths, input_shape, wrong_shapes in reduction_patterns:
- print(backend, reduction, pattern, axes_lengths, input_shape, wrong_shapes)
- x = np.arange(1, 1 + np.prod(input_shape), dtype="float32").reshape(input_shape)
- x /= x.mean()
- result_numpy = reduce(x, pattern, reduction, **axes_lengths)
- layer = backend.layers().Reduce(pattern, reduction, **axes_lengths)
- for shape in wrong_shapes:
- try:
- layer(backend.from_numpy(np.zeros(shape, dtype="float32")))
- except BaseException:
- pass
- else:
- raise AssertionError("Failure expected")
- # simple pickling / unpickling
- layer2 = pickle.loads(pickle.dumps(layer))
- result1 = backend.to_numpy(layer(backend.from_numpy(x)))
- result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
- assert np.allclose(result_numpy, result1)
- assert np.allclose(result1, result2)
- just_sum = backend.layers().Reduce("...->", reduction="sum")
- variable = backend.from_numpy(x)
- result = just_sum(layer(variable))
- result.backward()
- grad = backend.to_numpy(variable.grad)
- if reduction == "sum":
- assert np.allclose(grad, 1)
- if reduction == "mean":
- assert np.allclose(grad, grad.min())
- if reduction in ["max", "min"]:
- assert np.all(np.isin(grad, [0, 1]))
- assert np.sum(grad) > 0.5
- def test_reduce_symbolic():
- for backend in collect_test_backends(symbolic=True, layers=True):
- print("Test layer for ", backend.framework_name)
- for reduction in REDUCTIONS:
- for pattern, axes_lengths, input_shape, _wrong_shapes in reduction_patterns:
- x = np.arange(1, 1 + np.prod(input_shape), dtype="float32").reshape(input_shape)
- x /= x.mean()
- result_numpy = reduce(x, pattern, reduction, **axes_lengths)
- layer = backend.layers().Reduce(pattern, reduction, **axes_lengths)
- input_shape_of_nones = [None] * len(input_shape)
- shapes = [input_shape, input_shape_of_nones]
- for shape in shapes:
- symbol = backend.create_symbol(shape)
- eval_inputs = [(symbol, x)]
- result_symbol1 = layer(symbol)
- result1 = backend.eval_symbol(result_symbol1, eval_inputs)
- assert np.allclose(result_numpy, result1)
- layer2 = pickle.loads(pickle.dumps(layer))
- result_symbol2 = layer2(symbol)
- result2 = backend.eval_symbol(result_symbol2, eval_inputs)
- assert np.allclose(result1, result2)
- def create_torch_model(use_reduce=False, add_scripted_layer=False):
- if not is_backend_tested("torch"):
- pytest.skip()
- else:
- import torch.jit
- from torch.nn import Conv2d, Linear, MaxPool2d, ReLU, Sequential
- from einops.layers.torch import EinMix, Rearrange, Reduce
- return Sequential(
- Conv2d(3, 6, kernel_size=(5, 5)),
- Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2) if use_reduce else MaxPool2d(kernel_size=2),
- Conv2d(6, 16, kernel_size=(5, 5)),
- Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2),
- torch.jit.script(Rearrange("b c h w -> b (c h w)"))
- if add_scripted_layer
- else Rearrange("b c h w -> b (c h w)"),
- Linear(16 * 5 * 5, 120),
- ReLU(),
- Linear(120, 84),
- ReLU(),
- EinMix("b c1 -> (b c2)", weight_shape="c1 c2", bias_shape="c2", c1=84, c2=84),
- EinMix("(b c2) -> b c3", weight_shape="c2 c3", bias_shape="c3", c2=84, c3=84),
- Linear(84, 10),
- )
- def test_torch_layer():
- if not is_backend_tested("torch"):
- pytest.skip()
- else:
- # checked that torch present
- import torch
- import torch.jit
- model1 = create_torch_model(use_reduce=True)
- model2 = create_torch_model(use_reduce=False)
- input = torch.randn([10, 3, 32, 32])
- # random models have different predictions
- assert not torch.allclose(model1(input), model2(input))
- model2.load_state_dict(pickle.loads(pickle.dumps(model1.state_dict())))
- assert torch.allclose(model1(input), model2(input))
- # tracing (freezing)
- model3 = torch.jit.trace(model2, example_inputs=input)
- torch.testing.assert_close(model1(input), model3(input), atol=1e-3, rtol=1e-3)
- torch.testing.assert_close(model1(input + 1), model3(input + 1), atol=1e-3, rtol=1e-3)
- model4 = torch.jit.trace(model2, example_inputs=input)
- torch.testing.assert_close(model1(input), model4(input), atol=1e-3, rtol=1e-3)
- torch.testing.assert_close(model1(input + 1), model4(input + 1), atol=1e-3, rtol=1e-3)
- def test_torch_layers_scripting():
- if not is_backend_tested("torch"):
- pytest.skip()
- else:
- import torch
- for script_layer in [False, True]:
- model1 = create_torch_model(use_reduce=True, add_scripted_layer=script_layer)
- model2 = torch.jit.script(model1)
- input = torch.randn([10, 3, 32, 32])
- torch.testing.assert_close(model1(input), model2(input), atol=1e-3, rtol=1e-3)
- def test_keras_layer():
- rng = np.random.default_rng()
- if not is_backend_tested("tensorflow"):
- pytest.skip()
- else:
- import tensorflow as tf
- if tf.__version__ < "2.16.":
- # current implementation of layers follows new TF interface
- pytest.skip()
- from tensorflow.keras.layers import Conv2D as Conv2d
- from tensorflow.keras.layers import Dense as Linear
- from tensorflow.keras.layers import ReLU
- from tensorflow.keras.models import Sequential
- from einops.layers.keras import EinMix, Rearrange, Reduce, keras_custom_objects
- def create_keras_model():
- return Sequential(
- [
- Conv2d(6, kernel_size=5, input_shape=[32, 32, 3]),
- Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2),
- Conv2d(16, kernel_size=5),
- Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2),
- Rearrange("b c h w -> b (c h w)"),
- Linear(120),
- ReLU(),
- Linear(84),
- ReLU(),
- EinMix("b c1 -> (b c2)", weight_shape="c1 c2", bias_shape="c2", c1=84, c2=84),
- EinMix("(b c2) -> b c3", weight_shape="c2 c3", bias_shape="c3", c2=84, c3=84),
- Linear(10),
- ]
- )
- model1 = create_keras_model()
- model2 = create_keras_model()
- input = rng.normal(size=[10, 32, 32, 3]).astype("float32")
- # two randomly init models should provide different outputs
- assert not np.allclose(model1.predict_on_batch(input), model2.predict_on_batch(input))
- # get some temp filename
- tmp_model_filename = "/tmp/einops_tf_model.h5"
- # save arch + weights
- print("temp_path_keras1", tmp_model_filename)
- tf.keras.models.save_model(model1, tmp_model_filename)
- model3 = tf.keras.models.load_model(tmp_model_filename, custom_objects=keras_custom_objects)
- np.testing.assert_allclose(model1.predict_on_batch(input), model3.predict_on_batch(input))
- weight_filename = "/tmp/einops_tf_model.weights.h5"
- # save arch as json
- model4 = tf.keras.models.model_from_json(model1.to_json(), custom_objects=keras_custom_objects)
- model1.save_weights(weight_filename)
- model4.load_weights(weight_filename)
- model2.load_weights(weight_filename)
- # check that differently-inialized model receives same weights
- np.testing.assert_allclose(model1.predict_on_batch(input), model2.predict_on_batch(input))
- # ulimate test
- # save-load architecture, and then load weights - should return same result
- np.testing.assert_allclose(model1.predict_on_batch(input), model4.predict_on_batch(input))
- def test_flax_layers():
- """
- One-off simple tests for Flax layers.
- Unfortunately, Flax layers have a different interface from other layers.
- """
- if not is_backend_tested("jax"):
- pytest.skip()
- else:
- import flax
- import jax
- import jax.numpy as jnp
- from flax import linen as nn
- from einops.layers.flax import EinMix, Rearrange, Reduce
- class NN(nn.Module):
- @nn.compact
- def __call__(self, x):
- x = EinMix(
- "b (h h2) (w w2) c -> b h w c_out", "h2 w2 c c_out", "c_out", sizes=dict(h2=2, w2=3, c=4, c_out=5)
- )(x)
- x = Rearrange("b h w c -> b (w h c)", sizes=dict(c=5))(x)
- x = Reduce("b hwc -> b", "mean", dict(hwc=2 * 3 * 5))(x)
- return x
- model = NN()
- fixed_input = jnp.ones([10, 2 * 2, 3 * 3, 4])
- params = model.init(jax.random.PRNGKey(0), fixed_input)
- def eval_at_point(params):
- return jnp.linalg.norm(model.apply(params, fixed_input))
- vandg = jax.value_and_grad(eval_at_point)
- value0 = eval_at_point(params)
- value1, grad1 = vandg(params)
- assert jnp.allclose(value0, value1)
- if jax.__version__ < "0.6.0":
- tree_map = jax.tree_map
- else:
- tree_map = jax.tree.map
- params2 = tree_map(lambda x1, x2: x1 - x2 * 0.001, params, grad1)
- value2 = eval_at_point(params2)
- assert value0 >= value2, (value0, value2)
- # check serialization
- fbytes = flax.serialization.to_bytes(params)
- _loaded = flax.serialization.from_bytes(params, fbytes)
- def test_einmix_decomposition():
- """
- Testing that einmix correctly decomposes into smaller transformations.
- """
- from einops.layers._einmix import _EinmixDebugger
- mixin1 = _EinmixDebugger(
- "a b c d e -> e d c b a",
- weight_shape="d a b",
- d=2, a=3, b=5,
- ) # fmt: off
- assert mixin1.pre_reshape_pattern is None
- assert mixin1.post_reshape_pattern is None
- assert mixin1.einsum_pattern == "abcde,dab->edcba"
- assert mixin1.saved_weight_shape == [2, 3, 5]
- assert mixin1.saved_bias_shape is None
- mixin2 = _EinmixDebugger(
- "a b c d e -> e d c b a",
- weight_shape="d a b",
- bias_shape="a b c d e",
- a=1, b=2, c=3, d=4, e=5,
- ) # fmt: off
- assert mixin2.pre_reshape_pattern is None
- assert mixin2.post_reshape_pattern is None
- assert mixin2.einsum_pattern == "abcde,dab->edcba"
- assert mixin2.saved_weight_shape == [4, 1, 2]
- assert mixin2.saved_bias_shape == [5, 4, 3, 2, 1]
- mixin3 = _EinmixDebugger(
- "... -> ...",
- weight_shape="",
- bias_shape="",
- ) # fmt: off
- assert mixin3.pre_reshape_pattern is None
- assert mixin3.post_reshape_pattern is None
- assert mixin3.einsum_pattern == "...,->..."
- assert mixin3.saved_weight_shape == []
- assert mixin3.saved_bias_shape == []
- mixin4 = _EinmixDebugger(
- "b a ... -> b c ...",
- weight_shape="b a c",
- a=1, b=2, c=3,
- ) # fmt: off
- assert mixin4.pre_reshape_pattern is None
- assert mixin4.post_reshape_pattern is None
- assert mixin4.einsum_pattern == "ba...,bac->bc..."
- assert mixin4.saved_weight_shape == [2, 1, 3]
- assert mixin4.saved_bias_shape is None
- mixin5 = _EinmixDebugger(
- "(b a) ... -> b c (...)",
- weight_shape="b a c",
- a=1, b=2, c=3,
- ) # fmt: off
- assert mixin5.pre_reshape_pattern == "(b a) ... -> b a ..."
- assert mixin5.pre_reshape_lengths == dict(a=1, b=2)
- assert mixin5.post_reshape_pattern == "b c ... -> b c (...)"
- assert mixin5.einsum_pattern == "ba...,bac->bc..."
- assert mixin5.saved_weight_shape == [2, 1, 3]
- assert mixin5.saved_bias_shape is None
- mixin6 = _EinmixDebugger(
- "b ... (a c) -> b ... (a d)",
- weight_shape="c d",
- bias_shape="a d",
- a=1, c=3, d=4,
- ) # fmt: off
- assert mixin6.pre_reshape_pattern == "b ... (a c) -> b ... a c"
- assert mixin6.pre_reshape_lengths == dict(a=1, c=3)
- assert mixin6.post_reshape_pattern == "b ... a d -> b ... (a d)"
- assert mixin6.einsum_pattern == "b...ac,cd->b...ad"
- assert mixin6.saved_weight_shape == [3, 4]
- assert mixin6.saved_bias_shape == [1, 1, 4] # (b) a d, ellipsis does not participate
- mixin7 = _EinmixDebugger(
- "a ... (b c) -> a (... d b)",
- weight_shape="c d b",
- bias_shape="d b",
- b=2, c=3, d=4,
- ) # fmt: off
- assert mixin7.pre_reshape_pattern == "a ... (b c) -> a ... b c"
- assert mixin7.pre_reshape_lengths == dict(b=2, c=3)
- assert mixin7.post_reshape_pattern == "a ... d b -> a (... d b)"
- assert mixin7.einsum_pattern == "a...bc,cdb->a...db"
- assert mixin7.saved_weight_shape == [3, 4, 2]
- assert mixin7.saved_bias_shape == [1, 4, 2] # (a) d b, ellipsis does not participate
- def test_einmix_restrictions():
- """
- Testing different cases
- """
- from einops.layers._einmix import _EinmixDebugger
- with pytest.raises(EinopsError):
- _EinmixDebugger(
- "a b c d e -> e d c b a",
- weight_shape="d a b",
- d=2, a=3, # missing b
- ) # fmt: off
- with pytest.raises(EinopsError):
- _EinmixDebugger(
- "a b c d e -> e d c b a",
- weight_shape="w a b",
- d=2, a=3, b=1 # missing d
- ) # fmt: off
- with pytest.raises(EinopsError):
- _EinmixDebugger(
- "(...) a -> ... a",
- weight_shape="a", a=1, # ellipsis on the left
- ) # fmt: off
- with pytest.raises(EinopsError):
- _EinmixDebugger(
- "(...) a -> a ...",
- weight_shape="a", a=1, # ellipsis on the right side after bias axis
- bias_shape="a",
- ) # fmt: off
|