# packing.py
  1. from functools import lru_cache
  2. from typing import List, Sequence, Tuple, TypeVar, Union
  3. from einops import EinopsError
  4. from einops._backends import get_backend
  5. from einops.parsing import ParsedExpression
# Type variable used in the annotations of pack/unpack to stand for a tensor type.
Tensor = TypeVar("Tensor")
# A shape is a sequence of ints, given either as a tuple or as a list.
Shape = Union[Tuple[int, ...], List[int]]
  8. @lru_cache(maxsize=128)
  9. def analyze_pattern(pattern: str, opname: str) -> Tuple[int, int, int]:
  10. # Maybe some validation of identifiers?
  11. axes = pattern.split()
  12. axes_set = set(axes)
  13. if len(axes) != len(axes_set):
  14. raise EinopsError(f'Duplicates in axes names in {opname}(..., "{pattern}")')
  15. if "*" not in axes_set:
  16. raise EinopsError(f'No *-axis in {opname}(..., "{pattern}")')
  17. for axis in axes:
  18. if axis != "*":
  19. is_valid, reason = ParsedExpression.check_axis_name_return_reason(axis)
  20. if not is_valid:
  21. raise EinopsError(f'Invalid axis name {axis} in {opname}(..., "{pattern}")')
  22. n_axes_before = axes.index("*")
  23. n_axes_after = len(axes) - n_axes_before - 1
  24. min_axes = n_axes_before + n_axes_after
  25. return n_axes_before, n_axes_after, min_axes
  26. def pack(tensors: Sequence[Tensor], pattern: str) -> Tuple[Tensor, List[Shape]]:
  27. """
  28. Packs several tensors into one.
  29. See einops tutorial for introduction into packing (and how it replaces stack and concatenation).
  30. Parameters:
  31. tensors: tensors to be packed, can be of different dimensionality
  32. pattern: pattern that is shared for all inputs and output, e.g. "i j * k" or "batch seq *"
  33. Returns:
  34. (packed_tensor, packed_shapes aka PS)
  35. Example:
  36. ```python
  37. >>> from numpy import zeros as Z
  38. >>> inputs = [Z([2, 3, 5]), Z([2, 3, 7, 5]), Z([2, 3, 7, 9, 5])]
  39. >>> packed, ps = pack(inputs, 'i j * k')
  40. >>> packed.shape, ps
  41. ((2, 3, 71, 5), [(), (7,), (7, 9)])
  42. ```
  43. In this example, axes were matched to: i=2, j=3, k=5 based on order (first, second, and last).
  44. All other axes were 'packed' and concatenated.
  45. PS (packed shapes) contains information about axes that were matched to '*' in every input.
  46. Resulting tensor has as many elements as all inputs in total.
  47. Packing can be reversed with unpack, which additionally needs PS (packed shapes) to reconstruct order.
  48. ```python
  49. >>> inputs_unpacked = unpack(packed, ps, 'i j * k')
  50. >>> [x.shape for x in inputs_unpacked]
  51. [(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]
  52. ```
  53. Read the tutorial for introduction and application scenarios.
  54. """
  55. n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, "pack")
  56. # packing zero tensors is illegal
  57. backend = get_backend(tensors[0])
  58. reshaped_tensors: List[Tensor] = []
  59. packed_shapes: List[Shape] = []
  60. for i, tensor in enumerate(tensors):
  61. shape = backend.shape(tensor)
  62. if len(shape) < min_axes:
  63. raise EinopsError(
  64. f"packed tensor #{i} (enumeration starts with 0) has shape {shape}, "
  65. f"while pattern {pattern} assumes at least {min_axes} axes"
  66. )
  67. axis_after_packed_axes = len(shape) - n_axes_after
  68. packed_shapes.append(shape[n_axes_before:axis_after_packed_axes])
  69. reshaped_tensors.append(backend.reshape(tensor, (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:])))
  70. return backend.concat(reshaped_tensors, axis=n_axes_before), packed_shapes
  71. def prod(x: Shape) -> int:
  72. result = 1
  73. for i in x:
  74. result *= i
  75. return result
  76. def unpack(tensor: Tensor, packed_shapes: List[Shape], pattern: str) -> List[Tensor]:
  77. """
  78. Unpacks a single tensor into several by splitting over a selected axes.
  79. See einops tutorial for introduction into packing (and how it replaces stack and concatenation).
  80. Parameters:
  81. tensor: tensor to be unpacked
  82. packed_shapes: packed_shapes (aka PS) is a list of shapes that take place of '*' in each output.
  83. output will contain a single tensor for every provided shape
  84. pattern: pattern that is shared for input and all outputs, e.g. "i j * k" or "batch seq *",
  85. where * designates an axis to be unpacked
  86. Returns:
  87. list of tensors
  88. If framework supports views, results are views to the original tensor.
  89. Example:
  90. ```python
  91. >>> from numpy import zeros as Z
  92. >>> inputs = [Z([2, 3, 5]), Z([2, 3, 7, 5]), Z([2, 3, 7, 9, 5])]
  93. >>> packed, ps = pack(inputs, 'i j * k')
  94. >>> packed.shape, ps
  95. ((2, 3, 71, 5), [(), (7,), (7, 9)])
  96. ```
  97. In this example, axes were matched to: i=2, j=3, k=5 based on order (first, second, and last).
  98. All other axes were 'packed' and concatenated.
  99. PS (packed shapes) contains information about axes that were matched to '*' in every input.
  100. Resulting tensor has as many elements as all inputs in total.
  101. Packing can be reversed with unpack, which additionally needs PS (packed shapes) to reconstruct order.
  102. ```python
  103. >>> inputs_unpacked = unpack(packed, ps, 'i j * k')
  104. >>> [x.shape for x in inputs_unpacked]
  105. [(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]
  106. ```
  107. Read the tutorial for introduction and application scenarios.
  108. """
  109. n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, opname="unpack")
  110. backend = get_backend(tensor)
  111. input_shape = backend.shape(tensor)
  112. if len(input_shape) != n_axes_before + 1 + n_axes_after:
  113. raise EinopsError(f"unpack(..., {pattern}) received input of wrong dim with shape {input_shape}")
  114. unpacked_axis: int = n_axes_before
  115. lengths_of_composed_axes: List[int] = [-1 if -1 in p_shape else prod(p_shape) for p_shape in packed_shapes]
  116. n_unknown_composed_axes = sum(int(x == -1) for x in lengths_of_composed_axes)
  117. if n_unknown_composed_axes > 1:
  118. raise EinopsError(
  119. f"unpack(..., {pattern}) received more than one -1 in {packed_shapes} and can't infer dimensions"
  120. )
  121. # following manipulations allow to skip some shape verifications
  122. # and leave it to backends
  123. # [[], [2, 3], [4], [-1, 5], [6]] < examples of packed_axis
  124. # split positions when computed should be
  125. # [0, 1, 7, 11, N-6 , N ], where N = length of axis
  126. split_positions = [0] * len(packed_shapes) + [input_shape[unpacked_axis]]
  127. if n_unknown_composed_axes == 0:
  128. for i, x in enumerate(lengths_of_composed_axes[:-1]):
  129. split_positions[i + 1] = split_positions[i] + x
  130. else:
  131. unknown_composed_axis: int = lengths_of_composed_axes.index(-1)
  132. for i in range(unknown_composed_axis):
  133. split_positions[i + 1] = split_positions[i] + lengths_of_composed_axes[i]
  134. for j in range(unknown_composed_axis + 1, len(lengths_of_composed_axes))[::-1]:
  135. split_positions[j] = split_positions[j + 1] - lengths_of_composed_axes[j]
  136. shape_start = input_shape[:unpacked_axis]
  137. shape_end = input_shape[unpacked_axis + 1 :]
  138. slice_filler = (slice(None, None),) * unpacked_axis
  139. try:
  140. return [
  141. backend.reshape(
  142. # shortest way slice arbitrary axis
  143. tensor[(*slice_filler, slice(split_positions[i], split_positions[i + 1]))],
  144. (*shape_start, *element_shape, *shape_end),
  145. )
  146. for i, element_shape in enumerate(packed_shapes)
  147. ]
  148. except Exception as e:
  149. # this hits if there is an error during reshapes, which means passed shapes were incorrect
  150. raise EinopsError(
  151. f'Error during unpack(..., "{pattern}"): could not split axis of size {split_positions[-1]}'
  152. f" into requested {packed_shapes}"
  153. ) from e