_order.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. from __future__ import annotations
  2. from typing import Any, TYPE_CHECKING, Union
  3. if TYPE_CHECKING:
  4. from collections.abc import Sequence
  5. import torch # noqa: TC002
  6. from ._dim_entry import _match_levels, DimEntry, ndim_of_levels
  7. def _wrap_dim(arg: Any, orig_ndim: int, allow_none: bool = True) -> DimEntry:
  8. """
  9. Convert various dimension representations to DimEntry.
  10. Args:
  11. arg: The argument to convert (Dim, int, or other)
  12. orig_ndim: Original number of dimensions
  13. allow_none: Whether to allow None values
  14. Returns:
  15. DimEntry representation of the dimension
  16. """
  17. from . import Dim
  18. if arg is None and allow_none:
  19. return DimEntry() # None entry
  20. elif isinstance(arg, Dim):
  21. return DimEntry(arg)
  22. elif isinstance(arg, int):
  23. if arg < 0:
  24. pos = arg
  25. else:
  26. pos = arg - orig_ndim
  27. return DimEntry(pos)
  28. else:
  29. return DimEntry()
  30. def order(
  31. tensor_or_dim: Union[torch.Tensor, Any], *dims: Union[Any, Sequence[Any]]
  32. ) -> torch.Tensor:
  33. """
  34. Reorder the dimensions of a tensor or create a tensor from a dimension.
  35. It allows reordering tensor dimensions using first-class dimensions and
  36. positional indices.
  37. Args:
  38. tensor_or_dim: Input tensor with first-class dimensions, or a Dim object
  39. *dims: Dimensions or sequences of dimensions specifying the new order
  40. Returns:
  41. Tensor with reordered dimensions
  42. Examples:
  43. >>> import torch
  44. >>> from functorch.dim import dims
  45. >>> batch, channel, height, width = dims(4)
  46. >>> x = torch.randn(2, 3, 4, 5)[batch, channel, height, width]
  47. >>> # Reorder to [height, width, batch, channel]
  48. >>> y = order(x, height, width, batch, channel)
  49. """
  50. from . import Dim, DimList, Tensor
  51. # Handle first argument - tensor or dimension
  52. if isinstance(tensor_or_dim, Tensor):
  53. # First-class tensor
  54. orig_levels = tensor_or_dim._levels[:]
  55. data = tensor_or_dim._tensor
  56. has_device = tensor_or_dim._has_device
  57. elif isinstance(tensor_or_dim, Dim):
  58. # Single dimension - create range tensor
  59. orig_levels = [DimEntry(tensor_or_dim)]
  60. data = tensor_or_dim._get_range()
  61. has_device = False
  62. else:
  63. raise ValueError("First argument must be a Tensor or Dim object")
  64. flat_positional_dims = []
  65. to_flatten = [] # List of (start_index, length) pairs for flattening
  66. levels = orig_levels[:]
  67. orig_ndim = ndim_of_levels(levels)
  68. def append_dim(d: DimEntry) -> None:
  69. """Add a dimension to the reordering, removing it from available levels."""
  70. try:
  71. idx = levels.index(d)
  72. except ValueError:
  73. idx = None
  74. if idx is None:
  75. if d.is_positional():
  76. raise ValueError(
  77. f"tensor has {orig_ndim} positional dimensions, but {d.position() + orig_ndim} specified, "
  78. f"or it was specified twice"
  79. )
  80. else:
  81. raise ValueError(
  82. f"tensor does not contain dim {d.dim()} or it was specified twice"
  83. )
  84. levels[idx] = DimEntry()
  85. flat_positional_dims.append(d)
  86. n_new_positional = 0
  87. # Process each dimension argument
  88. for arg in dims:
  89. entry = _wrap_dim(arg, orig_ndim, False)
  90. if not entry.is_none():
  91. append_dim(entry)
  92. n_new_positional += 1
  93. elif isinstance(arg, DimList):
  94. # Handle DimList
  95. for dim in arg._dims:
  96. append_dim(DimEntry(dim))
  97. n_new_positional += 1
  98. else:
  99. # Handle sequences of dimensions for flattening
  100. n_new_positional += 1
  101. if not hasattr(arg, "__iter__"):
  102. raise ValueError("expected a Dim, List[Dim], or Sequence[Dim]")
  103. # Convert to list to get length
  104. seq = list(arg)
  105. to_flatten.append((len(flat_positional_dims), len(seq)))
  106. for item in seq:
  107. entry = _wrap_dim(item, orig_ndim, False)
  108. if entry.is_none():
  109. raise ValueError("expected a Dim or int")
  110. append_dim(entry)
  111. # Build new level ordering
  112. insert_point = -1
  113. new_levels: list[DimEntry] = []
  114. # Add remaining (non-reordered) levels, finding insertion point for new dimensions
  115. for level in levels:
  116. if level.is_none():
  117. continue
  118. if level.is_positional():
  119. if insert_point == -1:
  120. insert_point = len(new_levels)
  121. new_levels.extend(flat_positional_dims)
  122. new_levels.append(level)
  123. # If no positional dimensions found, append new dims at the end
  124. if insert_point == -1:
  125. insert_point = len(new_levels)
  126. new_levels.extend(flat_positional_dims)
  127. # Match tensor to new level structure
  128. if data is None:
  129. raise AssertionError("Cannot reorder None tensor")
  130. ndata = _match_levels(data, orig_levels, new_levels)
  131. # Handle dimension flattening if requested
  132. if to_flatten:
  133. # Now build the reshape target
  134. view_shape = []
  135. sizes = ndata.size()
  136. # Add dimensions before the reordered ones
  137. for i in range(insert_point):
  138. view_shape.append(sizes[i])
  139. # Process flattening groups
  140. i = 0
  141. for start_idx, length in to_flatten:
  142. # Add individual dims before this flattening group
  143. while i < start_idx:
  144. view_shape.append(sizes[insert_point + i])
  145. i += 1
  146. # Flatten the group
  147. new_size = 1
  148. for j in range(length):
  149. new_size *= sizes[insert_point + i + j]
  150. view_shape.append(new_size)
  151. i += length
  152. # Add remaining individual dims
  153. while i < len(flat_positional_dims):
  154. view_shape.append(sizes[insert_point + i])
  155. i += 1
  156. # Add dimensions after the reordered ones
  157. for i in range(insert_point + len(flat_positional_dims), len(levels)):
  158. view_shape.append(sizes[i])
  159. # Update levels by removing flattened dimensions
  160. n_to_remove = len(flat_positional_dims) - n_new_positional
  161. if n_to_remove > 0:
  162. # Remove flattened levels
  163. new_levels = (
  164. new_levels[:insert_point] + new_levels[insert_point + n_to_remove :]
  165. )
  166. ndata = ndata.reshape(view_shape)
  167. # Renumber positional dimensions (negative indexing from the right)
  168. seen = 0
  169. for i in range(len(new_levels) - 1, -1, -1):
  170. if new_levels[i].is_positional() or (
  171. i >= insert_point and i < insert_point + n_new_positional
  172. ):
  173. seen -= 1
  174. new_levels[i] = DimEntry(seen)
  175. result = Tensor.from_positional(ndata, new_levels, has_device)
  176. return result # type: ignore[return-value]