_wrap.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. """
  2. Python implementation of function wrapping functionality for functorch.dim.
  3. """
  4. from __future__ import annotations
  5. import functools
  6. from typing import Any, Optional, TYPE_CHECKING
  7. import torch
  8. from torch.utils._pytree import tree_map
  9. from ._dim_entry import DimEntry
  10. from ._enable_all_layers import EnableAllLayers
  11. from ._tensor_info import TensorInfo
  12. if TYPE_CHECKING:
  13. from collections.abc import Callable
  14. def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
  15. """Handle tensor conversion for torch function integration."""
  16. return tensor
  17. class WrappedOperator:
  18. """
  19. This class wraps PyTorch operations to support first-class dimensions.
  20. """
  21. def __init__(
  22. self, orig: Callable, wrapper_implementation: Callable, dim_name: str = "dim"
  23. ):
  24. self.orig = orig
  25. self.wrapper_implementation = wrapper_implementation
  26. self.name = getattr(orig, "__name__", "")
  27. self.doc = getattr(orig, "__doc__", None)
  28. self.dim_name = dim_name
  29. self.is_pointwise = False
  30. self.dim_offset = 0
  31. self.keepdim_offset = 1
  32. self.single_dim = False
  33. self.reduce = True
  34. # Update docstring if we have a dim_name
  35. if self.doc and self.dim_name:
  36. self.doc = f"{self.doc}\nArgument '{self.dim_name}' can be either an integer or a torchdim.Dim object.\n"
  37. def function(self) -> Callable:
  38. """Create a wrapped function that calls our wrapper implementation."""
  39. def wrapped_func(*args: Any, **kwargs: Any) -> Any:
  40. return self.wrapper_implementation(self, *args, **kwargs)
  41. # Copy metadata using functools.update_wrapper for just __name__ and __doc__
  42. functools.update_wrapper(
  43. wrapped_func, self.orig, assigned=("__name__",), updated=()
  44. )
  45. wrapped_func.__doc__ = self.doc
  46. return wrapped_func
  47. def _wrap_dim(dim: Any, ndim: int, keepdim: bool = False) -> DimEntry:
  48. """Convert single dimension specification to DimEntry object."""
  49. from . import Dim
  50. if isinstance(dim, Dim):
  51. if keepdim:
  52. raise ValueError("cannot preserve first-class dimensions with keepdim=True")
  53. return DimEntry(dim)
  54. elif isinstance(dim, int):
  55. i = dim
  56. while i >= 0:
  57. i -= ndim
  58. return DimEntry(i)
  59. else:
  60. return DimEntry()
  61. def _wrap_dims(dim: Any, ndim: int, keepdim: bool = False) -> list[DimEntry]:
  62. """Convert dimension specification to list of DimEntry objects."""
  63. de = _wrap_dim(dim, ndim, keepdim)
  64. result = []
  65. if not de.is_none():
  66. result.append(de)
  67. else:
  68. for d in dim:
  69. result.append(_wrap_dim(d, ndim, keepdim))
  70. return result
  71. def patched_dim_method(wrapper: WrappedOperator, *args: Any, **kwargs: Any) -> Any:
  72. """
  73. This is the core method that handles dimension-aware operations.
  74. """
  75. if not args:
  76. raise ValueError("Expected at least one argument (self)")
  77. # Get dimension argument
  78. dim_arg = kwargs.get(wrapper.dim_name)
  79. if dim_arg is None and wrapper.dim_offset < len(args):
  80. # Try to get dim from positional args (accounting for self at index 0)
  81. dim_idx = wrapper.dim_offset + 1
  82. if dim_idx < len(args):
  83. dim_arg = args[dim_idx]
  84. # If no dimension argument provided, fall back to standard functorch handling
  85. if dim_arg is None:
  86. info = TensorInfo.create(args[0], ensure_batched=True, ensure_present=False)
  87. if not info:
  88. return wrapper.orig(*args, **kwargs)
  89. with EnableAllLayers(info.levels) as guard:
  90. if info.batchedtensor is None:
  91. raise AssertionError("Expected batchedtensor to be non-None")
  92. guard.inplace_update_layers(info.batchedtensor, info.levels)
  93. new_args = list(args)
  94. new_args[0] = handle_from_tensor(info.batchedtensor)
  95. result = wrapper.orig(*new_args, **kwargs)
  96. return guard.from_batched(result, info.has_device)
  97. # Handle dimension-aware operation
  98. info = TensorInfo.create(args[0])
  99. if not info:
  100. return wrapper.orig(*args, **kwargs)
  101. # Check for keepdim parameter
  102. keepdim = False
  103. if wrapper.reduce:
  104. keepdim_arg = kwargs.get("keepdim")
  105. if keepdim_arg is None and wrapper.keepdim_offset < len(args):
  106. keepdim_idx = wrapper.keepdim_offset + 1
  107. if keepdim_idx < len(args):
  108. keepdim_arg = args[keepdim_idx]
  109. if keepdim_arg is not None:
  110. keepdim = bool(keepdim_arg)
  111. # Wrap dimensions
  112. ndim = info.ndim()
  113. dims = _wrap_dims(dim_arg, ndim, keepdim)
  114. # Convert dimensions to indices and validate
  115. dim_indices: list[int] = []
  116. seen = [False] * len(info.levels)
  117. for d in dims:
  118. midx = None
  119. for i, level in enumerate(info.levels):
  120. if level == d:
  121. midx = i
  122. break
  123. if midx is None:
  124. # Try to match by position/name more flexibly
  125. for i, level in enumerate(info.levels):
  126. if hasattr(level, "matches") and level.matches(d):
  127. midx = i
  128. break
  129. if midx is None:
  130. level_strs = [str(level) for level in info.levels]
  131. raise ValueError(
  132. f"Tensor with dimensions {level_strs} does not contain {d}"
  133. )
  134. seen[midx] = True
  135. dim_indices.append(midx)
  136. # Determine new levels after reduction
  137. new_levels = []
  138. if wrapper.reduce and not keepdim:
  139. for i, level in enumerate(info.levels):
  140. if not seen[i]:
  141. new_levels.append(level)
  142. else:
  143. new_levels = info.levels[:]
  144. # Create dimension indices for the original function
  145. if len(dim_indices) == 1:
  146. py_indices: Any = dim_indices[0]
  147. else:
  148. py_indices = tuple(dim_indices)
  149. # Update arguments
  150. new_args = list(args)
  151. new_kwargs = kwargs.copy()
  152. if info.tensor is None:
  153. raise AssertionError("Expected tensor to be non-None")
  154. new_args[0] = handle_from_tensor(info.tensor)
  155. # Update dimension argument
  156. if wrapper.dim_name in new_kwargs:
  157. new_kwargs[wrapper.dim_name] = py_indices
  158. else:
  159. dim_idx = wrapper.dim_offset + 1
  160. if dim_idx < len(new_args):
  161. new_args = list(new_args)
  162. new_args[dim_idx] = py_indices
  163. # Call original function
  164. result = wrapper.orig(*new_args, **new_kwargs)
  165. # Wrap results
  166. def wrap_result(obj: Any) -> Any:
  167. if isinstance(obj, torch.Tensor):
  168. from . import Tensor
  169. return Tensor.from_positional(obj, new_levels, info.has_device)
  170. return obj
  171. return tree_map(wrap_result, result)
  172. def _wrap(
  173. orig: Callable,
  174. dim_offset: Optional[int] = None,
  175. keepdim_offset: Optional[int] = None,
  176. dim_name: Optional[str] = None,
  177. single_dim: Optional[bool] = None,
  178. reduce: Optional[bool] = None,
  179. ) -> Callable:
  180. """
  181. Wrap a PyTorch function to support first-class dimensions.
  182. Args:
  183. orig: Original function to wrap
  184. dim_offset: Offset for dimension argument (default: 0)
  185. keepdim_offset: Offset for keepdim argument (default: 1)
  186. dim_name: Name of dimension parameter (default: "dim")
  187. single_dim: Whether function takes single dimension (default: False)
  188. reduce: Whether function reduces dimensions (default: True)
  189. """
  190. dim_name = dim_name or "dim"
  191. wrapper = WrappedOperator(orig, patched_dim_method, dim_name)
  192. if dim_offset is not None:
  193. wrapper.dim_offset = dim_offset
  194. if keepdim_offset is not None:
  195. wrapper.keepdim_offset = keepdim_offset
  196. if single_dim is not None:
  197. wrapper.single_dim = single_dim
  198. if reduce is not None:
  199. wrapper.reduce = reduce
  200. return wrapper.function()
  201. def call_torch_function(
  202. wrapper: WrappedOperator,
  203. func: Callable,
  204. types: tuple,
  205. args: tuple = (),
  206. kwargs: Optional[dict] = None,
  207. ) -> Any:
  208. """
  209. Handle __torch_function__ calls for wrapped operators.
  210. """
  211. if kwargs is None:
  212. kwargs = {}
  213. # Import here to avoid circular imports
  214. from . import _Tensor
  215. # Use the torch function mechanism from _Tensor
  216. return _Tensor.__torch_function__(func, types, args, kwargs)