| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- """
Specialization of einops for torch.
Unfortunately, torch's jit scripting mechanism isn't strong enough,
so a number of additional steps are needed
to support scripting, at least for layers.
The design of the main operations (dynamic resolution by lookup) is unlikely
to ever be supported by torch.jit.script,
but torch.compile seems to handle these operations just fine.
- """
- import warnings
- from typing import Dict, List, Tuple
- import torch
- from einops.einops import TransformRecipe, _reconstruct_from_shape_uncached
class TorchJitBackend:
    """
    Completely static backend that mimics part of normal backend functionality
    but restricted to be within torchscript.

    Every method is a ``@staticmethod`` built only from constructs torchscript
    accepts (explicit ``if``/``elif`` dispatch instead of ``getattr`` lookups).
    """

    @staticmethod
    def reduce(x: torch.Tensor, operation: str, reduced_axes: List[int]):
        """Reduce ``x`` over the axes listed in ``reduced_axes``.

        Supported operations: "min", "max", "sum", "mean", "prod", "any", "all".
        Raises NotImplementedError for any other ``operation``.
        """
        if operation == "min":
            return x.amin(dim=reduced_axes)
        elif operation == "max":
            return x.amax(dim=reduced_axes)
        elif operation == "sum":
            return x.sum(dim=reduced_axes)
        elif operation == "mean":
            return x.mean(dim=reduced_axes)
        elif operation == "prod" or operation == "any" or operation == "all":
            # Tensor.prod takes a single int dim (and older torch releases have the
            # same restriction for any/all), so reduce one axis at a time. Going
            # from the highest axis down keeps the indices of the remaining axes valid.
            for i in sorted(reduced_axes)[::-1]:
                if operation == "prod":
                    x = x.prod(dim=i)
                elif operation == "any":
                    x = x.any(dim=i)
                else:
                    x = x.all(dim=i)
            return x
        else:
            raise NotImplementedError("Unknown reduction ", operation)

    @staticmethod
    def transpose(x, axes: List[int]):
        """Permute the dimensions of ``x`` into the order given by ``axes``."""
        return x.permute(axes)

    @staticmethod
    def stack_on_zeroth_dimension(tensors: List[torch.Tensor]):
        """Stack ``tensors`` along a new leading dimension."""
        return torch.stack(tensors)

    @staticmethod
    def tile(x, repeats: List[int]):
        """Repeat ``x`` ``repeats[i]`` times along each dimension (copies data)."""
        return x.repeat(repeats)

    @staticmethod
    def add_axes(x, n_axes: int, pos2len: Dict[int, int]):
        """Insert new axes into ``x`` and broadcast them to the requested lengths.

        ``pos2len`` maps a position in the ``n_axes``-dimensional result to that
        axis' length. Existing axes keep their size (``-1`` in ``expand``); the
        new axes are broadcast with ``expand``, i.e. without copying data.
        """
        repeats = [-1] * n_axes
        # NOTE(review): unsqueezing happens in dict iteration order; this assumes the
        # keys of pos2len ascend (inserting a lower position after a higher one
        # would shift the latter). Presumably guaranteed by the recipe — confirm.
        for axis_position, axis_length in pos2len.items():
            x = torch.unsqueeze(x, axis_position)
            repeats[axis_position] = axis_length
        return x.expand(repeats)

    @staticmethod
    def is_float_type(x):
        """Return True if ``x`` has one of the float16/32/64 or bfloat16 dtypes."""
        return x.dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16]

    @staticmethod
    def shape(x):
        """Return the shape of ``x``."""
        return x.shape

    @staticmethod
    def reshape(x, shape: List[int]):
        """Reshape ``x`` to ``shape``."""
        return x.reshape(shape)
# mirrors einops.einops._apply_recipe
def apply_for_scriptable_torch(
    recipe: TransformRecipe, tensor: torch.Tensor, reduction_type: str, axes_dims: List[Tuple[str, int]]
) -> torch.Tensor:
    """Execute an einops ``recipe`` on ``tensor`` using only ``TorchJitBackend`` operations.

    The concrete plan is obtained from ``_reconstruct_from_shape_uncached`` for the
    tensor's current shape, then applied step by step:
    reshape -> permute -> reduce -> insert/broadcast axes -> reshape.
    A step is skipped when its plan entry is ``None`` (shapes, permutation) or
    empty (reduced / added axes).

    Args:
        recipe: the prepared transformation recipe.
        tensor: input tensor.
        reduction_type: reduction name forwarded to ``TorchJitBackend.reduce``.
        axes_dims: ``(axis_name, length)`` pairs forwarded to the shape reconstruction.

    Returns:
        The transformed tensor.
    """
    ops = TorchJitBackend
    plan = _reconstruct_from_shape_uncached(recipe, ops.shape(tensor), axes_dims=axes_dims)
    init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = plan

    result = tensor
    if init_shapes is not None:
        result = ops.reshape(result, init_shapes)
    if axes_reordering is not None:
        result = ops.transpose(result, axes_reordering)
    if len(reduced_axes) != 0:
        result = ops.reduce(result, reduction_type, reduced_axes)
    if len(added_axes) != 0:
        result = ops.add_axes(result, n_axes_w_added, added_axes)
    if final_shapes is not None:
        result = ops.reshape(result, final_shapes)
    return result
def _torch_major_minor(version):
    """Parse the leading ``major.minor`` of a version string into a tuple of ints.

    ``"2.10.0+cu121"`` -> ``(2, 10)``; ``"2"`` -> ``(2, 0)``.
    Pre-release and local suffixes are ignored. Returns ``None`` when ``version``
    is not a string or does not start with a digit.
    """
    import re

    if not isinstance(version, str):
        return None
    match = re.match(r"(\d+)(?:\.(\d+))?", version)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2) or 0)


def allow_ops_in_compiled_graph():
    """Register einops operations with torchdynamo via ``allow_in_graph``.

    Returns without doing anything on torch < 2.0 (no ``torch._dynamo`` there) and on
    torch >= 2.8 (einops does not need ``allow_in_graph`` there). Versions are
    compared as integer ``(major, minor)`` tuples rather than as strings, so
    "2.10" counts as newer than "2.8" and "10.x" is not mistaken for "1.x".
    If the version cannot be parsed, registration is attempted anyway.

    Emits an ImportWarning and returns if ``torch._dynamo.allow_in_graph`` cannot be
    imported. On success, sets the module global ``_ops_were_registered_in_torchdynamo``
    to True.
    """
    global _ops_were_registered_in_torchdynamo

    version = _torch_major_minor(getattr(torch, "__version__", None))
    if version is not None and version < (2, 0):
        # torch._dynamo and torch.compile appear in pytorch 2.0
        return
    if version is not None and version >= (2, 8):
        # einops don't need to use allow_in graph for torch 2.8 and above
        return

    try:
        from torch._dynamo import allow_in_graph
    except ImportError:
        warnings.warn(
            "allow_ops_in_compiled_graph failed to import torch: ensure pytorch >=2.0", ImportWarning, stacklevel=1
        )
        return

    from .einops import einsum, rearrange, reduce, repeat
    from .packing import pack, unpack

    for op in (rearrange, reduce, repeat, einsum, pack, unpack):
        allow_in_graph(op)

    # CF: https://github.com/pytorch/pytorch/blob/2df939aacac68e9621fbd5d876c78d86e72b41e2/torch/_dynamo/__init__.py#L222
    _ops_were_registered_in_torchdynamo = True
# Importing this module registers the einops ops with torchdynamo as a side effect.
# The call runs once, at import time; allow_ops_in_compiled_graph itself decides from
# torch.__version__ whether registration is needed and returns early otherwise.
allow_ops_in_compiled_graph()
|