test_layers.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. import pickle
  2. from collections import namedtuple
  3. import numpy as np
  4. import pytest
  5. from einops import EinopsError, rearrange, reduce
  6. from einops.tests import FLOAT_REDUCTIONS as REDUCTIONS
  7. from einops.tests import collect_test_backends, is_backend_tested
  8. __author__ = "Alex Rogozhnikov"
  9. testcase = namedtuple("testcase", ["pattern", "axes_lengths", "input_shape", "wrong_shapes"])
  10. rearrangement_patterns = [
  11. testcase(
  12. "b c h w -> b (c h w)",
  13. dict(c=20),
  14. (10, 20, 30, 40),
  15. [(), (10,), (10, 10, 10), (10, 21, 30, 40), [1, 20, 1, 1, 1]],
  16. ),
  17. testcase(
  18. "b c (h1 h2) (w1 w2) -> b (c h2 w2) h1 w1",
  19. dict(h2=2, w2=2),
  20. (10, 20, 30, 40),
  21. [(), (1, 1, 1, 1), (1, 10, 3), ()],
  22. ),
  23. testcase(
  24. "b ... c -> c b ...",
  25. dict(b=10),
  26. (10, 20, 30),
  27. [(), (10,), (5, 10)],
  28. ),
  29. ]
  30. def test_rearrange_imperative():
  31. for backend in collect_test_backends(symbolic=False, layers=True):
  32. print("Test layer for ", backend.framework_name)
  33. for pattern, axes_lengths, input_shape, wrong_shapes in rearrangement_patterns:
  34. x = np.arange(np.prod(input_shape), dtype="float32").reshape(input_shape)
  35. result_numpy = rearrange(x, pattern, **axes_lengths)
  36. layer = backend.layers().Rearrange(pattern, **axes_lengths)
  37. for shape in wrong_shapes:
  38. try:
  39. layer(backend.from_numpy(np.zeros(shape, dtype="float32")))
  40. except BaseException:
  41. pass
  42. else:
  43. raise AssertionError("Failure expected")
  44. # simple pickling / unpickling
  45. layer2 = pickle.loads(pickle.dumps(layer))
  46. result1 = backend.to_numpy(layer(backend.from_numpy(x)))
  47. result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
  48. assert np.allclose(result_numpy, result1)
  49. assert np.allclose(result1, result2)
  50. just_sum = backend.layers().Reduce("...->", reduction="sum")
  51. variable = backend.from_numpy(x)
  52. result = just_sum(layer(variable))
  53. result.backward()
  54. assert np.allclose(backend.to_numpy(variable.grad), 1)
  55. def test_rearrange_symbolic():
  56. for backend in collect_test_backends(symbolic=True, layers=True):
  57. print("Test layer for ", backend.framework_name)
  58. for pattern, axes_lengths, input_shape, _wrong_shapes in rearrangement_patterns:
  59. x = np.arange(np.prod(input_shape), dtype="float32").reshape(input_shape)
  60. result_numpy = rearrange(x, pattern, **axes_lengths)
  61. layer = backend.layers().Rearrange(pattern, **axes_lengths)
  62. input_shape_of_nones = [None] * len(input_shape)
  63. shapes = [input_shape, input_shape_of_nones]
  64. for shape in shapes:
  65. symbol = backend.create_symbol(shape)
  66. eval_inputs = [(symbol, x)]
  67. result_symbol1 = layer(symbol)
  68. result1 = backend.eval_symbol(result_symbol1, eval_inputs)
  69. assert np.allclose(result_numpy, result1)
  70. layer2 = pickle.loads(pickle.dumps(layer))
  71. result_symbol2 = layer2(symbol)
  72. result2 = backend.eval_symbol(result_symbol2, eval_inputs)
  73. assert np.allclose(result1, result2)
  74. # now testing back-propagation
  75. just_sum = backend.layers().Reduce("...->", reduction="sum")
  76. result_sum1 = backend.eval_symbol(just_sum(result_symbol1), eval_inputs)
  77. result_sum2 = np.sum(x)
  78. assert np.allclose(result_sum1, result_sum2)
  79. reduction_patterns = [
  80. *rearrangement_patterns,
  81. testcase("b c h w -> b ()", dict(b=10), (10, 20, 30, 40), [(10,), (10, 20, 30)]),
  82. testcase("b c (h1 h2) (w1 w2) -> b c h1 w1", dict(h1=15, h2=2, w2=2), (10, 20, 30, 40), [(10, 20, 31, 40)]),
  83. testcase("b ... c -> b", dict(b=10), (10, 20, 30, 40), [(10,), (11, 10)]),
  84. ]
  85. def test_reduce_imperative():
  86. for backend in collect_test_backends(symbolic=False, layers=True):
  87. print("Test layer for ", backend.framework_name)
  88. for reduction in REDUCTIONS:
  89. for pattern, axes_lengths, input_shape, wrong_shapes in reduction_patterns:
  90. print(backend, reduction, pattern, axes_lengths, input_shape, wrong_shapes)
  91. x = np.arange(1, 1 + np.prod(input_shape), dtype="float32").reshape(input_shape)
  92. x /= x.mean()
  93. result_numpy = reduce(x, pattern, reduction, **axes_lengths)
  94. layer = backend.layers().Reduce(pattern, reduction, **axes_lengths)
  95. for shape in wrong_shapes:
  96. try:
  97. layer(backend.from_numpy(np.zeros(shape, dtype="float32")))
  98. except BaseException:
  99. pass
  100. else:
  101. raise AssertionError("Failure expected")
  102. # simple pickling / unpickling
  103. layer2 = pickle.loads(pickle.dumps(layer))
  104. result1 = backend.to_numpy(layer(backend.from_numpy(x)))
  105. result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
  106. assert np.allclose(result_numpy, result1)
  107. assert np.allclose(result1, result2)
  108. just_sum = backend.layers().Reduce("...->", reduction="sum")
  109. variable = backend.from_numpy(x)
  110. result = just_sum(layer(variable))
  111. result.backward()
  112. grad = backend.to_numpy(variable.grad)
  113. if reduction == "sum":
  114. assert np.allclose(grad, 1)
  115. if reduction == "mean":
  116. assert np.allclose(grad, grad.min())
  117. if reduction in ["max", "min"]:
  118. assert np.all(np.isin(grad, [0, 1]))
  119. assert np.sum(grad) > 0.5
  120. def test_reduce_symbolic():
  121. for backend in collect_test_backends(symbolic=True, layers=True):
  122. print("Test layer for ", backend.framework_name)
  123. for reduction in REDUCTIONS:
  124. for pattern, axes_lengths, input_shape, _wrong_shapes in reduction_patterns:
  125. x = np.arange(1, 1 + np.prod(input_shape), dtype="float32").reshape(input_shape)
  126. x /= x.mean()
  127. result_numpy = reduce(x, pattern, reduction, **axes_lengths)
  128. layer = backend.layers().Reduce(pattern, reduction, **axes_lengths)
  129. input_shape_of_nones = [None] * len(input_shape)
  130. shapes = [input_shape, input_shape_of_nones]
  131. for shape in shapes:
  132. symbol = backend.create_symbol(shape)
  133. eval_inputs = [(symbol, x)]
  134. result_symbol1 = layer(symbol)
  135. result1 = backend.eval_symbol(result_symbol1, eval_inputs)
  136. assert np.allclose(result_numpy, result1)
  137. layer2 = pickle.loads(pickle.dumps(layer))
  138. result_symbol2 = layer2(symbol)
  139. result2 = backend.eval_symbol(result_symbol2, eval_inputs)
  140. assert np.allclose(result1, result2)
  141. def create_torch_model(use_reduce=False, add_scripted_layer=False):
  142. if not is_backend_tested("torch"):
  143. pytest.skip()
  144. else:
  145. import torch.jit
  146. from torch.nn import Conv2d, Linear, MaxPool2d, ReLU, Sequential
  147. from einops.layers.torch import EinMix, Rearrange, Reduce
  148. return Sequential(
  149. Conv2d(3, 6, kernel_size=(5, 5)),
  150. Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2) if use_reduce else MaxPool2d(kernel_size=2),
  151. Conv2d(6, 16, kernel_size=(5, 5)),
  152. Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2),
  153. torch.jit.script(Rearrange("b c h w -> b (c h w)"))
  154. if add_scripted_layer
  155. else Rearrange("b c h w -> b (c h w)"),
  156. Linear(16 * 5 * 5, 120),
  157. ReLU(),
  158. Linear(120, 84),
  159. ReLU(),
  160. EinMix("b c1 -> (b c2)", weight_shape="c1 c2", bias_shape="c2", c1=84, c2=84),
  161. EinMix("(b c2) -> b c3", weight_shape="c2 c3", bias_shape="c3", c2=84, c3=84),
  162. Linear(84, 10),
  163. )
  164. def test_torch_layer():
  165. if not is_backend_tested("torch"):
  166. pytest.skip()
  167. else:
  168. # checked that torch present
  169. import torch
  170. import torch.jit
  171. model1 = create_torch_model(use_reduce=True)
  172. model2 = create_torch_model(use_reduce=False)
  173. input = torch.randn([10, 3, 32, 32])
  174. # random models have different predictions
  175. assert not torch.allclose(model1(input), model2(input))
  176. model2.load_state_dict(pickle.loads(pickle.dumps(model1.state_dict())))
  177. assert torch.allclose(model1(input), model2(input))
  178. # tracing (freezing)
  179. model3 = torch.jit.trace(model2, example_inputs=input)
  180. torch.testing.assert_close(model1(input), model3(input), atol=1e-3, rtol=1e-3)
  181. torch.testing.assert_close(model1(input + 1), model3(input + 1), atol=1e-3, rtol=1e-3)
  182. model4 = torch.jit.trace(model2, example_inputs=input)
  183. torch.testing.assert_close(model1(input), model4(input), atol=1e-3, rtol=1e-3)
  184. torch.testing.assert_close(model1(input + 1), model4(input + 1), atol=1e-3, rtol=1e-3)
  185. def test_torch_layers_scripting():
  186. if not is_backend_tested("torch"):
  187. pytest.skip()
  188. else:
  189. import torch
  190. for script_layer in [False, True]:
  191. model1 = create_torch_model(use_reduce=True, add_scripted_layer=script_layer)
  192. model2 = torch.jit.script(model1)
  193. input = torch.randn([10, 3, 32, 32])
  194. torch.testing.assert_close(model1(input), model2(input), atol=1e-3, rtol=1e-3)
  195. def test_keras_layer():
  196. rng = np.random.default_rng()
  197. if not is_backend_tested("tensorflow"):
  198. pytest.skip()
  199. else:
  200. import tensorflow as tf
  201. if tf.__version__ < "2.16.":
  202. # current implementation of layers follows new TF interface
  203. pytest.skip()
  204. from tensorflow.keras.layers import Conv2D as Conv2d
  205. from tensorflow.keras.layers import Dense as Linear
  206. from tensorflow.keras.layers import ReLU
  207. from tensorflow.keras.models import Sequential
  208. from einops.layers.keras import EinMix, Rearrange, Reduce, keras_custom_objects
  209. def create_keras_model():
  210. return Sequential(
  211. [
  212. Conv2d(6, kernel_size=5, input_shape=[32, 32, 3]),
  213. Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2),
  214. Conv2d(16, kernel_size=5),
  215. Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2),
  216. Rearrange("b c h w -> b (c h w)"),
  217. Linear(120),
  218. ReLU(),
  219. Linear(84),
  220. ReLU(),
  221. EinMix("b c1 -> (b c2)", weight_shape="c1 c2", bias_shape="c2", c1=84, c2=84),
  222. EinMix("(b c2) -> b c3", weight_shape="c2 c3", bias_shape="c3", c2=84, c3=84),
  223. Linear(10),
  224. ]
  225. )
  226. model1 = create_keras_model()
  227. model2 = create_keras_model()
  228. input = rng.normal(size=[10, 32, 32, 3]).astype("float32")
  229. # two randomly init models should provide different outputs
  230. assert not np.allclose(model1.predict_on_batch(input), model2.predict_on_batch(input))
  231. # get some temp filename
  232. tmp_model_filename = "/tmp/einops_tf_model.h5"
  233. # save arch + weights
  234. print("temp_path_keras1", tmp_model_filename)
  235. tf.keras.models.save_model(model1, tmp_model_filename)
  236. model3 = tf.keras.models.load_model(tmp_model_filename, custom_objects=keras_custom_objects)
  237. np.testing.assert_allclose(model1.predict_on_batch(input), model3.predict_on_batch(input))
  238. weight_filename = "/tmp/einops_tf_model.weights.h5"
  239. # save arch as json
  240. model4 = tf.keras.models.model_from_json(model1.to_json(), custom_objects=keras_custom_objects)
  241. model1.save_weights(weight_filename)
  242. model4.load_weights(weight_filename)
  243. model2.load_weights(weight_filename)
  244. # check that differently-inialized model receives same weights
  245. np.testing.assert_allclose(model1.predict_on_batch(input), model2.predict_on_batch(input))
  246. # ulimate test
  247. # save-load architecture, and then load weights - should return same result
  248. np.testing.assert_allclose(model1.predict_on_batch(input), model4.predict_on_batch(input))
  249. def test_flax_layers():
  250. """
  251. One-off simple tests for Flax layers.
  252. Unfortunately, Flax layers have a different interface from other layers.
  253. """
  254. if not is_backend_tested("jax"):
  255. pytest.skip()
  256. else:
  257. import flax
  258. import jax
  259. import jax.numpy as jnp
  260. from flax import linen as nn
  261. from einops.layers.flax import EinMix, Rearrange, Reduce
  262. class NN(nn.Module):
  263. @nn.compact
  264. def __call__(self, x):
  265. x = EinMix(
  266. "b (h h2) (w w2) c -> b h w c_out", "h2 w2 c c_out", "c_out", sizes=dict(h2=2, w2=3, c=4, c_out=5)
  267. )(x)
  268. x = Rearrange("b h w c -> b (w h c)", sizes=dict(c=5))(x)
  269. x = Reduce("b hwc -> b", "mean", dict(hwc=2 * 3 * 5))(x)
  270. return x
  271. model = NN()
  272. fixed_input = jnp.ones([10, 2 * 2, 3 * 3, 4])
  273. params = model.init(jax.random.PRNGKey(0), fixed_input)
  274. def eval_at_point(params):
  275. return jnp.linalg.norm(model.apply(params, fixed_input))
  276. vandg = jax.value_and_grad(eval_at_point)
  277. value0 = eval_at_point(params)
  278. value1, grad1 = vandg(params)
  279. assert jnp.allclose(value0, value1)
  280. if jax.__version__ < "0.6.0":
  281. tree_map = jax.tree_map
  282. else:
  283. tree_map = jax.tree.map
  284. params2 = tree_map(lambda x1, x2: x1 - x2 * 0.001, params, grad1)
  285. value2 = eval_at_point(params2)
  286. assert value0 >= value2, (value0, value2)
  287. # check serialization
  288. fbytes = flax.serialization.to_bytes(params)
  289. _loaded = flax.serialization.from_bytes(params, fbytes)
  290. def test_einmix_decomposition():
  291. """
  292. Testing that einmix correctly decomposes into smaller transformations.
  293. """
  294. from einops.layers._einmix import _EinmixDebugger
  295. mixin1 = _EinmixDebugger(
  296. "a b c d e -> e d c b a",
  297. weight_shape="d a b",
  298. d=2, a=3, b=5,
  299. ) # fmt: off
  300. assert mixin1.pre_reshape_pattern is None
  301. assert mixin1.post_reshape_pattern is None
  302. assert mixin1.einsum_pattern == "abcde,dab->edcba"
  303. assert mixin1.saved_weight_shape == [2, 3, 5]
  304. assert mixin1.saved_bias_shape is None
  305. mixin2 = _EinmixDebugger(
  306. "a b c d e -> e d c b a",
  307. weight_shape="d a b",
  308. bias_shape="a b c d e",
  309. a=1, b=2, c=3, d=4, e=5,
  310. ) # fmt: off
  311. assert mixin2.pre_reshape_pattern is None
  312. assert mixin2.post_reshape_pattern is None
  313. assert mixin2.einsum_pattern == "abcde,dab->edcba"
  314. assert mixin2.saved_weight_shape == [4, 1, 2]
  315. assert mixin2.saved_bias_shape == [5, 4, 3, 2, 1]
  316. mixin3 = _EinmixDebugger(
  317. "... -> ...",
  318. weight_shape="",
  319. bias_shape="",
  320. ) # fmt: off
  321. assert mixin3.pre_reshape_pattern is None
  322. assert mixin3.post_reshape_pattern is None
  323. assert mixin3.einsum_pattern == "...,->..."
  324. assert mixin3.saved_weight_shape == []
  325. assert mixin3.saved_bias_shape == []
  326. mixin4 = _EinmixDebugger(
  327. "b a ... -> b c ...",
  328. weight_shape="b a c",
  329. a=1, b=2, c=3,
  330. ) # fmt: off
  331. assert mixin4.pre_reshape_pattern is None
  332. assert mixin4.post_reshape_pattern is None
  333. assert mixin4.einsum_pattern == "ba...,bac->bc..."
  334. assert mixin4.saved_weight_shape == [2, 1, 3]
  335. assert mixin4.saved_bias_shape is None
  336. mixin5 = _EinmixDebugger(
  337. "(b a) ... -> b c (...)",
  338. weight_shape="b a c",
  339. a=1, b=2, c=3,
  340. ) # fmt: off
  341. assert mixin5.pre_reshape_pattern == "(b a) ... -> b a ..."
  342. assert mixin5.pre_reshape_lengths == dict(a=1, b=2)
  343. assert mixin5.post_reshape_pattern == "b c ... -> b c (...)"
  344. assert mixin5.einsum_pattern == "ba...,bac->bc..."
  345. assert mixin5.saved_weight_shape == [2, 1, 3]
  346. assert mixin5.saved_bias_shape is None
  347. mixin6 = _EinmixDebugger(
  348. "b ... (a c) -> b ... (a d)",
  349. weight_shape="c d",
  350. bias_shape="a d",
  351. a=1, c=3, d=4,
  352. ) # fmt: off
  353. assert mixin6.pre_reshape_pattern == "b ... (a c) -> b ... a c"
  354. assert mixin6.pre_reshape_lengths == dict(a=1, c=3)
  355. assert mixin6.post_reshape_pattern == "b ... a d -> b ... (a d)"
  356. assert mixin6.einsum_pattern == "b...ac,cd->b...ad"
  357. assert mixin6.saved_weight_shape == [3, 4]
  358. assert mixin6.saved_bias_shape == [1, 1, 4] # (b) a d, ellipsis does not participate
  359. mixin7 = _EinmixDebugger(
  360. "a ... (b c) -> a (... d b)",
  361. weight_shape="c d b",
  362. bias_shape="d b",
  363. b=2, c=3, d=4,
  364. ) # fmt: off
  365. assert mixin7.pre_reshape_pattern == "a ... (b c) -> a ... b c"
  366. assert mixin7.pre_reshape_lengths == dict(b=2, c=3)
  367. assert mixin7.post_reshape_pattern == "a ... d b -> a (... d b)"
  368. assert mixin7.einsum_pattern == "a...bc,cdb->a...db"
  369. assert mixin7.saved_weight_shape == [3, 4, 2]
  370. assert mixin7.saved_bias_shape == [1, 4, 2] # (a) d b, ellipsis does not participate
  371. def test_einmix_restrictions():
  372. """
  373. Testing different cases
  374. """
  375. from einops.layers._einmix import _EinmixDebugger
  376. with pytest.raises(EinopsError):
  377. _EinmixDebugger(
  378. "a b c d e -> e d c b a",
  379. weight_shape="d a b",
  380. d=2, a=3, # missing b
  381. ) # fmt: off
  382. with pytest.raises(EinopsError):
  383. _EinmixDebugger(
  384. "a b c d e -> e d c b a",
  385. weight_shape="w a b",
  386. d=2, a=3, b=1 # missing d
  387. ) # fmt: off
  388. with pytest.raises(EinopsError):
  389. _EinmixDebugger(
  390. "(...) a -> ... a",
  391. weight_shape="a", a=1, # ellipsis on the left
  392. ) # fmt: off
  393. with pytest.raises(EinopsError):
  394. _EinmixDebugger(
  395. "(...) a -> a ...",
  396. weight_shape="a", a=1, # ellipsis on the right side after bias axis
  397. bias_shape="a",
  398. ) # fmt: off