| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269 |
- """
- Python implementation of function wrapping functionality for functorch.dim.
- """
- from __future__ import annotations
- import functools
- from typing import Any, Optional, TYPE_CHECKING
- import torch
- from torch.utils._pytree import tree_map
- from ._dim_entry import DimEntry
- from ._enable_all_layers import EnableAllLayers
- from ._tensor_info import TensorInfo
- if TYPE_CHECKING:
- from collections.abc import Callable
def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """
    Return ``tensor`` unchanged.

    This is the single place a tensor passes through before being handed
    to the original torch function; it currently performs no conversion.
    """
    return tensor
class WrappedOperator:
    """
    Pairs an original callable with the implementation that wraps it.

    The instance stores the settings that the wrapper implementation reads:

    * ``orig`` / ``wrapper_implementation``: the two callables.
    * ``name`` / ``doc``: taken from ``orig`` (``""`` / ``None`` when absent).
    * ``dim_name``: keyword name of the dimension parameter.
    * ``dim_offset`` / ``keepdim_offset``: positional offsets of the
      dimension and ``keepdim`` parameters, not counting ``self``.
    * ``reduce``: whether ``keepdim`` is looked up and matched levels are
      dropped from the result.
    * ``is_pointwise`` / ``single_dim``: stored flags; nothing in this
      class reads them.
    """

    def __init__(
        self, orig: Callable, wrapper_implementation: Callable, dim_name: str = "dim"
    ):
        self.orig = orig
        self.wrapper_implementation = wrapper_implementation
        self.name = getattr(orig, "__name__", "")

        # Extend the original docstring only when there is both a docstring
        # and a dimension parameter name to mention.
        doc = getattr(orig, "__doc__", None)
        if doc and dim_name:
            doc = f"{doc}\nArgument '{dim_name}' can be either an integer or a torchdim.Dim object.\n"
        self.doc = doc

        self.dim_name = dim_name
        self.is_pointwise = False
        self.dim_offset = 0
        self.keepdim_offset = 1
        self.single_dim = False
        self.reduce = True

    def function(self) -> Callable:
        """Build a plain function that forwards to ``wrapper_implementation``."""

        # Only ``__name__`` is copied from the original; the docstring is
        # assigned separately because it may have been extended in __init__.
        @functools.wraps(self.orig, assigned=("__name__",), updated=())
        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
            return self.wrapper_implementation(self, *args, **kwargs)

        wrapped_func.__doc__ = self.doc
        return wrapped_func
def _wrap_dim(dim: Any, ndim: int, keepdim: bool = False) -> DimEntry:
    """
    Convert a single dimension specification to a DimEntry object.

    Args:
        dim: A ``Dim`` object, an integer index, or any other value.
        ndim: Count used to normalise integer indices to negative form.
        keepdim: Whether the caller asked for ``keepdim=True``.

    Returns:
        ``DimEntry(dim)`` for a ``Dim``; ``DimEntry(i)`` with ``i < 0`` for
        an integer (non-negative values have ``ndim`` subtracted until they
        become negative, negative values pass through unchanged); an empty
        ``DimEntry()`` for anything else.

    Raises:
        ValueError: If ``dim`` is a ``Dim`` and ``keepdim`` is true, or if
            ``dim`` is a non-negative integer and ``ndim`` is not positive
            (there is nothing to normalise against).
    """
    from . import Dim

    if isinstance(dim, Dim):
        if keepdim:
            raise ValueError("cannot preserve first-class dimensions with keepdim=True")
        return DimEntry(dim)

    if not isinstance(dim, int):
        return DimEntry()

    if dim < 0:
        return DimEntry(dim)

    # Repeated subtraction of a non-positive ndim never reaches a negative
    # value, so fail loudly instead of looping forever.
    if ndim <= 0:
        raise ValueError(
            f"cannot wrap integer dimension {dim}: ndim must be positive, got {ndim}"
        )

    # Closed form of "subtract ndim until negative": result is in [-ndim, -1].
    return DimEntry(dim % ndim - ndim)
def _wrap_dims(dim: Any, ndim: int, keepdim: bool = False) -> list[DimEntry]:
    """
    Convert a dimension specification to a list of DimEntry objects.

    ``dim`` is first tried as a single specification. If that yields an
    empty entry, ``dim`` is iterated and each element is converted
    individually.
    """
    single = _wrap_dim(dim, ndim, keepdim)
    if single.is_none():
        # Not a lone Dim/int: treat it as an iterable of specifications.
        return [_wrap_dim(item, ndim, keepdim) for item in dim]
    return [single]
def patched_dim_method(wrapper: WrappedOperator, *args: Any, **kwargs: Any) -> Any:
    """
    Call ``wrapper.orig`` with dimension arguments translated to indices.

    ``args[0]`` is treated as ``self``. The dimension argument is read from
    ``kwargs[wrapper.dim_name]`` or, failing that, from the positional slot
    ``wrapper.dim_offset + 1``. Without a dimension argument the call is
    run under ``EnableAllLayers``; otherwise every requested dimension is
    resolved to its index in ``info.levels``, the original function is
    called with those indices, and tensor results are re-wrapped with the
    remaining levels.

    Raises:
        ValueError: If no positional argument is given, or a requested
            dimension is not among the tensor's levels.
        AssertionError: If the tensor info lacks the tensor needed for the
            chosen path.
    """
    if len(args) == 0:
        raise ValueError("Expected at least one argument (self)")

    def positional_after_self(offset: int) -> Any:
        # args[0] is self, so the parameter at `offset` lives at offset + 1.
        slot = offset + 1
        return args[slot] if slot < len(args) else None

    self_arg = args[0]

    # Keyword form wins; otherwise look at the positional slot.
    dim_arg = kwargs.get(wrapper.dim_name)
    if dim_arg is None:
        dim_arg = positional_after_self(wrapper.dim_offset)

    # No dimension argument: fall back to standard functorch handling.
    if dim_arg is None:
        batched_info = TensorInfo.create(
            self_arg, ensure_batched=True, ensure_present=False
        )
        if not batched_info:
            return wrapper.orig(*args, **kwargs)

        with EnableAllLayers(batched_info.levels) as guard:
            batched = batched_info.batchedtensor
            if batched is None:
                raise AssertionError("Expected batchedtensor to be non-None")
            guard.inplace_update_layers(batched, batched_info.levels)
            result = wrapper.orig(handle_from_tensor(batched), *args[1:], **kwargs)
            return guard.from_batched(result, batched_info.has_device)

    # Dimension-aware path.
    info = TensorInfo.create(self_arg)
    if not info:
        return wrapper.orig(*args, **kwargs)

    # keepdim is only consulted for reducing operators.
    keepdim = False
    if wrapper.reduce:
        keepdim_arg = kwargs.get("keepdim")
        if keepdim_arg is None:
            keepdim_arg = positional_after_self(wrapper.keepdim_offset)
        keepdim = keepdim_arg is not None and bool(keepdim_arg)

    requested = _wrap_dims(dim_arg, info.ndim(), keepdim)
    levels = info.levels

    def locate(entry: Any) -> int:
        # Exact equality first, then the looser `matches` hook if a level has one.
        for position, level in enumerate(levels):
            if level == entry:
                return position
        for position, level in enumerate(levels):
            if hasattr(level, "matches") and level.matches(entry):
                return position
        level_strs = [str(level) for level in levels]
        raise ValueError(
            f"Tensor with dimensions {level_strs} does not contain {entry}"
        )

    dim_indices: list[int] = [locate(entry) for entry in requested]
    matched = set(dim_indices)

    # A reduction without keepdim drops the matched levels; anything else
    # keeps the full level list.
    if wrapper.reduce and not keepdim:
        new_levels = [
            level for position, level in enumerate(levels) if position not in matched
        ]
    else:
        new_levels = levels[:]

    # A single index is passed as a bare int, several as a tuple.
    py_indices: Any = dim_indices[0] if len(dim_indices) == 1 else tuple(dim_indices)

    if info.tensor is None:
        raise AssertionError("Expected tensor to be non-None")
    new_args = [handle_from_tensor(info.tensor), *args[1:]]
    new_kwargs = dict(kwargs)

    # Put the translated indices back where the dimension argument came from.
    if wrapper.dim_name in new_kwargs:
        new_kwargs[wrapper.dim_name] = py_indices
    elif wrapper.dim_offset + 1 < len(new_args):
        new_args[wrapper.dim_offset + 1] = py_indices

    result = wrapper.orig(*new_args, **new_kwargs)

    def wrap_result(obj: Any) -> Any:
        if not isinstance(obj, torch.Tensor):
            return obj
        from . import Tensor

        return Tensor.from_positional(obj, new_levels, info.has_device)

    return tree_map(wrap_result, result)
def _wrap(
    orig: Callable,
    dim_offset: Optional[int] = None,
    keepdim_offset: Optional[int] = None,
    dim_name: Optional[str] = None,
    single_dim: Optional[bool] = None,
    reduce: Optional[bool] = None,
) -> Callable:
    """
    Wrap a PyTorch function to support first-class dimensions.

    Builds a ``WrappedOperator`` around ``orig`` with ``patched_dim_method``
    as its implementation and returns the resulting plain function. Any
    setting left as ``None`` keeps the ``WrappedOperator`` default.

    Args:
        orig: Original function to wrap
        dim_offset: Offset for dimension argument (default: 0)
        keepdim_offset: Offset for keepdim argument (default: 1)
        dim_name: Name of dimension parameter (default: "dim")
        single_dim: Whether function takes single dimension (default: False)
        reduce: Whether function reduces dimensions (default: True)
    """
    wrapper = WrappedOperator(orig, patched_dim_method, dim_name or "dim")

    overrides = {
        "dim_offset": dim_offset,
        "keepdim_offset": keepdim_offset,
        "single_dim": single_dim,
        "reduce": reduce,
    }
    for attribute, value in overrides.items():
        if value is not None:
            setattr(wrapper, attribute, value)

    return wrapper.function()
def call_torch_function(
    wrapper: WrappedOperator,
    func: Callable,
    types: tuple,
    args: tuple = (),
    kwargs: Optional[dict] = None,
) -> Any:
    """
    Forward a ``__torch_function__`` call to ``_Tensor.__torch_function__``.

    ``wrapper`` is accepted but not used by this function. A missing
    ``kwargs`` is replaced by an empty dict before forwarding.
    """
    effective_kwargs = {} if kwargs is None else kwargs

    # Imported at call time to avoid a circular import.
    from . import _Tensor

    return _Tensor.__torch_function__(func, types, args, effective_kwargs)
|