# _einmix.py
  1. import string
  2. import warnings
  3. from typing import Any, Dict, List, Optional
  4. from einops import EinopsError
  5. from einops.einops import _product
  6. from einops.parsing import ParsedExpression, _ellipsis
  7. def _report_axes(axes: set, report_message: str):
  8. if len(axes) > 0:
  9. raise EinopsError(report_message.format(axes))
  10. class _EinmixMixin:
  11. def __init__(self, pattern: str, weight_shape: str, bias_shape: Optional[str] = None, **axes_lengths: Any):
  12. """
  13. EinMix - Einstein summation with automated tensor management and axis packing/unpacking.
  14. EinMix is a combination of einops and MLP, see tutorial:
  15. https://github.com/arogozhnikov/einops/blob/main/docs/3-einmix-layer.ipynb
  16. Imagine taking einsum with two arguments, one of each input, and one - tensor with weights
  17. >>> einsum('time batch channel_in, channel_in channel_out -> time batch channel_out', input, weight)
  18. This layer manages weights for you, syntax highlights a special role of weight matrix
  19. >>> EinMix('time batch channel_in -> time batch channel_out', weight_shape='channel_in channel_out')
  20. But otherwise it is the same einsum under the hood. Plus einops-rearrange.
  21. Simple linear layer with a bias term (you have one like that in your framework)
  22. >>> EinMix('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=10, cout=20)
  23. There is no restriction to mix the last axis. Let's mix along height
  24. >>> EinMix('h w c-> hout w c', weight_shape='h hout', bias_shape='hout', h=32, hout=32)
  25. Example of channel-wise multiplication (like one used in normalizations)
  26. >>> EinMix('t b c -> t b c', weight_shape='c', c=128)
  27. Multi-head linear layer (each head is own linear layer):
  28. >>> EinMix('t b (head cin) -> t b (head cout)', weight_shape='head cin cout', ...)
  29. ... and yes, you need to specify all dimensions of weight shape/bias shape in parameters.
  30. Use cases:
  31. - when channel dimension is not last, use EinMix, not transposition
  32. - patch/segment embeddings
  33. - when need only within-group connections to reduce number of weights and computations
  34. - next-gen MLPs (follow tutorial link above to learn more!)
  35. - in general, any time you want to combine linear layer and einops.rearrange
  36. Uniform He initialization is applied to weight tensor.
  37. This accounts for the number of elements mixed and produced.
  38. Parameters
  39. :param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output
  40. :param weight_shape: axes of weight. A tensor of this shape is created, stored, and optimized in a layer
  41. If bias_shape is not specified, bias is not created.
  42. :param bias_shape: axes of bias added to output. Weights of this shape are created and stored. If `None` (the default), no bias is added.
  43. :param axes_lengths: dimensions of weight tensor
  44. """
  45. super().__init__()
  46. self.pattern = pattern
  47. self.weight_shape = weight_shape
  48. self.bias_shape = bias_shape
  49. self.axes_lengths = axes_lengths
  50. self.initialize_einmix(
  51. pattern=pattern, weight_shape=weight_shape, bias_shape=bias_shape, axes_lengths=axes_lengths
  52. )
  53. def initialize_einmix(self, pattern: str, weight_shape: str, bias_shape: Optional[str], axes_lengths: dict):
  54. left_pattern, right_pattern = pattern.split("->")
  55. left = ParsedExpression(left_pattern)
  56. right = ParsedExpression(right_pattern)
  57. weight = ParsedExpression(weight_shape)
  58. _report_axes(
  59. set.difference(right.identifiers, {*left.identifiers, *weight.identifiers}),
  60. "Unrecognized identifiers on the right side of EinMix {}",
  61. )
  62. if weight.has_ellipsis:
  63. raise EinopsError("Ellipsis is not supported in weight, as its shape should be fully specified")
  64. if left.has_ellipsis or right.has_ellipsis:
  65. if not (left.has_ellipsis and right.has_ellipsis):
  66. raise EinopsError(f"Ellipsis in EinMix should be on both sides, {pattern}")
  67. if left.has_ellipsis_parenthesized:
  68. raise EinopsError(f"Ellipsis on left side can't be in parenthesis, got {pattern}")
  69. if any(x.has_non_unitary_anonymous_axes for x in [left, right, weight]):
  70. raise EinopsError("Anonymous axes (numbers) are not allowed in EinMix")
  71. if "(" in weight_shape or ")" in weight_shape:
  72. raise EinopsError(f"Parenthesis is not allowed in weight shape: {weight_shape}")
  73. pre_reshape_pattern = None
  74. pre_reshape_lengths = None
  75. post_reshape_pattern = None
  76. if any(len(group) != 1 for group in left.composition):
  77. names: List[str] = []
  78. for group in left.composition:
  79. names += group
  80. names = [name if name != _ellipsis else "..." for name in names]
  81. composition = " ".join(names)
  82. pre_reshape_pattern = f"{left_pattern}-> {composition}"
  83. pre_reshape_lengths = {name: length for name, length in axes_lengths.items() if name in names}
  84. if any(len(group) != 1 for group in right.composition) or right.has_ellipsis_parenthesized:
  85. names = []
  86. for group in right.composition:
  87. names += group
  88. names = [name if name != _ellipsis else "..." for name in names]
  89. composition = " ".join(names)
  90. post_reshape_pattern = f"{composition} ->{right_pattern}"
  91. self._create_rearrange_layers(pre_reshape_pattern, pre_reshape_lengths, post_reshape_pattern, {})
  92. for axis in weight.identifiers:
  93. if axis not in axes_lengths:
  94. raise EinopsError(f"Dimension {axis} of weight should be specified")
  95. _report_axes(
  96. set.difference(set(axes_lengths), {*left.identifiers, *weight.identifiers}),
  97. "Axes {} are not used in pattern",
  98. )
  99. _report_axes(
  100. set.difference(weight.identifiers, {*left.identifiers, *right.identifiers}), "Weight axes {} are redundant"
  101. )
  102. if len(weight.identifiers) == 0:
  103. warnings.warn("EinMix: weight has no dimensions (means multiplication by a number)", stacklevel=2)
  104. _weight_shape = [axes_lengths[axis] for (axis,) in weight.composition]
  105. # single output element is a combination of fan_in input elements
  106. _fan_in = _product([axes_lengths[axis] for (axis,) in weight.composition if axis not in right.identifiers])
  107. if bias_shape is not None:
  108. # maybe I should put ellipsis in the beginning for simplicity?
  109. if not isinstance(bias_shape, str):
  110. raise EinopsError("bias shape should be string specifying which axes bias depends on")
  111. bias = ParsedExpression(bias_shape)
  112. _report_axes(
  113. set.difference(bias.identifiers, right.identifiers),
  114. "Bias axes {} not present in output",
  115. )
  116. _report_axes(
  117. set.difference(bias.identifiers, set(axes_lengths)),
  118. "Sizes not provided for bias axes {}",
  119. )
  120. _bias_shape = []
  121. used_non_trivial_size = False
  122. for axes in right.composition:
  123. if axes == _ellipsis:
  124. if used_non_trivial_size:
  125. raise EinopsError("all bias dimensions should go after ellipsis in the output")
  126. else:
  127. # handles ellipsis correctly
  128. for axis in axes:
  129. if axis == _ellipsis:
  130. if used_non_trivial_size:
  131. raise EinopsError("all bias dimensions should go after ellipsis in the output")
  132. elif axis in bias.identifiers:
  133. _bias_shape.append(axes_lengths[axis])
  134. used_non_trivial_size = True
  135. else:
  136. _bias_shape.append(1)
  137. else:
  138. _bias_shape = None
  139. weight_bound = (3 / _fan_in) ** 0.5
  140. bias_bound = (1 / _fan_in) ** 0.5
  141. self._create_parameters(_weight_shape, weight_bound, _bias_shape, bias_bound)
  142. # rewrite einsum expression with single-letter latin identifiers so that
  143. # expression will be understood by any framework
  144. mapped_identifiers = {*left.identifiers, *right.identifiers, *weight.identifiers}
  145. if _ellipsis in mapped_identifiers:
  146. mapped_identifiers.remove(_ellipsis)
  147. mapped_identifiers = sorted(mapped_identifiers)
  148. mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapped_identifiers)}
  149. mapping2letters[_ellipsis] = "..." # preserve ellipsis
  150. def write_flat_remapped(axes: ParsedExpression):
  151. result = []
  152. for composed_axis in axes.composition:
  153. if isinstance(composed_axis, list):
  154. result.extend([mapping2letters[axis] for axis in composed_axis])
  155. else:
  156. assert composed_axis == _ellipsis
  157. result.append("...")
  158. return "".join(result)
  159. self.einsum_pattern: str = (
  160. f"{write_flat_remapped(left)},{write_flat_remapped(weight)}->{write_flat_remapped(right)}"
  161. )
  162. def _create_rearrange_layers(
  163. self,
  164. pre_reshape_pattern: Optional[str],
  165. pre_reshape_lengths: Optional[Dict],
  166. post_reshape_pattern: Optional[str],
  167. post_reshape_lengths: Optional[Dict],
  168. ):
  169. raise NotImplementedError("Should be defined in framework implementations")
  170. def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
  171. """Shape and implementations"""
  172. raise NotImplementedError("Should be defined in framework implementations")
  173. def __repr__(self):
  174. params = repr(self.pattern)
  175. params += f", '{self.weight_shape}'"
  176. if self.bias_shape is not None:
  177. params += f", '{self.bias_shape}'"
  178. for axis, length in self.axes_lengths.items():
  179. params += f", {axis}={length}"
  180. return f"{self.__class__.__name__}({params})"
  181. class _EinmixDebugger(_EinmixMixin):
  182. """Used only to test mixin"""
  183. def _create_rearrange_layers(
  184. self,
  185. pre_reshape_pattern: Optional[str],
  186. pre_reshape_lengths: Optional[Dict],
  187. post_reshape_pattern: Optional[str],
  188. post_reshape_lengths: Optional[Dict],
  189. ):
  190. self.pre_reshape_pattern = pre_reshape_pattern
  191. self.pre_reshape_lengths = pre_reshape_lengths
  192. self.post_reshape_pattern = post_reshape_pattern
  193. self.post_reshape_lengths = post_reshape_lengths
  194. def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
  195. self.saved_weight_shape = weight_shape
  196. self.saved_bias_shape = bias_shape