__init__.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603
  1. from __future__ import annotations
  2. import dis
  3. import inspect
  4. import sys
  5. from typing import Any, Optional, TYPE_CHECKING, Union
  6. if TYPE_CHECKING:
  7. from collections.abc import Callable, Sequence
  8. import torch
  9. from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
  10. from ._dim_entry import _match_levels, DimEntry, ndim_of_levels
  11. from ._enable_all_layers import EnableAllLayers
  12. from ._py_inst_decoder import _PyInstDecoder
  13. from ._tensor_info import TensorInfo
# When True, _Tensor._torch_function_fallback takes its pointwise path for
# functions listed in op_properties.pointwise.
POINTWISE_OPTIMIZE = True
# When True, _Tensor.__torch_function__ may turn torch.Tensor.__mul__ into a
# delayed tensor instead of dispatching it through the fallback.
DOT_OPTIMIZED = True

# Global dimension level counter
_n_dims_created = 0
  18. def _relevant_op(opcode: Optional[str]) -> bool:
  19. """Check if opcode is relevant for variable assignment."""
  20. return bool(opcode and opcode.startswith("STORE_"))
def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Handle tensor conversion for torch function integration.

    Currently an identity hook: the argument is returned unchanged.
    """
    return tensor
  24. def _create_dim(name: str, size: Optional[int] = None) -> Dim:
  25. """Create a new Dim object."""
  26. return Dim(name, size if size is not None else -1)
  27. def dims(
  28. n: Optional[int] = None, sizes: Optional[list[Optional[int]]] = None
  29. ) -> Union[Dim, tuple[Dim, ...]]:
  30. """
  31. Create and return one or more Dim objects.
  32. Uses bytecode inspection to determine variable names when possible.
  33. Args:
  34. n (int, optional): The number of dimensions to create. Can be omitted if sizes is specified.
  35. sizes (List[Optional[int]], optional): A list the same size as the number of dimensions to be
  36. created, specifying each dimensions size, or None to leave the size unset.
  37. Returns:
  38. Union[Dim, Tuple[Dim, ...]]: Single Dim if n=1, tuple of Dims otherwise.
  39. Examples:
  40. >>> batch, channel, width, height = dims(4)
  41. >>> batch, channel, width, height = dims(sizes=[None, 3, 224, 224])
  42. >>> single_dim = dims(1)
  43. """
  44. specified_ndims = -1
  45. found_ndims = 0
  46. # Parse arguments
  47. if sizes is not None:
  48. specified_ndims = len(sizes)
  49. if n is not None:
  50. specified_ndims = n
  51. # Use bytecode inspection
  52. frame = inspect.currentframe()
  53. if frame is None:
  54. raise RuntimeError("Unable to get current frame")
  55. frame = frame.f_back
  56. try:
  57. if frame is None:
  58. raise RuntimeError("Unable to get caller frame")
  59. code = frame.f_code
  60. lasti = frame.f_lasti
  61. decoder = _PyInstDecoder(code, lasti)
  62. if sys.version_info >= (3, 11):
  63. if decoder.opcode() == "PRECALL":
  64. decoder.next()
  65. # Move to next instruction after the call
  66. decoder.next()
  67. # Determine number of dimensions from bytecode
  68. if _relevant_op(decoder.opcode()):
  69. found_ndims = 1
  70. elif decoder.opcode() == "UNPACK_SEQUENCE":
  71. found_ndims = decoder.oparg()
  72. decoder.next() # Move past UNPACK_SEQUENCE
  73. if specified_ndims == -1:
  74. if found_ndims == 0:
  75. raise SyntaxError(
  76. "dims() must be assigned to a sequence of variable names or have argument n specified"
  77. )
  78. specified_ndims = found_ndims
  79. if found_ndims != specified_ndims:
  80. found_ndims = 0
  81. def genobject(i: int) -> Dim:
  82. nonlocal found_ndims
  83. name = None
  84. if i < found_ndims:
  85. name = decoder.name()
  86. if not name:
  87. name = f"d{i}"
  88. found_ndims = 0
  89. else:
  90. decoder.next() # Move to next STORE instruction
  91. size = sizes[i] if sizes is not None else None
  92. return _create_dim(name, size)
  93. # Validate sizes parameter
  94. if sizes is not None and len(sizes) != specified_ndims:
  95. raise ValueError(f"expected {specified_ndims} sizes but found {len(sizes)}")
  96. if specified_ndims == 1:
  97. return genobject(0)
  98. result = []
  99. for i in range(specified_ndims):
  100. result.append(genobject(i))
  101. return tuple(result)
  102. finally:
  103. del frame
  104. class DimList:
  105. """
  106. A list of first-class dimensions that can be bound to tensor dimensions.
  107. A DimList can be in one of two states:
  108. 1. Unbound: Created with just a name, no specific dimensions yet
  109. 2. Bound: Either created with specific dimensions/sizes, or bound later via bind() or bind_len()
  110. """
  111. _name: Optional[str]
  112. _dims: list[Dim]
  113. _bound: bool
  114. def __init__(
  115. self,
  116. len_or_dims: Optional[Union[int, Sequence]] = None,
  117. name: Optional[str] = None,
  118. ):
  119. """
  120. Initialize a new DimList object.
  121. Args:
  122. len_or_dims: Optional length (int) or sequence of dimensions/sizes
  123. name: Optional name for the dimension list
  124. """
  125. # Initialize attributes
  126. self._name = name
  127. self._dims: list = []
  128. self._bound = False
  129. if isinstance(len_or_dims, int):
  130. self.bind_len(len_or_dims)
  131. elif len_or_dims is not None:
  132. dims = []
  133. for i, item in enumerate(len_or_dims):
  134. if isinstance(item, int):
  135. dim_name = f"{self._name}{i}" if self._name else f"dim{i}"
  136. dims.append(Dim(dim_name, item))
  137. else:
  138. dims.append(Dim(item))
  139. self._set_dims(dims)
  140. def _set_dims(self, dims: list) -> None:
  141. """Set the dimensions and mark as bound."""
  142. self._bound = True
  143. self._dims = dims
  144. def bind_len(self, size: int) -> None:
  145. """
  146. Bind this DimList to a specific length.
  147. Args:
  148. size: Number of dimensions to bind to
  149. Raises:
  150. DimensionBindError: If already bound to a different size
  151. """
  152. if self._bound:
  153. if len(self._dims) != size:
  154. raise DimensionBindError(
  155. f"Dimlist has size {len(self._dims)} but it is being bound to size {size}"
  156. )
  157. else:
  158. self._bound = True
  159. self._dims = []
  160. for i in range(size):
  161. dim_name = f"{self._name}{i}" if self._name else f"dim{i}"
  162. self._dims.append(Dim(dim_name))
  163. def bind(self, sizes: Sequence[int]) -> None:
  164. """
  165. Bind this DimList to specific sizes.
  166. Args:
  167. sizes: Sequence of sizes for each dimension
  168. Raises:
  169. ValueError: If sizes is not a sequence
  170. """
  171. if not hasattr(sizes, "__len__") or not hasattr(sizes, "__getitem__"):
  172. raise ValueError("expected a sequence")
  173. size = len(sizes)
  174. self.bind_len(size)
  175. for i, dim_size in enumerate(sizes):
  176. self._dims[i].size = int(dim_size)
  177. def _size(self) -> int:
  178. if not self._bound:
  179. raise DimensionBindError("DimList not bound")
  180. return len(self._dims)
  181. def size(self) -> int:
  182. """Return the size (number of dimensions) of this DimList."""
  183. return self._size()
  184. def _set_bound(self, b: bool) -> None:
  185. """Set the bound status (for internal use)."""
  186. self._bound = b
  187. @property
  188. def is_bound(self) -> bool:
  189. """Property to check if DimList is bound."""
  190. return self._bound
  191. def __len__(self) -> int:
  192. """Return the length of the DimList."""
  193. return self.size()
  194. def __getitem__(self, key: Union[int, slice]) -> Union[Dim, tuple[Dim, ...]]:
  195. if not self._bound:
  196. raise DimensionBindError("DimList not bound")
  197. if isinstance(key, int):
  198. if key < 0 or key >= len(self._dims):
  199. raise IndexError("index out of bounds")
  200. return self._dims[key]
  201. elif isinstance(key, slice):
  202. start, stop, step = key.indices(len(self._dims))
  203. result = []
  204. for i in range(start, stop, step):
  205. result.append(self._dims[i])
  206. return tuple(result)
  207. else:
  208. raise ValueError("expected an int or a slice")
  209. def __repr__(self) -> str:
  210. """Return string representation of the DimList."""
  211. if self._bound:
  212. # Show as tuple representation
  213. return f"({', '.join(repr(dim) for dim in self._dims)})"
  214. elif self._name is not None:
  215. # Show as *name for unbound with name
  216. return f"*{self._name}"
  217. else:
  218. # Show as <unbound_dimlist> for unbound without name
  219. return "<unbound_dimlist>"
  220. def __str__(self) -> str:
  221. """Return string representation of the DimList."""
  222. return self.__repr__()
  223. @classmethod
  224. def __torch_function__(
  225. cls,
  226. func: Callable,
  227. types: tuple,
  228. args: tuple = (),
  229. kwargs: Optional[dict] = None,
  230. ) -> Any:
  231. return _Tensor.__torch_function__(func, types, args, kwargs)
  232. def _create_dimlist(
  233. name: str, size: Optional[Union[int, list[Optional[int]]]] = None
  234. ) -> DimList:
  235. """Create a DimList object with the given name and optional size."""
  236. dimlist = DimList(name=name)
  237. if size is not None:
  238. if isinstance(size, int):
  239. dimlist.bind_len(size)
  240. else:
  241. # size is a list of optional ints
  242. dimlist.bind_len(len(size))
  243. for i, s in enumerate(size):
  244. if s is not None:
  245. dimlist._dims[i].size = s
  246. return dimlist
  247. def dimlists(
  248. n: Optional[int] = None, sizes: Optional[list[Optional[int]]] = None
  249. ) -> Union[DimList, tuple[DimList, ...]]:
  250. """
  251. Create and return one or more DimList objects.
  252. Similar to dims() but creates DimList objects instead.
  253. """
  254. specified_ndims = -1
  255. found_ndims = 0
  256. # Parse arguments
  257. if sizes is not None:
  258. specified_ndims = len(sizes)
  259. if n is not None:
  260. specified_ndims = n
  261. frame = inspect.currentframe()
  262. if frame is None:
  263. raise RuntimeError("Unable to get current frame")
  264. frame = frame.f_back
  265. try:
  266. if frame is None:
  267. raise RuntimeError("Unable to get caller frame")
  268. code = frame.f_code
  269. lasti = frame.f_lasti
  270. decoder = _PyInstDecoder(code, lasti)
  271. if sys.version_info >= (3, 11):
  272. if decoder.opcode() == "PRECALL":
  273. decoder.next()
  274. # Move to next instruction after the call
  275. decoder.next()
  276. # Determine number of dimensions from bytecode
  277. if _relevant_op(decoder.opcode()):
  278. found_ndims = 1
  279. elif decoder.opcode() == "UNPACK_SEQUENCE":
  280. found_ndims = decoder.oparg()
  281. decoder.next() # Move past UNPACK_SEQUENCE
  282. if specified_ndims == -1:
  283. if found_ndims == 0:
  284. raise SyntaxError(
  285. "dimlists() must be assigned to a sequence of variable names or have argument n specified"
  286. )
  287. specified_ndims = found_ndims
  288. if found_ndims != specified_ndims:
  289. found_ndims = 0
  290. # Generator function for dimlist names
  291. def genobject(i: int) -> str:
  292. nonlocal found_ndims
  293. name = None
  294. if i < found_ndims:
  295. name = decoder.name()
  296. if not name:
  297. name = f"d{i}"
  298. found_ndims = 0
  299. else:
  300. decoder.next() # Move to next STORE instruction
  301. return name
  302. # Validate sizes
  303. if sizes is not None and len(sizes) != specified_ndims:
  304. raise ValueError(f"expected {specified_ndims} sizes but found {len(sizes)}")
  305. # Create dimlists
  306. if specified_ndims == 1:
  307. name = genobject(0)
  308. return _create_dimlist(name, sizes[0] if sizes is not None else None)
  309. result = []
  310. for i in range(specified_ndims):
  311. name = genobject(i)
  312. size = sizes[i] if sizes is not None else None
  313. result.append(_create_dimlist(name, size))
  314. return tuple(result)
  315. finally:
  316. del frame
  317. class DimensionMismatchError(Exception):
  318. pass
  319. class DimensionBindError(Exception):
  320. pass
  321. from . import op_properties
  322. def _safe_print(*args: Any, **kwargs: Any) -> None:
  323. """Safe print that avoids recursive torch function dispatches."""
  324. import sys
  325. # Convert any torch objects to basic representations
  326. safe_args = []
  327. for arg in args:
  328. if hasattr(arg, "__class__") and "torch" in str(type(arg)):
  329. safe_args.append(f"<{type(arg).__name__}>")
  330. else:
  331. safe_args.append(str(arg))
  332. print(*safe_args, **kwargs, file=sys.stderr)
  333. class _Tensor:
  334. def _get_levels(self) -> list[Any]:
  335. raise NotImplementedError("_get_levels must be implemented by subclass")
  336. def _get_tensor(self) -> Optional[torch.Tensor]:
  337. raise NotImplementedError("_get_tensor must be implemented by subclass")
  338. @property
  339. def ndim(self) -> int:
  340. raise NotImplementedError("ndim must be implemented by subclass")
  341. @property
  342. def dims(self) -> tuple[Any, ...]:
  343. return tuple(l.dim() for l in self._get_levels() if not l.is_positional())
  344. def dim(self) -> int:
  345. return self.ndim
  346. @classmethod
  347. def __torch_function__(
  348. cls,
  349. func: Callable,
  350. types: tuple,
  351. args: tuple = (),
  352. kwargs: Optional[dict] = None,
  353. ) -> Any:
  354. if kwargs is None:
  355. kwargs = {}
  356. if DOT_OPTIMIZED and func is torch.Tensor.__mul__:
  357. # Check conditions: 2 args, both are tensor-like, both 0-dimensional
  358. if (
  359. len(args) == 2
  360. and not kwargs
  361. and isinstance(args[0], (_Tensor, torch.Tensor))
  362. and isinstance(args[1], (_Tensor, torch.Tensor))
  363. ):
  364. # Get tensor info for both operands
  365. lhs_info = TensorInfo.create(
  366. args[0], ensure_batched=False, ensure_present=False
  367. )
  368. rhs_info = TensorInfo.create(
  369. args[1], ensure_batched=False, ensure_present=False
  370. )
  371. if (
  372. lhs_info
  373. and rhs_info
  374. and lhs_info.tensor is not None
  375. and rhs_info.tensor is not None
  376. and lhs_info.tensor.dim() == 0
  377. and rhs_info.tensor.dim() == 0
  378. ):
  379. if (
  380. lhs_info.tensor.is_floating_point()
  381. and rhs_info.tensor.is_floating_point()
  382. ):
  383. # Collect all unique levels and has_device
  384. has_device = lhs_info.has_device or rhs_info.has_device
  385. levels = []
  386. for level in lhs_info.levels:
  387. if level not in levels:
  388. levels.append(level)
  389. for level in rhs_info.levels:
  390. if level not in levels:
  391. levels.append(level)
  392. # Debug print
  393. # print(f"DEBUG: Creating delayed mul, levels: {levels}, has_device: {has_device}")
  394. # Create delayed tensor
  395. return Tensor.create_delayed(func, args, levels, has_device)
  396. if func is torch.Tensor.__getitem__:
  397. from functorch.dim._getsetitem import getitem
  398. return getitem(cls, func, types, args, kwargs)
  399. if func is torch.Tensor.__setitem__:
  400. from functorch.dim._getsetitem import setitem
  401. # args should be (tensor, index, value)
  402. if len(args) == 3:
  403. setitem(args[0], args[1], args[2])
  404. return None
  405. else:
  406. raise ValueError(f"Expected 3 args for __setitem__, got {len(args)}")
  407. # Fast-path for len; mostly to avoid infinite loop in TestMinFunctorchOnly.test_softmax_split
  408. if func is torch.Tensor.__len__:
  409. return args[0].size(0)
  410. # Special handling for torch.softmax - use the pre-wrapped version
  411. if func is torch.softmax:
  412. return softmax(*args, **kwargs)
  413. # Special handling for torch.stack - use the custom stack function
  414. if func is torch.stack:
  415. return stack(*args, **kwargs)
  416. if (
  417. func is torch.Tensor.split
  418. or func is torch._VF.split # type: ignore[attr-defined]
  419. or func is torch._VF.split_with_sizes # type: ignore[attr-defined]
  420. or func is torch.split
  421. ):
  422. return split(*args, **kwargs)
  423. return _Tensor._torch_function_fallback(func, types, args, kwargs)
  424. @staticmethod
  425. def _torch_function_fallback(
  426. func: Callable, types: tuple, args: tuple, kwargs: dict
  427. ) -> Any:
  428. """Fallback torch function implementation for non-special-cased functions."""
  429. is_pointwise = POINTWISE_OPTIMIZE and func in op_properties.pointwise
  430. # TODO: optimize pytree here
  431. flat_args, spec = tree_flatten((args, kwargs))
  432. device_holding_tensor = None
  433. infos: list[TensorInfo] = []
  434. result_levels: list[DimEntry] = []
  435. for f in flat_args:
  436. info = TensorInfo.create(f, not is_pointwise, False)
  437. infos.append(info)
  438. if info:
  439. if not (is_pointwise or info.batchedtensor is not None):
  440. raise AssertionError(
  441. "Expected pointwise or batchedtensor to be set"
  442. )
  443. if device_holding_tensor is None and info.has_device:
  444. device_holding_tensor = info.tensor
  445. # Collect all unique levels
  446. for level in info.levels:
  447. if not isinstance(level, DimEntry):
  448. raise AssertionError(f"Expected DimEntry, got {type(level)}")
  449. if level not in result_levels:
  450. result_levels.append(level)
  451. if is_pointwise:
  452. # Pointwise operation: match all tensors to common levels
  453. for i, info in enumerate(infos):
  454. if info and info.tensor is not None:
  455. tensor = info.tensor
  456. if device_holding_tensor is not None and not info.has_device:
  457. tensor = tensor.to(device_holding_tensor.device)
  458. ml = _match_levels(tensor, info.levels, result_levels)
  459. flat_args[i] = handle_from_tensor(ml)
  460. unflat_args, unflat_kwargs = tree_unflatten(flat_args, spec)
  461. result = func(*unflat_args, **unflat_kwargs)
  462. # Wrap tensor results
  463. def wrap_tensor(obj: Any) -> Any:
  464. if isinstance(obj, torch.Tensor):
  465. return Tensor.from_positional(
  466. obj, result_levels, device_holding_tensor is not None
  467. )
  468. return obj
  469. # Small fastpath
  470. if isinstance(result, torch.Tensor):
  471. return wrap_tensor(result)
  472. else:
  473. return tree_map(wrap_tensor, result)
  474. # Non-pointwise operation: use functorch vmap layers
  475. with EnableAllLayers(result_levels) as guard:
  476. # Update arguments with batched tensors
  477. for i, info in enumerate(infos):
  478. if info and info.batchedtensor is not None:
  479. batched = info.batchedtensor
  480. if device_holding_tensor is not None and not info.has_device:
  481. batched = batched.to(device_holding_tensor.device)
  482. guard.inplace_update_layers(batched, info.levels)
  483. flat_args[i] = handle_from_tensor(batched)
  484. unflat_args, unflat_kwargs = tree_unflatten(flat_args, spec)
  485. result = func(*unflat_args, **unflat_kwargs)
  486. # Unwrap results from functorch layers
  487. def unwrap_tensor(obj: Any) -> Any:
  488. if isinstance(obj, torch.Tensor):
  489. return guard.from_batched(obj, device_holding_tensor is not None)
  490. return obj
  491. if isinstance(result, torch.Tensor):
  492. return unwrap_tensor(result)
  493. else:
  494. return tree_map(unwrap_tensor, result)
  495. def __setitem__(self, index: Any, value: Any) -> None:
  496. """Set values in tensor using first-class dimensions."""
  497. from functorch.dim._getsetitem import setitem
  498. return setitem(self, index, value)
  499. # expand and index are OK to be methods because they don't have torch.*
  500. # versions, but if they did they need the stack/cat treatment
  501. def expand(self, *args: Dim) -> _Tensor:
  502. """
  503. Expand tensor by adding new dimensions or expanding existing dimensions.
  504. If all arguments are Dim objects, adds new named dimensions.
  505. Otherwise, falls back to regular tensor expansion behavior.
  506. Args:
  507. args: Either Dim objects for new dimensions or sizes for regular expansion
  508. Returns:
  509. New tensor with expanded dimensions
  510. Example:
  511. >>> i, j = dims()
  512. >>> t = torch.randn(3, 4)
  513. >>> expanded = t[i].expand(j, k) # Add j, k dimensions
  514. >>> expanded2 = t[i].expand(2, 4) # Regular expand with sizes
  515. """
  516. info = TensorInfo.create(self, ensure_batched=False, ensure_present=False)
  517. for arg in args:
  518. if not isinstance(arg, Dim):
  519. # Not all args are Dims, fallback to regular expand
  520. if isinstance(self, torch.Tensor) and not isinstance(self, _Tensor):
  521. return torch.Tensor.expand(self, *args)
  522. else:
  523. return self.__torch_function__(
  524. torch.Tensor.expand, (type(self),), (self,) + args
  525. )
  526. # All args are Dim objects - proceed with first-class dimension expansion
  527. if not info:
  528. # No tensor info available, fallback
  529. return self.__torch_function__(
  530. torch.Tensor.expand, (type(self),), (self,) + args
  531. )
  532. # First-class dimension expansion - all args are Dim objects
  533. data = info.tensor
  534. if data is None:
  535. # No tensor data available, fallback
  536. return self.__torch_function__(
  537. torch.Tensor.expand, (type(self),), (self,) + args
  538. )
  539. levels = info.levels
  540. new_levels: list[DimEntry] = []
  541. new_sizes = []
  542. new_strides = []
  543. for d in args:
  544. # Check if dimension already exists in current levels or new_levels
  545. for level in levels:
  546. if not level.is_positional() and level.dim() is d:
  547. raise DimensionBindError(
  548. f"expanding dimension {d} already exists in tensor with dims"
  549. )
  550. for new_level in new_levels:
  551. if not new_level.is_positional() and new_level.dim() is d:
  552. raise DimensionBindError(
  553. f"expanding dimension {d} already exists in tensor with dims"
  554. )
  555. new_levels.append(DimEntry(d))
  556. new_sizes.append(d.size)
  557. new_strides.append(0)
  558. # Add existing levels
  559. new_levels.extend(levels)
  560. # Add existing sizes and strides
  561. orig_sizes = list(data.size())
  562. orig_strides = list(data.stride())
  563. new_sizes.extend(orig_sizes)
  564. new_strides.extend(orig_strides)
  565. # Create expanded tensor using as_strided
  566. expanded_data = data.as_strided(new_sizes, new_strides, data.storage_offset())
  567. # Return new tensor with expanded dimensions
  568. result = Tensor.from_positional(expanded_data, new_levels, info.has_device)
  569. return result # type: ignore[return-value] # Tensor and torch.Tensor are interchangeable
  570. def index(
  571. self,
  572. dims: Union[int, Dim, tuple[Union[int, Dim], ...], list[Union[int, Dim]]],
  573. indices: Union[
  574. int,
  575. slice,
  576. torch.Tensor,
  577. tuple[Union[int, slice, torch.Tensor], ...],
  578. list[Union[int, slice, torch.Tensor]],
  579. ],
  580. ) -> _Tensor:
  581. """
  582. Index tensor using first-class dimensions.
  583. """
  584. from ._dim_entry import _match_levels
  585. from ._getsetitem import getsetitem_flat, invoke_getitem
  586. from ._wrap import _wrap_dim
  587. # Helper to check if obj is a dimpack (tuple/list) and extract items
  588. def maybe_dimpack(obj: Any, check_first: bool = False) -> tuple[Any, bool]:
  589. if isinstance(obj, (tuple, list)):
  590. return list(obj), True
  591. return None, False
  592. def parse_dim_entry(s: Any) -> Any:
  593. d = _wrap_dim(s, self.ndim, False)
  594. if d.is_none():
  595. raise TypeError(f"expected a dimension specifyer but found {repr(s)}")
  596. return d
  597. # Helper for dimension not present errors
  598. def dim_not_present(d: Any) -> None:
  599. if d.is_positional():
  600. raise TypeError(
  601. f"dimension {d.position() + self.ndim} not in tensor of {self.ndim} dimensions"
  602. )
  603. else:
  604. raise TypeError(f"dimension {repr(d.dim())} not in tensor")
  605. dims_list: list[Union[int, Dim]] = []
  606. indices_list: list[Union[int, slice, torch.Tensor]] = []
  607. lhs_list = isinstance(dims, (tuple, list))
  608. rhs_list = isinstance(indices, (tuple, list))
  609. if lhs_list and rhs_list:
  610. # Type narrowing: we know dims and indices are sequences here
  611. dims_seq = dims # type: ignore[assignment]
  612. indices_seq = indices # type: ignore[assignment]
  613. if len(dims_seq) != len(indices_seq): # type: ignore[arg-type]
  614. raise TypeError(
  615. f"dims ({len(dims_seq)}) and indices ({len(indices_seq)}) must have the same length" # type: ignore[arg-type]
  616. )
  617. dims_list.extend(dims_seq) # type: ignore[arg-type]
  618. indices_list.extend(indices_seq) # type: ignore[arg-type]
  619. else:
  620. dims_list.append(dims) # type: ignore[arg-type]
  621. indices_list.append(indices) # type: ignore[arg-type]
  622. # Create tensor info
  623. self_info = TensorInfo.create(self, False, False)
  624. new_levels: list[Any] = []
  625. to_flatten: list[Any] = []
  626. dims_list_flat = []
  627. # Process each dim specification
  628. for i in range(len(dims_list)):
  629. m, is_dimpack = maybe_dimpack(dims_list[i], check_first=False)
  630. if is_dimpack:
  631. if len(m) == 0:
  632. dims_list_flat.append(DimEntry()) # Empty dimpack
  633. continue
  634. first = parse_dim_entry(m[0])
  635. dims_list_flat.append(first)
  636. if len(m) == 1:
  637. continue
  638. # Multi-element dimpack requires flattening
  639. if len(to_flatten) == 0:
  640. new_levels.extend(self_info.levels)
  641. rest = []
  642. for j in range(1, len(m)):
  643. d = parse_dim_entry(m[j])
  644. removed = False
  645. for k in range(len(new_levels)):
  646. if new_levels[k] == d:
  647. new_levels.pop(k)
  648. removed = True
  649. break
  650. if not removed:
  651. dim_not_present(d)
  652. rest.append(d)
  653. # Find first in new_levels
  654. first_idx = None
  655. for k in range(len(new_levels)):
  656. if new_levels[k] == first:
  657. first_idx = k
  658. break
  659. if first_idx is None:
  660. dim_not_present(first)
  661. continue # Skip this iteration if dimension not found
  662. for j, r in enumerate(rest):
  663. new_levels.insert(first_idx + 1 + j, r)
  664. to_flatten.extend(rest)
  665. else:
  666. dims_list_flat.append(parse_dim_entry(dims_list[i]))
  667. # Handle dimension flattening if needed
  668. if len(to_flatten) > 0:
  669. if self_info.tensor is None:
  670. raise AssertionError(
  671. "Cannot perform dimension flattening on None tensor"
  672. )
  673. rearranged = _match_levels(self_info.tensor, self_info.levels, new_levels)
  674. sizes = rearranged.size()
  675. new_sizes: list[Any] = []
  676. reshape_levels = []
  677. for i in range(len(new_levels)):
  678. if new_levels[i] in to_flatten:
  679. if len(new_sizes) == 0:
  680. new_sizes.append(sizes[i])
  681. else:
  682. new_sizes[-1] *= sizes[i]
  683. else:
  684. new_sizes.append(sizes[i])
  685. reshape_levels.append(new_levels[i])
  686. self_info.tensor = rearranged.reshape(new_sizes)
  687. self_info.levels = reshape_levels
  688. # Check for dimpacks in indices
  689. has_dimpacks = False
  690. for idx in indices_list:
  691. if isinstance(idx, (tuple, list)):
  692. has_dimpacks = True
  693. break
  694. # Call getsetitem_flat with correct parameters
  695. info = getsetitem_flat(
  696. self_info,
  697. [], # empty input_list
  698. dims_list_flat, # keys
  699. indices_list, # values
  700. has_dimpacks,
  701. )
  702. return invoke_getitem(info)
  703. def __repr__(self) -> str:
  704. tensor, levels, ndim = self._get_tensor(), self._get_levels(), self.ndim
  705. dims_repr = []
  706. for l in levels:
  707. if hasattr(l, "is_positional") and l.is_positional():
  708. # Convert negative positional to positive: -1 -> ndim-1, -2 -> ndim-2, etc.
  709. dims_repr.append(l.position() + ndim)
  710. elif hasattr(l, "dim"):
  711. dims_repr.append(l.dim())
  712. elif hasattr(l, "data"):
  713. dims_repr.append(l.data)
  714. else:
  715. dims_repr.append(l)
  716. return f"{tensor}\nwith dims={tuple(dims_repr)} sizes={tuple(tensor.size())}" # type: ignore[union-attr]
  # Tuple holding both tensor classes (first-class-dim base and plain torch.Tensor).
  # NOTE(review): presumably intended for isinstance() checks — confirm at call sites.
  TensorLike = (_Tensor, torch.Tensor)
  class Dim(_Tensor):
      """A named first-class dimension that can be bound to a size.

      A stored size of ``-1`` means "unbound". Binding happens through the
      ``size`` setter; after that, only the same size may be assigned again.
      """

      _level: int  # value of the module-wide creation counter at construction
      _name: str
      _size: int  # -1 while unbound
      _range: Optional[torch.Tensor]  # lazily cached torch.arange(size)
      _batchtensor: Optional[torch.Tensor]  # lazily cached batched form of _range

      def __init__(self, name: str, s: int = -1) -> None:
          """Create a dimension called *name*, optionally pre-bound to size *s*."""
          global _n_dims_created
          # Every Dim takes the next value of the shared creation counter.
          self._level = _n_dims_created
          _n_dims_created += 1
          self._name = name
          self._size = s
          # Both caches are populated on first use.
          self._range = None
          self._batchtensor = None

      @property
      def ndim(self) -> int:
          """A Dim always reports one dimension."""
          return 1

      @classmethod
      def check_exact(cls, obj: Any) -> bool:
          """Return True only when *obj* is exactly of this class (no subclasses)."""
          return type(obj) is cls

      @property
      def size(self) -> int:
          """The bound size.

          Raises:
              ValueError: if the dimension has not been bound yet.
          """
          bound = self._size
          if bound != -1:
              return bound
          raise ValueError(f"dimension {self._name} is unbound")

      @size.setter
      def size(self, v: int) -> None:
          """Bind the dimension to *v*, or verify that *v* matches the bound size.

          Raises:
              DimensionBindError: if already bound to a different size.
          """
          if self._size == -1:
              self._size = v
              return
          if self._size != v:
              raise DimensionBindError(
                  f"Dim '{repr(self)}' previously bound to a dimension of size {self._size} "
                  f"cannot bind to a dimension of size {v}"
              )

      @property
      def is_bound(self) -> bool:
          """Return True if this dimension is bound to a size."""
          return self._size != -1

      def _get_range(self) -> torch.Tensor:
          """
          Return ``torch.arange(self.size)``, computing and caching it on first use.

          Reading ``self.size`` raises ValueError while the dimension is unbound.
          """
          cached = self._range
          if cached is None:
              cached = self._range = torch.arange(self.size)
          return cached

      def _get_batchtensor(self) -> torch.Tensor:
          """
          Return the range tensor wrapped by ``torch._C._functorch._add_batch_dim``
          at index 0 with this Dim's ``_level``; the result is cached.
          """
          cached = self._batchtensor
          if cached is None:
              cached = self._batchtensor = torch._C._functorch._add_batch_dim(
                  self._get_range(), 0, self._level
              )
          return cached

      def __repr__(self) -> str:
          """A Dim prints as its name."""
          return self._name

      # note that Dim comes before tensor because we want the Dim API for things like size to take precedence.
      # Tensor defines format, but we want to print Dims with special formatting
      __format__ = object.__format__
  782. # Somewhat confusingly, an FCD tensor is also called Tensor. This confusion
  783. # is somewhat intentional, as FCD tensors are intended to be substitutable
  784. # with regular Tensor (just with some positional dims hidden).
  class Tensor(_Tensor):
      """A torch.Tensor paired with one level entry per dimension.

      Instances are produced by ``from_positional`` (data available at once)
      or ``create_delayed`` (data computed on the first ``_get_tensor`` call).
      """

      _tensor: Optional[torch.Tensor]
      _batchtensor: Optional[torch.Tensor]
      _levels: list[DimEntry]
      _has_device: bool
      _delayed: Optional[Callable[[], torch.Tensor]]
      _delayed_orig: Optional[Callable]
      _delayed_args: Optional[tuple]

      @property
      def ndim(self) -> int:
          """Number of positional entries among this tensor's levels."""
          return sum(1 for entry in self._levels if entry.is_positional())

      @classmethod
      def check_exact(cls, other: Any) -> bool:
          """Return True only when *other* is exactly of this class (no subclasses)."""
          return type(other) is cls

      @classmethod
      def from_positional(
          cls, tensor: torch.Tensor, levels: list[DimEntry], has_device: bool
      ) -> Union[_Tensor, torch.Tensor]:
          """
          Wrap *tensor* together with *levels* as a first-class-dim Tensor.

          Args:
              tensor: The underlying PyTorch tensor
              levels: One DimEntry per tensor dimension
              has_device: Device flag stored on the result

          Returns:
              *tensor* itself when *levels* contains no named (non-positional)
              entry, otherwise a new instance of this class.

          Raises:
              AssertionError: if positional entries are not consecutive, if the
                  last positional entry is neither 0 nor -1, or if
                  ``tensor.dim()`` differs from ``len(levels)``.
          """
          named_count = 0
          previous = 0
          for entry in levels:
              if not entry.is_positional():
                  named_count += 1
                  continue
              current = entry.position()
              # `previous` starts at 0, which disables the check for the first
              # positional entry encountered.
              if previous != 0 and previous + 1 != current:
                  raise AssertionError(
                      f"Positional dimensions must be consecutive, got {previous} then {current}"
                  )
              previous = current

          if previous not in (0, -1):
              raise AssertionError(
                  f"Final positional dimension must be 0 or -1, got {previous}"
              )

          if named_count == 0:
              return tensor

          wrapped = cls()
          wrapped._tensor = tensor
          wrapped._levels = levels
          wrapped._has_device = has_device
          wrapped._batchtensor = None  # built lazily by _get_batchtensor
          wrapped._delayed = None
          wrapped._delayed_orig = None
          wrapped._delayed_args = None

          if tensor.dim() != len(levels):
              raise AssertionError(
                  f"Tensor has {tensor.dim()} dimensions but {len(levels)} levels provided"
              )
          return wrapped

      @classmethod
      def create_delayed(
          cls, orig: Callable, args: tuple, levels: list[DimEntry], has_device: bool
      ) -> _Tensor:
          """
          Build a Tensor whose data is produced later by calling ``orig(*args)``.

          Arguments exposing ``_get_tensor`` are replaced by the result of that
          call before *orig* is invoked.
          """

          def evaluate_delayed() -> torch.Tensor:
              unwrapped = [
                  arg._get_tensor() if hasattr(arg, "_get_tensor") else arg
                  for arg in args
              ]
              return orig(*unwrapped)

          pending = cls()
          pending._tensor = None  # filled in by _get_tensor
          pending._levels = levels
          pending._has_device = has_device
          pending._batchtensor = None
          pending._delayed_orig = orig
          pending._delayed_args = args
          pending._delayed = evaluate_delayed
          return pending

      def _get_tensor(self) -> Optional[torch.Tensor]:
          """Return the underlying tensor, running a pending delayed operation first."""
          pending = getattr(self, "_delayed", None)
          if pending is not None and self._tensor is None:
              self._tensor = pending()
              # Forget the delayed operation so it never runs twice.
              self._delayed = None
              self._delayed_orig = None
              self._delayed_args = None
          return self._tensor

      def _get_levels(self) -> list[Any]:
          """Return the level entries."""
          return self._levels

      def _get_has_device(self) -> bool:
          """Return the stored device flag."""
          return self._has_device

      def _get_batchtensor(self) -> Optional[torch.Tensor]:
          """Return the batched form of the tensor, building and caching it on first use."""
          if self._batchtensor is None:
              data = self._get_tensor()
              self._batchtensor = self._add_batch_dims(data, self._get_levels())
          return self._batchtensor

      def _add_batch_dims(
          self, t: Optional[torch.Tensor], levels_: list[Any]
      ) -> Optional[torch.Tensor]:
          """
          Wrap *t* with one functorch batch dim per named level.

          On each pass the remaining named level with the smallest ``_level``
          is selected, *t* is wrapped via ``torch._C._functorch._add_batch_dim``
          and that level is replaced by an empty ``DimEntry()``. *t* is
          returned once no named level is left.

          Raises:
              AssertionError: if a named level remains but *t* is None.
          """
          remaining = list(levels_)
          while True:
              lowest_level = float("inf")  # INT_MAX equivalent
              batch_index = -1  # index counted over non-none entries only
              slot = -1  # index into `remaining`
              live_count = 0
              for raw_index, entry in enumerate(remaining):
                  if entry.is_none():
                      continue
                  if not entry.is_positional() and entry.dim()._level < lowest_level:
                      lowest_level = entry.dim()._level
                      batch_index = live_count
                      slot = raw_index
                  live_count += 1

              if batch_index == -1:
                  return t
              if t is None:
                  raise AssertionError("Expected t to be non-None")

              t = torch._C._functorch._add_batch_dim(t, batch_index, int(lowest_level))
              remaining[slot] = DimEntry()

      def order(self, *dims: Any) -> _Tensor:
          """Reorder the dimensions of this tensor (delegates to ``._order.order``)."""
          from ._order import order as order_impl

          return order_impl(self, *dims)  # type: ignore[return-value]  # Tensor and torch.Tensor are interchangeable
  def stack(tensors: Any, new_dim: Any, dim: int = 0) -> _Tensor:
      """
      Stack tensors along a new dimension.

      Args:
          tensors: Non-empty sequence of tensors to stack
          new_dim: The Dim that names the stacking axis; when it is not a Dim
              the call is forwarded to ``torch.stack(tensors, dim=dim)``
          dim: Where to insert the new dimension (default: 0)

      Returns:
          The stacked tensor carrying *new_dim*.

      Raises:
          ValueError: if *tensors* is empty.
          AssertionError: if an input has no tensor data.
          TypeError: if *dim* does not resolve to a level of the inputs.
      """
      if not tensors:
          raise ValueError("stack expects a non-empty sequence of tensors")

      if not isinstance(new_dim, Dim):
          # Not a first-class dim: plain torch.stack does the job.
          return torch.stack(tensors, dim=dim)  # type: ignore[return-value]

      infos = [
          TensorInfo.create(t, ensure_batched=False, ensure_present=False)
          for t in tensors
      ]

      # Ordered union of every level seen across the inputs.
      result_levels = []
      for info in infos:
          for level in info.levels:
              if level not in result_levels:
                  result_levels.append(level)

      # Binding the size here raises if new_dim is already bound differently.
      new_dim.size = len(tensors)

      # Bring every input to the common level layout.
      inputs = []
      for info in infos:
          if info.tensor is None:
              raise AssertionError("Cannot stack tensors with None tensor data")
          inputs.append(_match_levels(info.tensor, info.levels, result_levels))

      ndim = ndim_of_levels(result_levels)
      rawdim = 0
      if dim is not None and (not isinstance(dim, int) or dim != 0):
          from ._wrap import _wrap_dim

          wrapped_dim = _wrap_dim(dim, ndim, False)
          try:
              rawdim = result_levels.index(wrapped_dim)
          except ValueError:
              raise TypeError(f"Dimension {dim} does not exist in inputs") from None

      stacked = torch.stack(inputs, rawdim)
      result_levels.insert(rawdim, DimEntry(new_dim))

      has_device = infos[0].has_device if infos else True
      return Tensor.from_positional(stacked, result_levels, has_device)  # type: ignore[return-value]
  def split(tensor: Any, split_size_or_sections: Any, dim: Any = None) -> tuple:
      """
      Split tensor along a dimension.

      An int, or a sequence containing no Dim, is forwarded to
      ``torch.Tensor.split`` through ``_Tensor._torch_function_fallback``.
      A sequence made only of Dim objects splits along *dim* and gives each
      piece the corresponding Dim as its level; unbound Dims share whatever
      size is left after the bound ones (ceil-divided, in order).

      Raises:
          TypeError: for a Dim-valued *dim* combined with integer sizes, a mix
              of ints and Dims, a 0-dim tensor, a missing dimension, or sizes
              that do not fit the source dimension.
          AssertionError: if the tensor info carries no tensor data.
      """
      from ._wrap import _wrap_dim

      dim_is_object = isinstance(dim, Dim)

      def plain_split() -> tuple:
          # Integer sizes only make sense together with a non-Dim `dim`.
          if dim_is_object:
              raise TypeError(
                  "when dim is specified as a Dim object, split sizes must also be dimensions."
              )
          return _Tensor._torch_function_fallback(
              torch.Tensor.split,
              (type(tensor),),
              (tensor, split_size_or_sections),
              {"dim": dim},
          )

      if isinstance(split_size_or_sections, int):
          return plain_split()

      sizes = list(split_size_or_sections)
      dim_flags = [isinstance(item, Dim) for item in sizes]
      if not any(dim_flags):
          # No Dim at all (this includes an empty sequence).
          return plain_split()
      if not all(dim_flags):
          raise TypeError("split list must be ints or dims but got a mix")

      # From here on every entry of `sizes` is a Dim.
      self_info = TensorInfo.create(tensor, ensure_batched=False, ensure_present=False)
      ndim = self_info.ndim()
      if not dim_is_object and ndim == 0:
          raise TypeError("split expects at least a 1-dimension tensor")

      dim_l = DimEntry(-ndim) if dim is None else _wrap_dim(dim, ndim, False)

      idx = next(
          (i for i, level in enumerate(self_info.levels) if level == dim_l),
          None,
      )
      if idx is None:
          shown_dim = 0 if dim is None else dim
          raise TypeError(f"tensor does not contain dimension {shown_dim}")

      # Bound Dims contribute their size; unbound ones are filled in below.
      indices = []
      unbound = []
      total_size = 0
      for position, size_dim in enumerate(sizes):
          if size_dim.is_bound:
              bound_size = size_dim.size
              indices.append(bound_size)
              total_size += bound_size
          else:
              indices.append(0)
              unbound.append(position)

      if self_info.tensor is None:
          raise AssertionError("Cannot get tensor size on None tensor")
      tensor_size = self_info.tensor.size(idx)

      if unbound:
          if total_size > tensor_size:
              raise TypeError(
                  f"sizes of target dimensions add up to more ({total_size}) than source dim ({tensor_size})"
              )
          remaining_size = tensor_size - total_size
          # Ceil-divide the leftover among the unbound Dims.
          chunk_size = (remaining_size + len(unbound) - 1) // len(unbound)
          for position in unbound:
              share = min(chunk_size, remaining_size)
              sizes[position].size = share
              indices[position] = share
              remaining_size -= share
      elif tensor_size != total_size:
          raise TypeError(
              f"sum of sizes of target dimensions ({total_size}) do not match the source dim ({tensor_size})"
          )

      pieces = self_info.tensor.split_with_sizes(indices, idx)

      # Each piece reuses the source levels with the split axis renamed.
      base_levels = list(self_info.levels)
      wrapped_pieces = []
      for piece, size_dim in zip(pieces, sizes):
          piece_levels = list(base_levels)
          piece_levels[idx] = DimEntry(size_dim)
          wrapped_pieces.append(
              Tensor.from_positional(piece, piece_levels, self_info.has_device)
          )
      return tuple(wrapped_pieces)
  def cat(tensors: Any, dim: Any, new_dim: Any) -> _Tensor:
      """Stack *tensors* along a temporary Dim, then index ``[temp, dim]`` into *new_dim*."""
      stacking_dim = dims(1)  # Get single Dim instead of tuple
      stacked = stack(tensors, stacking_dim, dim)
      return stacked.index([stacking_dim, dim], new_dim)  # type: ignore[list-item]
  class DotPart:
      """
      One group of dimension entries used when arranging a dot product.

      ``dims`` holds the entries in insertion order and ``total_size`` is the
      product of the sizes of the non-positional entries added so far.
      """

      def __init__(self) -> None:
          self.total_size = 1
          self.dims: list[DimEntry] = []

      def append(self, dim_entry: Any) -> None:
          """Record *dim_entry*; non-positional entries also scale ``total_size``."""
          self.dims.append(dim_entry)
          if dim_entry.is_positional():
              return
          self.total_size *= dim_entry.dim().size
  def dot_prepare(parts: list[DotPart], tensor_info: TensorInfo) -> torch.Tensor:
      """
      Arrange a tensor for a dot product.

      The tensor is matched to the concatenation of every part's dims. If any
      part does not hold exactly one dim, the result is additionally reshaped
      to one axis per part, sized by that part's ``total_size``.

      Raises:
          RuntimeError: if *tensor_info* carries no tensor.
      """
      new_levels = [entry for part in parts for entry in part.dims]
      needs_reshape = any(len(part.dims) != 1 for part in parts)

      if tensor_info.tensor is None:
          raise RuntimeError("Cannot perform dot product on None tensor")

      matched = _match_levels(tensor_info.tensor, tensor_info.levels, new_levels)
      if needs_reshape:
          # Collapse each part into a single axis for the matrix product.
          return matched.reshape([part.total_size for part in parts])
      return matched
  def dot_finish(parts: list[DotPart], result_tensor: torch.Tensor) -> Tensor:
      """
      Turn a raw matrix-product result back into a first-class Tensor.

      The result levels are the concatenation of every part's dims. If any
      part does not hold exactly one dim, the tensor is first reshaped to one
      axis per level using each level's ``dim().size``.
      """
      result_levels = [entry for part in parts for entry in part.dims]

      if any(len(part.dims) != 1 for part in parts):
          # Undo the per-part flattening done before the product.
          expanded_shape = [level.dim().size for level in result_levels]
          result_tensor = result_tensor.reshape(expanded_shape)

      return Tensor.from_positional(result_tensor, result_levels, True)  # type: ignore[return-value]
  def dot(lhs: Any, rhs: Any, sum_dims: Any) -> Union[_Tensor, torch.Tensor]:
      """
      Perform dot product between two tensors along specified dimensions.

      Args:
          lhs: Left-hand side tensor
          rhs: Right-hand side tensor
          sum_dims: Dimensions to sum over (contract)

      Returns:
          Result of dot product

      Raises:
          AssertionError: if either operand has no tensor data.
          ValueError: if a dimension in *sum_dims* appears in neither operand.
      """
      lhs_info = TensorInfo.create(lhs, ensure_batched=False, ensure_present=False)
      rhs_info = TensorInfo.create(rhs, ensure_batched=False, ensure_present=False)

      if not lhs_info or not rhs_info:
          # Fall back to regular operations
          return torch.matmul(lhs, rhs)

      if lhs_info.tensor is None or rhs_info.tensor is None:
          raise AssertionError("Cannot perform dot product on None tensors")

      lhs_strides = lhs_info.tensor.stride()
      rhs_strides = rhs_info.tensor.stride()

      batch_part = DotPart()  # left-right-output (batch dims)
      lhs_only_part = DotPart()  # left-output only
      rhs_only_part = DotPart()  # right-output only
      contracted_part = DotPart()  # left-right (contracted dims)

      def insert_dim(d: Any, lhs_idx: Any, rhs_idx: Any) -> None:
          """File *d* under the part selected by sum membership and stride pattern."""
          # A missing index is treated like a zero stride.
          lhs_stride = 0 if lhs_idx is None else lhs_strides[lhs_idx]
          rhs_stride = 0 if rhs_idx is None else rhs_strides[rhs_idx]
          if d in sum_dims:
              target = contracted_part
          elif (lhs_stride == 0) == (rhs_stride == 0):
              target = batch_part  # both sides agree on zero/non-zero stride
          elif lhs_stride != 0:
              target = lhs_only_part
          else:
              target = rhs_only_part
          target.append(d)

      rhs_levels = rhs_info.levels
      rhs_seen = [False] * len(rhs_levels)

      # Every lhs level, paired with the first equal rhs level (if any).
      for lhs_idx, lhs_level in enumerate(lhs_info.levels):
          rhs_idx = next(
              (j for j, rhs_level in enumerate(rhs_levels) if lhs_level == rhs_level),
              None,
          )
          if rhs_idx is not None:
              rhs_seen[rhs_idx] = True
          insert_dim(lhs_level, lhs_idx, rhs_idx)

      # Then the rhs levels that had no lhs partner.
      for rhs_idx, rhs_level in enumerate(rhs_levels):
          if not rhs_seen[rhs_idx]:
              insert_dim(rhs_level, None, rhs_idx)

      # A count mismatch may mean a requested dim exists in neither operand.
      if len(contracted_part.dims) != len(sum_dims):
          for d in sum_dims:
              if not (d in lhs_info.levels or d in rhs_levels):
                  raise ValueError(f"summing over non-existent dimension {d}")

      if batch_part.dims:
          # Batched matrix multiply
          multiply = torch.bmm
          lhs_parts = [batch_part, lhs_only_part, contracted_part]
          rhs_parts = [batch_part, contracted_part, rhs_only_part]
          out_parts = [batch_part, lhs_only_part, rhs_only_part]
      else:
          # Regular matrix multiply
          multiply = torch.mm
          lhs_parts = [lhs_only_part, contracted_part]
          rhs_parts = [contracted_part, rhs_only_part]
          out_parts = [lhs_only_part, rhs_only_part]

      lhs_tensor = dot_prepare(lhs_parts, lhs_info)
      rhs_tensor = dot_prepare(rhs_parts, rhs_info)
      return dot_finish(out_parts, multiply(lhs_tensor, rhs_tensor))
  1214. from functorch.dim._wrap import _wrap
  1215. from functorch.dim.wrap_type import wrap_type
  # NOTE(review): presumably installs torch.Tensor's methods on _Tensor, routed
  # through _Tensor.__torch_function__ — confirm against wrap_type's definition.
  wrap_type(_Tensor, torch.Tensor, _Tensor.__torch_function__)
  # Drop the `ndim` attribute from _Tensor itself; the Dim and Tensor subclasses
  # each define their own `ndim` property.
  del _Tensor.ndim
  def index(self: Any, positions: Any, dims: Any) -> _Tensor:
      """
      Bind the given positions of a tensor to dims via ``_Tensor.index``.

      A first-class tensor is passed straight through. Anything else is first
      wrapped with ``Tensor.from_positional`` using its ``TensorInfo``.

      Args:
          positions: Tuple of dimension positions to bind
          dims: Dim objects or tuple of Dim objects to bind to

      Returns:
          First-class tensor with specified dimensions bound

      Raises:
          AssertionError: if the tensor info carries no tensor data.
      """
      if not isinstance(self, _Tensor):
          # Regular tensor: lift it into the first-class representation first.
          info = TensorInfo.create(self, ensure_batched=False, ensure_present=False)
          if info.tensor is None:
              raise AssertionError("Cannot index None tensor")
          self = Tensor.from_positional(info.tensor, info.levels, info.has_device)

      return _Tensor.index(self, positions, dims)  # type: ignore[arg-type]
  def _def(name: str, *args: Any, **kwargs: Any) -> None:
      """Set ``_Tensor.<name>`` to ``_wrap(torch.Tensor.<name>, *args, **kwargs)``."""
      wrapped = _wrap(getattr(torch.Tensor, name), *args, **kwargs)
      setattr(_Tensor, name, wrapped)
  # Each call below rebinds _Tensor.<name> to _wrap(torch.Tensor.<name>, ...).
  # NOTE(review): the meaning of the keyword options (keepdim_offset, single_dim,
  # reduce, dim_offset) is defined by _wrap in functorch.dim._wrap — confirm there.

  # Wrapped with _wrap's default options.
  _def("mean")
  _def("sum")
  _def("all")
  _def("amax")
  _def("amin")
  _def("aminmax")
  _def("any")
  _def("count_nonzero")
  _def("logsumexp")
  _def("nanmean")
  _def("nansum")
  _def("prod")
  # Wrapped with keepdim_offset=2.
  _def("std", keepdim_offset=2)
  _def("var", keepdim_offset=2)
  # Wrapped with single_dim=True.
  _def("max", single_dim=True)
  _def("min", single_dim=True)
  _def("argmax", single_dim=True)
  _def("argmin", single_dim=True)
  _def("kthvalue", single_dim=True)
  _def("median", single_dim=True)
  _def("nanmedian", single_dim=True)
  _def("mode", single_dim=True)
  # Wrapped with reduce=False.
  _def("sort", reduce=False)
  _def("argsort", reduce=False)
  _def("unbind", single_dim=True)
  _def("chunk", dim_offset=1, reduce=False)
  # Wrapped with single_dim=True and reduce=False.
  _def("cummax", single_dim=True, reduce=False)
  _def("cummin", single_dim=True, reduce=False)
  _def("cumprod", single_dim=True, reduce=False)
  _def("cumprod_", single_dim=True, reduce=False)
  _def("cumsum", single_dim=True, reduce=False)
  _def("cumsum_", single_dim=True, reduce=False)
  _def("logcumsumexp", single_dim=True, reduce=False)
  _def("renorm", dim_offset=1, single_dim=True, reduce=False)
  _def("softmax", single_dim=True, reduce=False)
  # Module-level wrapper around the functional softmax (not attached to _Tensor).
  softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
  1280. # stuff to handle in the future, because they require special
  1281. # binding logic for dims
  1282. # cross
  1283. # diag_embed
  1284. # diagonal
  1285. # diagonal_scatter
  1286. # diff
  1287. # nanquantile
  1288. # quantile
  1289. # roll
  1290. # rot90
  1291. # topk (new dims on output)
  1292. # should these all be subsumed by inplace indexing?
  1293. # index_add_
  1294. # index_add
  1295. # index_copy
  1296. # index_copy_
  1297. # index_fill
  1298. # index_fill_
  1299. # index_select
  1300. # scatter
  1301. # scatter_
  1302. # scatter_add
  1303. # scatter_add_
  1304. # scatter_reduce