_util.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # mypy: ignore-errors
"""Assorted utilities, which do not need anything other than torch and stdlib."""
  3. import operator
  4. import torch
  5. from . import _dtypes_impl
  6. # https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
  7. def is_sequence(seq):
  8. if isinstance(seq, str):
  9. return False
  10. try:
  11. len(seq)
  12. except Exception:
  13. return False
  14. return True
  15. class AxisError(ValueError, IndexError):
  16. pass
  17. class UFuncTypeError(TypeError, RuntimeError):
  18. pass
  19. def cast_if_needed(tensor, dtype):
  20. # NB: no casting if dtype=None
  21. if dtype is not None and tensor.dtype != dtype:
  22. tensor = tensor.to(dtype)
  23. return tensor
  24. def cast_int_to_float(x):
  25. # cast integers and bools to the default float dtype
  26. if _dtypes_impl._category(x.dtype) < 2:
  27. x = x.to(_dtypes_impl.default_dtypes().float_dtype)
  28. return x
  29. # a replica of the version in ./numpy/numpy/core/src/multiarray/common.h
  30. def normalize_axis_index(ax, ndim, argname=None):
  31. if not (-ndim <= ax < ndim):
  32. raise AxisError(f"axis {ax} is out of bounds for array of dimension {ndim}")
  33. if ax < 0:
  34. ax += ndim
  35. return ax
  36. # from https://github.com/numpy/numpy/blob/main/numpy/core/numeric.py#L1378
  37. def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
  38. """
  39. Normalizes an axis argument into a tuple of non-negative integer axes.
  40. This handles shorthands such as ``1`` and converts them to ``(1,)``,
  41. as well as performing the handling of negative indices covered by
  42. `normalize_axis_index`.
  43. By default, this forbids axes from being specified multiple times.
  44. Used internally by multi-axis-checking logic.
  45. Parameters
  46. ----------
  47. axis : int, iterable of int
  48. The un-normalized index or indices of the axis.
  49. ndim : int
  50. The number of dimensions of the array that `axis` should be normalized
  51. against.
  52. argname : str, optional
  53. A prefix to put before the error message, typically the name of the
  54. argument.
  55. allow_duplicate : bool, optional
  56. If False, the default, disallow an axis from being specified twice.
  57. Returns
  58. -------
  59. normalized_axes : tuple of int
  60. The normalized axis index, such that `0 <= normalized_axis < ndim`
  61. """
  62. # Optimization to speed-up the most common cases.
  63. if type(axis) not in (tuple, list):
  64. try:
  65. axis = [operator.index(axis)]
  66. except TypeError:
  67. pass
  68. # Going via an iterator directly is slower than via list comprehension.
  69. axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis)
  70. if not allow_duplicate and len(set(map(int, axis))) != len(axis):
  71. if argname:
  72. raise ValueError(f"repeated axis in `{argname}` argument")
  73. else:
  74. raise ValueError("repeated axis")
  75. return axis
  76. def allow_only_single_axis(axis):
  77. if axis is None:
  78. return axis
  79. if len(axis) != 1:
  80. raise NotImplementedError("does not handle tuple axis")
  81. return axis[0]
  82. def expand_shape(arr_shape, axis):
  83. # taken from numpy 1.23.x, expand_dims function
  84. if type(axis) not in (list, tuple):
  85. axis = (axis,)
  86. out_ndim = len(axis) + len(arr_shape)
  87. axis = normalize_axis_tuple(axis, out_ndim)
  88. shape_it = iter(arr_shape)
  89. shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)]
  90. return shape
  91. def apply_keepdims(tensor, axis, ndim):
  92. if axis is None:
  93. # tensor was a scalar
  94. shape = (1,) * ndim
  95. tensor = tensor.expand(shape).contiguous()
  96. else:
  97. shape = expand_shape(tensor.shape, axis)
  98. tensor = tensor.reshape(shape)
  99. return tensor
  100. def axis_none_flatten(*tensors, axis=None):
  101. """Flatten the arrays if axis is None."""
  102. if axis is None:
  103. tensors = tuple(ar.flatten() for ar in tensors)
  104. return tensors, 0
  105. else:
  106. return tensors, axis
  107. def typecast_tensor(t, target_dtype, casting):
  108. """Dtype-cast tensor to target_dtype.
  109. Parameters
  110. ----------
  111. t : torch.Tensor
  112. The tensor to cast
  113. target_dtype : torch dtype object
  114. The array dtype to cast all tensors to
  115. casting : str
  116. The casting mode, see `np.can_cast`
  117. Returns
  118. -------
  119. `torch.Tensor` of the `target_dtype` dtype
  120. Raises
  121. ------
  122. ValueError
  123. if the argument cannot be cast according to the `casting` rule
  124. """
  125. can_cast = _dtypes_impl.can_cast_impl
  126. if not can_cast(t.dtype, target_dtype, casting=casting):
  127. raise TypeError(
  128. f"Cannot cast array data from {t.dtype} to"
  129. f" {target_dtype} according to the rule '{casting}'"
  130. )
  131. return cast_if_needed(t, target_dtype)
  132. def typecast_tensors(tensors, target_dtype, casting):
  133. return tuple(typecast_tensor(t, target_dtype, casting) for t in tensors)
  134. def _try_convert_to_tensor(obj):
  135. try:
  136. tensor = torch.as_tensor(obj)
  137. except Exception as e:
  138. mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}."
  139. raise NotImplementedError(mesg) # noqa: B904
  140. return tensor
  141. def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
  142. """The core logic of the array(...) function.
  143. Parameters
  144. ----------
  145. obj : tensor_like
  146. The thing to coerce
  147. dtype : torch.dtype object or None
  148. Coerce to this torch dtype
  149. copy : bool
  150. Copy or not
  151. ndmin : int
  152. The results as least this many dimensions
  153. is_weak : bool
  154. Whether obj is a weakly typed python scalar.
  155. Returns
  156. -------
  157. tensor : torch.Tensor
  158. a tensor object with requested dtype, ndim and copy semantics.
  159. Notes
  160. -----
  161. This is almost a "tensor_like" coercive function. Does not handle wrapper
  162. ndarrays (those should be handled in the ndarray-aware layer prior to
  163. invoking this function).
  164. """
  165. if isinstance(obj, torch.Tensor):
  166. tensor = obj
  167. else:
  168. # tensor.dtype is the pytorch default, typically float32. If obj's elements
  169. # are not exactly representable in float32, we've lost precision:
  170. # >>> torch.as_tensor(1e12).item() - 1e12
  171. # -4096.0
  172. default_dtype = torch.get_default_dtype()
  173. torch.set_default_dtype(_dtypes_impl.get_default_dtype_for(torch.float32))
  174. try:
  175. tensor = _try_convert_to_tensor(obj)
  176. finally:
  177. torch.set_default_dtype(default_dtype)
  178. # type cast if requested
  179. tensor = cast_if_needed(tensor, dtype)
  180. # adjust ndim if needed
  181. ndim_extra = ndmin - tensor.ndim
  182. if ndim_extra > 0:
  183. tensor = tensor.view((1,) * ndim_extra + tensor.shape)
  184. # special handling for np._CopyMode
  185. try:
  186. copy = bool(copy)
  187. except ValueError:
  188. # TODO handle _CopyMode.IF_NEEDED correctly
  189. copy = False
  190. # copy if requested
  191. if copy:
  192. tensor = tensor.clone()
  193. return tensor
  194. def ndarrays_to_tensors(*inputs):
  195. """Convert all ndarrays from `inputs` to tensors. (other things are intact)"""
  196. from ._ndarray import ndarray
  197. if len(inputs) == 0:
  198. return ValueError()
  199. elif len(inputs) == 1:
  200. input_ = inputs[0]
  201. if isinstance(input_, ndarray):
  202. return input_.tensor
  203. elif isinstance(input_, tuple):
  204. result = []
  205. for sub_input in input_:
  206. sub_result = ndarrays_to_tensors(sub_input)
  207. result.append(sub_result)
  208. return tuple(result)
  209. else:
  210. return input_
  211. else:
  212. if not isinstance(inputs, tuple):
  213. raise AssertionError(
  214. f"Expected inputs to be a tuple, got {type(inputs).__name__}"
  215. )
  216. return ndarrays_to_tensors(inputs)