test_einsum.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. import string
  2. from typing import Any, Callable
  3. import numpy as np
  4. import pytest
  5. from einops.einops import EinopsError, _compactify_pattern_for_einsum, einsum
  6. from einops.tests import collect_test_backends
  7. class Arguments:
  8. def __init__(self, *args: Any, **kargs: Any):
  9. self.args = args
  10. self.kwargs = kargs
  11. def __call__(self, function: Callable):
  12. return function(*self.args, **self.kwargs)
# Cases consumed by `test_layer`, each of the form:
# (Arguments used to construct the layer, input shape, expected output shape)
test_layer_cases = [
    (
        Arguments("b c_in h w -> w c_out h b", "c_in c_out", bias_shape=None, c_out=13, c_in=12),
        (2, 12, 3, 4),
        (4, 13, 3, 2),
    ),
    (
        Arguments("b c_in h w -> w c_out h b", "c_in c_out", bias_shape="c_out", c_out=13, c_in=12),
        (2, 12, 3, 4),
        (4, 13, 3, 2),
    ),
    (
        Arguments("b c_in h w -> w c_in h b", "", bias_shape=None, c_in=12),
        (2, 12, 3, 4),
        (4, 12, 3, 2),
    ),
    (
        Arguments("b c_in h w -> b c_out", "c_in h w c_out", bias_shape=None, c_in=12, h=3, w=4, c_out=5),
        (2, 12, 3, 4),
        (2, 5),
    ),
    (
        Arguments("b t head c_in -> b t head c_out", "head c_in c_out", bias_shape=None, head=4, c_in=5, c_out=6),
        (2, 3, 4, 5),
        (2, 3, 4, 6),
    ),
]
  40. # Each of the form:
  41. # (Arguments, true_einsum_pattern, in_shapes, out_shape)
  42. test_functional_cases = [
  43. (
  44. # Basic:
  45. "b c h w, b w -> b h",
  46. "abcd,ad->ac",
  47. ((2, 3, 4, 5), (2, 5)),
  48. (2, 4),
  49. ),
  50. (
  51. # Three tensors:
  52. "b c h w, b w, b c -> b h",
  53. "abcd,ad,ab->ac",
  54. ((2, 3, 40, 5), (2, 5), (2, 3)),
  55. (2, 40),
  56. ),
  57. (
  58. # Ellipsis, and full names:
  59. "... one two three, three four five -> ... two five",
  60. "...abc,cde->...be",
  61. ((32, 5, 2, 3, 4), (4, 5, 6)),
  62. (32, 5, 3, 6),
  63. ),
  64. (
  65. # Ellipsis at the end:
  66. "one two three ..., three four five -> two five ...",
  67. "abc...,cde->be...",
  68. ((2, 3, 4, 32, 5), (4, 5, 6)),
  69. (3, 6, 32, 5),
  70. ),
  71. (
  72. # Ellipsis on multiple tensors:
  73. "... one two three, ... three four five -> ... two five",
  74. "...abc,...cde->...be",
  75. ((32, 5, 2, 3, 4), (32, 5, 4, 5, 6)),
  76. (32, 5, 3, 6),
  77. ),
  78. (
  79. # One tensor, and underscores:
  80. "first_tensor second_tensor -> first_tensor",
  81. "ab->a",
  82. ((5, 4),),
  83. (5,),
  84. ),
  85. (
  86. # Trace (repeated index)
  87. "i i -> ",
  88. "aa->",
  89. ((5, 5),),
  90. (),
  91. ),
  92. (
  93. # Too many spaces in string:
  94. " one two , three four->two four ",
  95. "ab,cd->bd",
  96. ((2, 3), (4, 5)),
  97. (3, 5),
  98. ),
  99. # The following tests were inspired by numpy's einsum tests
  100. # https://github.com/numpy/numpy/blob/v1.23.0/numpy/core/tests/test_einsum.py
  101. (
  102. # Trace with other indices
  103. "i middle i -> middle",
  104. "aba->b",
  105. ((5, 10, 5),),
  106. (10,),
  107. ),
  108. (
  109. # Ellipsis in the middle:
  110. "i ... i -> ...",
  111. "a...a->...",
  112. ((5, 3, 2, 1, 4, 5),),
  113. (3, 2, 1, 4),
  114. ),
  115. (
  116. # Product of first and last axes:
  117. "i ... i -> i ...",
  118. "a...a->a...",
  119. ((5, 3, 2, 1, 4, 5),),
  120. (5, 3, 2, 1, 4),
  121. ),
  122. (
  123. # Triple diagonal
  124. "one one one -> one",
  125. "aaa->a",
  126. ((5, 5, 5),),
  127. (5,),
  128. ),
  129. (
  130. # Axis swap:
  131. "i j k -> j i k",
  132. "abc->bac",
  133. ((1, 2, 3),),
  134. (2, 1, 3),
  135. ),
  136. (
  137. # Identity:
  138. "... -> ...",
  139. "...->...",
  140. ((5, 4, 3, 2, 1),),
  141. (5, 4, 3, 2, 1),
  142. ),
  143. (
  144. # Elementwise product of three tensors
  145. "..., ..., ... -> ...",
  146. "...,...,...->...",
  147. ((3, 2), (3, 2), (3, 2)),
  148. (3, 2),
  149. ),
  150. (
  151. # Basic summation:
  152. "index ->",
  153. "a->",
  154. ((10,)),
  155. (()),
  156. ),
  157. ]
  158. def test_layer():
  159. for backend in collect_test_backends(layers=True, symbolic=False):
  160. rng = np.random.default_rng()
  161. if backend.framework_name in ["tensorflow", "torch", "oneflow", "paddle"]:
  162. layer_type = backend.layers().EinMix
  163. for args, in_shape, out_shape in test_layer_cases:
  164. layer = args(layer_type)
  165. print("Running", layer.einsum_pattern, "for", backend.framework_name)
  166. input = rng.uniform(size=in_shape).astype("float32")
  167. input_framework = backend.from_numpy(input)
  168. output_framework = layer(input_framework)
  169. output = backend.to_numpy(output_framework)
  170. assert output.shape == out_shape
# Framework names the functional einsum tests run against: `test_functional`
# and `test_functional_symbolic` keep only the collected backends whose
# `framework_name` appears in this list.
valid_backends_functional = [
    "tensorflow",
    "torch",
    "jax",
    "numpy",
    "oneflow",
    "cupy",
    "tensorflow.keras",
    "paddle",
    "pytensor",
    "mlx",
]
  183. def test_functional():
  184. # Functional tests:
  185. backends = filter(lambda x: x.framework_name in valid_backends_functional, collect_test_backends())
  186. for backend in backends:
  187. for einops_pattern, true_pattern, in_shapes, out_shape in test_functional_cases:
  188. print(f"Running '{einops_pattern}' for {backend.framework_name}")
  189. # Create pattern:
  190. predicted_pattern = _compactify_pattern_for_einsum(einops_pattern)
  191. assert predicted_pattern == true_pattern
  192. # Generate example data:
  193. rstate = np.random.RandomState(0)
  194. in_arrays = [rstate.uniform(size=shape).astype("float32") for shape in in_shapes]
  195. in_arrays_framework = [backend.from_numpy(array) for array in in_arrays]
  196. # Loop over whether we call it manually with the backend,
  197. # or whether we use `einops.einsum`.
  198. for do_manual_call in [True, False]:
  199. # Actually run einsum:
  200. if do_manual_call:
  201. out_array = backend.einsum(predicted_pattern, *in_arrays_framework)
  202. else:
  203. out_array = einsum(*in_arrays_framework, einops_pattern)
  204. # Check shape:
  205. if tuple(out_array.shape) != out_shape:
  206. raise ValueError(f"Expected output shape {out_shape} but got {out_array.shape}")
  207. # Check values:
  208. true_out_array = np.einsum(true_pattern, *in_arrays)
  209. predicted_out_array = backend.to_numpy(out_array)
  210. np.testing.assert_array_almost_equal(predicted_out_array, true_out_array, decimal=5)
  211. def test_functional_symbolic():
  212. backends = filter(
  213. lambda x: x.framework_name in valid_backends_functional, collect_test_backends(symbolic=True, layers=False)
  214. )
  215. for backend in backends:
  216. for einops_pattern, true_pattern, in_shapes, out_shape in test_functional_cases:
  217. print(f"Running '{einops_pattern}' for symbolic {backend.framework_name}")
  218. # Create pattern:
  219. predicted_pattern = _compactify_pattern_for_einsum(einops_pattern)
  220. assert predicted_pattern == true_pattern
  221. rstate = np.random.RandomState(0)
  222. in_syms = [backend.create_symbol(in_shape) for in_shape in in_shapes]
  223. in_data = [rstate.uniform(size=in_shape).astype("float32") for in_shape in in_shapes]
  224. expected_out_data = np.einsum(true_pattern, *in_data)
  225. for do_manual_call in [True, False]:
  226. if do_manual_call:
  227. predicted_out_symbol = backend.einsum(predicted_pattern, *in_syms)
  228. else:
  229. predicted_out_symbol = einsum(*in_syms, einops_pattern)
  230. predicted_out_data = backend.eval_symbol(
  231. predicted_out_symbol,
  232. list(zip(in_syms, in_data)),
  233. )
  234. if predicted_out_data.shape != out_shape:
  235. raise ValueError(f"Expected output shape {out_shape} but got {predicted_out_data.shape}")
  236. np.testing.assert_array_almost_equal(predicted_out_data, expected_out_data, decimal=5)
  237. def test_functional_errors():
  238. # Specific backend does not matter, as errors are raised
  239. # during the pattern creation.
  240. rstate = np.random.RandomState(0)
  241. def create_tensor(*shape):
  242. return rstate.uniform(size=shape).astype("float32")
  243. # raise NotImplementedError("Singleton () axes are not yet supported in einsum.")
  244. with pytest.raises(NotImplementedError, match="^Singleton"):
  245. einsum(
  246. create_tensor(5, 1),
  247. "i () -> i",
  248. )
  249. # raise NotImplementedError("Shape rearrangement is not yet supported in einsum.")
  250. with pytest.raises(NotImplementedError, match="^Shape rearrangement"):
  251. einsum(
  252. create_tensor(5, 1),
  253. "a b -> (a b)",
  254. )
  255. with pytest.raises(NotImplementedError, match="^Shape rearrangement"):
  256. einsum(
  257. create_tensor(10, 1),
  258. "(a b) -> a b",
  259. )
  260. # raise RuntimeError("Encountered empty axis name in einsum.")
  261. # raise RuntimeError("Axis name in einsum must be a string.")
  262. # ^ Not tested, these are just a failsafe in case an unexpected error occurs.
  263. # raise NotImplementedError("Anonymous axes are not yet supported in einsum.")
  264. with pytest.raises(NotImplementedError, match="^Anonymous axes"):
  265. einsum(
  266. create_tensor(5, 1),
  267. "i 2 -> i",
  268. )
  269. # ParsedExpression error:
  270. with pytest.raises(EinopsError, match="^Invalid axis identifier"):
  271. einsum(
  272. create_tensor(5, 1),
  273. "i 2j -> i",
  274. )
  275. # raise ValueError("Einsum pattern must contain '->'.")
  276. with pytest.raises(ValueError, match="^Einsum pattern"):
  277. einsum(
  278. create_tensor(5, 3, 2),
  279. "i j k",
  280. )
  281. # raise RuntimeError("Too many axes in einsum.")
  282. with pytest.raises(RuntimeError, match="^Too many axes"):
  283. einsum(
  284. create_tensor(1),
  285. " ".join(string.ascii_letters) + " extra ->",
  286. )
  287. # raise RuntimeError("Unknown axis on right side of einsum.")
  288. with pytest.raises(RuntimeError, match="^Unknown axis"):
  289. einsum(
  290. create_tensor(5, 1),
  291. "i j -> k",
  292. )
  293. # raise ValueError(
  294. # "The last argument passed to `einops.einsum` must be a string,"
  295. # " representing the einsum pattern."
  296. # )
  297. with pytest.raises(ValueError, match="^The last argument"):
  298. einsum(
  299. "i j k -> i",
  300. create_tensor(5, 4, 3),
  301. )
  302. # raise ValueError(
  303. # "`einops.einsum` takes at minimum two arguments: the tensors,"
  304. # " followed by the pattern."
  305. # )
  306. with pytest.raises(ValueError, match="^`einops.einsum` takes"):
  307. einsum(
  308. "i j k -> i",
  309. )
  310. with pytest.raises(ValueError, match="^`einops.einsum` takes"):
  311. einsum(
  312. create_tensor(5, 1),
  313. )
  314. # TODO: Include check for giving normal einsum pattern rather than einops.