einops.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939
  1. import functools
  2. import itertools
  3. import string
  4. import typing
  5. from collections import OrderedDict
  6. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union, cast, overload
  7. if typing.TYPE_CHECKING:
  8. # for docstrings in pycharm
  9. import numpy as np # noqa E401
  10. from . import EinopsError
  11. from ._backends import get_backend
  12. from .parsing import AnonymousAxis, ParsedExpression, _ellipsis
# Generic placeholder for the tensor type handled by the selected backend.
Tensor = TypeVar("Tensor")
# Signature of a custom reduction: called as f(tensor, axes_to_reduce) and returns a tensor.
ReductionCallable = Callable[[Tensor, Tuple[int, ...]], Tensor]
# A reduction is given either as a string name or as a custom callable.
Reduction = Union[str, ReductionCallable]
# Axis sizes are deliberately left untyped.
Size = typing.Any

# Names of the built-in reductions.
_reductions = ("min", "max", "sum", "mean", "prod", "any", "all")

# magic integers are required to stay within
# traceable subset of language
# Sentinel stored for an axis whose length is not known when the recipe is prepared.
_unknown_axis_length = -999999
# Sentinel stored for an axis whose length is supplied by the caller (axes passed as kwargs).
_expected_axis_length = -99999
  22. def _product(sequence: List[int]) -> int:
  23. """minimalistic product that works both with numbers and symbols. Supports empty lists"""
  24. result = 1
  25. for element in sequence:
  26. result *= element
  27. return result
  28. def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int], backend):
  29. if callable(reduction_type):
  30. # custom callable
  31. return reduction_type(tensor, tuple(reduced_axes))
  32. else:
  33. # one of built-in operations
  34. assert reduction_type in _reductions
  35. if reduction_type == "mean":
  36. if not backend.is_float_type(tensor):
  37. raise NotImplementedError("reduce_mean is not available for non-floating tensors")
  38. return backend.reduce(tensor, reduction_type, tuple(reduced_axes))
def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes):
    """Merge neighboring axes so that fewer axes have to be handled.

    Two kinds of merges are performed:
    1. reduced axes that are adjacent in the input are joined into one reduced axis;
    2. two adjacent non-reduced input axes are joined when they also end up
       next to each other, in the same order, after the reordering.

    Parameters:
        init_shapes: list of axis lengths; it is sliced, concatenated and item-assigned,
            so it has to be a list. The caller's list is not modified, because a new list
            is created before any in-place multiplication.
        reduced_axes: indices (into init_shapes) of the axes that are reduced.
        axes_reordering: permutation of the axes that remain after reduction.
        final_shapes: passed through untouched.

    Returns:
        tuple (init_shapes, reduced_axes, axes_reordering, final_shapes) with the first
        three elements updated.
    """
    # 'collapses' neighboring axes if those participate in the result pattern in the same order
    # TODO add support for added_axes
    assert len(axes_reordering) + len(reduced_axes) == len(init_shapes)
    # joining consecutive axes that will be reduced
    # possibly we can skip this if all backends can optimize this (not sure)
    reduced_axes = tuple(sorted(reduced_axes))
    # walk from the end so that indices of not-yet-visited pairs stay valid after a merge
    for i in range(len(reduced_axes) - 1)[::-1]:
        if reduced_axes[i] + 1 == reduced_axes[i + 1]:
            removed_axis = reduced_axes[i + 1]
            removed_length = init_shapes[removed_axis]
            # drop the second axis of the pair and fold its length into the first one
            init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1 :]
            init_shapes[removed_axis - 1] *= removed_length
            # every reduced axis located after the removed one shifts left by one
            reduced_axes = reduced_axes[: i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2 :])

    # removing axes that are moved together during reshape
    def build_mapping():
        # maps each initial axis to its position in the result, or None for reduced axes
        init_to_final = {}
        for axis in range(len(init_shapes)):
            if axis in reduced_axes:
                init_to_final[axis] = None
            else:
                # index of this axis among the non-reduced axes seen so far
                after_reduction = sum(x is not None for x in init_to_final.values())
                init_to_final[axis] = list(axes_reordering).index(after_reduction)
        return init_to_final

    init_axis_to_final_axis = build_mapping()

    # again iterate from the end, so a merge does not disturb the pairs still to be checked
    for init_axis in range(len(init_shapes) - 1)[::-1]:
        if init_axis_to_final_axis[init_axis] is None:
            continue
        if init_axis_to_final_axis[init_axis + 1] is None:
            continue
        # both axes survive reduction; merge them if they stay adjacent and ordered in the result
        if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]:
            removed_axis = init_axis + 1
            removed_length = init_shapes[removed_axis]
            # position of the removed axis counted among non-reduced axes only
            removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis))

            reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes)
            init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1 :]
            init_shapes[removed_axis - 1] *= removed_length
            old_reordering = axes_reordering
            axes_reordering = []
            # rebuild the permutation without the removed axis, shifting larger indices down
            for axis in old_reordering:
                if axis == removed_axis_after_reduction:
                    pass
                elif axis < removed_axis_after_reduction:
                    axes_reordering.append(axis)
                else:
                    axes_reordering.append(axis - 1)
            # shapes and permutation changed, so the mapping has to be recomputed
            init_axis_to_final_axis = build_mapping()

    return init_shapes, reduced_axes, axes_reordering, final_shapes
# Tuple returned by _reconstruct_from_shape_uncached, in this order:
# (init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_after_adding_axes).
# init_shapes, axes_reordering and final_shapes are None when the corresponding step can be skipped.
CookedRecipe = Tuple[Optional[List[int]], Optional[List[int]], List[int], Dict[int, int], Optional[List[int]], int]

# Actual type is tuple[tuple[str, int], ...]
# However torch.jit.script does not "understand" the correct type,
# and torch_specific will use list version.
HashableAxesLengths = Tuple[Tuple[str, int], ...]
FakeHashableAxesLengths = List[Tuple[str, int]]
  93. class TransformRecipe:
  94. """
  95. Recipe describes actual computation pathway.
  96. Recipe can be applied to a tensor or variable.
  97. """
  98. # structure is non-mutable. In future, this can be non-mutable dataclass (python 3.7+)
  99. # update: pytorch 2.0 torch.jit.script seems to have problems with dataclasses unless they were explicitly provided
  100. def __init__(
  101. self,
  102. # list of sizes (or just sizes) for elementary axes as they appear in left expression.
  103. # this is what (after computing unknown parts) will be a shape after first transposition.
  104. # This does not include any ellipsis dimensions.
  105. elementary_axes_lengths: List[int],
  106. # if additional axes are provided, they should be set in prev array
  107. # This shows mapping from name to position
  108. axis_name2elementary_axis: Dict[str, int],
  109. # each dimension in input can help to reconstruct length of one elementary axis
  110. # or verify one of dimensions. Each element points to element of elementary_axes_lengths.
  111. input_composition_known_unknown: List[Tuple[List[int], List[int]]],
  112. # permutation applied to elementary axes, if ellipsis is absent
  113. axes_permutation: List[int],
  114. # permutation puts reduced axes in the end, we only need to know the first position.
  115. first_reduced_axis: int,
  116. # at which positions which of elementary axes should appear. Axis position -> axis index.
  117. added_axes: Dict[int, int],
  118. # ids of axes as they appear in result, again pointers to elementary_axes_lengths,
  119. # only used to infer result dimensions
  120. output_composite_axes: List[List[int]],
  121. ):
  122. self.elementary_axes_lengths: List[int] = elementary_axes_lengths
  123. self.axis_name2elementary_axis: Dict[str, int] = axis_name2elementary_axis
  124. self.input_composition_known_unknown: List[Tuple[List[int], List[int]]] = input_composition_known_unknown
  125. self.axes_permutation: List[int] = axes_permutation
  126. self.first_reduced_axis: int = first_reduced_axis
  127. self.added_axes: Dict[int, int] = added_axes
  128. self.output_composite_axes: List[List[int]] = output_composite_axes
  129. def _reconstruct_from_shape_uncached(
  130. self: TransformRecipe, shape: List[int], axes_dims: FakeHashableAxesLengths
  131. ) -> CookedRecipe:
  132. """
  133. Reconstruct all actual parameters using shape.
  134. Shape is a tuple that may contain integers, shape symbols (tf, theano) and UnknownSize (tf, previously mxnet)
  135. known axes can be integers or symbols, but not Nones.
  136. """
  137. # magic number
  138. need_init_reshape = False
  139. # last axis is allocated for collapsed ellipsis
  140. axes_lengths: List[int] = list(self.elementary_axes_lengths)
  141. for axis, dim in axes_dims:
  142. axes_lengths[self.axis_name2elementary_axis[axis]] = dim
  143. for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composition_known_unknown):
  144. length = shape[input_axis]
  145. if len(known_axes) == 0 and len(unknown_axes) == 1:
  146. # shortcut for the most common case
  147. axes_lengths[unknown_axes[0]] = length
  148. continue
  149. known_product = 1
  150. for axis in known_axes:
  151. known_product *= axes_lengths[axis]
  152. if len(unknown_axes) == 0:
  153. if isinstance(length, int) and isinstance(known_product, int) and length != known_product:
  154. raise EinopsError(f"Shape mismatch, {length} != {known_product}")
  155. else:
  156. # assert len(unknown_axes) == 1, 'this is enforced when recipe is created, so commented out'
  157. if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0:
  158. raise EinopsError(f"Shape mismatch, can't divide axis of length {length} in chunks of {known_product}")
  159. unknown_axis = unknown_axes[0]
  160. inferred_length: int = length // known_product
  161. axes_lengths[unknown_axis] = inferred_length
  162. if len(known_axes) + len(unknown_axes) != 1:
  163. need_init_reshape = True
  164. # at this point all axes_lengths are computed (either have values or variables, but not Nones)
  165. # elementary axes are ordered as they appear in input, then all added axes
  166. init_shapes: Optional[List[int]] = axes_lengths[: len(self.axes_permutation)] if need_init_reshape else None
  167. need_final_reshape = False
  168. final_shapes: List[int] = []
  169. for grouping in self.output_composite_axes:
  170. lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping]
  171. final_shapes.append(_product(lengths))
  172. if len(lengths) != 1:
  173. need_final_reshape = True
  174. added_axes: Dict[int, int] = {
  175. pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items()
  176. }
  177. # this list can be empty
  178. reduced_axes = list(range(self.first_reduced_axis, len(self.axes_permutation)))
  179. n_axes_after_adding_axes = len(added_axes) + len(self.axes_permutation)
  180. axes_reordering: Optional[List[int]] = self.axes_permutation
  181. if self.axes_permutation == list(range(len(self.axes_permutation))):
  182. axes_reordering = None
  183. _final_shapes = final_shapes if need_final_reshape else None
  184. return init_shapes, axes_reordering, reduced_axes, added_axes, _final_shapes, n_axes_after_adding_axes
# Memoized variant of _reconstruct_from_shape_uncached (LRU cache holding up to 1024 entries).
# The cache lookup hashes the arguments, so a call with unhashable arguments raises TypeError.
_reconstruct_from_shape = functools.lru_cache(1024)(_reconstruct_from_shape_uncached)
  186. def _apply_recipe(
  187. backend, recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths
  188. ) -> Tensor:
  189. # this method implements actual work for all backends for 3 operations
  190. try:
  191. init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape(
  192. recipe, backend.shape(tensor), axes_lengths
  193. )
  194. except TypeError:
  195. # shape or one of passed axes lengths is not hashable (i.e. they are symbols)
  196. _result = _reconstruct_from_shape_uncached(recipe, backend.shape(tensor), axes_lengths)
  197. (init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added) = _result
  198. if init_shapes is not None:
  199. tensor = backend.reshape(tensor, init_shapes)
  200. if axes_reordering is not None:
  201. tensor = backend.transpose(tensor, axes_reordering)
  202. if len(reduced_axes) > 0:
  203. tensor = _reduce_axes(tensor, reduction_type=reduction_type, reduced_axes=reduced_axes, backend=backend)
  204. if len(added_axes) > 0:
  205. tensor = backend.add_axes(tensor, n_axes=n_axes_w_added, pos2len=added_axes)
  206. if final_shapes is not None:
  207. tensor = backend.reshape(tensor, final_shapes)
  208. return tensor
  209. def _apply_recipe_array_api(
  210. xp, recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths
  211. ) -> Tensor:
  212. # completely-inline implementation
  213. init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape(
  214. recipe, tensor.shape, axes_lengths
  215. )
  216. if init_shapes is not None:
  217. tensor = xp.reshape(tensor, init_shapes)
  218. if axes_reordering is not None:
  219. tensor = xp.permute_dims(tensor, axes_reordering)
  220. if len(reduced_axes) > 0:
  221. if callable(reduction_type):
  222. # custom callable
  223. tensor = reduction_type(tensor, tuple(reduced_axes))
  224. else:
  225. # one of built-in operations
  226. assert reduction_type in _reductions
  227. tensor = getattr(xp, reduction_type)(tensor, axis=tuple(reduced_axes))
  228. if len(added_axes) > 0:
  229. # we use broadcasting
  230. for axis_position, _axis_length in added_axes.items():
  231. tensor = xp.expand_dims(tensor, axis=axis_position)
  232. final_shape = list(tensor.shape)
  233. for axis_position, axis_length in added_axes.items():
  234. final_shape[axis_position] = axis_length
  235. tensor = xp.broadcast_to(tensor, final_shape)
  236. if final_shapes is not None:
  237. tensor = xp.reshape(tensor, final_shapes)
  238. return tensor
@functools.lru_cache(256)
def _prepare_transformation_recipe(
    pattern: str,
    operation: Reduction,
    axes_names: Tuple[str, ...],
    ndim: int,
) -> TransformRecipe:
    """Perform initial parsing of pattern and provided supplementary info.

    The pattern is split on "->", both sides are parsed, the identifiers are validated
    against the requested operation, an ellipsis (if present) is expanded according to
    ndim, and all index bookkeeping needed later is packed into a TransformRecipe.

    Results are memoized with functools.lru_cache(256), so all arguments must be hashable.

    Parameters:
        pattern: string of the form "left -> right".
        operation: "rearrange", "repeat", a name from _reductions, or a callable.
        axes_names: names of the axes whose lengths are supplied by the caller
            (names only - the lengths themselves are not passed here).
        ndim: number of dimensions of the input tensor.

    Returns:
        TransformRecipe for this combination of arguments.

    Raises:
        EinopsError: when the pattern is inconsistent with the operation, with ndim,
            or with axes_names, or when sizes of some axes cannot be inferred.
    """
    # unpacking fails with ValueError unless the pattern contains exactly one "->"
    left_str, rght_str = pattern.split("->")
    left = ParsedExpression(left_str)
    rght = ParsedExpression(rght_str)

    # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction
    if not left.has_ellipsis and rght.has_ellipsis:
        raise EinopsError(f"Ellipsis found in right side, but not left side of a pattern {pattern}")
    if left.has_ellipsis and left.has_ellipsis_parenthesized:
        raise EinopsError(f"Ellipsis inside parenthesis in the left side is not allowed: {pattern}")
    if operation == "rearrange":
        # rearrange: both sides must use exactly the same identifiers
        if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes:
            raise EinopsError("Non-unitary anonymous axes are not supported in rearrange (exception is length 1)")
        difference = set.symmetric_difference(left.identifiers, rght.identifiers)
        if len(difference) > 0:
            raise EinopsError(f"Identifiers only on one side of expression (should be on both): {difference}")
    elif operation == "repeat":
        # repeat: nothing may disappear, and every new named axis needs a size from the caller
        difference = set.difference(left.identifiers, rght.identifiers)
        if len(difference) > 0:
            raise EinopsError(f"Unexpected identifiers on the left side of repeat: {difference}")
        axes_without_size = set.difference(
            {ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)},
            {*left.identifiers, *axes_names},
        )
        if len(axes_without_size) > 0:
            raise EinopsError(f"Specify sizes for new axes in repeat: {axes_without_size}")
    elif operation in _reductions or callable(operation):
        # reduction: nothing new may appear on the right side
        difference = set.difference(rght.identifiers, left.identifiers)
        if len(difference) > 0:
            raise EinopsError(f"Unexpected identifiers on the right side of reduce {operation}: {difference}")
    else:
        raise EinopsError(f"Unknown reduction {operation}. Expect one of {_reductions}.")

    if left.has_ellipsis:
        # the ellipsis stands for all dimensions not covered by the other composite axes
        n_other_dims = len(left.composition) - 1
        if ndim < n_other_dims:
            raise EinopsError(f"Wrong shape: expected >={n_other_dims} dims. Received {ndim}-dim tensor.")
        ellipsis_ndim = ndim - n_other_dims
        # one synthetic axis name per dimension covered by the ellipsis
        ell_axes = [_ellipsis + str(i) for i in range(ellipsis_ndim)]
        left_composition = []
        for composite_axis in left.composition:
            if composite_axis == _ellipsis:
                for axis in ell_axes:
                    left_composition.append([axis])
            else:
                left_composition.append(composite_axis)

        rght_composition = []
        for composite_axis in rght.composition:
            if composite_axis == _ellipsis:
                # bare ellipsis: each covered dimension stays a separate output dimension
                for axis in ell_axes:
                    rght_composition.append([axis])
            else:
                # ellipsis inside a group: its axes are spliced into that group
                group = []
                for axis in composite_axis:
                    if axis == _ellipsis:
                        group.extend(ell_axes)
                    else:
                        group.append(axis)
                rght_composition.append(group)

        # identifier sets of the parsed expressions are updated in place
        left.identifiers.update(ell_axes)
        left.identifiers.remove(_ellipsis)
        if rght.has_ellipsis:
            rght.identifiers.update(ell_axes)
            rght.identifiers.remove(_ellipsis)
    else:
        if ndim != len(left.composition):
            raise EinopsError(f"Wrong shape: expected {len(left.composition)} dims. Received {ndim}-dim tensor.")
        left_composition = left.composition
        rght_composition = rght.composition

    # parsing all dimensions to find out lengths
    # insertion order of this dict defines the position of every elementary axis
    axis_name2known_length: Dict[Union[str, AnonymousAxis], int] = OrderedDict()
    for composite_axis in left_composition:
        for axis_name in composite_axis:
            if isinstance(axis_name, AnonymousAxis):
                axis_name2known_length[axis_name] = axis_name.value
            else:
                axis_name2known_length[axis_name] = _unknown_axis_length

    # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point

    # axes that exist only on the right side are appended after all input axes
    # NOTE(review): repeat_axes_names is filled but not read again within this function
    repeat_axes_names = []
    for axis_name in rght.identifiers:
        if axis_name not in axis_name2known_length:
            if isinstance(axis_name, AnonymousAxis):
                axis_name2known_length[axis_name] = axis_name.value
            else:
                axis_name2known_length[axis_name] = _unknown_axis_length
            repeat_axes_names.append(axis_name)

    axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)}

    # axes provided as kwargs
    for elementary_axis in axes_names:
        if not ParsedExpression.check_axis_name(elementary_axis):
            raise EinopsError("Invalid name for an axis", elementary_axis)
        if elementary_axis not in axis_name2known_length:
            raise EinopsError(f"Axis {elementary_axis} is not used in transform")
        # mark as "will be provided by the caller" so it no longer counts as unknown
        axis_name2known_length[elementary_axis] = _expected_axis_length

    input_axes_known_unknown = []
    # some shapes are inferred later - all information is prepared for faster inference
    for composite_axis in left_composition:
        known: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length}
        unknown: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length}
        # at most one length per input dimension can be inferred from the shape
        if len(unknown) > 1:
            raise EinopsError(f"Could not infer sizes for {unknown}")
        assert len(unknown) + len(known) == len(composite_axis)
        input_axes_known_unknown.append(
            ([axis_name2position[axis] for axis in known], [axis_name2position[axis] for axis in unknown])
        )

    # NOTE(review): axis_position_after_reduction is filled but not read again within this function
    axis_position_after_reduction: Dict[str, int] = {}
    for axis_name in itertools.chain(*left_composition):
        if axis_name in rght.identifiers:
            axis_position_after_reduction[axis_name] = len(axis_position_after_reduction)

    result_axes_grouping: List[List[int]] = [
        [axis_name2position[axis] for axis in composite_axis] for i, composite_axis in enumerate(rght_composition)
    ]

    ordered_axis_left = list(itertools.chain(*left_composition))
    ordered_axis_rght = list(itertools.chain(*rght_composition))
    # axes present on the left only are reduced; they are placed after all kept axes
    reduced_axes = [axis for axis in ordered_axis_left if axis not in rght.identifiers]
    order_after_transposition = [axis for axis in ordered_axis_rght if axis in left.identifiers] + reduced_axes
    axes_permutation = [ordered_axis_left.index(axis) for axis in order_after_transposition]
    # output position -> elementary axis index, for axes present on the right only
    added_axes = {
        i: axis_name2position[axis_name]
        for i, axis_name in enumerate(ordered_axis_rght)
        if axis_name not in left.identifiers
    }

    first_reduced_axis = len(order_after_transposition) - len(reduced_axes)

    return TransformRecipe(
        elementary_axes_lengths=list(axis_name2known_length.values()),
        axis_name2elementary_axis={axis: axis_name2position[axis] for axis in axes_names},
        input_composition_known_unknown=input_axes_known_unknown,
        axes_permutation=axes_permutation,
        first_reduced_axis=first_reduced_axis,
        added_axes=added_axes,
        output_composite_axes=result_axes_grouping,
    )
  378. def _prepare_recipes_for_all_dims(
  379. pattern: str, operation: Reduction, axes_names: Tuple[str, ...]
  380. ) -> Dict[int, TransformRecipe]:
  381. """
  382. Internal function, used in layers.
  383. Layer makes all recipe creation when it is initialized, thus to keep recipes simple we pre-compute for all dims
  384. """
  385. left_str, rght_str = pattern.split("->")
  386. left = ParsedExpression(left_str)
  387. dims = [len(left.composition)]
  388. if left.has_ellipsis:
  389. dims = [len(left.composition) - 1 + ellipsis_dims for ellipsis_dims in range(8)]
  390. return {ndim: _prepare_transformation_recipe(pattern, operation, axes_names, ndim=ndim) for ndim in dims}
  391. @overload
  392. def reduce(tensor: List[Tensor], pattern: str, reduction: Reduction, **axes_lengths: Size) -> Tensor: ...
  393. @overload
  394. def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: Size) -> Tensor: ...
  395. def reduce(tensor: Union[Tensor, List[Tensor]], pattern: str, reduction: Reduction, **axes_lengths: Size) -> Tensor:
  396. """
  397. einops.reduce combines rearrangement and reduction using reader-friendly notation.
  398. Some examples:
  399. ```python
  400. >>> x = np.random.randn(100, 32, 64)
  401. # perform max-reduction on the first axis
  402. # Axis t does not appear on RHS - thus we reduced over t
  403. >>> y = reduce(x, 't b c -> b c', 'max')
  404. # same as previous, but using verbose names for axes
  405. >>> y = reduce(x, 'time batch channel -> batch channel', 'max')
  406. # let's pretend now that x is a batch of images
  407. # with 4 dims: batch=10, height=20, width=30, channel=40
  408. >>> x = np.random.randn(10, 20, 30, 40)
  409. # 2d max-pooling with kernel size = 2 * 2 for image processing
  410. >>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)
  411. # same as previous, using anonymous axes,
  412. # note: only reduced axes can be anonymous
  413. >>> y1 = reduce(x, 'b c (h1 2) (w1 2) -> b c h1 w1', 'max')
  414. # adaptive 2d max-pooling to 3 * 4 grid,
  415. # each element is max of 10x10 tile in the original tensor.
  416. >>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape
  417. (10, 20, 3, 4)
  418. # Global average pooling
  419. >>> reduce(x, 'b c h w -> b c', 'mean').shape
  420. (10, 20)
  421. # subtracting mean over batch for each channel;
  422. # similar to x - np.mean(x, axis=(0, 2, 3), keepdims=True)
  423. >>> y = x - reduce(x, 'b c h w -> 1 c 1 1', 'mean')
  424. # Subtracting per-image mean for each channel
  425. >>> y = x - reduce(x, 'b c h w -> b c 1 1', 'mean')
  426. # same as previous, but using empty compositions
  427. >>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean')
  428. ```
  429. Parameters:
  430. tensor: tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch).
  431. list of tensors is also accepted, those should be of the same type and shape
  432. pattern: string, reduction pattern
  433. reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod', 'any', 'all').
  434. Alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided.
  435. This allows using various reductions like: np.max, np.nanmean, tf.reduce_logsumexp, torch.var, etc.
  436. axes_lengths: any additional specifications for dimensions
  437. Returns:
  438. tensor of the same type as input
  439. """
  440. try:
  441. if isinstance(tensor, list):
  442. if len(tensor) == 0:
  443. raise TypeError("Rearrange/Reduce/Repeat can't be applied to an empty list")
  444. backend = get_backend(tensor[0])
  445. tensor = backend.stack_on_zeroth_dimension(tensor)
  446. else:
  447. backend = get_backend(tensor)
  448. hashable_axes_lengths = tuple(axes_lengths.items())
  449. shape = backend.shape(tensor)
  450. recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=len(shape))
  451. return _apply_recipe(
  452. backend, recipe, cast(Tensor, tensor), reduction_type=reduction, axes_lengths=hashable_axes_lengths
  453. )
  454. except EinopsError as e:
  455. message = f' Error while processing {reduction}-reduction pattern "{pattern}".'
  456. if not isinstance(tensor, list):
  457. message += f"\n Input tensor shape: {shape}. "
  458. else:
  459. message += "\n Input is list. "
  460. message += f"Additional info: {axes_lengths}."
  461. raise EinopsError(message + f"\n {e}") from None
@overload
def rearrange(tensor: List[Tensor], pattern: str, **axes_lengths: Size) -> Tensor: ...
@overload
def rearrange(tensor: Tensor, pattern: str, **axes_lengths: Size) -> Tensor: ...


def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths: Size) -> Tensor:
    """
    einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors.
    This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
    stack, concatenate and other operations.

    Implemented as a call to `reduce` with reduction="rearrange".

    Examples:

    ```python
    # suppose we have a set of 32 images in "h w c" format (height-width-channel)
    >>> images = [np.random.randn(30, 40, 3) for _ in range(32)]

    # stack along first (batch) axis, output is a single array
    >>> rearrange(images, 'b h w c -> b h w c').shape
    (32, 30, 40, 3)

    # stacked and reordered axes to "b c h w" format
    >>> rearrange(images, 'b h w c -> b c h w').shape
    (32, 3, 30, 40)

    # concatenate images along height (vertical axis), 960 = 32 * 30
    >>> rearrange(images, 'b h w c -> (b h) w c').shape
    (960, 40, 3)

    # concatenated images along horizontal axis, 1280 = 32 * 40
    >>> rearrange(images, 'b h w c -> h (b w) c').shape
    (30, 1280, 3)

    # flattened each image into a vector, 3600 = 30 * 40 * 3
    >>> rearrange(images, 'b h w c -> b (c h w)').shape
    (32, 3600)

    # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
    >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
    (128, 15, 20, 3)

    # space-to-depth operation
    >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
    (32, 15, 20, 12)

    ```

    When composing axes, C-order enumeration is used (consecutive elements have different last axis).
    Find more examples in einops tutorial.

    Parameters:
        tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch).
            list of tensors is also accepted, those should be of the same type and shape
        pattern: string, rearrangement pattern
        axes_lengths: any additional specifications for dimensions

    Returns:
        tensor of the same type as input. If possible, a view to the original tensor is returned.

    """
    return reduce(tensor, pattern, reduction="rearrange", **axes_lengths)
@overload
def repeat(tensor: List[Tensor], pattern: str, **axes_lengths: Size) -> Tensor: ...
@overload
def repeat(tensor: Tensor, pattern: str, **axes_lengths: Size) -> Tensor: ...


def repeat(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths: Size) -> Tensor:
    """
    einops.repeat allows reordering elements and repeating them in arbitrary combinations.
    This operation includes functionality of repeat, tile, and broadcast functions.

    Implemented as a call to `reduce` with reduction="repeat".

    Examples for repeat operation:

    ```python
    # a grayscale image (of shape height x width)
    >>> image = np.random.randn(30, 40)

    # change it to RGB format by repeating in each channel
    >>> repeat(image, 'h w -> h w c', c=3).shape
    (30, 40, 3)

    # repeat image 2 times along height (vertical axis)
    >>> repeat(image, 'h w -> (repeat h) w', repeat=2).shape
    (60, 40)

    # repeat image 2 times along height and 3 times along width
    >>> repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape
    (60, 120)

    # convert each pixel to a small square 2x2, i.e. upsample an image by 2x
    >>> repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
    (60, 80)

    # 'pixelate' an image first by downsampling by 2x, then upsampling
    >>> downsampled = reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2)
    >>> repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
    (30, 40)

    ```

    When composing axes, C-order enumeration is used (consecutive elements have different last axis).
    Find more examples in einops tutorial.

    Parameters:
        tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch).
            list of tensors is also accepted, those should be of the same type and shape
        pattern: string, rearrangement pattern
        axes_lengths: any additional specifications for dimensions

    Returns:
        Tensor of the same type as input. If possible, a view to the original tensor is returned.

    """
    return reduce(tensor, pattern, reduction="repeat", **axes_lengths)
  548. def parse_shape(x: Tensor, pattern: str) -> dict:
  549. """
  550. Parse a tensor shape to dictionary mapping axes names to their lengths.
  551. ```python
  552. # Use underscore to skip the dimension in parsing.
  553. >>> x = np.zeros([2, 3, 5, 7])
  554. >>> parse_shape(x, 'batch _ h w')
  555. {'batch': 2, 'h': 5, 'w': 7}
  556. # `parse_shape` output can be used to specify axes_lengths for other operations:
  557. >>> y = np.zeros([700])
  558. >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
  559. (2, 10, 5, 7)
  560. ```
  561. For symbolic frameworks may return symbols, not integers.
  562. Parameters:
  563. x: tensor of any supported framework
  564. pattern: str, space separated names for axes, underscore means skip axis
  565. Returns:
  566. dict, maps axes names to their lengths
  567. """
  568. exp = ParsedExpression(pattern, allow_underscore=True)
  569. shape = get_backend(x).shape(x)
  570. if exp.has_composed_axes():
  571. raise RuntimeError(f"Can't parse shape with composite axes: {pattern} {shape}")
  572. if len(shape) != len(exp.composition):
  573. if exp.has_ellipsis:
  574. if len(shape) < len(exp.composition) - 1:
  575. raise RuntimeError(f"Can't parse shape with this number of dimensions: {pattern} {shape}")
  576. else:
  577. raise RuntimeError(f"Can't parse shape with different number of dimensions: {pattern} {shape}")
  578. if exp.has_ellipsis:
  579. ellipsis_idx = exp.composition.index(_ellipsis)
  580. composition = (
  581. exp.composition[:ellipsis_idx]
  582. + ["_"] * (len(shape) - len(exp.composition) + 1)
  583. + exp.composition[ellipsis_idx + 1 :]
  584. )
  585. else:
  586. composition = exp.composition
  587. result = {}
  588. for axes, axis_length in zip(composition, shape): # type: ignore
  589. # axes either [], or [AnonymousAxis] or ['axis_name']
  590. if len(axes) == 0:
  591. if axis_length != 1:
  592. raise RuntimeError(f"Length of axis is not 1: {pattern} {shape}")
  593. else:
  594. [axis] = axes
  595. if isinstance(axis, str):
  596. if axis != "_":
  597. result[axis] = axis_length
  598. else:
  599. if axis.value != axis_length:
  600. raise RuntimeError(f"Length of anonymous axis does not match: {pattern} {shape}")
  601. return result
  602. # _enumerate_directions is not exposed in the public API
  603. def _enumerate_directions(x):
  604. """
  605. For an n-dimensional tensor, returns tensors to enumerate each axis.
  606. ```python
  607. x = np.zeros([2, 3, 4]) # or any other tensor
  608. i, j, k = _enumerate_directions(x)
  609. result = i + 2*j + 3*k
  610. ```
  611. `result[i, j, k] = i + 2j + 3k`, and also has the same shape as result
  612. Works very similarly to numpy.ogrid (open indexing grid)
  613. """
  614. backend = get_backend(x)
  615. shape = backend.shape(x)
  616. result = []
  617. for axis_id, axis_length in enumerate(shape):
  618. shape = [1] * len(shape)
  619. shape[axis_id] = axis_length
  620. result.append(backend.reshape(backend.arange(0, axis_length), shape))
  621. return result
  622. # to avoid importing numpy
  623. np_ndarray = Any
  624. def asnumpy(tensor: Tensor) -> np_ndarray:
  625. """
  626. Convert a tensor of an imperative framework (i.e. numpy/cupy/torch/jax/etc.) to `numpy.ndarray`
  627. Parameters:
  628. tensor: tensor of any known imperative framework
  629. Returns:
  630. `numpy.ndarray`, converted to numpy
  631. """
  632. return get_backend(tensor).to_numpy(tensor)
  633. def _validate_einsum_axis_name(axis_name):
  634. if len(axis_name) == 0:
  635. raise NotImplementedError("Singleton () axes are not yet supported in einsum.")
  636. if len(axis_name) > 1:
  637. raise NotImplementedError("Shape rearrangement is not yet supported in einsum.")
  638. axis_name = axis_name[0]
  639. if isinstance(axis_name, AnonymousAxis):
  640. raise NotImplementedError("Anonymous axes are not yet supported in einsum.")
  641. if len(axis_name) == 0:
  642. raise RuntimeError("Encountered empty axis name in einsum.")
  643. if not isinstance(axis_name, str):
  644. raise RuntimeError("Axis name in einsum must be a string.")
  645. @functools.lru_cache(256)
  646. def _compactify_pattern_for_einsum(pattern: str) -> str:
  647. if "->" not in pattern:
  648. # numpy allows this, so make sure users
  649. # don't accidentally do something like this.
  650. raise ValueError("Einsum pattern must contain '->'.")
  651. lefts_str, right_str = pattern.split("->")
  652. lefts = [ParsedExpression(left, allow_underscore=True, allow_duplicates=True) for left in lefts_str.split(",")]
  653. right = ParsedExpression(right_str, allow_underscore=True)
  654. # Start from 'a' and go up to 'Z'
  655. output_axis_names = string.ascii_letters
  656. i = 0
  657. axis_name_mapping = {}
  658. left_patterns = []
  659. for left in lefts:
  660. left_pattern = ""
  661. for raw_axis_name in left.composition:
  662. if raw_axis_name == _ellipsis:
  663. left_pattern += "..."
  664. continue
  665. _validate_einsum_axis_name(raw_axis_name)
  666. axis_name = raw_axis_name[0]
  667. if axis_name not in axis_name_mapping:
  668. if i >= len(output_axis_names):
  669. raise RuntimeError("Too many axes in einsum.")
  670. axis_name_mapping[axis_name] = output_axis_names[i]
  671. i += 1
  672. left_pattern += axis_name_mapping[axis_name]
  673. left_patterns.append(left_pattern)
  674. compact_pattern = ",".join(left_patterns) + "->"
  675. for raw_axis_name in right.composition:
  676. if raw_axis_name == _ellipsis:
  677. compact_pattern += "..."
  678. continue
  679. _validate_einsum_axis_name(raw_axis_name)
  680. axis_name = raw_axis_name[0]
  681. if axis_name not in axis_name_mapping:
  682. raise EinopsError(f"Unknown axis {axis_name} on right side of einsum {pattern}.")
  683. compact_pattern += axis_name_mapping[axis_name]
  684. return compact_pattern
  685. @typing.overload
  686. def einsum(tensor: Tensor, pattern: str, /) -> Tensor: ...
  687. @typing.overload
  688. def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str, /) -> Tensor: ...
  689. @typing.overload
  690. def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str, /) -> Tensor: ...
  691. @typing.overload
  692. def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str, /) -> Tensor: ...
  693. def einsum(*tensors_and_pattern: Union[Tensor, str]) -> Tensor:
  694. r"""
  695. einops.einsum calls einsum operations with einops-style named
  696. axes indexing, computing tensor products with an arbitrary
  697. number of tensors. Unlike typical einsum syntax, here you must
  698. pass tensors first, and then the pattern.
  699. Also, note that rearrange operations such as `"(batch chan) out"`,
  700. or singleton axes `()`, are not currently supported.
  701. Examples:
  702. For a given pattern such as:
  703. ```python
  704. >>> x, y, z = np.random.randn(3, 20, 20, 20)
  705. >>> output = einsum(x, y, z, "a b c, c b d, a g k -> a b k")
  706. ```
  707. the following formula is computed:
  708. ```tex
  709. output[a, b, k] = \sum_{c, d, g} x[a, b, c] * y[c, b, d] * z[a, g, k]
  710. ```
  711. where the summation over `c`, `d`, and `g` is performed
  712. because those axes names do not appear on the right-hand side.
  713. Let's see some additional examples:
  714. ```python
  715. # Filter a set of images:
  716. >>> batched_images = np.random.randn(128, 16, 16)
  717. >>> filters = np.random.randn(16, 16, 30)
  718. >>> result = einsum(batched_images, filters,
  719. ... "batch h w, h w channel -> batch channel")
  720. >>> result.shape
  721. (128, 30)
  722. # Matrix multiplication, with an unknown input shape:
  723. >>> batch_shape = (50, 30)
  724. >>> data = np.random.randn(*batch_shape, 20)
  725. >>> weights = np.random.randn(10, 20)
  726. >>> result = einsum(weights, data,
  727. ... "out_dim in_dim, ... in_dim -> ... out_dim")
  728. >>> result.shape
  729. (50, 30, 10)
  730. # Matrix trace on a single tensor:
  731. >>> matrix = np.random.randn(10, 10)
  732. >>> result = einsum(matrix, "i i ->")
  733. >>> result.shape
  734. ()
  735. ```
  736. Parameters:
  737. tensors_and_pattern:
  738. tensors: tensors of any supported library (numpy, tensorflow, pytorch, jax).
  739. pattern: string, einsum pattern, with commas
  740. separating specifications for each tensor.
  741. pattern should be provided after all tensors.
  742. Returns:
  743. Tensor of the same type as input, after processing with einsum.
  744. """
  745. if len(tensors_and_pattern) <= 1:
  746. raise ValueError(
  747. "`einops.einsum` takes at minimum two arguments: the tensors (at least one), followed by the pattern."
  748. )
  749. pattern = tensors_and_pattern[-1]
  750. if not isinstance(pattern, str):
  751. raise ValueError(
  752. "The last argument passed to `einops.einsum` must be a string, representing the einsum pattern."
  753. )
  754. tensors = tensors_and_pattern[:-1]
  755. pattern = _compactify_pattern_for_einsum(pattern)
  756. return get_backend(tensors[0]).einsum(pattern, *tensors)