# _torch_specific.py (4.2 KB)
"""
Specialization of einops for torch.

Unfortunately, torch's jit scripting mechanism isn't strong enough,
and to have scripting supported at least for layers,
a number of additional moves are needed.

Design of main operations (dynamic resolution by lookup) is unlikely
to be implemented by torch.jit.script,
but torch.compile seems to work with operations just fine.
"""
  10. import warnings
  11. from typing import Dict, List, Tuple
  12. import torch
  13. from einops.einops import TransformRecipe, _reconstruct_from_shape_uncached
  14. class TorchJitBackend:
  15. """
  16. Completely static backend that mimics part of normal backend functionality
  17. but restricted to be within torchscript.
  18. """
  19. @staticmethod
  20. def reduce(x: torch.Tensor, operation: str, reduced_axes: List[int]):
  21. if operation == "min":
  22. return x.amin(dim=reduced_axes)
  23. elif operation == "max":
  24. return x.amax(dim=reduced_axes)
  25. elif operation == "sum":
  26. return x.sum(dim=reduced_axes)
  27. elif operation == "mean":
  28. return x.mean(dim=reduced_axes)
  29. elif operation == "prod":
  30. for i in sorted(reduced_axes)[::-1]:
  31. x = x.prod(dim=i)
  32. return x
  33. else:
  34. raise NotImplementedError("Unknown reduction ", operation)
  35. @staticmethod
  36. def transpose(x, axes: List[int]):
  37. return x.permute(axes)
  38. @staticmethod
  39. def stack_on_zeroth_dimension(tensors: List[torch.Tensor]):
  40. return torch.stack(tensors)
  41. @staticmethod
  42. def tile(x, repeats: List[int]):
  43. return x.repeat(repeats)
  44. @staticmethod
  45. def add_axes(x, n_axes: int, pos2len: Dict[int, int]):
  46. repeats = [-1] * n_axes
  47. for axis_position, axis_length in pos2len.items():
  48. x = torch.unsqueeze(x, axis_position)
  49. repeats[axis_position] = axis_length
  50. return x.expand(repeats)
  51. @staticmethod
  52. def is_float_type(x):
  53. return x.dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16]
  54. @staticmethod
  55. def shape(x):
  56. return x.shape
  57. @staticmethod
  58. def reshape(x, shape: List[int]):
  59. return x.reshape(shape)
  60. # mirrors einops.einops._apply_recipe
  61. def apply_for_scriptable_torch(
  62. recipe: TransformRecipe, tensor: torch.Tensor, reduction_type: str, axes_dims: List[Tuple[str, int]]
  63. ) -> torch.Tensor:
  64. backend = TorchJitBackend
  65. (
  66. init_shapes,
  67. axes_reordering,
  68. reduced_axes,
  69. added_axes,
  70. final_shapes,
  71. n_axes_w_added,
  72. ) = _reconstruct_from_shape_uncached(recipe, backend.shape(tensor), axes_dims=axes_dims)
  73. if init_shapes is not None:
  74. tensor = backend.reshape(tensor, init_shapes)
  75. if axes_reordering is not None:
  76. tensor = backend.transpose(tensor, axes_reordering)
  77. if len(reduced_axes) > 0:
  78. tensor = backend.reduce(tensor, operation=reduction_type, reduced_axes=reduced_axes)
  79. if len(added_axes) > 0:
  80. tensor = backend.add_axes(tensor, n_axes=n_axes_w_added, pos2len=added_axes)
  81. if final_shapes is not None:
  82. tensor = backend.reshape(tensor, final_shapes)
  83. return tensor
  84. def allow_ops_in_compiled_graph():
  85. if hasattr(torch, "__version__") and torch.__version__[0] < "2":
  86. # torch._dynamo and torch.compile appear in pytorch 2.0
  87. return
  88. if hasattr(torch, "__version__") and torch.__version__ >= "2.8":
  89. # einops don't need to use allow_in graph for torch 2.8 and above
  90. return
  91. try:
  92. from torch._dynamo import allow_in_graph
  93. except ImportError:
  94. warnings.warn(
  95. "allow_ops_in_compiled_graph failed to import torch: ensure pytorch >=2.0", ImportWarning, stacklevel=1
  96. )
  97. return
  98. from .einops import einsum, rearrange, reduce, repeat
  99. from .packing import pack, unpack
  100. allow_in_graph(rearrange)
  101. allow_in_graph(reduce)
  102. allow_in_graph(repeat)
  103. allow_in_graph(einsum)
  104. allow_in_graph(pack)
  105. allow_in_graph(unpack)
  106. # CF: https://github.com/pytorch/pytorch/blob/2df939aacac68e9621fbd5d876c78d86e72b41e2/torch/_dynamo/__init__.py#L222
  107. global _ops_were_registered_in_torchdynamo
  108. _ops_were_registered_in_torchdynamo = True
  109. # module import automatically registers ops in torchdynamo
  110. allow_ops_in_compiled_graph()