| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939 |
- import functools
- import itertools
- import string
- import typing
- from collections import OrderedDict
- from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union, cast, overload
- if typing.TYPE_CHECKING:
- # for docstrings in pycharm
- import numpy as np # noqa E401
- from . import EinopsError
- from ._backends import get_backend
- from .parsing import AnonymousAxis, ParsedExpression, _ellipsis
# Generic type standing for a tensor/array object of any supported framework
Tensor = TypeVar("Tensor")
# user-provided reduction: called as f(tensor, reduced_axes_tuple) and must return a tensor
ReductionCallable = Callable[[Tensor, Tuple[int, ...]], Tensor]
# either the name of a built-in reduction (see _reductions) or a ReductionCallable
Reduction = Union[str, ReductionCallable]
# axis length as provided by the user: an int, or possibly a symbol for symbolic frameworks
Size = typing.Any

# names of reductions that can be passed as strings
_reductions = ("min", "max", "sum", "mean", "prod", "any", "all")

# magic integers are required to stay within
# traceable subset of language
# placeholder: axis length has to be inferred from the input shape
_unknown_axis_length = -999999
# placeholder: axis length will be provided through keyword arguments (axes_lengths)
_expected_axis_length = -99999
- def _product(sequence: List[int]) -> int:
- """minimalistic product that works both with numbers and symbols. Supports empty lists"""
- result = 1
- for element in sequence:
- result *= element
- return result
def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int], backend):
    """
    Reduce ``tensor`` over the axes listed in ``reduced_axes``.

    ``reduction_type`` is either a callable invoked as ``reduction_type(tensor, axes_tuple)``,
    or the name of a built-in reduction that is delegated to ``backend.reduce``.

    Raises:
        NotImplementedError: for the built-in "mean" when ``backend.is_float_type(tensor)`` is false.
    """
    axes = tuple(reduced_axes)
    if callable(reduction_type):
        # custom callable
        return reduction_type(tensor, axes)
    # one of built-in operations; names are validated when the recipe is prepared
    assert reduction_type in _reductions
    if reduction_type == "mean" and not backend.is_float_type(tensor):
        raise NotImplementedError("reduce_mean is not available for non-floating tensors")
    return backend.reduce(tensor, reduction_type, axes)
def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes):
    """
    Merge neighbouring elementary axes to lower the number of dimensions handed to a backend.

    Two kinds of merges are performed:
      1. runs of consecutive reduced axes are joined into a single reduced axis;
      2. two consecutive non-reduced axes are joined when they also end up consecutive,
         in the same order, after the transposition described by ``axes_reordering``.

    Parameters:
        init_shapes: lengths of the elementary axes. Must be a list: merged copies are built by
            slicing/concatenation and then updated by item assignment (the caller's list is not mutated).
        reduced_axes: indices into ``init_shapes`` of the axes that get reduced.
        axes_reordering: permutation over the axes that survive reduction; position k holds the
            (post-reduction) index of the axis placed at output position k.
        final_shapes: passed through unchanged.

    Returns:
        tuple ``(init_shapes, reduced_axes, axes_reordering, final_shapes)`` after merging.

    NOTE(review): not referenced in the visible portion of this module — confirm whether it is still used.
    """
    # 'collapses' neighboring axes if those participate in the result pattern in the same order
    # TODO add support for added_axes
    assert len(axes_reordering) + len(reduced_axes) == len(init_shapes)
    # joining consecutive axes that will be reduced
    # possibly we can skip this if all backends can optimize this (not sure)
    reduced_axes = tuple(sorted(reduced_axes))
    # walk pairs from the end: removing entry i + 1 only shifts entries that were already visited
    for i in range(len(reduced_axes) - 1)[::-1]:
        if reduced_axes[i] + 1 == reduced_axes[i + 1]:
            removed_axis = reduced_axes[i + 1]
            removed_length = init_shapes[removed_axis]
            init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1 :]
            # fold the removed axis into its left neighbour
            init_shapes[removed_axis - 1] *= removed_length
            # drop the merged entry and shift the later reduced axes left by one
            reduced_axes = reduced_axes[: i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2 :])

    # removing axes that are moved together during reshape
    def build_mapping():
        # init axis -> position in the output order (None for reduced axes)
        init_to_final = {}
        for axis in range(len(init_shapes)):
            if axis in reduced_axes:
                init_to_final[axis] = None
            else:
                # index of this axis among the axes that survive reduction
                after_reduction = sum(x is not None for x in init_to_final.values())
                init_to_final[axis] = list(axes_reordering).index(after_reduction)
        return init_to_final

    init_axis_to_final_axis = build_mapping()

    for init_axis in range(len(init_shapes) - 1)[::-1]:
        if init_axis_to_final_axis[init_axis] is None:
            continue
        if init_axis_to_final_axis[init_axis + 1] is None:
            continue
        # neighbours in the input that stay neighbours (same order) in the output can be merged
        if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]:
            removed_axis = init_axis + 1
            removed_length = init_shapes[removed_axis]
            # computed before reduced_axes is renumbered below
            removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis))

            reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes)
            init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1 :]
            init_shapes[removed_axis - 1] *= removed_length
            # drop the merged axis from the permutation and renumber the axes after it
            old_reordering = axes_reordering
            axes_reordering = []
            for axis in old_reordering:
                if axis == removed_axis_after_reduction:
                    pass
                elif axis < removed_axis_after_reduction:
                    axes_reordering.append(axis)
                else:
                    axes_reordering.append(axis - 1)
            init_axis_to_final_axis = build_mapping()

    return init_shapes, reduced_axes, axes_reordering, final_shapes
# Value returned by _reconstruct_from_shape_uncached, in this order:
# (init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_after_adding_axes)
CookedRecipe = Tuple[Optional[List[int]], Optional[List[int]], List[int], Dict[int, int], Optional[List[int]], int]

# Actual type is tuple[tuple[str, int], ...]
# However torch.jit.script does not "understand" the correct type,
# and torch_specific will use list version.
# Each pair is (axis_name, axis_length).
HashableAxesLengths = Tuple[Tuple[str, int], ...]
FakeHashableAxesLengths = List[Tuple[str, int]]
class TransformRecipe:
    """
    Recipe describes actual computation pathway.
    Recipe can be applied to a tensor or variable.

    A recipe holds only the pattern-dependent information; concrete axis lengths are plugged in
    later, when the recipe is combined with a particular input shape.

    Attributes:
        elementary_axes_lengths: one entry per elementary axis, in order of appearance
            (axes of the left expression first, then axes present only on the right).
            Entries are known lengths or placeholder values that are filled in from the input
            shape / keyword lengths. Ellipsis dimensions are already expanded into separate axes.
            The first ``len(axes_permutation)`` entries form the shape after the initial reshape.
        axis_name2elementary_axis: name -> index into ``elementary_axes_lengths`` for every axis
            whose length is passed as a keyword argument.
        input_composition_known_unknown: for each input dimension, a pair
            (indices of axes with known length, indices of axes with unknown length);
            the unknown part has at most one element, inferred from the dimension length.
        axes_permutation: permutation applied to elementary axes of the input.
        first_reduced_axis: the permutation moves reduced axes to the end; this is the index
            of the first of them.
        added_axes: output axis position -> index of the elementary axis that appears there.
        output_composite_axes: for each output dimension, indices of the elementary axes
            composing it (used only to infer result dimensions).
    """

    # structure is non-mutable. In future, this can be non-mutable dataclass (python 3.7+)
    # update: pytorch 2.0 torch.jit.script seems to have problems with dataclasses unless they were explicitly provided

    def __init__(
        self,
        elementary_axes_lengths: List[int],
        axis_name2elementary_axis: Dict[str, int],
        input_composition_known_unknown: List[Tuple[List[int], List[int]]],
        axes_permutation: List[int],
        first_reduced_axis: int,
        added_axes: Dict[int, int],
        output_composite_axes: List[List[int]],
    ):
        # attribute annotations are kept explicit, presumably for torch.jit.script type inference
        self.elementary_axes_lengths: List[int] = elementary_axes_lengths
        self.axis_name2elementary_axis: Dict[str, int] = axis_name2elementary_axis
        self.input_composition_known_unknown: List[Tuple[List[int], List[int]]] = (
            input_composition_known_unknown
        )
        self.axes_permutation: List[int] = axes_permutation
        self.first_reduced_axis: int = first_reduced_axis
        self.added_axes: Dict[int, int] = added_axes
        self.output_composite_axes: List[List[int]] = output_composite_axes
def _reconstruct_from_shape_uncached(
    self: TransformRecipe, shape: List[int], axes_dims: FakeHashableAxesLengths
) -> CookedRecipe:
    """
    Reconstruct all actual parameters using shape.

    Shape is a tuple that may contain integers, shape symbols (tf, theano) and UnknownSize (tf, previously mxnet)
    known axes can be integers or symbols, but not Nones.

    Parameters:
        self: recipe prepared for the pattern / operation / number of input dimensions.
        shape: shape of the input tensor.
        axes_dims: (axis_name, length) pairs provided by the user as keyword arguments.

    Returns:
        (init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_after_adding_axes).
        init_shapes, axes_reordering and final_shapes are None when the corresponding
        reshape / transpose step is not needed.

    Raises:
        EinopsError: when integer dimensions contradict the pattern or the provided axis lengths.
    """
    need_init_reshape = False

    # start from the recipe's lengths/placeholders and plug in lengths provided by the user
    axes_lengths: List[int] = list(self.elementary_axes_lengths)
    for axis, dim in axes_dims:
        axes_lengths[self.axis_name2elementary_axis[axis]] = dim

    for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composition_known_unknown):
        length = shape[input_axis]
        if len(known_axes) == 0 and len(unknown_axes) == 1:
            # shortcut for the most common case
            axes_lengths[unknown_axes[0]] = length
            continue

        known_product = 1
        for axis in known_axes:
            known_product *= axes_lengths[axis]

        if len(unknown_axes) == 0:
            if isinstance(length, int) and isinstance(known_product, int) and length != known_product:
                raise EinopsError(f"Shape mismatch, {length} != {known_product}")
        else:
            # assert len(unknown_axes) == 1, 'this is enforced when recipe is created, so commented out'
            if isinstance(known_product, int) and known_product == 0:
                # the unknown length can't be recovered by division; `%` / `//` below would
                # otherwise raise a bare ZeroDivisionError
                raise EinopsError(
                    f"Shape mismatch, can't infer axis length from dimension of length {length}: "
                    f"provided axes in this dimension have zero total length"
                )
            if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0:
                raise EinopsError(f"Shape mismatch, can't divide axis of length {length} in chunks of {known_product}")

            unknown_axis = unknown_axes[0]
            inferred_length: int = length // known_product
            axes_lengths[unknown_axis] = inferred_length

        if len(known_axes) + len(unknown_axes) != 1:
            # composite or empty input dimension: needs splitting into elementary axes
            need_init_reshape = True

    # at this point all axes_lengths are computed (either have values or variables, but not Nones)

    # elementary axes are ordered as they appear in input, then all added axes
    init_shapes: Optional[List[int]] = axes_lengths[: len(self.axes_permutation)] if need_init_reshape else None

    need_final_reshape = False
    final_shapes: List[int] = []
    for grouping in self.output_composite_axes:
        lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping]
        final_shapes.append(_product(lengths))
        if len(lengths) != 1:
            need_final_reshape = True

    added_axes: Dict[int, int] = {
        pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items()
    }

    # this list can be empty
    reduced_axes = list(range(self.first_reduced_axis, len(self.axes_permutation)))

    # reduced axes are not subtracted: the checks in _prepare_transformation_recipe ensure that
    # an operation never both reduces and adds axes
    n_axes_after_adding_axes = len(added_axes) + len(self.axes_permutation)

    axes_reordering: Optional[List[int]] = self.axes_permutation
    if self.axes_permutation == list(range(len(self.axes_permutation))):
        # identity permutation: transposition can be skipped
        axes_reordering = None

    _final_shapes = final_shapes if need_final_reshape else None
    return init_shapes, axes_reordering, reduced_axes, added_axes, _final_shapes, n_axes_after_adding_axes


# cached version; arguments (shape, axes lengths) must be hashable, see _apply_recipe for the fallback
_reconstruct_from_shape = functools.lru_cache(1024)(_reconstruct_from_shape_uncached)
def _apply_recipe(
    backend, recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths
) -> Tensor:
    """
    Execute a prepared recipe on ``tensor`` through an einops backend.

    Shared by rearrange, reduce and repeat. The steps run in this order, each one skipped when
    not needed: initial reshape, transpose, reduction, insertion of new axes, final reshape.
    """
    try:
        cooked = _reconstruct_from_shape(recipe, backend.shape(tensor), axes_lengths)
    except TypeError:
        # shape or one of passed axes lengths is not hashable (i.e. they are symbols)
        cooked = _reconstruct_from_shape_uncached(recipe, backend.shape(tensor), axes_lengths)
    init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = cooked

    if init_shapes is not None:
        tensor = backend.reshape(tensor, init_shapes)
    if axes_reordering is not None:
        tensor = backend.transpose(tensor, axes_reordering)
    if reduced_axes:
        tensor = _reduce_axes(tensor, reduction_type=reduction_type, reduced_axes=reduced_axes, backend=backend)
    if added_axes:
        tensor = backend.add_axes(tensor, n_axes=n_axes_w_added, pos2len=added_axes)
    if final_shapes is not None:
        tensor = backend.reshape(tensor, final_shapes)
    return tensor
def _apply_recipe_array_api(
    xp, recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths
) -> Tensor:
    """
    Execute a prepared recipe using only functions of an Array API namespace ``xp``.

    Completely-inline counterpart of ``_apply_recipe`` that calls ``xp.reshape``, ``xp.permute_dims``,
    the reduction functions, ``xp.expand_dims`` and ``xp.broadcast_to`` directly.
    Shapes and axes are passed as tuples, which is what the Array API standard specifies for these functions.
    """
    init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, _n_axes_w_added = _reconstruct_from_shape(
        recipe, tensor.shape, axes_lengths
    )
    if init_shapes is not None:
        tensor = xp.reshape(tensor, tuple(init_shapes))
    if axes_reordering is not None:
        tensor = xp.permute_dims(tensor, tuple(axes_reordering))
    if len(reduced_axes) > 0:
        if callable(reduction_type):
            # custom callable
            tensor = reduction_type(tensor, tuple(reduced_axes))
        else:
            # one of built-in operations; their names match Array API functions
            assert reduction_type in _reductions
            tensor = getattr(xp, reduction_type)(tensor, axis=tuple(reduced_axes))
    if len(added_axes) > 0:
        # we use broadcasting: insert length-1 axes, then broadcast them to the requested lengths.
        # Insert in increasing position order so every position refers to the final layout.
        for axis_position in sorted(added_axes):
            tensor = xp.expand_dims(tensor, axis=axis_position)
        final_shape = list(tensor.shape)
        for axis_position, axis_length in added_axes.items():
            final_shape[axis_position] = axis_length
        tensor = xp.broadcast_to(tensor, tuple(final_shape))
    if final_shapes is not None:
        tensor = xp.reshape(tensor, tuple(final_shapes))
    return tensor
@functools.lru_cache(256)
def _prepare_transformation_recipe(
    pattern: str,
    operation: Reduction,
    axes_names: Tuple[str, ...],
    ndim: int,
) -> TransformRecipe:
    """Perform initial parsing of pattern and provided supplementary info.

    Parameters:
        pattern: pattern of the form "left -> right".
        operation: "rearrange", "repeat", the name of a built-in reduction, or a reduction callable.
        axes_names: names of axes whose lengths are passed as keyword arguments. Only the names are
            needed here; the lengths are plugged in later by _reconstruct_from_shape_uncached.
        ndim: number of dimensions of the input tensor.

    Returns:
        TransformRecipe with all information that does not depend on concrete axis lengths.

    Raises:
        EinopsError: for ellipsis misuse, identifiers inconsistent with the operation, an unknown
            reduction, a wrong number of dimensions, invalid or unused keyword axis names,
            or an input composition with more than one axis of unknown length.
    """
    left_str, rght_str = pattern.split("->")
    left = ParsedExpression(left_str)
    rght = ParsedExpression(rght_str)

    # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction
    if not left.has_ellipsis and rght.has_ellipsis:
        raise EinopsError(f"Ellipsis found in right side, but not left side of a pattern {pattern}")
    if left.has_ellipsis and left.has_ellipsis_parenthesized:
        raise EinopsError(f"Ellipsis inside parenthesis in the left side is not allowed: {pattern}")
    if operation == "rearrange":
        if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes:
            raise EinopsError("Non-unitary anonymous axes are not supported in rearrange (exception is length 1)")
        difference = set.symmetric_difference(left.identifiers, rght.identifiers)
        if len(difference) > 0:
            raise EinopsError(f"Identifiers only on one side of expression (should be on both): {difference}")
    elif operation == "repeat":
        difference = set.difference(left.identifiers, rght.identifiers)
        if len(difference) > 0:
            raise EinopsError(f"Unexpected identifiers on the left side of repeat: {difference}")
        axes_without_size = set.difference(
            {ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)},
            {*left.identifiers, *axes_names},
        )
        if len(axes_without_size) > 0:
            raise EinopsError(f"Specify sizes for new axes in repeat: {axes_without_size}")
    elif operation in _reductions or callable(operation):
        difference = set.difference(rght.identifiers, left.identifiers)
        if len(difference) > 0:
            raise EinopsError(f"Unexpected identifiers on the right side of reduce {operation}: {difference}")
    else:
        raise EinopsError(f"Unknown reduction {operation}. Expect one of {_reductions}.")

    if left.has_ellipsis:
        # expand the ellipsis into individually named axes, now that ndim is known
        n_other_dims = len(left.composition) - 1
        if ndim < n_other_dims:
            raise EinopsError(f"Wrong shape: expected >={n_other_dims} dims. Received {ndim}-dim tensor.")
        ellipsis_ndim = ndim - n_other_dims
        ell_axes = [_ellipsis + str(i) for i in range(ellipsis_ndim)]
        left_composition = []
        for composite_axis in left.composition:
            if composite_axis == _ellipsis:
                for axis in ell_axes:
                    left_composition.append([axis])
            else:
                left_composition.append(composite_axis)

        rght_composition = []
        for composite_axis in rght.composition:
            if composite_axis == _ellipsis:
                for axis in ell_axes:
                    rght_composition.append([axis])
            else:
                # on the right side the ellipsis may appear inside parenthesis
                group = []
                for axis in composite_axis:
                    if axis == _ellipsis:
                        group.extend(ell_axes)
                    else:
                        group.append(axis)
                rght_composition.append(group)

        left.identifiers.update(ell_axes)
        left.identifiers.remove(_ellipsis)
        if rght.has_ellipsis:
            rght.identifiers.update(ell_axes)
            rght.identifiers.remove(_ellipsis)
    else:
        if ndim != len(left.composition):
            raise EinopsError(f"Wrong shape: expected {len(left.composition)} dims. Received {ndim}-dim tensor.")
        left_composition = left.composition
        rght_composition = rght.composition

    # parsing all dimensions to find out lengths
    axis_name2known_length: Dict[Union[str, AnonymousAxis], int] = OrderedDict()
    for composite_axis in left_composition:
        for axis_name in composite_axis:
            if isinstance(axis_name, AnonymousAxis):
                axis_name2known_length[axis_name] = axis_name.value
            else:
                axis_name2known_length[axis_name] = _unknown_axis_length

    # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point

    # axes that appear only on the right side (repeat) are appended after the input axes
    for axis_name in rght.identifiers:
        if axis_name not in axis_name2known_length:
            if isinstance(axis_name, AnonymousAxis):
                axis_name2known_length[axis_name] = axis_name.value
            else:
                axis_name2known_length[axis_name] = _unknown_axis_length

    axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)}

    # axes provided as kwargs
    for elementary_axis in axes_names:
        if not ParsedExpression.check_axis_name(elementary_axis):
            raise EinopsError(f"Invalid name for an axis: {elementary_axis}")
        if elementary_axis not in axis_name2known_length:
            raise EinopsError(f"Axis {elementary_axis} is not used in transform")
        axis_name2known_length[elementary_axis] = _expected_axis_length

    input_axes_known_unknown = []
    # some shapes are inferred later - all information is prepared for faster inference
    for composite_axis in left_composition:
        known: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length}
        unknown: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length}
        if len(unknown) > 1:
            raise EinopsError(f"Could not infer sizes for {unknown}")
        assert len(unknown) + len(known) == len(composite_axis)
        input_axes_known_unknown.append(
            ([axis_name2position[axis] for axis in known], [axis_name2position[axis] for axis in unknown])
        )

    result_axes_grouping: List[List[int]] = [
        [axis_name2position[axis] for axis in composite_axis] for composite_axis in rght_composition
    ]

    ordered_axis_left = list(itertools.chain(*left_composition))
    ordered_axis_rght = list(itertools.chain(*rght_composition))
    reduced_axes = [axis for axis in ordered_axis_left if axis not in rght.identifiers]
    # kept axes in output order, followed by the reduced axes
    order_after_transposition = [axis for axis in ordered_axis_rght if axis in left.identifiers] + reduced_axes
    axes_permutation = [ordered_axis_left.index(axis) for axis in order_after_transposition]
    added_axes = {
        i: axis_name2position[axis_name]
        for i, axis_name in enumerate(ordered_axis_rght)
        if axis_name not in left.identifiers
    }

    first_reduced_axis = len(order_after_transposition) - len(reduced_axes)

    return TransformRecipe(
        elementary_axes_lengths=list(axis_name2known_length.values()),
        axis_name2elementary_axis={axis: axis_name2position[axis] for axis in axes_names},
        input_composition_known_unknown=input_axes_known_unknown,
        axes_permutation=axes_permutation,
        first_reduced_axis=first_reduced_axis,
        added_axes=added_axes,
        output_composite_axes=result_axes_grouping,
    )
def _prepare_recipes_for_all_dims(
    pattern: str, operation: Reduction, axes_names: Tuple[str, ...]
) -> Dict[int, TransformRecipe]:
    """
    Internal function, used in layers.

    Layers create all recipes at initialization time. To keep recipes simple, one recipe is
    precomputed per possible input dimensionality: a single one without ellipsis, or one per
    ellipsis size from 0 to 7 dimensions when the left side contains an ellipsis.
    """
    left_str, _ = pattern.split("->")
    left = ParsedExpression(left_str)
    n_groups = len(left.composition)
    if left.has_ellipsis:
        # the ellipsis occupies one composition slot but may cover any number of dimensions
        candidate_ndims = [n_groups - 1 + n_ellipsis_dims for n_ellipsis_dims in range(8)]
    else:
        candidate_ndims = [n_groups]

    recipes = {}
    for ndim in candidate_ndims:
        recipes[ndim] = _prepare_transformation_recipe(pattern, operation, axes_names, ndim=ndim)
    return recipes
@overload
def reduce(tensor: List[Tensor], pattern: str, reduction: Reduction, **axes_lengths: Size) -> Tensor: ...


@overload
def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: Size) -> Tensor: ...


def reduce(tensor: Union[Tensor, List[Tensor]], pattern: str, reduction: Reduction, **axes_lengths: Size) -> Tensor:
    """
    einops.reduce combines rearrangement and reduction using reader-friendly notation.

    Some examples:

    ```python
    >>> x = np.random.randn(100, 32, 64)

    # perform max-reduction on the first axis
    # Axis t does not appear on RHS - thus we reduced over t
    >>> y = reduce(x, 't b c -> b c', 'max')

    # same as previous, but using verbose names for axes
    >>> y = reduce(x, 'time batch channel -> batch channel', 'max')

    # let's pretend now that x is a batch of images
    # with 4 dims: batch=10, height=20, width=30, channel=40
    >>> x = np.random.randn(10, 20, 30, 40)

    # 2d max-pooling with kernel size = 2 * 2 for image processing
    >>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)

    # same as previous, using anonymous axes,
    # note: only reduced axes can be anonymous
    >>> y1 = reduce(x, 'b c (h1 2) (w1 2) -> b c h1 w1', 'max')

    # adaptive 2d max-pooling to 3 * 4 grid,
    # each element is max of 10x10 tile in the original tensor.
    >>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape
    (10, 20, 3, 4)

    # Global average pooling
    >>> reduce(x, 'b c h w -> b c', 'mean').shape
    (10, 20)

    # subtracting mean over batch for each channel;
    # similar to x - np.mean(x, axis=(0, 2, 3), keepdims=True)
    >>> y = x - reduce(x, 'b c h w -> 1 c 1 1', 'mean')

    # Subtracting per-image mean for each channel
    >>> y = x - reduce(x, 'b c h w -> b c 1 1', 'mean')

    # same as previous, but using empty compositions
    >>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean')
    ```

    Parameters:
        tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch).
            A list of tensors is also accepted; those should be of the same type and shape,
            and are stacked along a new first axis.
        pattern: string, reduction pattern
        reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod', 'any', 'all').
            Alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided.
            This allows using various reductions like: np.max, np.nanmean, tf.reduce_logsumexp, torch.var, etc.
        axes_lengths: any additional specifications for dimensions

    Returns:
        tensor of the same type as input

    Raises:
        TypeError: when an empty list is passed.
        EinopsError: when the pattern is invalid or does not match the input; the message is extended
            with the pattern, the input shape and the provided axes lengths.
    """
    try:
        if not isinstance(tensor, list):
            backend = get_backend(tensor)
        elif tensor:
            backend = get_backend(tensor[0])
            tensor = backend.stack_on_zeroth_dimension(tensor)
        else:
            raise TypeError("Rearrange/Reduce/Repeat can't be applied to an empty list")

        shape = backend.shape(tensor)
        # the recipe only needs axis names, so it is cached; the lengths travel separately
        recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=len(shape))
        return _apply_recipe(
            backend,
            recipe,
            cast(Tensor, tensor),
            reduction_type=reduction,
            axes_lengths=tuple(axes_lengths.items()),
        )
    except EinopsError as e:
        if isinstance(tensor, list):
            input_description = "\n Input is list. "
        else:
            input_description = f"\n Input tensor shape: {shape}. "
        message = (
            f' Error while processing {reduction}-reduction pattern "{pattern}".'
            + input_description
            + f"Additional info: {axes_lengths}."
        )
        raise EinopsError(message + f"\n {e}") from None
@overload
def rearrange(tensor: List[Tensor], pattern: str, **axes_lengths: Size) -> Tensor: ...


@overload
def rearrange(tensor: Tensor, pattern: str, **axes_lengths: Size) -> Tensor: ...


def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths: Size) -> Tensor:
    """
    einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors.
    This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
    stack, concatenate and other operations.

    Examples:

    ```python
    # suppose we have a set of 32 images in "h w c" format (height-width-channel)
    >>> images = [np.random.randn(30, 40, 3) for _ in range(32)]

    # stack along first (batch) axis, output is a single array
    >>> rearrange(images, 'b h w c -> b h w c').shape
    (32, 30, 40, 3)

    # stacked and reordered axes to "b c h w" format
    >>> rearrange(images, 'b h w c -> b c h w').shape
    (32, 3, 30, 40)

    # concatenate images along height (vertical axis), 960 = 32 * 30
    >>> rearrange(images, 'b h w c -> (b h) w c').shape
    (960, 40, 3)

    # concatenated images along horizontal axis, 1280 = 32 * 40
    >>> rearrange(images, 'b h w c -> h (b w) c').shape
    (30, 1280, 3)

    # flattened each image into a vector, 3600 = 30 * 40 * 3
    >>> rearrange(images, 'b h w c -> b (c h w)').shape
    (32, 3600)

    # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
    >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
    (128, 15, 20, 3)

    # space-to-depth operation
    >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
    (32, 15, 20, 12)
    ```

    When composing axes, C-order enumeration is used (consecutive elements have different last axis).
    Find more examples in the einops tutorial.

    Parameters:
        tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch).
            A list of tensors is also accepted; those should be of the same type and shape.
        pattern: string, rearrangement pattern
        axes_lengths: any additional specifications for dimensions

    Returns:
        tensor of the same type as input. If possible, a view to the original tensor is returned.
    """
    # rearrange is the reduction-free special case of the shared reduce machinery
    operation = "rearrange"
    return reduce(tensor, pattern, reduction=operation, **axes_lengths)
@overload
def repeat(tensor: List[Tensor], pattern: str, **axes_lengths: Size) -> Tensor: ...


@overload
def repeat(tensor: Tensor, pattern: str, **axes_lengths: Size) -> Tensor: ...


def repeat(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths: Size) -> Tensor:
    """
    einops.repeat allows reordering elements and repeating them in arbitrary combinations.
    This operation includes functionality of repeat, tile, and broadcast functions.

    Examples for repeat operation:

    ```python
    # a grayscale image (of shape height x width)
    >>> image = np.random.randn(30, 40)

    # change it to RGB format by repeating in each channel
    >>> repeat(image, 'h w -> h w c', c=3).shape
    (30, 40, 3)

    # repeat image 2 times along height (vertical axis)
    >>> repeat(image, 'h w -> (repeat h) w', repeat=2).shape
    (60, 40)

    # repeat image 2 times along height and 3 times along width
    >>> repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape
    (60, 120)

    # convert each pixel to a small square 2x2, i.e. upsample an image by 2x
    >>> repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
    (60, 80)

    # 'pixelate' an image first by downsampling by 2x, then upsampling
    >>> downsampled = reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2)
    >>> repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
    (30, 40)
    ```

    When composing axes, C-order enumeration is used (consecutive elements have different last axis).
    Find more examples in the einops tutorial.

    Parameters:
        tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch).
            A list of tensors is also accepted; those should be of the same type and shape.
        pattern: string, repetition pattern
        axes_lengths: any additional specifications for dimensions (lengths of new axes are required)

    Returns:
        Tensor of the same type as input. If possible, a view to the original tensor is returned.
    """
    # repeat shares the implementation of reduce; new axes are handled by the recipe
    operation = "repeat"
    return reduce(tensor, pattern, reduction=operation, **axes_lengths)
def parse_shape(x: Tensor, pattern: str) -> dict:
    """
    Parse a tensor shape to dictionary mapping axes names to their lengths.

    ```python
    # Use underscore to skip the dimension in parsing.
    >>> x = np.zeros([2, 3, 5, 7])
    >>> parse_shape(x, 'batch _ h w')
    {'batch': 2, 'h': 5, 'w': 7}

    # `parse_shape` output can be used to specify axes_lengths for other operations:
    >>> y = np.zeros([700])
    >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
    (2, 10, 5, 7)
    ```

    For symbolic frameworks may return symbols, not integers.

    Parameters:
        x: tensor of any supported framework
        pattern: str, space separated names for axes, underscore means skip axis.
            An ellipsis skips any number of dimensions, `()` requires a dimension of length 1,
            and anonymous (numeric) axes must match the actual length.

    Returns:
        dict, maps axes names to their lengths

    Raises:
        RuntimeError: for composite axes, a mismatching number of dimensions,
            or lengths contradicting `()` / anonymous axes.
    """
    exp = ParsedExpression(pattern, allow_underscore=True)
    shape = get_backend(x).shape(x)
    if exp.has_composed_axes():
        raise RuntimeError(f"Can't parse shape with composite axes: {pattern} {shape}")

    n_dims = len(shape)
    n_groups = len(exp.composition)
    if n_dims != n_groups:
        if not exp.has_ellipsis:
            raise RuntimeError(f"Can't parse shape with different number of dimensions: {pattern} {shape}")
        if n_dims < n_groups - 1:
            raise RuntimeError(f"Can't parse shape with this number of dimensions: {pattern} {shape}")

    composition = exp.composition
    if exp.has_ellipsis:
        # every dimension covered by the ellipsis is treated as a skipped axis
        ellipsis_position = composition.index(_ellipsis)
        n_skipped = n_dims - n_groups + 1
        composition = composition[:ellipsis_position] + [["_"]] * n_skipped + composition[ellipsis_position + 1 :]

    result = {}
    for axes, axis_length in zip(composition, shape):  # type: ignore
        # axes either [], or [AnonymousAxis] or ['axis_name']
        if not axes:
            if axis_length != 1:
                raise RuntimeError(f"Length of axis is not 1: {pattern} {shape}")
            continue
        (axis,) = axes
        if isinstance(axis, str):
            if axis != "_":
                result[axis] = axis_length
        elif axis.value != axis_length:
            raise RuntimeError(f"Length of anonymous axis does not match: {pattern} {shape}")
    return result
# _enumerate_directions is not exposed in the public API
def _enumerate_directions(x):
    """
    For an n-dimensional tensor, returns tensors to enumerate each axis.

    ```python
    x = np.zeros([2, 3, 4]) # or any other tensor
    i, j, k = _enumerate_directions(x)
    result = i + 2*j + 3*k
    ```

    `result[i, j, k] = i + 2j + 3k`, and (by broadcasting) result has the same shape as `x`.
    Works very similarly to numpy.ogrid (open indexing grid).

    Returns:
        list with one tensor per axis of `x`; the tensor for axis d holds `arange(0, shape[d])`
        laid out along axis d, with length 1 along every other axis.
    """
    backend = get_backend(x)
    shape = backend.shape(x)
    result = []
    for axis_id, axis_length in enumerate(shape):
        # all-ones shape except along the enumerated axis (a separate name avoids shadowing `shape`)
        direction_shape = [1] * len(shape)
        direction_shape[axis_id] = axis_length
        result.append(backend.reshape(backend.arange(0, axis_length), direction_shape))
    return result
# numpy is only imported under TYPE_CHECKING, so a runtime alias stands in for numpy.ndarray
np_ndarray = Any


def asnumpy(tensor: Tensor) -> np_ndarray:
    """
    Convert a tensor of an imperative framework (i.e. numpy/cupy/torch/jax/etc.) to `numpy.ndarray`.

    The conversion itself is delegated to the einops backend that recognizes `tensor`.

    Parameters:
        tensor: tensor of any known imperative framework

    Returns:
        `numpy.ndarray`, converted to numpy
    """
    backend = get_backend(tensor)
    return backend.to_numpy(tensor)
- def _validate_einsum_axis_name(axis_name):
- if len(axis_name) == 0:
- raise NotImplementedError("Singleton () axes are not yet supported in einsum.")
- if len(axis_name) > 1:
- raise NotImplementedError("Shape rearrangement is not yet supported in einsum.")
- axis_name = axis_name[0]
- if isinstance(axis_name, AnonymousAxis):
- raise NotImplementedError("Anonymous axes are not yet supported in einsum.")
- if len(axis_name) == 0:
- raise RuntimeError("Encountered empty axis name in einsum.")
- if not isinstance(axis_name, str):
- raise RuntimeError("Axis name in einsum must be a string.")
@functools.lru_cache(256)
def _compactify_pattern_for_einsum(pattern: str) -> str:
    """
    Translate an einops-style einsum pattern into the compact single-letter
    syntax accepted by backend ``einsum`` implementations.

    Each distinct named axis gets the next letter of ``string.ascii_letters``
    in order of first appearance on the left-hand side; ellipses stay ``...``.
    For example ``"batch h w, h w channel -> batch channel"`` becomes
    ``"abc,bcd->ad"``. Results are memoized via ``lru_cache``.

    Raises:
        ValueError: if the pattern does not contain exactly one ``->``.
        RuntimeError: if more distinct axes are used than there are letters.
        EinopsError: if the right side names an axis absent from every input.
        NotImplementedError / RuntimeError: from ``_validate_einsum_axis_name``
            for unsupported axis forms (``()``, groups, anonymous axes).
    """
    if "->" not in pattern:
        # numpy allows this, so make sure users
        # don't accidentally do something like this.
        raise ValueError("Einsum pattern must contain '->'.")
    if pattern.count("->") > 1:
        # Otherwise the unpacking below fails with an opaque "too many values to unpack".
        raise ValueError(f"Einsum pattern must contain exactly one '->': {pattern}")
    lefts_str, right_str = pattern.split("->")
    lefts = [ParsedExpression(left, allow_underscore=True, allow_duplicates=True) for left in lefts_str.split(",")]
    right = ParsedExpression(right_str, allow_underscore=True)

    # Start from 'a' and go up to 'Z'
    output_axis_names = string.ascii_letters
    axis_name_mapping = {}

    def translate(composition, allow_new_axes):
        # Convert one side's composition to letters; only input sides may introduce new axes.
        letters = []
        for raw_axis_name in composition:
            if raw_axis_name == _ellipsis:
                letters.append("...")
                continue
            _validate_einsum_axis_name(raw_axis_name)
            axis_name = raw_axis_name[0]
            if axis_name not in axis_name_mapping:
                if not allow_new_axes:
                    raise EinopsError(f"Unknown axis {axis_name} on right side of einsum {pattern}.")
                if len(axis_name_mapping) >= len(output_axis_names):
                    raise RuntimeError("Too many axes in einsum.")
                axis_name_mapping[axis_name] = output_axis_names[len(axis_name_mapping)]
            letters.append(axis_name_mapping[axis_name])
        return "".join(letters)

    left_patterns = [translate(left.composition, allow_new_axes=True) for left in lefts]
    return ",".join(left_patterns) + "->" + translate(right.composition, allow_new_axes=False)
@typing.overload
def einsum(tensor: Tensor, pattern: str, /) -> Tensor: ...


@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str, /) -> Tensor: ...


@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str, /) -> Tensor: ...


@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str, /) -> Tensor: ...


def einsum(*tensors_and_pattern: Union[Tensor, str]) -> Tensor:
    r"""
    Evaluate an einsum over any number of tensors, naming axes einops-style
    (multi-letter, space-separated names) instead of single letters.
    Unlike the usual einsum APIs, every tensor comes first and the pattern is
    the last positional argument.

    Axis grouping such as `"(batch chan) out"` and singleton axes `()`
    are not supported yet.

    Examples:

    For a pattern such as:

    ```python
    >>> x, y, z = np.random.randn(3, 20, 20, 20)
    >>> output = einsum(x, y, z, "a b c, c b d, a g k -> a b k")
    ```

    the computed formula is:

    ```tex
    output[a, b, k] = \sum_{c, d, g} x[a, b, c] * y[c, b, d] * z[a, g, k]
    ```

    so `c`, `d` and `g` are summed out, because they do not appear on the
    right-hand side.

    More examples:

    ```python
    # Filter a set of images:
    >>> batched_images = np.random.randn(128, 16, 16)
    >>> filters = np.random.randn(16, 16, 30)
    >>> result = einsum(batched_images, filters,
    ...                 "batch h w, h w channel -> batch channel")
    >>> result.shape
    (128, 30)

    # Matrix multiplication, with an unknown input shape:
    >>> batch_shape = (50, 30)
    >>> data = np.random.randn(*batch_shape, 20)
    >>> weights = np.random.randn(10, 20)
    >>> result = einsum(weights, data,
    ...                 "out_dim in_dim, ... in_dim -> ... out_dim")
    >>> result.shape
    (50, 30, 10)

    # Matrix trace on a single tensor:
    >>> matrix = np.random.randn(10, 10)
    >>> result = einsum(matrix, "i i ->")
    >>> result.shape
    ()
    ```

    Parameters:
        tensors_and_pattern:
            tensors: tensors of any supported library (numpy, tensorflow, pytorch, jax).
            pattern: string, einsum pattern, with commas
                separating specifications for each tensor.
                pattern should be provided after all tensors.

    Returns:
        Tensor of the same type as input, after processing with einsum.

    Raises:
        ValueError: if fewer than two arguments are passed, if the last argument
            is not a string, or if the pattern has no '->'.
    """
    if len(tensors_and_pattern) < 2:
        raise ValueError(
            "`einops.einsum` takes at minimum two arguments: the tensors (at least one), followed by the pattern."
        )
    *tensors, raw_pattern = tensors_and_pattern
    if not isinstance(raw_pattern, str):
        raise ValueError(
            "The last argument passed to `einops.einsum` must be a string, representing the einsum pattern."
        )
    # Translate to single-letter syntax before touching the backend.
    compact_pattern = _compactify_pattern_for_einsum(raw_pattern)
    backend = get_backend(tensors[0])
    return backend.einsum(compact_pattern, *tensors)
|