__init__.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. __author__ = "Alex Rogozhnikov"
  2. from typing import Any, Dict
  3. from einops import EinopsError
  4. from einops.einops import TransformRecipe, _apply_recipe, _prepare_recipes_for_all_dims, get_backend
  5. class RearrangeMixin:
  6. """
  7. Rearrange layer behaves identically to einops.rearrange operation.
  8. :param pattern: str, rearrangement pattern
  9. :param axes_lengths: any additional specification of dimensions
  10. See einops.rearrange for source_examples.
  11. """
  12. def __init__(self, pattern: str, **axes_lengths: Any) -> None:
  13. super().__init__()
  14. self.pattern = pattern
  15. self.axes_lengths = axes_lengths
  16. # self._recipe = self.recipe() # checking parameters
  17. self._multirecipe = self.multirecipe()
  18. self._axes_lengths = tuple(self.axes_lengths.items())
  19. def __repr__(self) -> str:
  20. params = repr(self.pattern)
  21. for axis, length in self.axes_lengths.items():
  22. params += f", {axis}={length}"
  23. return f"{self.__class__.__name__}({params})"
  24. def multirecipe(self) -> Dict[int, TransformRecipe]:
  25. try:
  26. return _prepare_recipes_for_all_dims(
  27. self.pattern, operation="rearrange", axes_names=tuple(self.axes_lengths)
  28. )
  29. except EinopsError as e:
  30. raise EinopsError(f" Error while preparing {self!r}\n {e}") from None
  31. def _apply_recipe(self, x):
  32. backend = get_backend(x)
  33. return _apply_recipe(
  34. backend=backend,
  35. recipe=self._multirecipe[len(x.shape)],
  36. tensor=x,
  37. reduction_type="rearrange",
  38. axes_lengths=self._axes_lengths,
  39. )
  40. def __getstate__(self):
  41. return {"pattern": self.pattern, "axes_lengths": self.axes_lengths}
  42. def __setstate__(self, state):
  43. self.__init__(pattern=state["pattern"], **state["axes_lengths"])
  44. class ReduceMixin:
  45. """
  46. Reduce layer behaves identically to einops.reduce operation.
  47. :param pattern: str, rearrangement pattern
  48. :param reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
  49. :param axes_lengths: any additional specification of dimensions
  50. See einops.reduce for source_examples.
  51. """
  52. def __init__(self, pattern: str, reduction: str, **axes_lengths: Any):
  53. super().__init__()
  54. self.pattern = pattern
  55. self.reduction = reduction
  56. self.axes_lengths = axes_lengths
  57. self._multirecipe = self.multirecipe()
  58. self._axes_lengths = tuple(self.axes_lengths.items())
  59. def __repr__(self):
  60. params = f"{self.pattern!r}, {self.reduction!r}"
  61. for axis, length in self.axes_lengths.items():
  62. params += f", {axis}={length}"
  63. return f"{self.__class__.__name__}({params})"
  64. def multirecipe(self) -> Dict[int, TransformRecipe]:
  65. try:
  66. return _prepare_recipes_for_all_dims(
  67. self.pattern, operation=self.reduction, axes_names=tuple(self.axes_lengths)
  68. )
  69. except EinopsError as e:
  70. raise EinopsError(f" Error while preparing {self!r}\n {e}") from None
  71. def _apply_recipe(self, x):
  72. backend = get_backend(x)
  73. return _apply_recipe(
  74. backend=backend,
  75. recipe=self._multirecipe[len(x.shape)],
  76. tensor=x,
  77. reduction_type=self.reduction,
  78. axes_lengths=self._axes_lengths,
  79. )
  80. def __getstate__(self):
  81. return {"pattern": self.pattern, "reduction": self.reduction, "axes_lengths": self.axes_lengths}
  82. def __setstate__(self, state):
  83. self.__init__(pattern=state["pattern"], reduction=state["reduction"], **state["axes_lengths"])