| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356 |
- import string
- from typing import Any, Callable
- import numpy as np
- import pytest
- from einops.einops import EinopsError, _compactify_pattern_for_einsum, einsum
- from einops.tests import collect_test_backends
class Arguments:
    """Bundle positional and keyword arguments for a deferred call.

    ``Arguments(*args, **kwargs)(function)`` is equivalent to
    ``function(*args, **kwargs)``. The test tables use it to store constructor
    arguments for layers that are only instantiated once a backend's layer
    class is known.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        self.args = args
        self.kwargs = kwargs

    def __call__(self, function: Callable) -> Any:
        """Call ``function`` with the stored arguments and return its result."""
        return function(*self.args, **self.kwargs)

    def __repr__(self) -> str:
        # Readable representation so failing test cases are identifiable in output.
        parts = [repr(arg) for arg in self.args]
        parts += [f"{key}={value!r}" for key, value in self.kwargs.items()]
        return f"{type(self).__name__}({', '.join(parts)})"
# Each entry: (Arguments for the EinMix constructor, input shape, expected output shape).
test_layer_cases = [
    # Mixing over channels plus a transpose, without bias.
    (
        Arguments("b c_in h w -> w c_out h b", "c_in c_out", bias_shape=None, c_out=13, c_in=12),
        (2, 12, 3, 4),
        (4, 13, 3, 2),
    ),
    # Same as above, with a bias over the output channels.
    (
        Arguments("b c_in h w -> w c_out h b", "c_in c_out", bias_shape="c_out", c_out=13, c_in=12),
        (2, 12, 3, 4),
        (4, 13, 3, 2),
    ),
    # Empty weight pattern: pure axis permutation.
    (
        Arguments("b c_in h w -> w c_in h b", "", bias_shape=None, c_in=12),
        (2, 12, 3, 4),
        (4, 12, 3, 2),
    ),
    # Contract all non-batch axes into a new output axis.
    (
        Arguments("b c_in h w -> b c_out", "c_in h w c_out", bias_shape=None, c_in=12, h=3, w=4, c_out=5),
        (2, 12, 3, 4),
        (2, 5),
    ),
    # Per-head weights (head axis kept in both input and output).
    (
        Arguments("b t head c_in -> b t head c_out", "head c_in c_out", bias_shape=None, head=4, c_in=5, c_out=6),
        (2, 3, 4, 5),
        (2, 3, 4, 6),
    ),
]
# Each entry has the form:
# (einops_pattern, true_einsum_pattern, in_shapes, out_shape)
# where ``in_shapes`` is a tuple holding one shape tuple per input tensor.
test_functional_cases = [
    (
        # Basic:
        "b c h w, b w -> b h",
        "abcd,ad->ac",
        ((2, 3, 4, 5), (2, 5)),
        (2, 4),
    ),
    (
        # Three tensors:
        "b c h w, b w, b c -> b h",
        "abcd,ad,ab->ac",
        ((2, 3, 40, 5), (2, 5), (2, 3)),
        (2, 40),
    ),
    (
        # Ellipsis, and full names:
        "... one two three, three four five -> ... two five",
        "...abc,cde->...be",
        ((32, 5, 2, 3, 4), (4, 5, 6)),
        (32, 5, 3, 6),
    ),
    (
        # Ellipsis at the end:
        "one two three ..., three four five -> two five ...",
        "abc...,cde->be...",
        ((2, 3, 4, 32, 5), (4, 5, 6)),
        (3, 6, 32, 5),
    ),
    (
        # Ellipsis on multiple tensors:
        "... one two three, ... three four five -> ... two five",
        "...abc,...cde->...be",
        ((32, 5, 2, 3, 4), (32, 5, 4, 5, 6)),
        (32, 5, 3, 6),
    ),
    (
        # One tensor, and underscores:
        "first_tensor second_tensor -> first_tensor",
        "ab->a",
        ((5, 4),),
        (5,),
    ),
    (
        # Trace (repeated index)
        "i i -> ",
        "aa->",
        ((5, 5),),
        (),
    ),
    (
        # Too many spaces in string:
        " one two , three four->two four ",
        "ab,cd->bd",
        ((2, 3), (4, 5)),
        (3, 5),
    ),
    # The following tests were inspired by numpy's einsum tests
    # https://github.com/numpy/numpy/blob/v1.23.0/numpy/core/tests/test_einsum.py
    (
        # Trace with other indices
        "i middle i -> middle",
        "aba->b",
        ((5, 10, 5),),
        (10,),
    ),
    (
        # Ellipsis in the middle:
        "i ... i -> ...",
        "a...a->...",
        ((5, 3, 2, 1, 4, 5),),
        (3, 2, 1, 4),
    ),
    (
        # Product of first and last axes:
        "i ... i -> i ...",
        "a...a->a...",
        ((5, 3, 2, 1, 4, 5),),
        (5, 3, 2, 1, 4),
    ),
    (
        # Triple diagonal
        "one one one -> one",
        "aaa->a",
        ((5, 5, 5),),
        (5,),
    ),
    (
        # Axis swap:
        "i j k -> j i k",
        "abc->bac",
        ((1, 2, 3),),
        (2, 1, 3),
    ),
    (
        # Identity:
        "... -> ...",
        "...->...",
        ((5, 4, 3, 2, 1),),
        (5, 4, 3, 2, 1),
    ),
    (
        # Elementwise product of three tensors
        "..., ..., ... -> ...",
        "...,...,...->...",
        ((3, 2), (3, 2), (3, 2)),
        (3, 2),
    ),
    (
        # Basic summation:
        "index ->",
        "a->",
        # Trailing comma required: ``((10,))`` would be just ``(10,)``, i.e. the
        # int ``10`` as the "shape" of the single input instead of ``(10,)``.
        ((10,),),
        (),
    ),
]
def test_layer():
    """Check that each backend's ``EinMix`` layer produces the expected output shape.

    Runs every entry of ``test_layer_cases`` for the non-symbolic, layer-capable
    backends whose framework is one of tensorflow, torch, oneflow or paddle;
    other backends are skipped.
    """
    for backend in collect_test_backends(layers=True, symbolic=False):
        if backend.framework_name not in ["tensorflow", "torch", "oneflow", "paddle"]:
            continue
        # Seeded so any failure is reproducible; only output shapes are asserted.
        rng = np.random.default_rng(0)
        layer_type = backend.layers().EinMix
        for args, in_shape, out_shape in test_layer_cases:
            layer = args(layer_type)
            print("Running", layer.einsum_pattern, "for", backend.framework_name)
            # Named ``input_numpy`` rather than ``input`` to avoid shadowing the builtin.
            input_numpy = rng.uniform(size=in_shape).astype("float32")
            input_framework = backend.from_numpy(input_numpy)
            output_framework = layer(input_framework)
            output = backend.to_numpy(output_framework)
            assert output.shape == out_shape
# Framework names whose backends are exercised by the functional einsum tests.
valid_backends_functional = [
    "tensorflow", "torch", "jax", "numpy", "oneflow",
    "cupy", "tensorflow.keras", "paddle", "pytensor", "mlx",
]
def test_functional():
    """Check ``einops.einsum`` against ``np.einsum`` for every functional test case.

    For each backend from ``collect_test_backends()`` whose framework name is in
    ``valid_backends_functional``, and for each entry of ``test_functional_cases``:

    * ``_compactify_pattern_for_einsum`` must turn the einops pattern into the
      expected single-letter einsum pattern;
    * both a direct ``backend.einsum`` call with the compacted pattern and
      ``einops.einsum`` with the original pattern must produce the expected
      output shape and values (to 5 decimals).
    """
    backends = [
        backend
        for backend in collect_test_backends()
        if backend.framework_name in valid_backends_functional
    ]
    for backend in backends:
        for einops_pattern, true_pattern, in_shapes, out_shape in test_functional_cases:
            print(f"Running '{einops_pattern}' for {backend.framework_name}")
            # Create pattern:
            predicted_pattern = _compactify_pattern_for_einsum(einops_pattern)
            assert predicted_pattern == true_pattern
            # Generate example data (fixed seed so failures are reproducible):
            rstate = np.random.RandomState(0)
            in_arrays = [rstate.uniform(size=shape).astype("float32") for shape in in_shapes]
            in_arrays_framework = [backend.from_numpy(array) for array in in_arrays]
            # The reference result does not depend on how einsum is invoked,
            # so compute it once per case rather than once per call style.
            true_out_array = np.einsum(true_pattern, *in_arrays)
            # Loop over whether we call it manually with the backend,
            # or whether we use `einops.einsum`.
            for do_manual_call in [True, False]:
                if do_manual_call:
                    out_array = backend.einsum(predicted_pattern, *in_arrays_framework)
                else:
                    out_array = einsum(*in_arrays_framework, einops_pattern)
                # Check shape (normalised to a tuple for backends with custom shape types):
                if tuple(out_array.shape) != out_shape:
                    raise ValueError(f"Expected output shape {out_shape} but got {out_array.shape}")
                # Check values:
                predicted_out_array = backend.to_numpy(out_array)
                np.testing.assert_array_almost_equal(predicted_out_array, true_out_array, decimal=5)
def test_functional_symbolic():
    """Check ``einops.einsum`` on symbolic tensors against ``np.einsum``.

    For each symbolic backend (``collect_test_backends(symbolic=True, layers=False)``)
    whose framework name is in ``valid_backends_functional``, and for each entry of
    ``test_functional_cases``:

    * the compacted pattern must equal the expected einsum pattern;
    * einsum is applied to placeholder symbols both directly via
      ``backend.einsum`` and via ``einops.einsum``; the result is evaluated on
      fixed random data and compared with ``np.einsum`` in shape and value.
    """
    backends = [
        backend
        for backend in collect_test_backends(symbolic=True, layers=False)
        if backend.framework_name in valid_backends_functional
    ]
    for backend in backends:
        for einops_pattern, true_pattern, in_shapes, out_shape in test_functional_cases:
            print(f"Running '{einops_pattern}' for symbolic {backend.framework_name}")
            # Create pattern:
            predicted_pattern = _compactify_pattern_for_einsum(einops_pattern)
            assert predicted_pattern == true_pattern
            rstate = np.random.RandomState(0)
            in_syms = [backend.create_symbol(in_shape) for in_shape in in_shapes]
            in_data = [rstate.uniform(size=in_shape).astype("float32") for in_shape in in_shapes]
            expected_out_data = np.einsum(true_pattern, *in_data)
            for do_manual_call in [True, False]:
                if do_manual_call:
                    predicted_out_symbol = backend.einsum(predicted_pattern, *in_syms)
                else:
                    predicted_out_symbol = einsum(*in_syms, einops_pattern)
                # Bind each placeholder symbol to its concrete input data.
                predicted_out_data = backend.eval_symbol(
                    predicted_out_symbol,
                    list(zip(in_syms, in_data)),
                )
                # Normalise to a tuple, as ``test_functional`` does, so a backend
                # returning a non-tuple shape object still compares correctly.
                if tuple(predicted_out_data.shape) != out_shape:
                    raise ValueError(f"Expected output shape {out_shape} but got {predicted_out_data.shape}")
                np.testing.assert_array_almost_equal(predicted_out_data, expected_out_data, decimal=5)
def test_functional_errors():
    """Check that malformed or unsupported einsum patterns raise the expected errors.

    The specific backend does not matter, as errors are raised during pattern
    creation, so plain numpy arrays are used as inputs.

    Note: ``pytest.raises(match=...)`` treats ``match`` as a regular expression
    (applied with ``re.search``), so literal ``.`` characters are escaped.
    """
    rstate = np.random.RandomState(0)

    def create_tensor(*shape):
        return rstate.uniform(size=shape).astype("float32")

    # raise NotImplementedError("Singleton () axes are not yet supported in einsum.")
    with pytest.raises(NotImplementedError, match="^Singleton"):
        einsum(
            create_tensor(5, 1),
            "i () -> i",
        )

    # raise NotImplementedError("Shape rearrangement is not yet supported in einsum.")
    with pytest.raises(NotImplementedError, match="^Shape rearrangement"):
        einsum(
            create_tensor(5, 1),
            "a b -> (a b)",
        )

    with pytest.raises(NotImplementedError, match="^Shape rearrangement"):
        einsum(
            create_tensor(10, 1),
            "(a b) -> a b",
        )

    # raise RuntimeError("Encountered empty axis name in einsum.")
    # raise RuntimeError("Axis name in einsum must be a string.")
    # ^ Not tested, these are just a failsafe in case an unexpected error occurs.

    # raise NotImplementedError("Anonymous axes are not yet supported in einsum.")
    with pytest.raises(NotImplementedError, match="^Anonymous axes"):
        einsum(
            create_tensor(5, 1),
            "i 2 -> i",
        )

    # ParsedExpression error:
    with pytest.raises(EinopsError, match="^Invalid axis identifier"):
        einsum(
            create_tensor(5, 1),
            "i 2j -> i",
        )

    # raise ValueError("Einsum pattern must contain '->'.")
    with pytest.raises(ValueError, match="^Einsum pattern"):
        einsum(
            create_tensor(5, 3, 2),
            "i j k",
        )

    # raise RuntimeError("Too many axes in einsum.")
    # 52 single-letter axes plus one more exceed the available letters.
    with pytest.raises(RuntimeError, match="^Too many axes"):
        einsum(
            create_tensor(1),
            " ".join(string.ascii_letters) + " extra ->",
        )

    # raise RuntimeError("Unknown axis on right side of einsum.")
    with pytest.raises(RuntimeError, match="^Unknown axis"):
        einsum(
            create_tensor(5, 1),
            "i j -> k",
        )

    # raise ValueError(
    #     "The last argument passed to `einops.einsum` must be a string,"
    #     " representing the einsum pattern."
    # )
    with pytest.raises(ValueError, match="^The last argument"):
        einsum(
            "i j k -> i",
            create_tensor(5, 4, 3),
        )

    # raise ValueError(
    #     "`einops.einsum` takes at minimum two arguments: the tensors,"
    #     " followed by the pattern."
    # )
    with pytest.raises(ValueError, match=r"^`einops\.einsum` takes"):
        einsum(
            "i j k -> i",
        )

    with pytest.raises(ValueError, match=r"^`einops\.einsum` takes"):
        einsum(
            create_tensor(5, 1),
        )

    # TODO: Include check for giving normal einsum pattern rather than einops.
|