| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215 |
- from __future__ import annotations
- from typing import Any, TYPE_CHECKING, Union
- if TYPE_CHECKING:
- from collections.abc import Sequence
- import torch # noqa: TC002
- from ._dim_entry import _match_levels, DimEntry, ndim_of_levels
def _wrap_dim(arg: Any, orig_ndim: int, allow_none: bool = True) -> DimEntry:
    """
    Normalize a dimension specifier into a DimEntry.

    Args:
        arg: The value to convert: a Dim, an int position, or anything else.
        orig_ndim: Number of positional dimensions used to turn a
            non-negative index into its from-the-right (negative) form.
        allow_none: When true, a ``None`` argument is answered immediately
            with an empty DimEntry.

    Returns:
        A DimEntry wrapping the Dim, a DimEntry holding the translated int
        position, or an empty DimEntry for any unrecognized value (this
        fallback also covers ``None`` when ``allow_none`` is false).
    """
    from . import Dim

    if allow_none and arg is None:
        return DimEntry()  # explicit "none" entry
    if isinstance(arg, Dim):
        return DimEntry(arg)
    if isinstance(arg, int):
        # Negative indices already count from the right; non-negative ones
        # are shifted by orig_ndim so they are expressed the same way.
        return DimEntry(arg if arg < 0 else arg - orig_ndim)
    # Unrecognized input: callers detect this via an empty entry.
    return DimEntry()
def order(
    tensor_or_dim: Union[torch.Tensor, Any], *dims: Union[Any, Sequence[Any]]
) -> torch.Tensor:
    """
    Move selected dimensions of a first-class-dim tensor into positional slots.

    Each argument in ``dims`` contributes one new positional dimension, in
    the order given, except a DimList, which contributes one positional
    dimension per contained dim. A plain iterable of dims/ints is treated as
    a group whose members are flattened (via reshape) into a single
    positional dimension. The new positional dimensions are placed where the
    first untouched positional level sits, or at the end when there is none.

    Args:
        tensor_or_dim: A first-class ``Tensor``, or a single ``Dim`` (in
            which case the value returned by its ``_get_range()`` is used
            as the data).
        *dims: Dim objects, int positions, DimLists, or iterables of
            dims/ints to flatten together.

    Returns:
        The result of ``Tensor.from_positional`` on the rearranged data.

    Raises:
        ValueError: If the first argument is neither a Tensor nor a Dim, if
            a requested dim/position is not present (or was already used),
            or if an argument is not a dim, int, DimList or iterable of
            dims/ints.
        AssertionError: If the underlying data is ``None``.

    Examples:
        >>> import torch
        >>> from functorch.dim import dims
        >>> batch, channel, height, width = dims(4)
        >>> x = torch.randn(2, 3, 4, 5)[batch, channel, height, width]
        >>> # Reorder to [height, width, batch, channel]
        >>> y = order(x, height, width, batch, channel)
    """
    from . import Dim, DimList, Tensor

    # Resolve the source data and its level layout.
    if isinstance(tensor_or_dim, Tensor):
        source_levels = tensor_or_dim._levels[:]
        data = tensor_or_dim._tensor
        has_device = tensor_or_dim._has_device
    elif isinstance(tensor_or_dim, Dim):
        # A bare Dim is ordered as its range tensor with a single level.
        source_levels = [DimEntry(tensor_or_dim)]
        data = tensor_or_dim._get_range()
        has_device = False
    else:
        raise ValueError("First argument must be a Tensor or Dim object")

    moved: list[DimEntry] = []  # entries pulled out, in request order
    flatten_groups: list[tuple[int, int]] = []  # (offset into `moved`, size)
    remaining = source_levels[:]  # claimed slots are blanked in place
    orig_ndim = ndim_of_levels(remaining)

    def claim(d: DimEntry) -> None:
        """Pull ``d`` out of the remaining levels and queue it for reordering."""
        try:
            slot = remaining.index(d)
        except ValueError:
            slot = -1
        if slot < 0:
            # Either never present, or already blanked by an earlier claim.
            if d.is_positional():
                raise ValueError(
                    f"tensor has {orig_ndim} positional dimensions, but {d.position() + orig_ndim} specified, "
                    f"or it was specified twice"
                )
            raise ValueError(
                f"tensor does not contain dim {d.dim()} or it was specified twice"
            )
        remaining[slot] = DimEntry()
        moved.append(d)

    n_new_positional = 0
    for arg in dims:
        single = _wrap_dim(arg, orig_ndim, False)
        if not single.is_none():
            # A lone Dim or int becomes one positional dimension.
            claim(single)
            n_new_positional += 1
            continue

        if isinstance(arg, DimList):
            # Every member of a DimList becomes its own positional dimension.
            for member in arg._dims:
                claim(DimEntry(member))
                n_new_positional += 1
            continue

        # Otherwise the argument must be an iterable whose members are
        # flattened together into a single positional dimension.
        n_new_positional += 1
        if not hasattr(arg, "__iter__"):
            raise ValueError("expected a Dim, List[Dim], or Sequence[Dim]")
        members = list(arg)  # materialize so the group size is known
        # NOTE(review): an empty iterable records a zero-length group, which
        # later adds a size-1 entry to the reshape target without adding a
        # level -- confirm whether callers can pass one.
        flatten_groups.append((len(moved), len(members)))
        for member in members:
            wrapped = _wrap_dim(member, orig_ndim, False)
            if wrapped.is_none():
                raise ValueError("expected a Dim or int")
            claim(wrapped)

    # Rebuild the level list: untouched levels keep their relative order and
    # the moved entries are spliced in just before the first untouched
    # positional level.
    insert_point = -1
    new_levels: list[DimEntry] = []
    for level in remaining:
        if level.is_none():
            continue
        if level.is_positional() and insert_point == -1:
            insert_point = len(new_levels)
            new_levels.extend(moved)
        new_levels.append(level)
    if insert_point == -1:
        # No untouched positional level: the moved entries go last.
        insert_point = len(new_levels)
        new_levels.extend(moved)

    if data is None:
        raise AssertionError("Cannot reorder None tensor")
    ndata = _match_levels(data, source_levels, new_levels)

    if flatten_groups:
        sizes = ndata.size()
        n_moved = len(moved)

        # Sizes of everything that precedes the moved block.
        view_shape = [sizes[k] for k in range(insert_point)]

        # Walk the moved block: ungrouped entries keep their size, each
        # group collapses to the product of its members' sizes.
        cursor = 0
        for start_idx, length in flatten_groups:
            for k in range(cursor, start_idx):
                view_shape.append(sizes[insert_point + k])
            if cursor < start_idx:
                cursor = start_idx
            base = insert_point + cursor
            new_size = 1
            for offset in range(length):
                new_size *= sizes[base + offset]
            view_shape.append(new_size)
            cursor += length

        # Ungrouped entries after the last group.
        for k in range(cursor, n_moved):
            view_shape.append(sizes[insert_point + k])

        # Sizes of everything that follows the moved block.
        for k in range(insert_point + n_moved, len(remaining)):
            view_shape.append(sizes[k])

        # Flattening merged some entries, so drop the surplus levels from the
        # moved block; the survivors are renumbered below.
        n_to_remove = n_moved - n_new_positional
        if n_to_remove > 0:
            head = new_levels[:insert_point]
            tail = new_levels[insert_point + n_to_remove :]
            new_levels = head + tail
        ndata = ndata.reshape(view_shape)

    # Renumber positional levels from the right (-1, -2, ...). Slots in the
    # moved block become positional regardless of what they held before.
    seen = 0
    window_end = insert_point + n_new_positional
    for idx in reversed(range(len(new_levels))):
        in_moved_block = insert_point <= idx < window_end
        if new_levels[idx].is_positional() or in_moved_block:
            seen -= 1
            new_levels[idx] = DimEntry(seen)

    result = Tensor.from_positional(ndata, new_levels, has_device)
    return result  # type: ignore[return-value]
|