_multiufuncs.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629
  1. import collections
  2. import numbers
  3. import numpy as np
  4. from ._input_validation import _nonneg_int_or_fail
  5. from ._special_ufuncs import (legendre_p, assoc_legendre_p,
  6. sph_legendre_p, sph_harm_y)
  7. from ._gufuncs import (legendre_p_all, assoc_legendre_p_all,
  8. sph_legendre_p_all, sph_harm_y_all)
  9. __all__ = [
  10. "assoc_legendre_p",
  11. "assoc_legendre_p_all",
  12. "legendre_p",
  13. "legendre_p_all",
  14. "sph_harm_y",
  15. "sph_harm_y_all",
  16. "sph_legendre_p",
  17. "sph_legendre_p_all",
  18. ]
  19. class MultiUFunc:
  20. def __init__(self, ufunc_or_ufuncs, name=None, doc=None, *,
  21. force_complex_output=False, **default_kwargs):
  22. if not isinstance(ufunc_or_ufuncs, np.ufunc):
  23. if isinstance(ufunc_or_ufuncs, collections.abc.Mapping):
  24. ufuncs_iter = ufunc_or_ufuncs.values()
  25. elif isinstance(ufunc_or_ufuncs, collections.abc.Iterable):
  26. ufuncs_iter = ufunc_or_ufuncs
  27. else:
  28. raise ValueError("ufunc_or_ufuncs should be a ufunc or a"
  29. " ufunc collection")
  30. # Perform input validation to ensure all ufuncs in ufuncs are
  31. # actually ufuncs and all take the same input types.
  32. seen_input_types = set()
  33. for ufunc in ufuncs_iter:
  34. if not isinstance(ufunc, np.ufunc):
  35. raise ValueError("All ufuncs must have type `numpy.ufunc`."
  36. f" Received {ufunc_or_ufuncs}")
  37. seen_input_types.add(frozenset(x.split("->")[0] for x in ufunc.types))
  38. if len(seen_input_types) > 1:
  39. raise ValueError("All ufuncs must take the same input types.")
  40. self.__name__ = name
  41. self._ufunc_or_ufuncs = ufunc_or_ufuncs
  42. self.__doc = doc
  43. self.__force_complex_output = force_complex_output
  44. self._default_kwargs = default_kwargs
  45. self._resolve_out_shapes = None
  46. self._finalize_out = None
  47. self._key = None
  48. self._ufunc_default_args = lambda *args, **kwargs: ()
  49. self._ufunc_default_kwargs = lambda *args, **kwargs: {}
  50. @property
  51. def __doc__(self):
  52. return self.__doc
  53. def _override_key(self, func):
  54. """Set `key` method by decorating a function.
  55. """
  56. self._key = func
  57. def _override_ufunc_default_args(self, func):
  58. self._ufunc_default_args = func
  59. def _override_ufunc_default_kwargs(self, func):
  60. self._ufunc_default_kwargs = func
  61. def _override_resolve_out_shapes(self, func):
  62. """Set `resolve_out_shapes` method by decorating a function."""
  63. if func.__doc__ is None:
  64. func.__doc__ = \
  65. """Resolve to output shapes based on relevant inputs."""
  66. func.__name__ = "resolve_out_shapes"
  67. self._resolve_out_shapes = func
  68. def _override_finalize_out(self, func):
  69. self._finalize_out = func
  70. def _resolve_ufunc(self, **kwargs):
  71. """Resolve to a ufunc based on keyword arguments."""
  72. if isinstance(self._ufunc_or_ufuncs, np.ufunc):
  73. return self._ufunc_or_ufuncs
  74. ufunc_key = self._key(**kwargs)
  75. return self._ufunc_or_ufuncs[ufunc_key]
  76. def __call__(self, *args, **kwargs):
  77. kwargs = self._default_kwargs | kwargs
  78. args += self._ufunc_default_args(**kwargs)
  79. ufunc = self._resolve_ufunc(**kwargs)
  80. # array arguments to be passed to the ufunc
  81. ufunc_args = [np.asarray(arg) for arg in args[-ufunc.nin:]]
  82. ufunc_kwargs = self._ufunc_default_kwargs(**kwargs)
  83. if (self._resolve_out_shapes is not None):
  84. ufunc_arg_shapes = tuple(np.shape(ufunc_arg) for ufunc_arg in ufunc_args)
  85. ufunc_out_shapes = self._resolve_out_shapes(*args[:-ufunc.nin],
  86. *ufunc_arg_shapes, ufunc.nout,
  87. **kwargs)
  88. ufunc_arg_dtypes = tuple(ufunc_arg.dtype if hasattr(ufunc_arg, 'dtype')
  89. else np.dtype(type(ufunc_arg))
  90. for ufunc_arg in ufunc_args)
  91. if hasattr(ufunc, 'resolve_dtypes'):
  92. ufunc_dtypes = ufunc_arg_dtypes + ufunc.nout * (None,)
  93. ufunc_dtypes = ufunc.resolve_dtypes(ufunc_dtypes)
  94. ufunc_out_dtypes = ufunc_dtypes[-ufunc.nout:]
  95. else:
  96. ufunc_out_dtype = np.result_type(*ufunc_arg_dtypes)
  97. if (not np.issubdtype(ufunc_out_dtype, np.inexact)):
  98. ufunc_out_dtype = np.float64
  99. ufunc_out_dtypes = ufunc.nout * (ufunc_out_dtype,)
  100. if self.__force_complex_output:
  101. ufunc_out_dtypes = tuple(np.result_type(1j, ufunc_out_dtype)
  102. for ufunc_out_dtype in ufunc_out_dtypes)
  103. out = tuple(np.empty(ufunc_out_shape, dtype=ufunc_out_dtype)
  104. for ufunc_out_shape, ufunc_out_dtype
  105. in zip(ufunc_out_shapes, ufunc_out_dtypes))
  106. ufunc_kwargs['out'] = out
  107. out = ufunc(*ufunc_args, **ufunc_kwargs)
  108. if (self._finalize_out is not None):
  109. out = self._finalize_out(out)
  110. return out
  111. sph_legendre_p = MultiUFunc(
  112. sph_legendre_p,
  113. "sph_legendre_p",
  114. r"""sph_legendre_p(n, m, theta, *, diff_n=0)
  115. Spherical Legendre polynomial of the first kind.
  116. Parameters
  117. ----------
  118. n : ArrayLike[int]
  119. Degree of the spherical Legendre polynomial. Must have ``n >= 0``.
  120. m : ArrayLike[int]
  121. Order of the spherical Legendre polynomial.
  122. theta : ArrayLike[float]
  123. Input value.
  124. diff_n : Optional[int]
  125. A non-negative integer. Compute and return all derivatives up
  126. to order ``diff_n``. Default is 0.
  127. Returns
  128. -------
  129. p : ndarray or tuple[ndarray]
  130. Spherical Legendre polynomial with ``diff_n`` derivatives.
  131. Notes
  132. -----
  133. The spherical counterpart of an (unnormalized) associated Legendre polynomial has
  134. the additional factor
  135. .. math::
  136. \sqrt{\frac{(2 n + 1) (n - m)!}{4 \pi (n + m)!}}
  137. It is the same as the spherical harmonic :math:`Y_{n}^{m}(\theta, \phi)`
  138. with :math:`\phi = 0`.
  139. """, diff_n=0
  140. )
  141. @sph_legendre_p._override_key
  142. def _(diff_n):
  143. diff_n = _nonneg_int_or_fail(diff_n, "diff_n", strict=False)
  144. if not 0 <= diff_n <= 2:
  145. raise ValueError(
  146. "diff_n is currently only implemented for orders 0, 1, and 2,"
  147. f" received: {diff_n}."
  148. )
  149. return diff_n
  150. @sph_legendre_p._override_finalize_out
  151. def _(out):
  152. return np.moveaxis(out, -1, 0)
  153. sph_legendre_p_all = MultiUFunc(
  154. sph_legendre_p_all,
  155. "sph_legendre_p_all",
  156. """sph_legendre_p_all(n, m, theta, *, diff_n=0)
  157. All spherical Legendre polynomials of the first kind up to the
  158. specified degree ``n``, order ``m``, and all derivatives up
  159. to order ``diff_n``.
  160. Output shape is ``(diff_n + 1, n + 1, 2 * m + 1, ...)``. The entry at
  161. ``(i, j, k)`` corresponds to the ``i``-th derivative, degree ``j``, and
  162. order ``k`` for all ``0 <= i <= diff_n``, ``0 <= j <= n``, and
  163. ``-m <= k <= m``.
  164. See Also
  165. --------
  166. sph_legendre_p
  167. """, diff_n=0
  168. )
  169. @sph_legendre_p_all._override_key
  170. def _(diff_n):
  171. diff_n = _nonneg_int_or_fail(diff_n, "diff_n", strict=False)
  172. if not 0 <= diff_n <= 2:
  173. raise ValueError(
  174. "diff_n is currently only implemented for orders 0, 1, and 2,"
  175. f" received: {diff_n}."
  176. )
  177. return diff_n
  178. @sph_legendre_p_all._override_ufunc_default_kwargs
  179. def _(diff_n):
  180. return {'axes': [()] + [(0, 1, -1)]}
  181. @sph_legendre_p_all._override_resolve_out_shapes
  182. def _(n, m, theta_shape, nout, diff_n):
  183. if not isinstance(n, numbers.Integral) or (n < 0):
  184. raise ValueError("n must be a non-negative integer.")
  185. return ((n + 1, 2 * abs(m) + 1) + theta_shape + (diff_n + 1,),)
  186. @sph_legendre_p_all._override_finalize_out
  187. def _(out):
  188. return np.moveaxis(out, -1, 0)
  189. assoc_legendre_p = MultiUFunc(
  190. assoc_legendre_p,
  191. "assoc_legendre_p",
  192. r"""assoc_legendre_p(n, m, z, *, branch_cut=2, norm=False, diff_n=0)
  193. Associated Legendre polynomial of the first kind.
  194. Parameters
  195. ----------
  196. n : ArrayLike[int]
  197. Degree of the associated Legendre polynomial. Must have ``n >= 0``.
  198. m : ArrayLike[int]
  199. order of the associated Legendre polynomial.
  200. z : ArrayLike[float | complex]
  201. Input value.
  202. branch_cut : Optional[ArrayLike[int]]
  203. Selects branch cut. Must be 2 (default) or 3.
  204. 2: cut on the real axis ``|z| > 1``
  205. 3: cut on the real axis ``-1 < z < 1``
  206. norm : Optional[bool]
  207. If ``True``, compute the normalized associated Legendre polynomial.
  208. Default is ``False``.
  209. diff_n : Optional[int]
  210. A non-negative integer. Compute and return all derivatives up
  211. to order ``diff_n``. Default is 0.
  212. Returns
  213. -------
  214. p : ndarray or tuple[ndarray]
  215. Associated Legendre polynomial with ``diff_n`` derivatives.
  216. Notes
  217. -----
  218. The normalized counterpart of an (unnormalized) associated Legendre
  219. polynomial has the additional factor
  220. .. math::
  221. \sqrt{\frac{(2 n + 1) (n - m)!}{2 (n + m)!}}
  222. """, branch_cut=2, norm=False, diff_n=0
  223. )
  224. @assoc_legendre_p._override_key
  225. def _(branch_cut, norm, diff_n):
  226. diff_n = _nonneg_int_or_fail(diff_n, "diff_n", strict=False)
  227. if not 0 <= diff_n <= 2:
  228. raise ValueError(
  229. "diff_n is currently only implemented for orders 0, 1, and 2,"
  230. f" received: {diff_n}."
  231. )
  232. return norm, diff_n
  233. @assoc_legendre_p._override_ufunc_default_args
  234. def _(branch_cut, norm, diff_n):
  235. return branch_cut,
  236. @assoc_legendre_p._override_finalize_out
  237. def _(out):
  238. return np.moveaxis(out, -1, 0)
  239. assoc_legendre_p_all = MultiUFunc(
  240. assoc_legendre_p_all,
  241. "assoc_legendre_p_all",
  242. """assoc_legendre_p_all(n, m, z, *, branch_cut=2, norm=False, diff_n=0)
  243. All associated Legendre polynomials of the first kind up to the
  244. specified degree ``n``, order ``m``, and all derivatives up
  245. to order ``diff_n``.
  246. Output shape is ``(diff_n + 1, n + 1, 2 * m + 1, ...)``. The entry at
  247. ``(i, j, k)`` corresponds to the ``i``-th derivative, degree ``j``, and
  248. order ``k`` for all ``0 <= i <= diff_n``, ``0 <= j <= n``, and
  249. ``-m <= k <= m``.
  250. See Also
  251. --------
  252. assoc_legendre_p
  253. """, branch_cut=2, norm=False, diff_n=0
  254. )
  255. @assoc_legendre_p_all._override_key
  256. def _(branch_cut, norm, diff_n):
  257. if not ((isinstance(diff_n, numbers.Integral))
  258. and diff_n >= 0):
  259. raise ValueError(
  260. f"diff_n must be a non-negative integer, received: {diff_n}."
  261. )
  262. if not 0 <= diff_n <= 2:
  263. raise ValueError(
  264. "diff_n is currently only implemented for orders 0, 1, and 2,"
  265. f" received: {diff_n}."
  266. )
  267. return norm, diff_n
  268. @assoc_legendre_p_all._override_ufunc_default_args
  269. def _(branch_cut, norm, diff_n):
  270. return branch_cut,
  271. @assoc_legendre_p_all._override_ufunc_default_kwargs
  272. def _(branch_cut, norm, diff_n):
  273. return {'axes': [(), ()] + [(0, 1, -1)]}
  274. @assoc_legendre_p_all._override_resolve_out_shapes
  275. def _(n, m, z_shape, branch_cut_shape, nout, **kwargs):
  276. diff_n = kwargs['diff_n']
  277. if not isinstance(n, numbers.Integral) or (n < 0):
  278. raise ValueError("n must be a non-negative integer.")
  279. if not isinstance(m, numbers.Integral) or (m < 0):
  280. raise ValueError("m must be a non-negative integer.")
  281. return ((n + 1, 2 * abs(m) + 1) +
  282. np.broadcast_shapes(z_shape, branch_cut_shape) + (diff_n + 1,),)
  283. @assoc_legendre_p_all._override_finalize_out
  284. def _(out):
  285. return np.moveaxis(out, -1, 0)
  286. legendre_p = MultiUFunc(
  287. legendre_p,
  288. "legendre_p",
  289. """legendre_p(n, z, *, diff_n=0)
  290. Legendre polynomial of the first kind.
  291. Parameters
  292. ----------
  293. n : ArrayLike[int]
  294. Degree of the Legendre polynomial. Must have ``n >= 0``.
  295. z : ArrayLike[float]
  296. Input value.
  297. diff_n : Optional[int]
  298. A non-negative integer. Compute and return all derivatives up
  299. to order ``diff_n``. Default is 0.
  300. Returns
  301. -------
  302. p : ndarray or tuple[ndarray]
  303. Legendre polynomial with ``diff_n`` derivatives.
  304. See Also
  305. --------
  306. legendre
  307. References
  308. ----------
  309. .. [1] Zhang, Shanjie and Jin, Jianming. "Computation of Special
  310. Functions", John Wiley and Sons, 1996.
  311. https://people.sc.fsu.edu/~jburkardt/f77_src/special_functions/special_functions.html
  312. """, diff_n=0
  313. )
  314. @legendre_p._override_key
  315. def _(diff_n):
  316. if (not isinstance(diff_n, numbers.Integral)) or (diff_n < 0):
  317. raise ValueError(
  318. f"diff_n must be a non-negative integer, received: {diff_n}."
  319. )
  320. if not 0 <= diff_n <= 2:
  321. raise NotImplementedError(
  322. "diff_n is currently only implemented for orders 0, 1, and 2,"
  323. f" received: {diff_n}."
  324. )
  325. return diff_n
  326. @legendre_p._override_finalize_out
  327. def _(out):
  328. return np.moveaxis(out, -1, 0)
  329. legendre_p_all = MultiUFunc(
  330. legendre_p_all,
  331. "legendre_p_all",
  332. """legendre_p_all(n, z, *, diff_n=0)
  333. All Legendre polynomials of the first kind up to the specified degree
  334. ``n`` and all derivatives up to order ``diff_n``.
  335. Output shape is ``(diff_n + 1, n + 1, ...)``. The entry at ``(i, j)``
  336. corresponds to the ``i``-th derivative and degree ``j`` for all
  337. ``0 <= i <= diff_n`` and ``0 <= j <= n``.
  338. See Also
  339. --------
  340. legendre_p
  341. """, diff_n=0
  342. )
  343. @legendre_p_all._override_key
  344. def _(diff_n):
  345. diff_n = _nonneg_int_or_fail(diff_n, "diff_n", strict=False)
  346. if not 0 <= diff_n <= 2:
  347. raise ValueError(
  348. "diff_n is currently only implemented for orders 0, 1, and 2,"
  349. f" received: {diff_n}."
  350. )
  351. return diff_n
  352. @legendre_p_all._override_ufunc_default_kwargs
  353. def _(diff_n):
  354. return {'axes': [(), (0, -1)]}
  355. @legendre_p_all._override_resolve_out_shapes
  356. def _(n, z_shape, nout, diff_n):
  357. n = _nonneg_int_or_fail(n, 'n', strict=False)
  358. return nout * ((n + 1,) + z_shape + (diff_n + 1,),)
  359. @legendre_p_all._override_finalize_out
  360. def _(out):
  361. return np.moveaxis(out, -1, 0)
  362. sph_harm_y = MultiUFunc(
  363. sph_harm_y,
  364. "sph_harm_y",
  365. r"""sph_harm_y(n, m, theta, phi, *, diff_n=0)
  366. Spherical harmonics. They are defined as
  367. .. math::
  368. Y_n^m(\theta,\phi) = \sqrt{\frac{2 n + 1}{4 \pi} \frac{(n - m)!}{(n + m)!}}
  369. P_n^m(\cos(\theta)) e^{i m \phi}
  370. where :math:`P_n^m` are the (unnormalized) associated Legendre polynomials.
  371. Parameters
  372. ----------
  373. n : ArrayLike[int]
  374. Degree of the harmonic. Must have ``n >= 0``. This is
  375. often denoted by ``l`` (lower case L) in descriptions of
  376. spherical harmonics.
  377. m : ArrayLike[int]
  378. Order of the harmonic.
  379. theta : ArrayLike[float]
  380. Polar (colatitudinal) coordinate; must be in ``[0, pi]``.
  381. phi : ArrayLike[float]
  382. Azimuthal (longitudinal) coordinate; must be in ``[0, 2*pi]``.
  383. diff_n : Optional[int]
  384. A non-negative integer. Compute and return all derivatives up
  385. to order ``diff_n``. Default is 0.
  386. Returns
  387. -------
  388. y : ndarray[complex] or tuple[ndarray[complex]]
  389. Spherical harmonics with ``diff_n`` derivatives.
  390. Notes
  391. -----
  392. There are different conventions for the meanings of the input
  393. arguments ``theta`` and ``phi``. In SciPy ``theta`` is the
  394. polar angle and ``phi`` is the azimuthal angle. It is common to
  395. see the opposite convention, that is, ``theta`` as the azimuthal angle
  396. and ``phi`` as the polar angle.
  397. Note that SciPy's spherical harmonics include the Condon-Shortley
  398. phase [2]_ because it is part of `sph_legendre_p`.
  399. With SciPy's conventions, the first several spherical harmonics
  400. are
  401. .. math::
  402. Y_0^0(\theta, \phi) &= \frac{1}{2} \sqrt{\frac{1}{\pi}} \\
  403. Y_1^{-1}(\theta, \phi) &= \frac{1}{2} \sqrt{\frac{3}{2\pi}}
  404. e^{-i\phi} \sin(\theta) \\
  405. Y_1^0(\theta, \phi) &= \frac{1}{2} \sqrt{\frac{3}{\pi}}
  406. \cos(\theta) \\
  407. Y_1^1(\theta, \phi) &= -\frac{1}{2} \sqrt{\frac{3}{2\pi}}
  408. e^{i\phi} \sin(\theta).
  409. References
  410. ----------
  411. .. [1] Digital Library of Mathematical Functions, 14.30.
  412. https://dlmf.nist.gov/14.30
  413. .. [2] https://en.wikipedia.org/wiki/Spherical_harmonics#Condon.E2.80.93Shortley_phase
  414. """, force_complex_output=True, diff_n=0
  415. )
  416. @sph_harm_y._override_key
  417. def _(diff_n):
  418. diff_n = _nonneg_int_or_fail(diff_n, "diff_n", strict=False)
  419. if not 0 <= diff_n <= 2:
  420. raise ValueError(
  421. "diff_n is currently only implemented for orders 0, 1, and 2,"
  422. f" received: {diff_n}."
  423. )
  424. return diff_n
  425. @sph_harm_y._override_finalize_out
  426. def _(out):
  427. if (out.shape[-1] == 1):
  428. return out[..., 0, 0]
  429. if (out.shape[-1] == 2):
  430. return out[..., 0, 0], out[..., [1, 0], [0, 1]]
  431. if (out.shape[-1] == 3):
  432. return (out[..., 0, 0], out[..., [1, 0], [0, 1]],
  433. out[..., [[2, 1], [1, 0]], [[0, 1], [1, 2]]])
  434. sph_harm_y_all = MultiUFunc(
  435. sph_harm_y_all,
  436. "sph_harm_y_all",
  437. """sph_harm_y_all(n, m, theta, phi, *, diff_n=0)
  438. All spherical harmonics up to the specified degree ``n``, order ``m``,
  439. and all derivatives up to order ``diff_n``.
  440. Returns a tuple of length ``diff_n + 1`` (if ``diff_n > 0``). The first
  441. entry corresponds to the spherical harmonics, the second entry
  442. (if ``diff_n >= 1``) to the gradient, and the third entry
  443. (if ``diff_n >= 2``) to the Hessian matrix. Each entry is an array of
  444. shape ``(n + 1, 2 * m + 1, ...)``, where the entry at ``(i, j)``
  445. corresponds to degree ``i`` and order ``j`` for all ``0 <= i <= n``
  446. and ``-m <= j <= m``.
  447. See Also
  448. --------
  449. sph_harm_y
  450. """, force_complex_output=True, diff_n=0
  451. )
  452. @sph_harm_y_all._override_key
  453. def _(diff_n):
  454. diff_n = _nonneg_int_or_fail(diff_n, "diff_n", strict=False)
  455. if not 0 <= diff_n <= 2:
  456. raise ValueError(
  457. "diff_n is currently only implemented for orders 2,"
  458. f" received: {diff_n}."
  459. )
  460. return diff_n
  461. @sph_harm_y_all._override_ufunc_default_kwargs
  462. def _(diff_n):
  463. return {'axes': [(), ()] + [(0, 1, -2, -1)]}
  464. @sph_harm_y_all._override_resolve_out_shapes
  465. def _(n, m, theta_shape, phi_shape, nout, **kwargs):
  466. diff_n = kwargs['diff_n']
  467. if not isinstance(n, numbers.Integral) or (n < 0):
  468. raise ValueError("n must be a non-negative integer.")
  469. return ((n + 1, 2 * abs(m) + 1) + np.broadcast_shapes(theta_shape, phi_shape) +
  470. (diff_n + 1, diff_n + 1),)
  471. @sph_harm_y_all._override_finalize_out
  472. def _(out):
  473. if (out.shape[-1] == 1):
  474. return out[..., 0, 0]
  475. if (out.shape[-1] == 2):
  476. return out[..., 0, 0], out[..., [1, 0], [0, 1]]
  477. if (out.shape[-1] == 3):
  478. return (out[..., 0, 0], out[..., [1, 0], [0, 1]],
  479. out[..., [[2, 1], [1, 0]], [[0, 1], [1, 2]]])