test_other.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. import subprocess
  2. import tempfile
  3. from doctest import testmod
  4. from pathlib import Path
  5. import numpy as np
  6. import pytest
  7. import einops
  8. import einops.layers
  9. from einops._backends import AbstractBackend
  10. from einops.einops import _optimize_transformation, parse_shape, rearrange
  11. from einops.tests import collect_test_backends, is_backend_tested
  12. __author__ = "Alex Rogozhnikov"
  13. rng = np.random.default_rng()
  14. def test_doctests_examples():
  15. # tests docstrings, additionally
  16. testmod(einops.layers, raise_on_error=True, extraglobs=dict(np=np))
  17. testmod(einops.einops, raise_on_error=True, extraglobs=dict(np=np))
  18. def test_backends_installed():
  19. """
  20. This test will fail if some of backends are not installed or can't be imported
  21. Other tests will just work and only test installed backends.
  22. """
  23. from . import parse_backends_to_test
  24. backends_to_test = set(parse_backends_to_test())
  25. errors = []
  26. # Find backend subclasses recursively
  27. backend_subclasses = []
  28. backends = AbstractBackend.__subclasses__()
  29. while backends:
  30. backend = backends.pop()
  31. backends += backend.__subclasses__()
  32. backend_subclasses.append(backend)
  33. for backend_type in backend_subclasses:
  34. if backend_type.framework_name not in backends_to_test:
  35. continue
  36. try:
  37. # instantiate
  38. backend_type()
  39. backends_to_test.remove(backend_type.framework_name)
  40. except Exception as e:
  41. errors.append((backend_type.framework_name, e))
  42. assert len(errors) == 0, errors
  43. assert len(backends_to_test) == 0, f"did not instantiate {backends_to_test=}, they won't be tested"
  44. def test_optimize_transformations_numpy():
  45. print("Testing optimizations")
  46. shapes = [[2] * n_dimensions for n_dimensions in range(14)]
  47. shapes += [[3] * n_dimensions for n_dimensions in range(6)]
  48. shapes += [[2, 3, 5, 7]]
  49. shapes += [[2, 3, 5, 7, 11, 17]]
  50. for shape in shapes:
  51. for _attempt in range(5):
  52. n_dimensions = len(shape)
  53. x = rng.integers(0, 2**12, size=shape).reshape([-1])
  54. init_shape = shape[:]
  55. n_reduced = rng.integers(0, n_dimensions + 1)
  56. reduced_axes = tuple(rng.permutation(n_dimensions)[:n_reduced])
  57. axes_reordering = rng.permutation(n_dimensions - n_reduced)
  58. final_shape = rng.integers(0, 1024, size=333) # just random
  59. init_shape2, reduced_axes2, axes_reordering2, final_shape2 = combination2 = _optimize_transformation(
  60. init_shape, reduced_axes, axes_reordering, final_shape
  61. )
  62. assert np.array_equal(final_shape, final_shape2)
  63. result1 = x.reshape(init_shape).sum(axis=reduced_axes).transpose(axes_reordering).reshape([-1])
  64. result2 = x.reshape(init_shape2).sum(axis=reduced_axes2).transpose(axes_reordering2).reshape([-1])
  65. assert np.array_equal(result1, result2)
  66. # testing we can't optimize this formula again
  67. combination3 = _optimize_transformation(*combination2)
  68. for a, b in zip(combination2, combination3):
  69. assert np.array_equal(a, b)
  70. _IMPERATIVE_BACKENDS = collect_test_backends(symbolic=False, layers=False)
  71. x_np = np.zeros([10, 20, 30, 40])
  72. def test_parse_shape_imperative():
  73. for backend in _IMPERATIVE_BACKENDS:
  74. print("Shape parsing for ", backend.framework_name)
  75. parsed1 = parse_shape(x_np, "a b c d")
  76. parsed2 = parse_shape(backend.from_numpy(x_np), "a b c d")
  77. assert parsed1 == parsed2 == dict(a=10, b=20, c=30, d=40)
  78. assert parsed1 != dict(a=1, b=20, c=30, d=40) != parsed2
  79. def test_underscore():
  80. for backend in _IMPERATIVE_BACKENDS:
  81. parsed1 = parse_shape(x_np, "_ _ _ _")
  82. parsed2 = parse_shape(backend.from_numpy(x_np), "_ _ _ _")
  83. assert parsed1 == parsed2 == dict()
  84. def test_underscore_one():
  85. for backend in _IMPERATIVE_BACKENDS:
  86. parsed1 = parse_shape(x_np, "_ _ _ hello")
  87. parsed2 = parse_shape(backend.from_numpy(x_np), "_ _ _ hello")
  88. assert parsed1 == parsed2 == dict(hello=40)
  89. def test_underscore_several():
  90. for backend in _IMPERATIVE_BACKENDS:
  91. parsed1 = parse_shape(x_np, "_ _ a1 a1a111a")
  92. parsed2 = parse_shape(backend.from_numpy(x_np), "_ _ a1 a1a111a")
  93. assert parsed1 == parsed2 == dict(a1=30, a1a111a=40)
  94. def test_repeating():
  95. with pytest.raises(einops.EinopsError):
  96. parse_shape(x_np, "a a b b")
  97. for backend in _IMPERATIVE_BACKENDS:
  98. with pytest.raises(einops.EinopsError):
  99. parse_shape(backend.from_numpy(x_np), "a a b b")
  100. def test_ellipsis():
  101. for backend in _IMPERATIVE_BACKENDS:
  102. for shape, pattern, expected in [
  103. ([10, 20], "...", dict()),
  104. ([10], "... a", dict(a=10)),
  105. ([10, 20], "... a", dict(a=20)),
  106. ([10, 20, 30], "... a", dict(a=30)),
  107. ([10, 20, 30, 40], "... a", dict(a=40)),
  108. ([10], "a ...", dict(a=10)),
  109. ([10, 20], "a ...", dict(a=10)),
  110. ([10, 20, 30], "a ...", dict(a=10)),
  111. ([10, 20, 30, 40], "a ...", dict(a=10)),
  112. ([10, 20, 30, 40], " a ... b", dict(a=10, b=40)),
  113. ([10, 40], " a ... b", dict(a=10, b=40)),
  114. ]:
  115. x = np.ones(shape)
  116. parsed1 = parse_shape(x, pattern)
  117. parsed2 = parse_shape(backend.from_numpy(x), pattern)
  118. assert parsed1 == parsed2 == expected
  119. def test_parse_with_anonymous_axes():
  120. for backend in _IMPERATIVE_BACKENDS:
  121. for shape, pattern, expected in [
  122. ([1, 2, 3, 4], "1 2 3 a", dict(a=4)),
  123. ([10, 1, 2], "a 1 2", dict(a=10)),
  124. ([10, 1, 2], "a () 2", dict(a=10)),
  125. ]:
  126. x = np.ones(shape)
  127. parsed1 = parse_shape(x, pattern)
  128. parsed2 = parse_shape(backend.from_numpy(x), pattern)
  129. assert parsed1 == parsed2 == expected
  130. def test_failures():
  131. for backend in _IMPERATIVE_BACKENDS:
  132. # every test should fail
  133. for shape, pattern in [
  134. ([1, 2, 3, 4], "a b c"),
  135. ([1, 2, 3, 4], "2 a b c"),
  136. ([1, 2, 3, 4], "a b c ()"),
  137. ([1, 2, 3, 4], "a b c d e"),
  138. ([1, 2, 3, 4], "a b c d e ..."),
  139. ([1, 2, 3, 4], "a b c ()"),
  140. ]:
  141. with pytest.raises(RuntimeError):
  142. x = np.ones(shape)
  143. parse_shape(backend.from_numpy(x), pattern)
  144. _SYMBOLIC_BACKENDS = [
  145. *collect_test_backends(symbolic=True, layers=False),
  146. *collect_test_backends(symbolic=True, layers=True),
  147. ]
  148. # tensorflow.keras needs special way to compile,
  149. # shape vars can be used only inside layers but not as outputs
  150. _SYMBOLIC_BACKENDS = [backend for backend in _SYMBOLIC_BACKENDS if backend.framework_name != "tensorflow.keras"]
  151. @pytest.mark.parametrize("backend", _SYMBOLIC_BACKENDS)
  152. def test_parse_shape_symbolic(backend):
  153. for input_shape in [
  154. [10, 20, 30, 40],
  155. [10, 20, None, None],
  156. [None, None, None, None],
  157. ]:
  158. print(f"special shape parsing {backend.framework_name=} {input_shape=}")
  159. input_symbol = backend.create_symbol(input_shape)
  160. shape_placeholder = parse_shape(input_symbol, "a b c d")
  161. out_shape = {}
  162. for name, symbol in shape_placeholder.items():
  163. out_shape[name] = (
  164. symbol
  165. if isinstance(symbol, int)
  166. else backend.eval_symbol(symbol, [(input_symbol, np.zeros([10, 20, 30, 40]))])
  167. ) # out shape element is either int, or symbol that we are able to eval
  168. print(out_shape)
  169. result_placeholder = rearrange(
  170. input_symbol, "a b (c1 c2) (d1 d2) -> (a b d1) c1 (c2 d2)", **parse_shape(input_symbol, "a b c1 _"), d2=2
  171. )
  172. result = backend.eval_symbol(result_placeholder, [(input_symbol, np.zeros([10, 20, 30, 40]))])
  173. print(result.shape)
  174. assert result.shape == (10 * 20 * 20, 30, 1 * 2)
  175. assert np.allclose(result, 0)
  176. @pytest.mark.parametrize("backend", _SYMBOLIC_BACKENDS)
  177. def test_parse_shape_symbolic_ellipsis(backend):
  178. for static_shape, shape, pattern, expected in [
  179. ([10, 20], [None, None], "...", dict()),
  180. ([10], [None], "... a", dict(a=10)),
  181. ([10, 20], [None, None], "... a", dict(a=20)),
  182. ([10, 20, 30], [None, None, None], "... a", dict(a=30)),
  183. ([10, 20, 30, 40], [None, None, None, None], "... a", dict(a=40)),
  184. ([10], [None], "a ...", dict(a=10)),
  185. ([10, 20], [None, None], "a ...", dict(a=10)),
  186. ([10, 20, 30], [None, None, None], "a ...", dict(a=10)),
  187. ([10, 20, 30, 40], [None, None, None, None], "a ...", dict(a=10)),
  188. ([10, 20, 30, 40], [None, None, None, None], " a ... b", dict(a=10, b=40)),
  189. ([10, 40], [None, None], " a ... b ", dict(a=10, b=40)),
  190. ]:
  191. input_symbol = backend.create_symbol(shape)
  192. shape_placeholder = parse_shape(input_symbol, pattern)
  193. out_shape = {}
  194. for name, symbol in shape_placeholder.items():
  195. if isinstance(symbol, int):
  196. out_shape[name] = symbol
  197. else:
  198. out_shape[name] = backend.eval_symbol(symbol, [(input_symbol, np.zeros(static_shape))])
  199. assert out_shape == expected
  200. def test_is_float_type():
  201. backends = collect_test_backends(symbolic=False, layers=False)
  202. backends += collect_test_backends(symbolic=False, layers=True)
  203. for backend in backends:
  204. for dtype in ["int32", "int64", "float32", "float64"]:
  205. is_float = "float" in dtype
  206. input = np.zeros([3, 4, 5], dtype=dtype)
  207. input = backend.from_numpy(input)
  208. assert backend.is_float_type(input) == is_float, (dtype, backend, input.dtype)
def test_torch_compile_for_functions():
    """
    Test ensures that allow_ops_in_compiled_graph allows compiling in a single graph
    Additionally we ensure that after compilation cache works properly
        (by changing shapes and patterns)
    We additionally check that pack/unpack still can be handled
        despite variable number of inputs/outputs
    """
    # Skipped entirely when the torch backend is not among the tested ones.
    if not is_backend_tested("torch"):
        pytest.skip()
    import torch
    from torch import nn

    from einops import einsum, pack, reduce, repeat, unpack
    from einops._torch_specific import allow_ops_in_compiled_graph

    allow_ops_in_compiled_graph()

    class TorchModuleWithOperations(nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x_abc, suffix=""):
            a, b, c = x_abc.shape

            def suf(pattern):
                # Append `suffix` to every whitespace-separated token whose last
                # character is "a", "c" or "d"; all other tokens are kept as-is.
                parts = pattern.split()
                return " ".join([p if p[-1] not in "acd" else p + suffix for p in parts])

            # patterns look a bit strange because names a, c, d will be modified on every run
            # by suf function
            x_abcd = repeat(x_abc, suf("a b c -> a b c 4"))
            x_abc = reduce(x_abcd, suf("a b c d -> a b c"), "min")
            # the number of packed tensors depends on the suffix length
            x_abdc, ps = pack([x_abc] * (2 + len(suffix)), suf("a b * c"))
            # note: the unpack pattern is passed literally, not through suf
            x_array = unpack(rearrange(x_abdc, suf("a b d c -> (a b ) 1 c d")), ps, "ab one1 c *")
            x1 = x_array[0] + len(x_array)
            x1 = rearrange(x1, suf("(a b ) 1 c -> a b c"), b=b)
            addition = einsum(x_abc, x_abcd, suf("a b c , a b c d -> d"))[0]
            return x1 + addition

    original = TorchModuleWithOperations()
    # fullgraph=True is requested for the compiled copy of the module
    compiled = torch.compile(original, fullgraph=True)
    for size in [10, 20, 40]:
        x = torch.rand([size, size + 1, size + 2])
        for suffix in ["", "suf1", "other_suffix"]:
            result1 = compiled(x, suffix)
            # reference: the uncompiled module run on a double copy of the input,
            # cast back with .float() before comparison
            result2 = original(x.double(), suffix).float()
            torch.testing.assert_close(result1, result2, atol=1e-5, rtol=1e-5)
  250. def test_torch_compile_for_layers():
  251. """
  252. Einops layers are in general very friendly towards tracing/compiling,
  253. but we still want to make sure we can compile them.
  254. """
  255. if not is_backend_tested("torch"):
  256. pytest.skip()
  257. import torch
  258. from torch import nn
  259. from einops.layers.torch import EinMix, Rearrange, Reduce
  260. original = nn.Sequential(
  261. Rearrange("b (t c) -> b t c", c=16),
  262. EinMix("b t c -> qkv b t cout", weight_shape="qkv c cout", bias_shape="qkv cout", qkv=3, c=16, cout=8),
  263. Reduce("qkv b t cout -> b t qkv", "min", cout=8),
  264. )
  265. compiled = torch.compile(original, fullgraph=True)
  266. for size in [16, 32, 64]:
  267. x = torch.rand([size, size])
  268. result1 = original(x)
  269. result2 = compiled(x)
  270. assert torch.allclose(result1, result2)
  271. src = """
  272. import einops
  273. import numpy as np
  274. from concurrent.futures import ThreadPoolExecutor
  275. import torch
  276. def f():
  277. return einops.rearrange(np.ndarray((20, 150, 150)), "... i j -> ... j i")
  278. with ThreadPoolExecutor(max_workers=2) as ex:
  279. fs = []
  280. for i in range(20):
  281. fs.append(ex.submit(f))
  282. for fut in fs:
  283. fut.result()
  284. """
  285. def test_einops_threading():
  286. # requires both. Reproduces problem from https://github.com/arogozhnikov/einops/issues/391
  287. if not is_backend_tested("torch"):
  288. pytest.skip()
  289. if not is_backend_tested("numpy"):
  290. pytest.skip()
  291. with tempfile.TemporaryDirectory() as d:
  292. testfile = Path(d).joinpath("test.py")
  293. testfile.write_text(src)
  294. subprocess.run(["python", testfile.absolute().as_posix()], check=True)