_getsetitem.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. from __future__ import annotations
  2. from dataclasses import dataclass, field
  3. from typing import Any, Optional, TYPE_CHECKING, Union
  4. import torch
  5. from ._dim_entry import _match_levels, DimEntry
  6. from ._tensor_info import TensorInfo
  7. if TYPE_CHECKING:
  8. from . import Dim
  9. def _safe_index(lst: list, item: Any) -> Optional[int]:
  10. """
  11. Helper function to find index of item in list.
  12. For DimEntry objects, uses __eq__ comparison which properly handles
  13. both positional and Dim entries.
  14. Returns the index if found, None if not found.
  15. """
  16. for i, list_item in enumerate(lst):
  17. # Use == for DimEntry objects as they have proper __eq__ implementation
  18. if isinstance(item, DimEntry) and isinstance(list_item, DimEntry):
  19. if list_item == item:
  20. return i
  21. elif list_item is item:
  22. return i
  23. return None
  24. @dataclass
  25. class IndexingInfo:
  26. can_call_original: bool = False
  27. advanced_indexing: bool = False
  28. self_tensor: Optional[torch.Tensor] = None
  29. flat_inputs: list[Any] = field(default_factory=list)
  30. result_levels: list[DimEntry] = field(default_factory=list)
  31. has_device: bool = False
  32. def has_dims(obj: Any) -> bool:
  33. """
  34. Check if an object has first-class dimensions.
  35. This function checks if the object is either a Dim or a functorch Tensor
  36. that has first-class dimensions, using the proper check_exact methods.
  37. """
  38. from . import Dim, Tensor
  39. return Dim.check_exact(obj) or Tensor.check_exact(obj)
  40. def _bind_dims_to_size(sz: int, sd: int, dims: list, nsz: list, nsd: list) -> None:
  41. """
  42. Bind dimensions to size and calculate proper strides for dim packs.
  43. """
  44. from . import DimensionBindError
  45. rhs_prod = 1
  46. for i, dim in enumerate(dims):
  47. if not dim.is_bound:
  48. # Check for multiple unbound dimensions
  49. for j in range(i + 1, len(dims)):
  50. if not dims[j].is_bound:
  51. raise DimensionBindError(
  52. f"cannot infer the sizes of two dimensions at once {dim!r} and {dims[j]!r}"
  53. )
  54. rhs_prod *= dims[j].size
  55. # Calculate the size for this unbound dimension
  56. if sz % rhs_prod != 0:
  57. tup = tuple(dim.size if dim.is_bound else "?" for dim in dims)
  58. raise DimensionBindError(
  59. f"inferred dimension does not evenly fit into larger dimension: {sz} vs {tup}"
  60. )
  61. inferred_size = sz // rhs_prod
  62. dim.size = inferred_size
  63. rhs_prod = sz
  64. break
  65. else:
  66. rhs_prod *= dim.size
  67. # Final validation that dimensions match
  68. if rhs_prod != sz:
  69. tup = tuple(dims)
  70. raise DimensionBindError(
  71. f"Dimension sizes to do not match ({sz} != {rhs_prod}) when matching dimension pack {tup}"
  72. )
  73. # Calculate new sizes and strides for each dimension in the pack
  74. # First calculate all strides by iterating in reverse
  75. new_strides = [0] * len(dims)
  76. current_stride = sd
  77. for i in reversed(range(len(dims))):
  78. new_strides[i] = current_stride
  79. current_stride *= dims[i].size
  80. # Then append sizes and strides in forward order
  81. for i in range(len(dims)):
  82. nsz.append(dims[i].size)
  83. nsd.append(new_strides[i])
  84. def slice_to_tuple(flat_inputs: list) -> tuple:
  85. return tuple(flat_inputs)
  86. def extractIndices(index: Any, indices: list) -> bool:
  87. if isinstance(index, tuple): # mpy::tuple_view::check
  88. indices.extend(index)
  89. return True
  90. elif isinstance(index, torch.Tensor): # THPVariable_Check
  91. indices.append(index)
  92. return False
  93. elif not hasattr(index, "__iter__") or isinstance(
  94. index, (str, bytes)
  95. ): # !mpy::is_sequence
  96. indices.append(index)
  97. return False
  98. # Handle sequence case (list)
  99. if isinstance(index, list):
  100. if len(index) >= 32:
  101. indices.extend(index)
  102. return True
  103. # Check each item in the sequence
  104. for item in index:
  105. if (
  106. isinstance(item, (torch.Tensor, slice))
  107. or hasattr(item, "__iter__")
  108. or item is ...
  109. or item is None
  110. or has_dims(item)
  111. ):
  112. indices.extend(index)
  113. return True
  114. # If we got here, treat as single index
  115. indices.append(index)
  116. return False
  117. # Default case
  118. indices.append(index)
  119. return False
  120. def getitem(cls: Any, func: Any, types: Any, args: Any, kwargs: Any) -> Any:
  121. self = args[0]
  122. index = args[1]
  123. iinfo = getsetitem(self, index, has_dims(self))
  124. if iinfo.can_call_original:
  125. # Call original tensor __getitem__ directly, bypassing __torch_function__
  126. return torch.Tensor.__getitem__(self, index)
  127. return invoke_getitem(iinfo)
  128. def setitem(self: Any, index: Any, rhs: Any) -> None:
  129. """Set values in tensor using first-class dimensions."""
  130. from . import DimensionBindError, TensorInfo
  131. iinfo = getsetitem(self, index, has_dims(self) or has_dims(rhs))
  132. if iinfo.can_call_original:
  133. # Call original tensor __setitem__ directly, bypassing __torch_function__
  134. torch._C.TensorBase.__setitem__(self, index, rhs)
  135. return
  136. # Handle RHS tensor with dimensions
  137. rhs_info = TensorInfo.create(rhs, False, False)
  138. if rhs_info:
  139. # Check that rhs dimensions are compatible with result dimensions
  140. for l in rhs_info.levels:
  141. if not l.is_positional():
  142. # Find this dimension in result levels
  143. found = False
  144. for result_level in iinfo.result_levels:
  145. if (
  146. not result_level.is_positional()
  147. and result_level.dim() is l.dim()
  148. ):
  149. found = True
  150. break
  151. if not found:
  152. # Create tuple representation of result levels for error message
  153. result_dims: list[Union[int, Dim]] = []
  154. for rl in iinfo.result_levels:
  155. if rl.is_positional():
  156. result_dims.append(rl.position())
  157. else:
  158. result_dims.append(rl.dim())
  159. raise DimensionBindError(
  160. f"rhs of setitem contains dimension {l.dim()!r} which is not in the dimension on the left "
  161. f"({tuple(result_dims)!r})"
  162. )
  163. # Match RHS tensor to result levels
  164. if rhs_info.tensor is None:
  165. raise AssertionError("Cannot match levels on None tensor")
  166. matched_rhs = _match_levels(
  167. rhs_info.tensor, rhs_info.levels, iinfo.result_levels
  168. )
  169. else:
  170. matched_rhs = rhs
  171. # For advanced indexing with dimensions, we need special handling
  172. if iinfo.advanced_indexing:
  173. # Use advanced indexing - the flat_inputs already contain matched tensors
  174. tup = slice_to_tuple(iinfo.flat_inputs)
  175. if iinfo.self_tensor is None:
  176. raise RuntimeError("Cannot setitem on None tensor")
  177. torch._C.TensorBase.__setitem__(iinfo.self_tensor, tup, matched_rhs)
  178. else:
  179. # Simple copy operation
  180. if iinfo.self_tensor is None:
  181. raise RuntimeError("Cannot copy to None tensor")
  182. iinfo.self_tensor.copy_(matched_rhs)
  183. def invoke_getitem(iinfo: IndexingInfo) -> Any:
  184. if iinfo.advanced_indexing:
  185. self_tensor = iinfo.self_tensor
  186. tup = slice_to_tuple(iinfo.flat_inputs)
  187. if self_tensor is None:
  188. raise RuntimeError("Cannot getitem on None tensor")
  189. rtensor = self_tensor[tup]
  190. else:
  191. rtensor = iinfo.self_tensor # type: ignore[assignment]
  192. if rtensor is None:
  193. raise RuntimeError("Cannot getitem on None tensor")
  194. # rtensor is now guaranteed to be not None
  195. # Create a Tensor with the proper dimensions using the class method
  196. from . import Tensor
  197. return Tensor.from_positional(rtensor, iinfo.result_levels, iinfo.has_device)
  198. def getsetitem(self: Any, index: Any, tensors_have_dims: bool) -> IndexingInfo:
  199. from . import DimList # Import DimList for type checking
  200. can_call_original_getitem = not tensors_have_dims
  201. input_list = []
  202. if has_dims(index):
  203. input_list.append(index)
  204. else:
  205. is_sequence = extractIndices(index, input_list)
  206. # nothing about first class dims here, fallback to getitem
  207. if can_call_original_getitem and not is_sequence:
  208. return IndexingInfo(can_call_original=True)
  209. # Calculate how many dimensions have been indexed in order to compute the
  210. # size of ... or expand a potentially unbound dimension list.
  211. dims_indexed = 0
  212. expanding_object = -1
  213. unbound_dim_list = None
  214. dimlists = [] # Track DimList positions for later processing
  215. def check_expanding(i: int) -> None:
  216. nonlocal expanding_object
  217. if expanding_object != -1:
  218. from . import DimensionBindError
  219. raise DimensionBindError(
  220. f"at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets "
  221. f"{expanding_object} and {i}"
  222. )
  223. expanding_object = i
  224. def is_dimpack(s: Any) -> bool:
  225. from . import Dim
  226. return (
  227. isinstance(s, (tuple, list))
  228. and len(s) > 0
  229. and all(Dim.check_exact(item) for item in s)
  230. )
  231. has_dimpacks_or_none = False
  232. for i, s in enumerate(input_list):
  233. if has_dims(s):
  234. can_call_original_getitem = False
  235. dims_indexed += 1
  236. elif s is ...:
  237. check_expanding(i)
  238. elif isinstance(s, DimList):
  239. can_call_original_getitem = False
  240. if not s.is_bound:
  241. check_expanding(i)
  242. unbound_dim_list = s
  243. else:
  244. dims_indexed += len(s._dims)
  245. dimlists.append(i)
  246. elif s is None:
  247. has_dimpacks_or_none = True
  248. elif is_dimpack(s):
  249. can_call_original_getitem = False
  250. has_dimpacks_or_none = True
  251. dims_indexed += 1
  252. else:
  253. dims_indexed += 1
  254. # Early return if we can use original getitem
  255. if can_call_original_getitem:
  256. return IndexingInfo(can_call_original=True)
  257. self_info = TensorInfo.create(self, False, True)
  258. total_dims = len(self_info.levels) # Total dimensions (positional + named)
  259. if dims_indexed > total_dims:
  260. raise ValueError(
  261. f"at least {dims_indexed} indices were supplied but the tensor only has {total_dims} dimensions"
  262. )
  263. # Expand any unbound dimension list, or expand ... into individual : slices.
  264. expanding_dims = total_dims - dims_indexed
  265. if expanding_object != -1:
  266. if unbound_dim_list is not None:
  267. # Bind unbound dimension list to the expanding dimensions
  268. unbound_dim_list.bind_len(expanding_dims)
  269. else:
  270. # Expand ... into slice(None) objects
  271. no_slices = [slice(None)] * expanding_dims
  272. input_list = (
  273. input_list[:expanding_object]
  274. + no_slices
  275. + input_list[expanding_object + 1 :]
  276. )
  277. # Flatten out any dimensions stored in dimlist elements directly into the inputs
  278. # Process in reverse order to maintain indices
  279. for i in range(len(dimlists) - 1, -1, -1):
  280. idx = dimlists[i]
  281. # We added more elements to input because of ...
  282. # so we need to also adjust the index to get back to where the
  283. # dimlist existed
  284. if (
  285. unbound_dim_list is None
  286. and expanding_object != -1
  287. and idx > expanding_object
  288. ):
  289. idx += expanding_dims
  290. dl = input_list[idx]
  291. # PRIVATE here naughty
  292. input_list = input_list[:idx] + dl._dims + input_list[idx + 1 :]
  293. return getsetitem_flat(self_info, input_list, [], [], has_dimpacks_or_none)
  294. def getsetitem_flat(
  295. self_info: TensorInfo,
  296. input_list: list,
  297. keys: list[DimEntry],
  298. values: list,
  299. has_dimpacks_or_none: bool,
  300. ) -> IndexingInfo:
  301. from . import Dim
  302. # Track dimension usage
  303. seen_dims: list[Any] = []
  304. seen_dims_nuses: list[int] = []
  305. def add_dim(dim: Any) -> None:
  306. # Use safe indexing to avoid triggering __torch_function__ on Dim objects
  307. idx = _safe_index(seen_dims, dim)
  308. if idx is not None:
  309. seen_dims_nuses[idx] += 1
  310. else:
  311. seen_dims.append(dim)
  312. seen_dims_nuses.append(1)
  313. flat_inputs = []
  314. tensor_inputs: list[Any] = []
  315. device_holding_tensor = None
  316. def append_flat_handle(handle: Any) -> None:
  317. flat_inputs.append(handle)
  318. tensor_inputs.append(None)
  319. def append_tensor_input(ti: TensorInfo) -> None:
  320. flat_inputs.append(None)
  321. tensor_inputs.append(ti)
  322. nonlocal device_holding_tensor
  323. if ti.has_device and device_holding_tensor is None:
  324. device_holding_tensor = ti.tensor
  325. nsz = []
  326. nsd = []
  327. if self_info.tensor is None:
  328. raise RuntimeError("Cannot get size/stride on None tensor")
  329. sz = self_info.tensor.size()
  330. sd = self_info.tensor.stride()
  331. def append_size(i: int) -> None:
  332. if has_dimpacks_or_none:
  333. nsz.append(sz[i])
  334. nsd.append(sd[i])
  335. input_it = input_list[:]
  336. def parse_nones() -> None:
  337. nonlocal input_it
  338. while input_it and input_it[0] is None:
  339. append_flat_handle(slice(None))
  340. nsz.append(1)
  341. nsd.append(0)
  342. input_it = input_it[1:]
  343. def append_item(i: int, arg: Any) -> None:
  344. if Dim.check_exact(arg):
  345. d = arg
  346. if d._size == -1:
  347. d.size = sz[i]
  348. add_dim(d)
  349. append_size(i)
  350. append_flat_handle(arg)
  351. return
  352. info = TensorInfo.create(arg, False, False)
  353. if info:
  354. append_size(i)
  355. append_tensor_input(info)
  356. for level in info.levels:
  357. if not level.is_positional():
  358. add_dim(level.dim())
  359. return
  360. if has_dimpacks_or_none:
  361. if isinstance(arg, (tuple, list)) and all(Dim.check_exact(d) for d in arg):
  362. # dim pack
  363. dim_pack = list(arg)
  364. for d in dim_pack:
  365. add_dim(d)
  366. append_flat_handle(d)
  367. _bind_dims_to_size(sz[i], sd[i], dim_pack, nsz, nsd)
  368. return
  369. append_size(i)
  370. append_flat_handle(arg)
  371. # Match indexing expressions with tensor dimensions
  372. for i, level in enumerate(self_info.levels):
  373. # Use safe indexing to avoid triggering __torch_function__ on DimEntry comparisons
  374. idx = _safe_index(keys, level)
  375. if idx is not None:
  376. append_item(i, values[idx])
  377. else:
  378. if level.is_positional():
  379. parse_nones()
  380. if not input_it:
  381. append_flat_handle(slice(None))
  382. append_size(i)
  383. else:
  384. arg = input_it[0]
  385. input_it = input_it[1:]
  386. append_item(i, arg)
  387. else:
  388. add_dim(level.dim())
  389. append_flat_handle(level.dim())
  390. append_size(i)
  391. parse_nones()
  392. # Restride tensor if needed
  393. if has_dimpacks_or_none and nsz:
  394. if self_info.tensor is None:
  395. raise RuntimeError("Cannot restride None tensor")
  396. self_tensor = self_info.tensor.as_strided(
  397. nsz, nsd, self_info.tensor.storage_offset()
  398. )
  399. else:
  400. self_tensor = self_info.tensor
  401. # Determine result shape and indexing requirements
  402. result_levels: list[Any] = []
  403. index_levels = []
  404. tensor_insert_point = -1
  405. requires_getindex = False
  406. def mark_tensor_index() -> None:
  407. nonlocal tensor_insert_point
  408. if tensor_insert_point == -1:
  409. tensor_insert_point = len(result_levels)
  410. elif tensor_insert_point != len(result_levels):
  411. tensor_insert_point = 0
  412. for i, inp in enumerate(flat_inputs):
  413. if tensor_inputs[i] is not None:
  414. requires_getindex = True
  415. mark_tensor_index()
  416. for level in tensor_inputs[i].levels:
  417. if level not in index_levels:
  418. index_levels.append(level)
  419. elif Dim.check_exact(inp):
  420. d = inp
  421. # Use safe indexing to avoid triggering __torch_function__
  422. dim_idx = _safe_index(seen_dims, d)
  423. if dim_idx is None:
  424. raise AssertionError(f"Dim {d} not found in seen_dims")
  425. if seen_dims_nuses[dim_idx] == 1:
  426. flat_inputs[i] = slice(None)
  427. result_levels.append(DimEntry(d))
  428. else:
  429. requires_getindex = True
  430. flat_inputs[i] = None
  431. tensor_inputs[i] = TensorInfo(
  432. d._get_range(), [DimEntry(d)], False, None
  433. )
  434. if DimEntry(d) not in index_levels:
  435. index_levels.append(DimEntry(d))
  436. mark_tensor_index()
  437. else:
  438. if inp != slice(None):
  439. requires_getindex = True
  440. if not isinstance(inp, int):
  441. result_levels.append(DimEntry(-1))
  442. # Insert indexing dimensions at first tensor use point
  443. if tensor_insert_point != -1:
  444. for level in reversed(index_levels):
  445. result_levels.insert(tensor_insert_point, level)
  446. # Match tensors to indexing shape
  447. if requires_getindex:
  448. for i in range(len(flat_inputs)):
  449. if tensor_inputs[i] is not None:
  450. t = tensor_inputs[i].tensor
  451. if t is None:
  452. raise AssertionError("TensorInfo should have valid tensor data")
  453. if (
  454. not tensor_inputs[i].has_device
  455. and device_holding_tensor is not None
  456. ):
  457. t = t.to(device_holding_tensor.device)
  458. flat_inputs[i] = _match_levels(t, tensor_inputs[i].levels, index_levels)
  459. # Number positional dimensions correctly
  460. seen_positionals = 0
  461. for i in reversed(range(len(result_levels))):
  462. if result_levels[i].is_positional():
  463. seen_positionals += 1
  464. result_levels[i] = DimEntry(-seen_positionals)
  465. return IndexingInfo(
  466. can_call_original=False,
  467. advanced_indexing=requires_getindex,
  468. self_tensor=self_tensor,
  469. flat_inputs=flat_inputs,
  470. result_levels=result_levels,
  471. has_device=self_info.has_device,
  472. )