_ndarray.py 21 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729
  1. # mypy: ignore-errors
  2. from __future__ import annotations
  3. import builtins
  4. import math
  5. import operator
  6. from collections.abc import Sequence
  7. import torch
  8. from . import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util
  9. from ._normalizations import (
  10. ArrayLike,
  11. normalize_array_like,
  12. normalizer,
  13. NotImplementedType,
  14. )
  15. newaxis = None
  16. FLAGS = [
  17. "C_CONTIGUOUS",
  18. "F_CONTIGUOUS",
  19. "OWNDATA",
  20. "WRITEABLE",
  21. "ALIGNED",
  22. "WRITEBACKIFCOPY",
  23. "FNC",
  24. "FORC",
  25. "BEHAVED",
  26. "CARRAY",
  27. "FARRAY",
  28. ]
  29. SHORTHAND_TO_FLAGS = {
  30. "C": "C_CONTIGUOUS",
  31. "F": "F_CONTIGUOUS",
  32. "O": "OWNDATA",
  33. "W": "WRITEABLE",
  34. "A": "ALIGNED",
  35. "X": "WRITEBACKIFCOPY",
  36. "B": "BEHAVED",
  37. "CA": "CARRAY",
  38. "FA": "FARRAY",
  39. }
  40. class Flags:
  41. def __init__(self, flag_to_value: dict):
  42. invalid_keys = [k for k in flag_to_value if k not in FLAGS]
  43. if invalid_keys:
  44. raise AssertionError(f"Invalid flag keys: {invalid_keys}")
  45. self._flag_to_value = flag_to_value
  46. def __getattr__(self, attr: str):
  47. if attr.islower() and attr.upper() in FLAGS:
  48. return self[attr.upper()]
  49. else:
  50. raise AttributeError(f"No flag attribute '{attr}'")
  51. def __getitem__(self, key):
  52. if key in SHORTHAND_TO_FLAGS:
  53. key = SHORTHAND_TO_FLAGS[key]
  54. if key in FLAGS:
  55. try:
  56. return self._flag_to_value[key]
  57. except KeyError as e:
  58. raise NotImplementedError(f"{key=}") from e
  59. else:
  60. raise KeyError(f"No flag key '{key}'")
  61. def __setattr__(self, attr, value):
  62. if attr.islower() and attr.upper() in FLAGS:
  63. self[attr.upper()] = value
  64. else:
  65. super().__setattr__(attr, value)
  66. def __setitem__(self, key, value):
  67. if key in FLAGS or key in SHORTHAND_TO_FLAGS:
  68. raise NotImplementedError("Modifying flags is not implemented")
  69. else:
  70. raise KeyError(f"No flag key '{key}'")
  71. def create_method(fn, name=None):
  72. name = name or fn.__name__
  73. def f(*args, **kwargs):
  74. return fn(*args, **kwargs)
  75. f.__name__ = name
  76. f.__qualname__ = f"ndarray.{name}"
  77. return f
  78. # Map ndarray.name_method -> np.name_func
  79. # If name_func == None, it means that name_method == name_func
  80. methods = {
  81. "clip": None,
  82. "nonzero": None,
  83. "repeat": None,
  84. "round": None,
  85. "squeeze": None,
  86. "swapaxes": None,
  87. "ravel": None,
  88. # linalg
  89. "diagonal": None,
  90. "dot": None,
  91. "trace": None,
  92. # sorting
  93. "argsort": None,
  94. "searchsorted": None,
  95. # reductions
  96. "argmax": None,
  97. "argmin": None,
  98. "any": None,
  99. "all": None,
  100. "max": None,
  101. "min": None,
  102. "ptp": None,
  103. "sum": None,
  104. "prod": None,
  105. "mean": None,
  106. "var": None,
  107. "std": None,
  108. # scans
  109. "cumsum": None,
  110. "cumprod": None,
  111. # advanced indexing
  112. "take": None,
  113. "choose": None,
  114. }
  115. dunder = {
  116. "abs": "absolute",
  117. "invert": None,
  118. "pos": "positive",
  119. "neg": "negative",
  120. "gt": "greater",
  121. "lt": "less",
  122. "ge": "greater_equal",
  123. "le": "less_equal",
  124. }
  125. # dunder methods with right-looking and in-place variants
  126. ri_dunder = {
  127. "add": None,
  128. "sub": "subtract",
  129. "mul": "multiply",
  130. "truediv": "divide",
  131. "floordiv": "floor_divide",
  132. "pow": "power",
  133. "mod": "remainder",
  134. "and": "bitwise_and",
  135. "or": "bitwise_or",
  136. "xor": "bitwise_xor",
  137. "lshift": "left_shift",
  138. "rshift": "right_shift",
  139. "matmul": None,
  140. }
  141. def _upcast_int_indices(index):
  142. if isinstance(index, torch.Tensor):
  143. if index.dtype in (torch.int8, torch.int16, torch.int32, torch.uint8):
  144. return index.to(torch.int64)
  145. elif isinstance(index, tuple):
  146. return tuple(_upcast_int_indices(i) for i in index)
  147. return index
  148. def _has_advanced_indexing(index):
  149. """Check if there's any advanced indexing"""
  150. return any(
  151. isinstance(idx, (Sequence, bool))
  152. or (isinstance(idx, torch.Tensor) and (idx.dtype == torch.bool or idx.ndim > 0))
  153. for idx in index
  154. )
  155. def _numpy_compatible_indexing(index):
  156. """Convert scalar indices to lists when advanced indexing is present for NumPy compatibility."""
  157. if not isinstance(index, tuple):
  158. index = (index,)
  159. # Check if there's any advanced indexing (sequences, booleans, or tensors)
  160. has_advanced = _has_advanced_indexing(index)
  161. if not has_advanced:
  162. return index
  163. # Convert integer scalar indices to single-element lists when advanced indexing is present
  164. # Note: Do NOT convert boolean scalars (True/False) as they have special meaning in NumPy
  165. converted = []
  166. for idx in index:
  167. if isinstance(idx, int) and not isinstance(idx, bool):
  168. # Integer scalars should be converted to lists
  169. converted.append([idx])
  170. elif (
  171. isinstance(idx, torch.Tensor)
  172. and idx.ndim == 0
  173. and not torch.is_floating_point(idx)
  174. and idx.dtype != torch.bool
  175. ):
  176. # Zero-dimensional tensors holding integers should be treated the same as integer scalars
  177. converted.append([idx])
  178. else:
  179. # Everything else (booleans, lists, slices, etc.) stays as is
  180. converted.append(idx)
  181. return tuple(converted)
  182. def _get_bool_depth(s):
  183. """Returns the depth of a boolean sequence/tensor"""
  184. if isinstance(s, bool):
  185. return True, 0
  186. if isinstance(s, torch.Tensor) and s.dtype == torch.bool:
  187. return True, s.ndim
  188. if not (isinstance(s, Sequence) and s and s[0] != s):
  189. return False, 0
  190. is_bool, depth = _get_bool_depth(s[0])
  191. return is_bool, depth + 1
def _numpy_empty_ellipsis_patch(index, tensor_ndim):
    """
    Patch for NumPy-compatible ellipsis behavior when ellipsis doesn't match any dimensions.

    In NumPy, when an ellipsis (...) doesn't actually match any dimensions of the input array,
    it still acts as a separator between advanced indices. PyTorch doesn't have this behavior.

    This function detects when we have:
    1. Advanced indexing on both sides of an ellipsis
    2. The ellipsis doesn't actually match any dimensions

    In that case the ellipsis is replaced by ``None`` (a new size-1 axis) so that
    the advanced indices are still separated when PyTorch indexes.

    Args:
        index: an index item or a tuple of index items (wrapped into a tuple if needed).
        tensor_ndim: number of dimensions of the tensor being indexed.

    Returns:
        ``(new_index, squeeze_fn, unsqueeze_fn)``. ``new_index`` is always a tuple.
        When the patch applies, ``squeeze_fn`` removes the extra axis from a
        ``__getitem__`` result and ``unsqueeze_fn`` inserts it into a
        ``__setitem__`` value; otherwise both are identity functions.
    """
    if not isinstance(index, tuple):
        index = (index,)
    # Find ellipsis position
    ellipsis_pos = None
    for i, idx in enumerate(index):
        if idx is Ellipsis:
            ellipsis_pos = i
            break
    # If no ellipsis, no patch needed
    if ellipsis_pos is None:
        return index, lambda x: x, lambda x: x
    # Count non-ellipsis dimensions consumed by the index:
    # boolean masks consume `depth` dims (0 for a scalar True/False),
    # None and ... consume none, anything else consumes one.
    consumed_dims = 0
    for idx in index:
        is_bool, depth = _get_bool_depth(idx)
        if is_bool:
            consumed_dims += depth
        elif idx is Ellipsis or idx is None:
            continue
        else:
            consumed_dims += 1
    # Calculate how many dimensions the ellipsis should match
    ellipsis_dims = tensor_ndim - consumed_dims
    # Check if ellipsis doesn't match any dimensions
    if ellipsis_dims == 0:
        # Check if we have advanced indexing on both sides of ellipsis
        left_advanced = _has_advanced_indexing(index[:ellipsis_pos])
        right_advanced = _has_advanced_indexing(index[ellipsis_pos + 1 :])
        if left_advanced and right_advanced:
            # This is the case where NumPy and PyTorch differ
            # We need to ensure the advanced indices are treated as separated
            new_index = index[:ellipsis_pos] + (None,) + index[ellipsis_pos + 1 :]
            # Position (counted from the end) of the inserted None axis in the
            # result: it is followed by one axis per slice to the right of the
            # ellipsis. NOTE(review): None items to the right of the ellipsis
            # also add axes but are not counted here — confirm that is intended.
            end_ndims = 1 + sum(
                1 for idx in index[ellipsis_pos + 1 :] if isinstance(idx, slice)
            )

            def squeeze_fn(x):
                # Drop the axis introduced by the inserted None.
                return x.squeeze(-end_ndims)

            def unsqueeze_fn(x):
                # Add the matching axis to an assigned value; scalars and
                # values with fewer dims are returned unchanged (presumably
                # left to broadcasting — confirm).
                if isinstance(x, torch.Tensor) and x.ndim >= end_ndims:
                    return x.unsqueeze(-end_ndims)
                return x

            return new_index, squeeze_fn, unsqueeze_fn
    return index, lambda x: x, lambda x: x
  244. # Used to indicate that a parameter is unspecified (as opposed to explicitly
  245. # `None`)
  246. class _Unspecified:
  247. pass
  248. _Unspecified.unspecified = _Unspecified()
  249. ###############################################################
  250. # ndarray class #
  251. ###############################################################
  252. class ndarray:
  253. def __init__(self, t=None):
  254. if t is None:
  255. self.tensor = torch.Tensor()
  256. elif isinstance(t, torch.Tensor):
  257. self.tensor = t
  258. else:
  259. raise ValueError(
  260. "ndarray constructor is not recommended; prefer"
  261. "either array(...) or zeros/empty(...)"
  262. )
  263. # Register NumPy functions as methods
  264. for method, name in methods.items():
  265. fn = getattr(_funcs, name or method)
  266. vars()[method] = create_method(fn, method)
  267. # Regular methods but coming from ufuncs
  268. conj = create_method(_ufuncs.conjugate, "conj")
  269. conjugate = create_method(_ufuncs.conjugate)
  270. for method, name in dunder.items():
  271. fn = getattr(_ufuncs, name or method)
  272. method = f"__{method}__"
  273. vars()[method] = create_method(fn, method)
  274. for method, name in ri_dunder.items():
  275. fn = getattr(_ufuncs, name or method)
  276. plain = f"__{method}__"
  277. vars()[plain] = create_method(fn, plain)
  278. rvar = f"__r{method}__"
  279. vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)
  280. ivar = f"__i{method}__"
  281. vars()[ivar] = create_method(
  282. lambda self, other, fn=fn: fn(self, other, out=self), ivar
  283. )
  284. # There's no __idivmod__
  285. __divmod__ = create_method(_ufuncs.divmod, "__divmod__")
  286. __rdivmod__ = create_method(
  287. lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__"
  288. )
  289. # prevent loop variables leaking into the ndarray class namespace
  290. del ivar, rvar, name, plain, fn, method
  291. @property
  292. def shape(self):
  293. return tuple(self.tensor.shape)
  294. @property
  295. def size(self):
  296. return self.tensor.numel()
  297. @property
  298. def ndim(self):
  299. return self.tensor.ndim
  300. @property
  301. def dtype(self):
  302. return _dtypes.dtype(self.tensor.dtype)
  303. @property
  304. def strides(self):
  305. elsize = self.tensor.element_size()
  306. return tuple(stride * elsize for stride in self.tensor.stride())
  307. @property
  308. def itemsize(self):
  309. return self.tensor.element_size()
  310. @property
  311. def flags(self):
  312. # Note contiguous in torch is assumed C-style
  313. return Flags(
  314. {
  315. "C_CONTIGUOUS": self.tensor.is_contiguous(),
  316. "F_CONTIGUOUS": self.T.tensor.is_contiguous(),
  317. "OWNDATA": self.tensor._base is None,
  318. "WRITEABLE": True, # pytorch does not have readonly tensors
  319. }
  320. )
  321. @property
  322. def data(self):
  323. return self.tensor.data_ptr()
  324. @property
  325. def nbytes(self):
  326. return self.tensor.storage().nbytes()
  327. @property
  328. def T(self):
  329. return self.transpose()
  330. @property
  331. def real(self):
  332. return _funcs.real(self)
  333. @real.setter
  334. def real(self, value):
  335. self.tensor.real = asarray(value).tensor
  336. @property
  337. def imag(self):
  338. return _funcs.imag(self)
  339. @imag.setter
  340. def imag(self, value):
  341. self.tensor.imag = asarray(value).tensor
  342. # ctors
  343. def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
  344. if order != "K":
  345. raise NotImplementedError(f"astype(..., order={order} is not implemented.")
  346. if casting != "unsafe":
  347. raise NotImplementedError(
  348. f"astype(..., casting={casting} is not implemented."
  349. )
  350. if not subok:
  351. raise NotImplementedError(f"astype(..., subok={subok} is not implemented.")
  352. if not copy:
  353. raise NotImplementedError(f"astype(..., copy={copy} is not implemented.")
  354. torch_dtype = _dtypes.dtype(dtype).torch_dtype
  355. t = self.tensor.to(torch_dtype)
  356. return ndarray(t)
  357. @normalizer
  358. def copy(self: ArrayLike, order: NotImplementedType = "C"):
  359. return self.clone()
  360. @normalizer
  361. def flatten(self: ArrayLike, order: NotImplementedType = "C"):
  362. return torch.flatten(self)
  363. def resize(self, *new_shape, refcheck=False):
  364. # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
  365. if refcheck:
  366. raise NotImplementedError(
  367. f"resize(..., refcheck={refcheck} is not implemented."
  368. )
  369. if new_shape in [(), (None,)]:
  370. return
  371. # support both x.resize((2, 2)) and x.resize(2, 2)
  372. if len(new_shape) == 1:
  373. new_shape = new_shape[0]
  374. if isinstance(new_shape, int):
  375. new_shape = (new_shape,)
  376. if builtins.any(x < 0 for x in new_shape):
  377. raise ValueError("all elements of `new_shape` must be non-negative")
  378. new_numel, old_numel = math.prod(new_shape), self.tensor.numel()
  379. self.tensor.resize_(new_shape)
  380. if new_numel >= old_numel:
  381. # zero-fill new elements
  382. if not self.tensor.is_contiguous():
  383. raise AssertionError("tensor must be contiguous for resize with growth")
  384. b = self.tensor.flatten() # does not copy
  385. b[old_numel:].zero_()
  386. def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified):
  387. if dtype is _Unspecified.unspecified:
  388. dtype = self.dtype
  389. if type is not _Unspecified.unspecified:
  390. raise NotImplementedError(f"view(..., type={type} is not implemented.")
  391. torch_dtype = _dtypes.dtype(dtype).torch_dtype
  392. tview = self.tensor.view(torch_dtype)
  393. return ndarray(tview)
  394. @normalizer
  395. def fill(self, value: ArrayLike):
  396. # Both Pytorch and NumPy accept 0D arrays/tensors and scalars, and
  397. # error out on D > 0 arrays
  398. self.tensor.fill_(value)
  399. def tolist(self):
  400. return self.tensor.tolist()
  401. def __iter__(self):
  402. return (ndarray(x) for x in self.tensor.__iter__())
  403. def __str__(self):
  404. return (
  405. str(self.tensor)
  406. .replace("tensor", "torch.ndarray")
  407. .replace("dtype=torch.", "dtype=")
  408. )
  409. __repr__ = create_method(__str__)
  410. def __eq__(self, other):
  411. try:
  412. return _ufuncs.equal(self, other)
  413. except (RuntimeError, TypeError):
  414. # Failed to convert other to array: definitely not equal.
  415. falsy = torch.full(self.shape, fill_value=False, dtype=bool)
  416. return asarray(falsy)
  417. def __ne__(self, other):
  418. return ~(self == other)
  419. def __index__(self):
  420. try:
  421. return operator.index(self.tensor.item())
  422. except Exception as exc:
  423. raise TypeError(
  424. "only integer scalar arrays can be converted to a scalar index"
  425. ) from exc
  426. def __bool__(self):
  427. return bool(self.tensor)
  428. def __int__(self):
  429. return int(self.tensor)
  430. def __float__(self):
  431. return float(self.tensor)
  432. def __complex__(self):
  433. return complex(self.tensor)
  434. def is_integer(self):
  435. try:
  436. v = self.tensor.item()
  437. result = int(v) == v
  438. except Exception:
  439. result = False
  440. return result
  441. def __len__(self):
  442. return self.tensor.shape[0]
  443. def __contains__(self, x):
  444. return self.tensor.__contains__(x)
  445. def transpose(self, *axes):
  446. # np.transpose(arr, axis=None) but arr.transpose(*axes)
  447. return _funcs.transpose(self, axes)
  448. def reshape(self, *shape, order="C"):
  449. # arr.reshape(shape) and arr.reshape(*shape)
  450. return _funcs.reshape(self, shape, order=order)
  451. def sort(self, axis=-1, kind=None, order=None):
  452. # ndarray.sort works in-place
  453. _funcs.copyto(self, _funcs.sort(self, axis, kind, order))
  454. def item(self, *args):
  455. # Mimic NumPy's implementation with three special cases (no arguments,
  456. # a flat index and a multi-index):
  457. # https://github.com/numpy/numpy/blob/main/numpy/_core/src/multiarray/methods.c#L702
  458. if args == ():
  459. return self.tensor.item()
  460. elif len(args) == 1:
  461. # int argument
  462. return self.ravel()[args[0]]
  463. else:
  464. return self.__getitem__(args)
  465. def __getitem__(self, index):
  466. tensor = self.tensor
  467. def neg_step(i, s):
  468. if not (isinstance(s, slice) and s.step is not None and s.step < 0):
  469. return s
  470. nonlocal tensor
  471. tensor = torch.flip(tensor, (i,))
  472. # Account for the fact that a slice includes the start but not the end
  473. if not (isinstance(s.start, int) or s.start is None):
  474. raise AssertionError(
  475. f"slice start must be int or None, got {type(s.start).__name__}"
  476. )
  477. if not (isinstance(s.stop, int) or s.stop is None):
  478. raise AssertionError(
  479. f"slice stop must be int or None, got {type(s.stop).__name__}"
  480. )
  481. start = s.stop + 1 if s.stop else None
  482. stop = s.start + 1 if s.start else None
  483. return slice(start, stop, -s.step)
  484. if isinstance(index, Sequence):
  485. index = type(index)(neg_step(i, s) for i, s in enumerate(index))
  486. else:
  487. index = neg_step(0, index)
  488. index = _util.ndarrays_to_tensors(index)
  489. index = _upcast_int_indices(index)
  490. # Apply NumPy-compatible indexing conversion
  491. index = _numpy_compatible_indexing(index)
  492. # Apply NumPy-compatible empty ellipsis behavior
  493. index, maybe_squeeze, _ = _numpy_empty_ellipsis_patch(index, tensor.ndim)
  494. return maybe_squeeze(ndarray(tensor.__getitem__(index)))
  495. def __setitem__(self, index, value):
  496. index = _util.ndarrays_to_tensors(index)
  497. index = _upcast_int_indices(index)
  498. # Apply NumPy-compatible indexing conversion
  499. index = _numpy_compatible_indexing(index)
  500. # Apply NumPy-compatible empty ellipsis behavior
  501. index, _, maybe_unsqueeze = _numpy_empty_ellipsis_patch(index, self.tensor.ndim)
  502. if not _dtypes_impl.is_scalar(value):
  503. value = normalize_array_like(value)
  504. value = _util.cast_if_needed(value, self.tensor.dtype)
  505. return self.tensor.__setitem__(index, maybe_unsqueeze(value))
  506. take = _funcs.take
  507. put = _funcs.put
  508. def __dlpack__(self, *, stream=None):
  509. return self.tensor.__dlpack__(stream=stream)
  510. def __dlpack_device__(self):
  511. return self.tensor.__dlpack_device__()
  512. def _tolist(obj):
  513. """Recursively convert tensors into lists."""
  514. a1 = []
  515. for elem in obj:
  516. if isinstance(elem, (list, tuple)):
  517. elem = _tolist(elem)
  518. if isinstance(elem, ndarray):
  519. a1.append(elem.tensor.tolist())
  520. else:
  521. a1.append(elem)
  522. return a1
  523. # This is the ideally the only place which talks to ndarray directly.
  524. # The rest goes through asarray (preferred) or array.
  525. def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
  526. if subok is not False:
  527. raise NotImplementedError("'subok' parameter is not supported.")
  528. if like is not None:
  529. raise NotImplementedError("'like' parameter is not supported.")
  530. if order != "K":
  531. raise NotImplementedError
  532. # a happy path
  533. if (
  534. isinstance(obj, ndarray)
  535. and copy is False
  536. and dtype is None
  537. and ndmin <= obj.ndim
  538. ):
  539. return obj
  540. if isinstance(obj, (list, tuple)):
  541. # FIXME and they have the same dtype, device, etc
  542. if obj and all(isinstance(x, torch.Tensor) for x in obj):
  543. # list of arrays: *under torch.Dynamo* these are FakeTensors
  544. obj = torch.stack(obj)
  545. else:
  546. # XXX: remove tolist
  547. # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
  548. obj = _tolist(obj)
  549. # is obj an ndarray already?
  550. if isinstance(obj, ndarray):
  551. obj = obj.tensor
  552. # is a specific dtype requested?
  553. torch_dtype = None
  554. if dtype is not None:
  555. torch_dtype = _dtypes.dtype(dtype).torch_dtype
  556. tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
  557. return ndarray(tensor)
  558. def asarray(a, dtype=None, order="K", *, like=None):
  559. return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0)
  560. def ascontiguousarray(a, dtype=None, *, like=None):
  561. arr = asarray(a, dtype=dtype, like=like)
  562. if not arr.tensor.is_contiguous():
  563. arr.tensor = arr.tensor.contiguous()
  564. return arr
  565. def from_dlpack(x, /):
  566. t = torch.from_dlpack(x)
  567. return ndarray(t)
  568. def _extract_dtype(entry):
  569. try:
  570. dty = _dtypes.dtype(entry)
  571. except Exception:
  572. dty = asarray(entry).dtype
  573. return dty
  574. def can_cast(from_, to, casting="safe"):
  575. from_ = _extract_dtype(from_)
  576. to_ = _extract_dtype(to)
  577. return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)
  578. def result_type(*arrays_and_dtypes):
  579. tensors = []
  580. for entry in arrays_and_dtypes:
  581. try:
  582. t = asarray(entry).tensor
  583. except (RuntimeError, ValueError, TypeError):
  584. dty = _dtypes.dtype(entry)
  585. t = torch.empty(1, dtype=dty.torch_dtype)
  586. tensors.append(t)
  587. torch_dtype = _dtypes_impl.result_type_impl(*tensors)
  588. return _dtypes.dtype(torch_dtype)