array_api.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. from typing import List, Sequence, Tuple
  2. from .einops import EinopsError, Reduction, Tensor, _apply_recipe_array_api, _prepare_transformation_recipe
  3. from .packing import analyze_pattern, prod
  4. def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor:
  5. if isinstance(tensor, list):
  6. if len(tensor) == 0:
  7. raise TypeError("Einops can't be applied to an empty list")
  8. xp = tensor[0].__array_namespace__()
  9. tensor = xp.stack(tensor)
  10. else:
  11. xp = tensor.__array_namespace__()
  12. try:
  13. hashable_axes_lengths = tuple(axes_lengths.items())
  14. recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=tensor.ndim)
  15. return _apply_recipe_array_api(
  16. xp,
  17. recipe=recipe,
  18. tensor=tensor,
  19. reduction_type=reduction,
  20. axes_lengths=hashable_axes_lengths,
  21. )
  22. except EinopsError as e:
  23. message = f' Error while processing {reduction}-reduction pattern "{pattern}".'
  24. if not isinstance(tensor, list):
  25. message += f"\n Input tensor shape: {tensor.shape}. "
  26. else:
  27. message += "\n Input is list. "
  28. message += f"Additional info: {axes_lengths}."
  29. raise EinopsError(message + f"\n {e}") from None
  30. def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
  31. return reduce(tensor, pattern, reduction="repeat", **axes_lengths)
  32. def rearrange(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
  33. return reduce(tensor, pattern, reduction="rearrange", **axes_lengths)
  34. def asnumpy(tensor: Tensor):
  35. import numpy as np
  36. return np.from_dlpack(tensor)
  37. Shape = Tuple
  38. def pack(tensors: Sequence[Tensor], pattern: str) -> Tuple[Tensor, List[Shape]]:
  39. n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, "pack")
  40. xp = tensors[0].__array_namespace__()
  41. reshaped_tensors: List[Tensor] = []
  42. packed_shapes: List[Shape] = []
  43. for i, tensor in enumerate(tensors):
  44. shape = tensor.shape
  45. if len(shape) < min_axes:
  46. raise EinopsError(
  47. f"packed tensor #{i} (enumeration starts with 0) has shape {shape}, "
  48. f"while pattern {pattern} assumes at least {min_axes} axes"
  49. )
  50. axis_after_packed_axes = len(shape) - n_axes_after
  51. packed_shapes.append(shape[n_axes_before:axis_after_packed_axes])
  52. reshaped_tensors.append(xp.reshape(tensor, (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:])))
  53. return xp.concat(reshaped_tensors, axis=n_axes_before), packed_shapes
  54. def unpack(tensor: Tensor, packed_shapes: List[Shape], pattern: str) -> List[Tensor]:
  55. xp = tensor.__array_namespace__()
  56. n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, opname="unpack")
  57. # backend = get_backend(tensor)
  58. input_shape = tensor.shape
  59. if len(input_shape) != n_axes_before + 1 + n_axes_after:
  60. raise EinopsError(f"unpack(..., {pattern}) received input of wrong dim with shape {input_shape}")
  61. unpacked_axis: int = n_axes_before
  62. lengths_of_composed_axes: List[int] = [-1 if -1 in p_shape else prod(p_shape) for p_shape in packed_shapes]
  63. n_unknown_composed_axes = sum(x == -1 for x in lengths_of_composed_axes)
  64. if n_unknown_composed_axes > 1:
  65. raise EinopsError(
  66. f"unpack(..., {pattern}) received more than one -1 in {packed_shapes} and can't infer dimensions"
  67. )
  68. # following manipulations allow to skip some shape verifications
  69. # and leave it to backends
  70. # [[], [2, 3], [4], [-1, 5], [6]] < examples of packed_axis
  71. # split positions when computed should be
  72. # [0, 1, 7, 11, N-6 , N ], where N = length of axis
  73. split_positions = [0] * len(packed_shapes) + [input_shape[unpacked_axis]]
  74. if n_unknown_composed_axes == 0:
  75. for i, x in enumerate(lengths_of_composed_axes[:-1]):
  76. split_positions[i + 1] = split_positions[i] + x
  77. else:
  78. unknown_composed_axis: int = lengths_of_composed_axes.index(-1)
  79. for i in range(unknown_composed_axis):
  80. split_positions[i + 1] = split_positions[i] + lengths_of_composed_axes[i]
  81. for j in range(unknown_composed_axis + 1, len(lengths_of_composed_axes))[::-1]:
  82. split_positions[j] = split_positions[j + 1] - lengths_of_composed_axes[j]
  83. shape_start = input_shape[:unpacked_axis]
  84. shape_end = input_shape[unpacked_axis + 1 :]
  85. slice_filler = (slice(None, None),) * unpacked_axis
  86. try:
  87. return [
  88. xp.reshape(
  89. # shortest way slice arbitrary axis
  90. tensor[(*slice_filler, slice(split_positions[i], split_positions[i + 1]), ...)],
  91. (*shape_start, *element_shape, *shape_end),
  92. )
  93. for i, element_shape in enumerate(packed_shapes)
  94. ]
  95. except Exception as e:
  96. # this hits if there is an error during reshapes, which means passed shapes were incorrect
  97. raise RuntimeError(
  98. f'Error during unpack(..., "{pattern}"): could not split axis of size {split_positions[-1]}'
  99. f" into requested {packed_shapes}"
  100. ) from e