_index.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. """Indexing mixin for sparse array/matrix classes.
  2. """
  3. import numpy as np
  4. from ._sputils import isintlike
  5. from ._base import sparray, issparse
# Scalar types that the isinstance() checks in this module treat as a single
# integer index (Python ints and any NumPy integer scalar).
INT_TYPES = (int, np.integer)
  7. def _broadcast_arrays(*arrays):
  8. """
  9. Same as np.broadcast_arrays(a, b) but old writeability rules.
  10. NumPy >= 1.17.0 transitions broadcast_arrays to return
  11. read-only arrays. Set writeability explicitly to avoid warnings.
  12. Retain the old writeability rules, as our Cython code assumes
  13. the old behavior.
  14. """
  15. broadcast_arrays = np.broadcast_arrays(*arrays)
  16. for x, a in zip(broadcast_arrays, arrays):
  17. x.flags.writeable = a.flags.writeable
  18. return broadcast_arrays
class IndexMixin:
    """
    This class provides common dispatching and validation logic for indexing.

    ``__getitem__`` and ``__setitem__`` validate the key with
    ``_validate_indices`` and then forward to one of the ``_get_*`` /
    ``_set_*`` hooks below, chosen by the type (int, slice or array) of each
    validated index.  Most hooks raise ``NotImplementedError`` in this class.
    """

    def __getitem__(self, key):
        """Return the selection described by `key`.

        The key is validated by ``_validate_indices``, which also yields the
        requested result shape (`new_shape`, including axes added by None).
        An ``IndexError`` is raised when that shape has more than 2
        dimensions, or when an array index would produce more than 2
        dimensions.  The final packaging differs depending on whether `self`
        is an ``sparray`` instance.
        """
        index, new_shape, _, _ = _validate_indices(key, self.shape, self.format)
        if len(new_shape) > 2:
            raise IndexError("Indexing that leads to >2D is not supported by "
                             f"{self.format} format. Try converting to COO format")

        # 1D array
        if len(index) == 1:
            idx = index[0]
            # a 0-d ndarray index is unwrapped to a scalar before dispatching
            if isinstance(idx, np.ndarray):
                if idx.shape == ():
                    idx = idx.item()
            if isinstance(idx, INT_TYPES):
                res = self._get_int(idx)
            elif isinstance(idx, slice):
                res = self._get_slice(idx)
            else:  # assume array idx
                res = self._get_array(idx)

            # package the result and return
            if not isinstance(self, sparray):
                return res
            # handle np.newaxis in idx when result would otherwise be a scalar
            if res.shape == () and new_shape != ():
                if len(new_shape) == 1:
                    return self.__class__([res], shape=new_shape, dtype=self.dtype)
                if len(new_shape) == 2:
                    return self.__class__([[res]], shape=new_shape, dtype=self.dtype)
            return res.reshape(new_shape)

        # 2D array
        row, col = index

        # Dispatch to specialized methods.
        if isinstance(row, INT_TYPES):
            if isinstance(col, INT_TYPES):
                res = self._get_intXint(row, col)
            elif isinstance(col, slice):
                res = self._get_intXslice(row, col)
            elif col.ndim == 1:
                res = self._get_intXarray(row, col)
            elif col.ndim == 2:
                # 2-D column arrays go to the same hook as 1-D ones
                res = self._get_intXarray(row, col)
            else:
                raise IndexError('index results in >2 dimensions')
        elif isinstance(row, slice):
            if isinstance(col, INT_TYPES):
                res = self._get_sliceXint(row, col)
            elif isinstance(col, slice):
                # A[:, :] is served by a plain copy
                if row == slice(None) and row == col:
                    res = self.copy()
                else:
                    res = self._get_sliceXslice(row, col)
            elif col.ndim == 1:
                res = self._get_sliceXarray(row, col)
            else:
                raise IndexError('index results in >2 dimensions')
        else:
            # row is an array index
            if isinstance(col, INT_TYPES):
                res = self._get_arrayXint(row, col)
            elif isinstance(col, slice):
                res = self._get_arrayXslice(row, col)
            # arrayXarray preprocess
            elif (row.ndim == 2 and row.shape[1] == 1
                  and (col.ndim == 1 or col.shape[0] == 1)):
                # outer indexing
                res = self._get_columnXarray(row[:, 0], col.reshape(-1))
            else:
                # inner indexing
                row, col = _broadcast_arrays(row, col)
                if row.shape != col.shape:
                    raise IndexError('number of row and column indices differ')
                if row.size == 0:
                    # empty selection: build an empty container of matching 2D shape
                    res = self.__class__(np.atleast_2d(row).shape, dtype=self.dtype)
                else:
                    res = self._get_arrayXarray(row, col)

        # handle spmatrix (must be 2d, dont let 1d new_shape start reshape)
        if not isinstance(self, sparray):
            if new_shape == () or (len(new_shape) == 1 and res.ndim != 0):
                # res handles cases not inflated by None
                return res
            if len(new_shape) == 1:
                # shape inflated to 1D by None in index. Make 2D
                new_shape = (1,) + new_shape
            # reshape if needed (when None changes shape, e.g. A[1,:,None])
            return res if new_shape == res.shape else res.reshape(new_shape)

        # package the result and return
        if res.shape != new_shape:
            # handle formats that support indexing but not 1D (lil for now)
            if self.format == "lil" and len(new_shape) != 2:
                if res.shape == ():
                    return self._coo_container([res], shape = new_shape)
                return res.tocoo().reshape(new_shape)
            return res.reshape(new_shape)
        return res

    def __setitem__(self, key, x):
        """Assign `x` to the positions selected by `key`.

        The key is validated by ``_validate_indices``.  A single int (or an
        int pair in 2D) requires `x` to hold exactly one element, otherwise
        ``ValueError`` is raised.  Slices are expanded with ``np.arange`` and
        `x` is broadcast to the index shape before calling the ``_set_*``
        hooks.  Assignments of zero elements return without calling a hook.
        """
        index, new_shape, _, _ = _validate_indices(key, self.shape, self.format)

        # 1D array
        if len(index) == 1:
            idx = index[0]

            # the 1D path always works on a dense ndarray value
            if issparse(x):
                x = x.toarray()
            else:
                x = np.asarray(x, dtype=self.dtype)

            if isinstance(idx, INT_TYPES):
                if x.size != 1:
                    raise ValueError('Trying to assign a sequence to an item')
                self._set_int(idx, x.flat[0])
                return

            if isinstance(idx, slice):
                # check for simple case of slice that gives 1 item
                # Note: Python `range` does not use lots of memory
                idx_range = range(*idx.indices(self.shape[0]))
                N = len(idx_range)
                if N == 1 and x.size == 1:
                    self._set_int(idx_range[0], x.flat[0])
                    return
                idx = np.arange(*idx.indices(self.shape[0]))
                idx_shape = idx.shape
            else:
                idx_shape = idx.squeeze().shape
            # broadcast scalar to full 1d
            if x.squeeze().shape != idx_shape:
                x = np.broadcast_to(x, idx.shape)
            if x.size != 0:
                self._set_array(idx, x)
            return

        # 2D array
        row, col = index

        if isinstance(row, INT_TYPES) and isinstance(col, INT_TYPES):
            if issparse(x):
                x = x.toarray()
            else:
                x = np.asarray(x, dtype=self.dtype)
            if x.size != 1:
                raise ValueError('Trying to assign a sequence to an item')
            self._set_intXint(row, col, x.flat[0])
            return

        # expand slices to explicit index arrays; a row slice becomes a column
        # vector and a col slice a row vector so the pair broadcasts to a grid
        if isinstance(row, slice):
            row = np.arange(*row.indices(self.shape[0]))[:, None]
        else:
            row = np.atleast_1d(row)

        if isinstance(col, slice):
            col = np.arange(*col.indices(self.shape[1]))[None, :]
            if row.ndim == 1:
                row = row[:, None]
        else:
            col = np.atleast_1d(col)

        i, j = _broadcast_arrays(row, col)
        if i.shape != j.shape:
            raise IndexError('number of row and column indices differ')

        if issparse(x):
            # nothing to assign when x has a zero-length axis
            if 0 in x.shape:
                return
            if i.ndim == 1:
                # Inner indexing, so treat them like row vectors.
                i = i[None]
                j = j[None]
            x = x.tocoo(copy=False).reshape(x._shape_as_2d, copy=True)
            # x may have length 1 along an axis where the index grid is longer
            broadcast_row = x.shape[0] == 1 and i.shape[0] != 1
            broadcast_col = x.shape[1] == 1 and i.shape[1] != 1
            if not ((broadcast_row or x.shape[0] == i.shape[0]) and
                    (broadcast_col or x.shape[1] == i.shape[1])):
                raise ValueError('shape mismatch in assignment')
            x.sum_duplicates()
            self._set_arrayXarray_sparse(i, j, x)
        else:
            # Make x and i into the same shape
            x = np.asarray(x, dtype=self.dtype)
            if x.squeeze().shape != i.squeeze().shape:
                x = np.broadcast_to(x, i.shape)
            if x.size == 0:
                return
            x = x.reshape(i.shape)
            self._set_arrayXarray(i, j, x)

    def _getrow(self, i):
        """Return a copy of row i of the matrix, as a (1 x n) row vector.
        """
        M, N = self.shape
        i = int(i)
        if i < -M or i >= M:
            raise IndexError(f'index ({i}) out of range')
        # normalize a negative index before handing it to the hook
        if i < 0:
            i += M
        return self._get_intXslice(i, slice(None))

    def _getcol(self, i):
        """Return a copy of column i of the matrix, as a (m x 1) column vector.
        """
        M, N = self.shape
        i = int(i)
        if i < -N or i >= N:
            raise IndexError(f'index ({i}) out of range')
        # normalize a negative index before handing it to the hook
        if i < 0:
            i += N
        return self._get_sliceXint(slice(None), i)

    # Hooks used by __getitem__ when the validated index has a single entry.
    # Each raises NotImplementedError in this class.

    def _get_int(self, idx):
        raise NotImplementedError()

    def _get_slice(self, idx):
        raise NotImplementedError()

    def _get_array(self, idx):
        raise NotImplementedError()

    # Hooks used by __getitem__ for a (row, col) index pair; the name encodes
    # the kind of row index, then the kind of column index.

    def _get_intXint(self, row, col):
        raise NotImplementedError()

    def _get_intXarray(self, row, col):
        raise NotImplementedError()

    def _get_intXslice(self, row, col):
        raise NotImplementedError()

    def _get_sliceXint(self, row, col):
        raise NotImplementedError()

    def _get_sliceXslice(self, row, col):
        raise NotImplementedError()

    def _get_sliceXarray(self, row, col):
        raise NotImplementedError()

    def _get_arrayXint(self, row, col):
        raise NotImplementedError()

    def _get_arrayXslice(self, row, col):
        raise NotImplementedError()

    def _get_columnXarray(self, row, col):
        raise NotImplementedError()

    def _get_arrayXarray(self, row, col):
        raise NotImplementedError()

    # Hooks used by __setitem__.

    def _set_int(self, idx, x):
        raise NotImplementedError()

    def _set_array(self, idx, x):
        raise NotImplementedError()

    def _set_intXint(self, row, col, x):
        raise NotImplementedError()

    def _set_arrayXarray(self, row, col, x):
        raise NotImplementedError()

    def _set_arrayXarray_sparse(self, row, col, x):
        """Assign sparse `x` at (row, col) by converting it to a dense array."""
        # Fall back to densifying x
        x = np.asarray(x.toarray(), dtype=self.dtype)
        # stretch x to the shape of the index grid before the dense assignment
        x, _ = _broadcast_arrays(x, row)
        self._set_arrayXarray(row, col, x)
  253. def _validate_indices(key, self_shape, self_format):
  254. """Returns four sequences: (index, requested shape, arrays, nones)
  255. index : tuple of validated idx objects. bool arrays->nonzero(),
  256. arrays broadcast, ints and slices as they are, Nones removed
  257. requested shape : the shape of the indexed space, including Nones
  258. arr_pos : position within index of all arrays or ints (for array fancy indexing)
  259. none_pos : insert positions to put newaxis coords in indexed space.
  260. """
  261. self_ndim = len(self_shape)
  262. # single ellipsis
  263. if key is Ellipsis:
  264. return (slice(None),) * self_ndim, self_shape, [], []
  265. if not isinstance(key, tuple):
  266. key = [key]
  267. # pass 1:
  268. # - expand ellipsis to allow matching to self_shape
  269. # - preprocess boolean array index
  270. # - error on sparse array as an index
  271. # - count the ndim of the index and check if too long
  272. ellps_pos = None
  273. index_1st = []
  274. prelim_ndim = 0
  275. for i, idx in enumerate(key):
  276. if idx is Ellipsis:
  277. if ellps_pos is not None:
  278. raise IndexError('an index can only have a single ellipsis')
  279. ellps_pos = i
  280. elif idx is None:
  281. index_1st.append(idx)
  282. elif isinstance(idx, slice) or isintlike(idx):
  283. index_1st.append(idx)
  284. prelim_ndim += 1
  285. elif (ix := _compatible_boolean_index(idx, self_ndim)) is not None:
  286. # can't check the shape of ix until we resolve ellipsis (pass 2)
  287. index_1st.append(ix)
  288. prelim_ndim += ix.ndim
  289. elif issparse(idx):
  290. # TODO: make sparse indexing work for sparray
  291. raise IndexError(
  292. 'Indexing with sparse matrices is not supported '
  293. 'except boolean indexing where matrix and index '
  294. 'are equal shapes.')
  295. else: # dense array
  296. index_1st.append(np.asarray(idx))
  297. prelim_ndim += 1
  298. if prelim_ndim > self_ndim:
  299. raise IndexError(
  300. 'Too many indices for array or tuple index out of range. '
  301. f'Key {key} needs {prelim_ndim}D. Array is {self_ndim}D'
  302. )
  303. ellip_slices = (self_ndim - prelim_ndim) * [slice(None)]
  304. if ellip_slices:
  305. if ellps_pos is None:
  306. index_1st.extend(ellip_slices)
  307. else:
  308. index_1st = index_1st[:ellps_pos] + ellip_slices + index_1st[ellps_pos:]
  309. # second pass (have processed ellipsis and preprocessed arrays)
  310. # pass 2:
  311. # note: integer arrays provide info for one axis even if >1D array.
  312. # The shape of array affects outgoing(get)/incoming(set) shape only
  313. # - form `new_shape` (shape of outgo/incom-ing result of key
  314. # - form `index` (validated form of each slice/int/array index)
  315. # - validate and make canonical: slice and int
  316. # - turn bool arrays to int arrays via `.nonzero()`
  317. # - collect positions of Newaxis/None in `none_positions`
  318. # - collect positions of "array or int" in `arr_int_pos`
  319. idx_shape = []
  320. index_ndim = 0
  321. index = []
  322. array_indices = []
  323. none_positions = []
  324. arr_int_pos = [] # track positions of arrays and integers
  325. for i, idx in enumerate(index_1st):
  326. if idx is None:
  327. none_positions.append(len(idx_shape))
  328. idx_shape.append(1)
  329. elif isinstance(idx, slice):
  330. index.append(idx)
  331. Ms = self_shape[index_ndim]
  332. len_slice = len(range(*idx.indices(Ms)))
  333. idx_shape.append(len_slice)
  334. index_ndim += 1
  335. elif isintlike(idx):
  336. N = self_shape[index_ndim]
  337. if not (-N <= idx < N):
  338. raise IndexError(f'index ({idx}) out of range')
  339. idx = int(idx + N if idx < 0 else idx)
  340. index.append(idx)
  341. arr_int_pos.append(index_ndim)
  342. index_ndim += 1
  343. # bool array (checked in first pass)
  344. elif idx.dtype.kind == 'b':
  345. tmp_ndim = index_ndim + idx.ndim
  346. mid_shape = self_shape[index_ndim:tmp_ndim]
  347. if idx.shape != mid_shape:
  348. raise IndexError(
  349. f"bool index {i} has shape {mid_shape} instead of {idx.shape}"
  350. )
  351. index.extend(idx.nonzero())
  352. array_indices.extend(range(index_ndim, tmp_ndim))
  353. arr_int_pos.extend(range(index_ndim, tmp_ndim))
  354. index_ndim = tmp_ndim
  355. else: # dense array
  356. N = self_shape[index_ndim]
  357. idx = _asindices(idx, N, self_format)
  358. index.append(idx)
  359. arr_int_pos.append(index_ndim)
  360. array_indices.append(index_ndim)
  361. index_ndim += 1
  362. if len(array_indices) > 1:
  363. arr_shapes = [index[i].shape for i in array_indices]
  364. try:
  365. arr_shape = np.broadcast_shapes(*arr_shapes)
  366. except ValueError:
  367. shapes = " ".join(str(shp) for shp in arr_shapes)
  368. msg = (f'shape mismatch: indexing arrays could not be broadcast '
  369. f'together with shapes {shapes}')
  370. raise IndexError(msg)
  371. # len(array_indices) implies arr_int_pos has at least one element
  372. # if arrays and ints not adjacent, move to front of shape
  373. if len(arr_int_pos) != (arr_int_pos[-1] - arr_int_pos[0] + 1):
  374. idx_shape = list(arr_shape) + idx_shape
  375. else:
  376. arr_pos = arr_int_pos[0]
  377. idx_shape = idx_shape[:arr_pos] + list(arr_shape) + idx_shape[arr_pos:]
  378. elif len(array_indices) == 1:
  379. arr_shape = index[array_indices[0]].shape
  380. arr_pos = arr_int_pos[0]
  381. idx_shape = idx_shape[:arr_pos] + list(arr_shape) + idx_shape[arr_pos:]
  382. return tuple(index), tuple(idx_shape), arr_int_pos, none_positions
  383. def _asindices(idx, length, format):
  384. """Convert `idx` to a valid index for an axis with a given length.
  385. Subclasses that need special validation can override this method.
  386. """
  387. try:
  388. ix = np.asarray(idx)
  389. except (ValueError, TypeError, MemoryError) as e:
  390. raise IndexError('invalid index') from e
  391. if format != "coo" and ix.ndim not in (1, 2) or format == "coo" and ix.ndim == 0:
  392. raise IndexError(f'Index dimension must be 1 or 2. Got {ix.ndim}')
  393. # LIL routines handle bounds-checking for us, so don't do it here.
  394. if format == "lil":
  395. return ix
  396. if ix.size == 0:
  397. return ix
  398. # Check bounds
  399. max_indx = ix.max()
  400. if max_indx >= length:
  401. raise IndexError(f'index ({max_indx}) out of range')
  402. min_indx = ix.min()
  403. if min_indx < 0:
  404. if min_indx < -length:
  405. raise IndexError(f'index ({min_indx}) out of range')
  406. if ix is idx or not ix.flags.owndata:
  407. ix = ix.copy()
  408. ix[ix < 0] += length
  409. return ix
  410. def _compatible_boolean_index(idx, desired_ndim):
  411. """Check for boolean array or array-like. peek before asarray for array-like"""
  412. # use attribute ndim to indicate a compatible array and check dtype
  413. # if not, look at 1st element as quick rejection of bool, else slower asanyarray
  414. if not hasattr(idx, 'ndim'):
  415. # is first element boolean?
  416. try:
  417. ix = next(iter(idx), None)
  418. for _ in range(desired_ndim):
  419. if isinstance(ix, bool):
  420. break
  421. ix = next(iter(ix), None)
  422. else:
  423. return None
  424. except TypeError:
  425. return None
  426. # since first is boolean, construct array and check all elements
  427. idx = np.asanyarray(idx)
  428. if idx.dtype.kind == 'b':
  429. return idx
  430. return None