| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764 |
"""
Backends in `einops` are organized to meet the following requirements
- backends are not imported unless those are actually needed, because
    - backends may not be installed
    - importing all available backends will lead to a significant memory footprint
    - backends may be present but installed with errors (but never used),
      importing may lead to crashes
- backend should be either symbolic or imperative
    - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined
- if backend can't provide symbols for shape dimensions, UnknownSize objects are used
"""
- import sys
__author__ = "Alex Rogozhnikov"

# Backend instances created so far, keyed by their framework name.
_loaded_backends: dict = dict()
# Maps an exact tensor type to the backend instance that handles it.
_type2backend: dict = dict()
# When true, get_backend prints which backend classes it probes and instantiates.
_debug_importing = False
def get_backend(tensor) -> "AbstractBackend":
    """
    Return the backend able to handle ``tensor`` (e.g. the numpy backend for numpy.ndarray).

    Lookup order:
      1. the per-type cache ``_type2backend``;
      2. backends that were already instantiated (``_loaded_backends``);
      3. every (transitive) subclass of AbstractBackend whose framework is present
         in ``sys.modules``; such backends are instantiated and cached on the way.

    Raises RuntimeError when no backend recognizes the tensor.
    """
    tensor_type = type(tensor)
    cached = _type2backend.get(tensor_type, None)
    if cached is not None:
        return cached

    # Snapshot of the registry: iteration must not observe concurrent insertions.
    snapshot = list(_loaded_backends.items())
    for _, loaded in snapshot:
        if loaded.is_appropriate_type(tensor):
            _type2backend[tensor_type] = loaded
            return loaded

    # Find backend subclasses recursively
    discovered = []
    pending = AbstractBackend.__subclasses__()
    while pending:
        candidate = pending.pop()
        pending += candidate.__subclasses__()
        discovered.append(candidate)

    # handles modification of _loaded_backends from other thread, see #391
    already_tried = [name for name, _ in snapshot]

    for BackendSubclass in discovered:
        if _debug_importing:
            print("Testing for subclass of ", BackendSubclass)
        if BackendSubclass.framework_name in already_tried:
            continue
        # check that module was already imported. Otherwise it can't be imported
        if BackendSubclass.framework_name not in sys.modules:
            continue
        if _debug_importing:
            print("Imported backend for ", BackendSubclass.framework_name)
        backend = BackendSubclass()
        _loaded_backends[backend.framework_name] = backend
        if backend.is_appropriate_type(tensor):
            _type2backend[tensor_type] = backend
            return backend

    raise RuntimeError(f"Tensor type unknown to einops {type(tensor)}")
class AbstractBackend:
    """Base backend class, major part of methods are only for debugging purposes."""

    # Name of the framework module; subclasses set it as a class attribute.
    framework_name: str

    def is_appropriate_type(self, tensor):
        """Tell whether this backend can handle ``tensor``; must be overridden."""
        raise NotImplementedError()

    def from_numpy(self, x):
        """Convert a numpy array to a framework tensor (imperative backends only)."""
        raise NotImplementedError("framework doesn't support imperative execution")

    def to_numpy(self, x):
        """Convert a framework tensor to a numpy array (imperative backends only)."""
        raise NotImplementedError("framework doesn't support imperative execution")

    def create_symbol(self, shape):
        """Create a symbolic tensor of the given shape (symbolic backends only)."""
        raise NotImplementedError("framework doesn't support symbolic computations")

    def eval_symbol(self, symbol, symbol_value_pairs):
        """Evaluate ``symbol``; symbol_value_pairs is list[tuple[symbol, value-tensor]]."""
        raise NotImplementedError("framework doesn't support symbolic computations")

    def arange(self, start, stop):
        """Range tensor; supplementary method used only in testing, so should implement CPU version."""
        raise NotImplementedError("framework doesn't implement arange")

    def shape(self, x):
        """Return ``x.shape``: a tuple with integers or "shape symbols" (which will evaluate to actual size)."""
        return x.shape

    def reshape(self, x, shape):
        """Delegate to the tensor's own ``reshape`` method."""
        return x.reshape(shape)

    def transpose(self, x, axes):
        """Delegate to the tensor's own ``transpose`` method."""
        return x.transpose(axes)

    def reduce(self, x, operation, axes):
        """Call the tensor method named ``operation`` with ``axis=axes``."""
        reducer = getattr(x, operation)
        return reducer(axis=axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        """Stack tensors along a new leading axis; must be overridden."""
        raise NotImplementedError()

    def add_axis(self, x, new_position):
        """Insert one new axis at ``new_position``; must be overridden."""
        raise NotImplementedError()

    def add_axes(self, x, n_axes, pos2len):
        """
        Insert an axis at every position listed in ``pos2len`` (via add_axis),
        then tile the result so each inserted axis gets its requested length.
        ``n_axes`` is the length of the repeats tuple handed to ``tile``.
        """
        tile_counts = [1 for _ in range(n_axes)]
        for position, length in pos2len.items():
            x = self.add_axis(x, position)
            tile_counts[position] = length
        return self.tile(x, tuple(tile_counts))

    def tile(self, x, repeats):
        """repeats - same lengths as x.shape"""
        raise NotImplementedError()

    def concat(self, tensors, axis: int):
        """concatenates tensors along axis.
        Assume identical across tensors: devices, dtypes and shapes except selected axis."""
        raise NotImplementedError()

    def is_float_type(self, x):
        """Tell whether ``x`` has a floating dtype; must be overridden."""
        # some backends (torch) can't compute average for non-floating types.
        # Decided to drop average for all backends if type is not floating
        raise NotImplementedError()

    def layers(self):
        """Return the module with framework-specific layers, if the backend has one."""
        raise NotImplementedError("backend does not provide layers")

    def __repr__(self):
        return f"<einops backend for {self.framework_name}>"

    def einsum(self, pattern, *x):
        """Framework einsum over the given operands, if supported."""
        raise NotImplementedError("backend does not support einsum")
class UnknownSize:
    """pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements"""

    # Arithmetic on an unknown size stays unknown, so the same object is returned.

    def __mul__(self, other):
        return self

    def __rmul__(self, other):
        return self

    def __floordiv__(self, other):
        return self

    def __eq__(self, other):
        # we don't know actual size, so comparison with anything succeeds
        return True

    def __hash__(self):
        # every instance hashes to the same constant value
        return hash(None)
class NumpyBackend(AbstractBackend):
    """einops backend operating on ``numpy.ndarray``."""

    framework_name = "numpy"

    def __init__(self):
        # numpy is imported on instantiation, not at module import time
        import numpy as numpy_module

        self.np = numpy_module

    def is_appropriate_type(self, tensor):
        """True when ``tensor`` is an instance of the module's ``ndarray``."""
        return isinstance(tensor, self.np.ndarray)

    def from_numpy(self, x):
        """Identity: the input is returned unchanged."""
        return x

    def to_numpy(self, x):
        """Identity: the input is returned unchanged."""
        return x

    def arange(self, start, stop):
        return self.np.arange(start, stop)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.np.stack(tensors)

    def tile(self, x, repeats):
        return self.np.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.np.concatenate(tensors, axis=axis)

    def is_float_type(self, x):
        """Check ``x.dtype`` against a fixed list of floating dtype names."""
        floating_names = ("float16", "float32", "float64", "float128", "bfloat16")
        return x.dtype in floating_names

    def add_axis(self, x, new_position):
        return self.np.expand_dims(x, new_position)

    def einsum(self, pattern, *x):
        return self.np.einsum(pattern, *x)
class JaxBackend(NumpyBackend):
    """einops backend for jax; reuses NumpyBackend methods with ``jax.numpy`` as ``self.np``."""

    framework_name = "jax"

    def __init__(self):
        super().__init__()
        # keep the original numpy module reachable before self.np is replaced
        self.onp = self.np

        import jax.numpy as jax_numpy

        self.np = jax_numpy

    def from_numpy(self, x):
        """Convert via ``jax.numpy.asarray``."""
        return self.np.asarray(x)

    def to_numpy(self, x):
        """Convert back via the original numpy's ``asarray``."""
        return self.onp.asarray(x)
class TorchBackend(AbstractBackend):
    """einops backend operating on ``torch.Tensor``."""

    framework_name = "torch"

    def __init__(self):
        import torch as torch_module

        self.torch = torch_module
        # importing would register operations in torch._dynamo for torch.compile
        from . import _torch_specific  # noqa

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.torch.Tensor)

    def from_numpy(self, x):
        """Wrap a numpy array; floating tensors additionally get ``requires_grad = True``."""
        tensor = self.torch.from_numpy(x)
        # attach grad only to floating types
        if self.is_float_type(tensor):
            tensor.requires_grad = True
        return tensor

    def to_numpy(self, x):
        """Detach, move to CPU and convert to a numpy array."""
        detached = x.detach()
        return detached.cpu().numpy()

    def arange(self, start, stop):
        return self.torch.arange(start, stop, dtype=self.torch.int64)

    def reduce(self, x, operation, reduced_axes):
        """
        Reduce ``x`` over ``reduced_axes``.

        min/max/sum/mean are computed in a single call (amin/amax/sum/mean);
        any/all/prod are applied axis by axis, highest axis first.
        Raises NotImplementedError for any other operation name.
        """
        if operation == "min":
            return x.amin(dim=reduced_axes)
        if operation == "max":
            return x.amax(dim=reduced_axes)
        if operation == "sum":
            return x.sum(dim=reduced_axes)
        if operation == "mean":
            return x.mean(dim=reduced_axes)
        if operation in ("any", "all", "prod"):
            # pytorch supports reducing only one operation at a time
            for axis in sorted(reduced_axes, reverse=True):
                x = getattr(x, operation)(dim=axis)
            return x
        raise NotImplementedError("Unknown reduction ", operation)

    def transpose(self, x, axes):
        return x.permute(axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.torch.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        """Insert the axes listed in ``pos2len`` and expand them; other entries are passed as -1."""
        target_sizes = [-1 for _ in range(n_axes)]
        for position, length in pos2len.items():
            x = self.add_axis(x, position)
            target_sizes[position] = length
        return x.expand(target_sizes)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        return self.torch.cat(tensors, dim=axis)

    def add_axis(self, x, new_position):
        return self.torch.unsqueeze(x, new_position)

    def is_float_type(self, x):
        """Check ``x.dtype`` against float16/float32/float64/bfloat16."""
        torch_module = self.torch
        floating = [torch_module.float16, torch_module.float32, torch_module.float64, torch_module.bfloat16]
        return x.dtype in floating

    def layers(self):
        from .layers import torch as torch_layers

        return torch_layers

    def einsum(self, pattern, *x):
        return self.torch.einsum(pattern, *x)
class CupyBackend(AbstractBackend):
    """einops backend operating on ``cupy.ndarray``."""

    framework_name = "cupy"

    def __init__(self):
        import cupy as cupy_module

        self.cupy = cupy_module

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.cupy.ndarray)

    def from_numpy(self, x):
        """Convert via ``cupy.asarray``."""
        return self.cupy.asarray(x)

    def to_numpy(self, x):
        """Convert via ``cupy.asnumpy``."""
        return self.cupy.asnumpy(x)

    def arange(self, start, stop):
        return self.cupy.arange(start, stop)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.cupy.stack(tensors)

    def tile(self, x, repeats):
        return self.cupy.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.cupy.concatenate(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.cupy.expand_dims(x, new_position)

    def is_float_type(self, x):
        """Check ``x.dtype`` against a fixed list of floating dtype names."""
        floating_names = ("float16", "float32", "float64", "float128", "bfloat16")
        return x.dtype in floating_names

    def einsum(self, pattern, *x):
        return self.cupy.einsum(pattern, *x)
class HashableTuple:
    """Overcomes non-hashability of symbolic elements"""

    # default equality and hash is used (True only with itself, hash taken of id)

    def __init__(self, elements: tuple):
        self.elements = elements

    def __len__(self):
        return len(self.elements)

    def __getitem__(self, item):
        return self.elements[item]

    def __iter__(self):
        # generator over the wrapped elements, in order
        for element in self.elements:
            yield element
class TensorflowBackend(AbstractBackend):
    """einops backend operating on ``tf.Tensor`` and ``tf.Variable``."""

    framework_name = "tensorflow"

    def __init__(self):
        import tensorflow

        self.tf = tensorflow

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, (self.tf.Tensor, self.tf.Variable))

    def from_numpy(self, x):
        """Convert a numpy array to a tensor; asserts eager execution."""
        assert self.tf.executing_eagerly()
        return self.tf.convert_to_tensor(x)

    def to_numpy(self, x):
        """Convert a tensor to numpy via ``x.numpy()``; asserts eager execution."""
        assert self.tf.executing_eagerly()
        return x.numpy()

    def arange(self, start, stop):
        return self.tf.range(start, stop)

    def shape(self, x):
        """
        Return the shape of ``x``.

        In eager mode: a tuple of ints, with UnknownSize() for dimensions reported as None.
        Otherwise: a tuple mixing static ints (where known) with components of
        ``tf.shape(x)``; if that tuple can't be hashed it is wrapped in HashableTuple.
        """
        if self.tf.executing_eagerly():
            return tuple(UnknownSize() if d is None else int(d) for d in x.shape)

        static_shape = x.shape.as_list()
        tf_shape = self.tf.shape(x)
        # use the static shape where known, otherwise use the TF shape components.
        # Explicit `is None` test: a static dimension of size 0 is known and must
        # stay a plain int rather than fall through to the symbolic component.
        shape = tuple(tf_shape[dim] if s is None else s for dim, s in enumerate(static_shape))
        try:
            hash(shape)
        except TypeError:
            # unhashable symbols in shape. Wrap tuple to be hashable.
            return HashableTuple(shape)
        return shape

    def reduce(self, x, operation, axes):
        """Dispatch to ``tf.reduce_<operation>``."""
        return getattr(self.tf, "reduce_" + operation)(x, axis=axes)

    def reshape(self, x, shape):
        return self.tf.reshape(x, shape)

    def transpose(self, x, axes):
        return self.tf.transpose(x, axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.tf.stack(tensors)

    def tile(self, x, repeats):
        return self.tf.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.tf.concat(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.tf.expand_dims(x, new_position)

    def is_float_type(self, x):
        """Check ``x.dtype`` against a fixed list of floating dtype names."""
        return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")

    def layers(self):
        from .layers import tensorflow

        return tensorflow

    def einsum(self, pattern, *x):
        return self.tf.einsum(pattern, *x)
class TFKerasBackend(AbstractBackend):
    """Symbolic einops backend built on ``tf.keras`` and ``tf.keras.backend``."""

    framework_name = "tensorflow.keras"

    def __init__(self):
        import tensorflow

        self.tf = tensorflow
        self.keras = tensorflow.keras
        self.K = tensorflow.keras.backend

    def is_appropriate_type(self, tensor):
        """True for objects that are tf tensors and also keras tensors."""
        return self.tf.is_tensor(tensor) and self.K.is_keras_tensor(tensor)

    def create_symbol(self, shape):
        """Create a ``keras.Input`` with ``batch_shape=shape``."""
        return self.keras.Input(batch_shape=shape)

    def eval_symbol(self, symbol, symbol_value_pairs):
        """Build a keras Model from the input symbols to ``symbol`` and run it on the paired values."""
        input_symbols = [sym for (sym, _) in symbol_value_pairs]
        model = self.keras.models.Model(input_symbols, symbol)
        input_values = [value for (_, value) in symbol_value_pairs]
        return model.predict_on_batch(input_values)

    def arange(self, start, stop):
        return self.K.arange(start, stop)

    def shape(self, x):
        """Return the components of ``K.shape(x)`` wrapped in a HashableTuple."""
        symbolic_shape = self.K.shape(x)  # tf tensor
        return HashableTuple(tuple(symbolic_shape))

    def reduce(self, x, operation, axes):
        """Dispatch to the ``keras.backend`` function named ``operation``."""
        reducer = getattr(self.K, operation)
        return reducer(x, axis=axes)

    def reshape(self, x, shape):
        return self.K.reshape(x, shape)

    def transpose(self, x, axes):
        return self.K.permute_dimensions(x, axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.K.stack(tensors)

    def tile(self, x, repeats):
        return self.K.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.K.concatenate(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.K.expand_dims(x, new_position)

    def is_float_type(self, x):
        """True when the dtype name reported by ``K.dtype`` contains "float"."""
        dtype_name = self.K.dtype(x)
        return "float" in dtype_name

    def layers(self):
        from .layers import keras as keras_layers

        return keras_layers
class OneFlowBackend(AbstractBackend):
    """einops backend operating on ``oneflow.Tensor``."""

    framework_name = "oneflow"

    def __init__(self):
        import oneflow

        self.flow = oneflow

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.flow.Tensor)

    def from_numpy(self, x):
        """Wrap a numpy array; floating tensors additionally get ``requires_grad = True``."""
        tensor = self.flow.from_numpy(x)
        # attach grad only to floating types
        if self.is_float_type(tensor):
            tensor.requires_grad = True
        return tensor

    def to_numpy(self, x):
        """Detach, move to CPU and convert to a numpy array."""
        detached = x.detach()
        return detached.cpu().numpy()

    def arange(self, start, stop):
        return self.flow.arange(start, stop, dtype=self.flow.int64)

    def reduce(self, x, operation, reduced_axes):
        """
        Reduce ``x`` one axis at a time, highest axis first.

        The operation name is checked inside the loop, so NotImplementedError
        for an unknown name is raised only if there is at least one axis.
        """
        for axis in sorted(reduced_axes, reverse=True):
            if operation in ("min", "max"):
                # min/max over a dim return a pair; only the first item (values) is kept
                x, _ = getattr(x, operation)(dim=axis)
            elif operation in ["sum", "mean", "prod", "any", "all"]:
                x = getattr(x, operation)(dim=axis)
            else:
                raise NotImplementedError("Unknown reduction ", operation)
        return x

    def transpose(self, x, axes):
        return x.permute(axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.flow.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        """Insert the axes listed in ``pos2len`` and expand them; other entries are passed as -1."""
        target_sizes = [-1 for _ in range(n_axes)]
        for position, length in pos2len.items():
            x = self.add_axis(x, position)
            target_sizes[position] = length
        return x.expand(*target_sizes)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        return self.flow.concat(tensors, dim=axis)

    def add_axis(self, x, new_position):
        return self.flow.unsqueeze(x, new_position)

    def is_float_type(self, x):
        """Check ``x.dtype`` against float16/float32/float64."""
        flow = self.flow
        floating = [flow.float16, flow.float32, flow.float64]
        return x.dtype in floating

    def layers(self):
        from .layers import oneflow as oneflow_layers

        return oneflow_layers

    def einsum(self, pattern, *x):
        return self.flow.einsum(pattern, *x)
class PaddleBackend(AbstractBackend):
    """einops backend operating on paddle tensors."""

    framework_name = "paddle"

    def __init__(self):
        import paddle as paddle_module

        self.paddle = paddle_module

    def is_appropriate_type(self, tensor):
        return self.paddle.is_tensor(tensor)

    def from_numpy(self, x):
        """Convert to a paddle tensor and set ``stop_gradient = False`` on it."""
        converted = self.paddle.to_tensor(x)
        converted.stop_gradient = False
        return converted

    def to_numpy(self, x):
        detached = x.detach()
        return detached.numpy()

    def arange(self, start, stop):
        return self.paddle.arange(start, stop, dtype=self.paddle.int64)

    def reduce(self, x, operation, axes):
        """Reduce via the base implementation; squeeze axis 0 when every axis is reduced."""
        reduces_all_axes = len(axes) == x.ndim
        reduced = super().reduce(x, operation, axes)
        if reduces_all_axes:
            # currently paddle returns 1d tensor instead of 0d
            return reduced.squeeze(0)
        return reduced

    def transpose(self, x, axes):
        return x.transpose(axes)

    def add_axes(self, x, n_axes, pos2len):
        """Insert the axes listed in ``pos2len`` and expand them; other entries are passed as -1."""
        target_sizes = [-1 for _ in range(n_axes)]
        for position, length in pos2len.items():
            x = self.add_axis(x, position)
            target_sizes[position] = length
        return x.expand(target_sizes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.paddle.stack(tensors)

    def reshape(self, x, shape):
        return x.reshape(shape)

    def tile(self, x, repeats):
        return x.tile(repeats)

    def concat(self, tensors, axis: int):
        return self.paddle.concat(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return x.unsqueeze(new_position)

    def is_float_type(self, x):
        """Check ``x.dtype`` against float16/float32/float64."""
        paddle_module = self.paddle
        floating = [paddle_module.float16, paddle_module.float32, paddle_module.float64]
        return x.dtype in floating

    def layers(self):
        from .layers import paddle as paddle_layers

        return paddle_layers

    def einsum(self, pattern, *x):
        return self.paddle.einsum(pattern, *x)

    def shape(self, x):
        """Return ``x.shape`` converted to a tuple."""
        return tuple(x.shape)
class TinygradBackend(AbstractBackend):
    """einops backend operating on ``tinygrad.Tensor``."""

    framework_name = "tinygrad"

    def __init__(self):
        import tinygrad as tinygrad_module

        self.tinygrad = tinygrad_module

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.tinygrad.Tensor)

    def from_numpy(self, x):
        return self.tinygrad.Tensor(x)

    def to_numpy(self, x):
        return x.numpy()

    def arange(self, start, stop):
        return self.tinygrad.Tensor.arange(start, stop)

    def shape(self, x):
        return x.shape

    def reshape(self, x, shape):
        return x.reshape(shape)

    def transpose(self, x, axes):
        return x.permute(axes)

    def reduce(self, x, operation, axes):
        """Apply the tensor method named ``operation`` one axis at a time, highest axis first."""
        for axis in sorted(axes, reverse=True):
            reducer = getattr(x, operation)
            x = reducer(axis=axis)
        return x

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.tinygrad.Tensor.stack(tensors)

    def add_axis(self, x, new_position):
        return x.unsqueeze(new_position)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        """Concatenate along ``axis``; a single-element sequence returns its only tensor as is."""
        if len(tensors) > 1:
            head, tail = tensors[0], tensors[1:]
            return head.cat(*tail, dim=axis)
        return tensors[0]

    def is_float_type(self, x):
        return self.tinygrad.dtypes.is_float(x.dtype)

    def einsum(self, pattern, *x):
        return self.tinygrad.Tensor.einsum(pattern, *x)
class PyTensorBackend(AbstractBackend):
    """einops backend operating on ``pytensor.tensor.TensorVariable``."""

    framework_name = "pytensor"

    def __init__(self):
        from pytensor import tensor as pytensor_tensor

        self.pt = pytensor_tensor

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.pt.TensorVariable)

    def is_float_type(self, x):
        """Check ``x.dtype`` against ``pytensor.tensor.type.float_dtypes``."""
        return x.dtype in self.pt.type.float_dtypes

    def from_numpy(self, x):
        return self.pt.as_tensor(x)

    def to_numpy(self, x):
        return x.eval()  # Will only work if there are no symbolic inputs

    def create_symbol(self, shape):
        """Create a symbolic tensor; a shape that is neither tuple nor list is wrapped in a 1-tuple."""
        if isinstance(shape, tuple | list):
            return self.pt.tensor(shape=shape)
        return self.pt.tensor(shape=(shape,))

    def eval_symbol(self, symbol, symbol_value_pairs):
        """Evaluate ``symbol`` with the pairs converted to a dict of inputs."""
        bindings = dict(symbol_value_pairs)
        return symbol.eval(bindings)

    def arange(self, start, stop):
        return self.pt.arange(start, stop)

    def shape(self, x):
        """Return a tuple mixing static dimensions (where not None) with symbolic ones."""
        dims = []
        # use the static shape dimensions where known
        for static_dim, symbolic_dim in zip(x.type.shape, x.shape):
            dims.append(symbolic_dim if static_dim is None else static_dim)
        return tuple(dims)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.pt.stack(tensors)

    def tile(self, x, repeats):
        return self.pt.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.pt.concatenate(tensors, axis=axis)

    def add_axis(self, x, new_position):
        return self.pt.expand_dims(x, new_position)

    def einsum(self, pattern, *x):
        return self.pt.einsum(pattern, *x)
class MLXBackend(AbstractBackend):
    """einops backend operating on ``mlx.core.array``."""

    framework_name = "mlx"

    def __init__(self):
        import mlx.core as mx
        import numpy as np

        self.mx = mx
        self.np = np

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.mx.array)

    def from_numpy(self, x):
        return self.mx.array(x)

    def to_numpy(self, x):
        """Convert to a numpy array; bfloat16 arrays are first cast to float32."""
        # presumably because numpy has no native bfloat16 dtype
        if x.dtype == self.mx.bfloat16:
            x = x.astype(self.mx.float32)
        return self.np.array(x)

    def arange(self, start, stop):
        return self.mx.arange(start, stop)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.mx.stack(tensors)

    def add_axis(self, x, new_position):
        """
        Insert a single new axis at ``new_position``.

        This is the single-axis primitive that the inherited
        ``add_axes(x, n_axes, pos2len)`` calls for each position before tiling,
        so it must be named ``add_axis`` and must not shadow ``add_axes``.
        """
        return self.mx.expand_dims(x, new_position)

    def tile(self, x, repeats):
        return self.mx.tile(x, repeats)

    def concat(self, tensors, axis: int):
        return self.mx.concatenate(tensors, axis=axis)

    def is_float_type(self, x):
        """True when ``x.dtype`` is a sub-dtype of ``mlx.core.floating``."""
        return self.mx.issubdtype(x.dtype, self.mx.floating)

    def einsum(self, pattern, *x):
        return self.mx.einsum(pattern, *x)
|