test_examples.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. import numpy as np
  2. import pytest
  3. from einops import parse_shape, rearrange, reduce
  4. from einops.tests import is_backend_tested
  5. from einops.tests.test_ops import imp_op_backends
  6. def test_rearrange_examples():
  7. def test1(x):
  8. # transpose
  9. y = rearrange(x, "b c h w -> b h w c")
  10. assert tuple(y.shape) == (10, 30, 40, 20)
  11. return y
  12. def test2(x):
  13. # view / reshape
  14. y = rearrange(x, "b c h w -> b (c h w)")
  15. assert tuple(y.shape) == (10, 20 * 30 * 40)
  16. return y
  17. def test3(x):
  18. # depth-to-space
  19. y = rearrange(x, "b (c h1 w1) h w -> b c (h h1) (w w1)", h1=2, w1=2)
  20. assert tuple(y.shape) == (10, 5, 30 * 2, 40 * 2)
  21. return y
  22. def test4(x):
  23. # space-to-depth
  24. y = rearrange(x, "b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=2, w1=2)
  25. assert tuple(y.shape) == (10, 20 * 4, 30 // 2, 40 // 2)
  26. return y
  27. def test5(x):
  28. # simple transposition
  29. y = rearrange(x, "b1 sound b2 letter -> b1 b2 sound letter")
  30. assert tuple(y.shape) == (10, 30, 20, 40)
  31. return y
  32. def test6(x):
  33. # parsing parameters
  34. t = rearrange(x, "b c h w -> (b h w) c")
  35. t = t[:, ::2] # replacement for dot-product, just changes size of second axis
  36. assert tuple(t.shape) == (10 * 30 * 40, 10)
  37. y = rearrange(t, "(b h w) c2 -> b c2 h w", **parse_shape(x, "b _ h w"))
  38. assert tuple(y.shape) == (10, 10, 30, 40)
  39. return y
  40. def test7(x):
  41. # split of embedding into groups
  42. y1, y2 = rearrange(x, "b (c g) h w -> g b c h w", g=2)
  43. assert tuple(y1.shape) == (10, 10, 30, 40)
  44. assert tuple(y2.shape) == (10, 10, 30, 40)
  45. return y1 + y2 # only one tensor is expected in output
  46. def test8(x):
  47. # max-pooling
  48. y = reduce(x, "b c (h h1) (w w1) -> b c h w", reduction="max", h1=2, w1=2)
  49. assert tuple(y.shape) == (10, 20, 30 // 2, 40 // 2)
  50. return y
  51. def test9(x):
  52. # squeeze - unsqueeze
  53. y = reduce(x, "b c h w -> b c () ()", reduction="max")
  54. assert tuple(y.shape) == (10, 20, 1, 1)
  55. y = rearrange(y, "b c () () -> c b")
  56. assert tuple(y.shape) == (20, 10)
  57. return y
  58. def test10(x):
  59. # stack
  60. tensors = list(x + 0) # 0 is needed https://github.com/tensorflow/tensorflow/issues/23185
  61. tensors = rearrange(tensors, "b c h w -> b h w c")
  62. assert tuple(tensors.shape) == (10, 30, 40, 20)
  63. return tensors
  64. def test11(x):
  65. # concatenate
  66. tensors = list(x + 0) # 0 is needed https://github.com/tensorflow/tensorflow/issues/23185
  67. tensors = rearrange(tensors, "b c h w -> h (b w) c")
  68. assert tuple(tensors.shape) == (30, 10 * 40, 20)
  69. return tensors
  70. def shufflenet(x, convolve, c1, c2):
  71. # shufflenet reordering example
  72. x = convolve(x)
  73. x = rearrange(x, "b (c1 c2) h w-> b (c2 c1) h w", c1=c1, c2=c2)
  74. x = convolve(x)
  75. return x
  76. def convolve_strided_1d(x, stride, usual_convolution):
  77. x = rearrange(x, "b c t1 t2 -> b c (t1 t2)") # reduce dimensionality
  78. x = rearrange(x, "b c (t stride) -> (stride b) c t", stride=stride)
  79. x = usual_convolution(x)
  80. x = rearrange(x, "(stride b) c t -> b c (t stride)", stride=stride)
  81. return x
  82. def convolve_strided_2d(x, h_stride, w_stride, usual_convolution):
  83. x = rearrange(x, "b c (h hs) (w ws) -> (hs ws b) c h w", hs=h_stride, ws=w_stride)
  84. x = usual_convolution(x)
  85. x = rearrange(x, "(hs ws b) c h w -> b c (h hs) (w ws)", hs=h_stride, ws=w_stride)
  86. return x
  87. def unet_like_1d(x, usual_convolution):
  88. # u-net like steps for increasing / reducing dimensionality
  89. x = rearrange(x, "b c t1 t2 -> b c (t1 t2)") # reduce dimensionality
  90. y = rearrange(x, "b c (t dt) -> b (dt c) t", dt=2)
  91. y = usual_convolution(y)
  92. x = x + rearrange(y, "b (dt c) t -> b c (t dt)", dt=2)
  93. return x
  94. # mock for convolution (works for all backends)
  95. def convolve_mock(x):
  96. return x
  97. tests = [
  98. test1,
  99. test2,
  100. test3,
  101. test4,
  102. test5,
  103. test6,
  104. test7,
  105. test8,
  106. test9,
  107. test10,
  108. test11,
  109. lambda x: shufflenet(x, convolve=convolve_mock, c1=4, c2=5),
  110. lambda x: convolve_strided_1d(x, stride=2, usual_convolution=convolve_mock),
  111. lambda x: convolve_strided_2d(x, h_stride=2, w_stride=2, usual_convolution=convolve_mock),
  112. lambda x: unet_like_1d(x, usual_convolution=convolve_mock),
  113. ]
  114. for backend in imp_op_backends:
  115. print("testing source_examples for ", backend.framework_name)
  116. for test in tests:
  117. x = np.arange(10 * 20 * 30 * 40).reshape([10, 20, 30, 40])
  118. result1 = test(x)
  119. result2 = backend.to_numpy(test(backend.from_numpy(x)))
  120. assert np.array_equal(result1, result2)
  121. # now with strides
  122. x = np.arange(10 * 2 * 20 * 3 * 30 * 1 * 40).reshape([10 * 2, 20 * 3, 30 * 1, 40 * 1])
  123. # known torch bug - torch doesn't support negative steps
  124. last_step = -1 if (backend.framework_name != "torch" and backend.framework_name != "oneflow") else 1
  125. indexing_expression = np.index_exp[::2, ::3, ::1, ::last_step]
  126. result1 = test(x[indexing_expression])
  127. result2 = backend.to_numpy(test(backend.from_numpy(x)[indexing_expression]))
  128. assert np.array_equal(result1, result2)
  129. def tensor_train_example_numpy():
  130. # kept here just for a collection, only tested for numpy
  131. # https://arxiv.org/pdf/1509.06569.pdf, (5)
  132. x = np.ones([3, 4, 5, 6])
  133. rank = 4
  134. if np.__version__ < "1.15.0":
  135. # numpy.einsum fails here, skip test
  136. return
  137. # creating appropriate Gs
  138. Gs = [np.ones([d, d, rank, rank]) for d in x.shape]
  139. Gs[0] = Gs[0][:, :, :1, :]
  140. Gs[-1] = Gs[-1][:, :, :, :1]
  141. # einsum way
  142. y = x.reshape((1, *x.shape))
  143. for G in Gs:
  144. # taking partial results left-to-right
  145. # y = numpy.einsum('i j alpha beta, alpha i ... -> beta ... j', G, y)
  146. y = np.einsum("i j a b, a i ... -> b ... j", G, y)
  147. y1 = y.reshape(-1)
  148. # alternative way
  149. y = x.reshape(-1)
  150. for G in Gs:
  151. i, j, alpha, beta = G.shape
  152. y = rearrange(y, "(i rest alpha) -> rest (alpha i)", alpha=alpha, i=i)
  153. y = y @ rearrange(G, "i j alpha beta -> (alpha i) (j beta)")
  154. y = rearrange(y, "rest (beta j) -> (beta rest j)", beta=beta, j=j)
  155. y2 = y
  156. assert np.allclose(y1, y2)
  157. # yet another way
  158. y = x
  159. for G in Gs:
  160. i, j, alpha, beta = G.shape
  161. y = rearrange(y, "i ... (j alpha) -> ... j (alpha i)", alpha=alpha, i=i)
  162. y = y @ rearrange(G, "i j alpha beta -> (alpha i) (j beta)")
  163. y3 = y.reshape(-1)
  164. assert np.allclose(y1, y3)
  165. def test_pytorch_yolo_fragment():
  166. if not is_backend_tested("torch"):
  167. pytest.skip()
  168. import torch
  169. def old_way(tensor, num_classes, num_anchors, anchors, stride_h, stride_w):
  170. # https://github.com/BobLiu20/YOLOv3_PyTorch/blob/c6b483743598b5f64d520d81e7e5f47ba936d4c9/nets/yolo_loss.py#L28-L44
  171. bs = tensor.size(0)
  172. in_h = tensor.size(2)
  173. in_w = tensor.size(3)
  174. scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in anchors]
  175. prediction = tensor.view(bs, num_anchors, 5 + num_classes, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()
  176. # Get outputs
  177. x = torch.sigmoid(prediction[..., 0]) # Center x
  178. y = torch.sigmoid(prediction[..., 1]) # Center y
  179. w = prediction[..., 2] # Width
  180. h = prediction[..., 3] # Height
  181. conf = torch.sigmoid(prediction[..., 4]) # Conf
  182. pred_cls = torch.sigmoid(prediction[..., 5:]) # Cls pred.
  183. # https://github.com/BobLiu20/YOLOv3_PyTorch/blob/c6b483743598b5f64d520d81e7e5f47ba936d4c9/nets/yolo_loss.py#L70-L92
  184. FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
  185. LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
  186. # Calculate offsets for each grid
  187. grid_x = (
  188. torch.linspace(0, in_w - 1, in_w)
  189. .repeat(in_w, 1)
  190. .repeat(bs * num_anchors, 1, 1)
  191. .view(x.shape)
  192. .type(FloatTensor)
  193. )
  194. grid_y = (
  195. torch.linspace(0, in_h - 1, in_h)
  196. .repeat(in_h, 1)
  197. .t()
  198. .repeat(bs * num_anchors, 1, 1)
  199. .view(y.shape)
  200. .type(FloatTensor)
  201. )
  202. # Calculate anchor w, h
  203. anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
  204. anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
  205. anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
  206. anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
  207. # Add offset and scale with anchors
  208. pred_boxes = FloatTensor(prediction[..., :4].shape)
  209. pred_boxes[..., 0] = x.data + grid_x
  210. pred_boxes[..., 1] = y.data + grid_y
  211. pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
  212. pred_boxes[..., 3] = torch.exp(h.data) * anchor_h
  213. # Results
  214. _scale = torch.Tensor([stride_w, stride_h] * 2).type(FloatTensor)
  215. output = torch.cat(
  216. (pred_boxes.view(bs, -1, 4) * _scale, conf.view(bs, -1, 1), pred_cls.view(bs, -1, num_classes)), -1
  217. )
  218. return output
  219. def new_way(tensor, num_classes, num_anchors, anchors, stride_h, stride_w):
  220. raw_predictions = rearrange(tensor, " b (anchor prediction) h w -> prediction b anchor h w", anchor=num_anchors)
  221. anchors = torch.FloatTensor(anchors).to(tensor.device)
  222. anchor_sizes = rearrange(anchors, "anchor dim -> dim () anchor () ()")
  223. _, _, _, in_h, in_w = raw_predictions.shape
  224. grid_h = rearrange(torch.arange(in_h).float(), "h -> () () h ()").to(tensor.device)
  225. grid_w = rearrange(torch.arange(in_w).float(), "w -> () () () w").to(tensor.device)
  226. predicted_bboxes = torch.zeros_like(raw_predictions)
  227. predicted_bboxes[0] = (raw_predictions[0].sigmoid() + grid_h) * stride_h # center y
  228. predicted_bboxes[1] = (raw_predictions[1].sigmoid() + grid_w) * stride_w # center x
  229. predicted_bboxes[2:4] = (raw_predictions[2:4].exp()) * anchor_sizes # bbox width and height
  230. predicted_bboxes[4] = raw_predictions[4].sigmoid() # confidence
  231. predicted_bboxes[5:] = raw_predictions[5:].sigmoid() # class predictions
  232. # only to match results of original code, not needed
  233. return rearrange(predicted_bboxes, "prediction b anchor h w -> b anchor h w prediction")
  234. stride_h = 4
  235. stride_w = 4
  236. batch_size = 5
  237. num_classes = 12
  238. anchors = [[50, 100], [100, 50], [75, 75]]
  239. num_anchors = len(anchors)
  240. x = torch.randn([batch_size, num_anchors * (5 + num_classes), 1, 1])
  241. result1 = old_way(
  242. tensor=x,
  243. num_anchors=num_anchors,
  244. num_classes=num_classes,
  245. stride_h=stride_h,
  246. stride_w=stride_w,
  247. anchors=anchors,
  248. )
  249. result2 = new_way(
  250. tensor=x,
  251. num_anchors=num_anchors,
  252. num_classes=num_classes,
  253. stride_h=stride_h,
  254. stride_w=stride_w,
  255. anchors=anchors,
  256. )
  257. result1 = result1.reshape(result2.shape)
  258. assert torch.allclose(result1, result2)