test_logsumexp.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. import itertools as it
  2. import math
  3. import pytest
  4. import numpy as np
  5. from scipy._lib._array_api import (is_array_api_strict, make_xp_test_case,
  6. xp_default_dtype, xp_device)
  7. from scipy._lib._array_api_no_0d import (xp_assert_equal, xp_assert_close,
  8. xp_assert_less)
  9. from scipy.special import log_softmax, logsumexp, softmax
  10. from scipy.special._logsumexp import _wrap_radians
  11. dtypes = ['float32', 'float64', 'int32', 'int64', 'complex64', 'complex128']
  12. integral_dtypes = ['int32', 'int64']
  13. def test_wrap_radians(xp):
  14. x = xp.asarray([-math.pi-1, -math.pi, -1, -1e-300,
  15. 0, 1e-300, 1, math.pi, math.pi+1])
  16. ref = xp.asarray([math.pi-1, math.pi, -1, -1e-300,
  17. 0, 1e-300, 1, math.pi, -math.pi+1])
  18. res = _wrap_radians(x, xp=xp)
  19. xp_assert_close(res, ref, atol=0)
  20. # numpy warning filters don't work for dask (dask/dask#3245)
  21. # (also we should not expect the numpy warning filter to work for any Array API
  22. # library)
  23. @pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning")
  24. @pytest.mark.filterwarnings("ignore:divide by zero encountered:RuntimeWarning")
  25. @pytest.mark.filterwarnings("ignore:overflow encountered:RuntimeWarning")
  26. @make_xp_test_case(logsumexp)
  27. class TestLogSumExp:
  28. def test_logsumexp(self, xp):
  29. # Test with zero-size array
  30. a = xp.asarray([])
  31. desired = xp.asarray(-xp.inf)
  32. xp_assert_equal(logsumexp(a), desired)
  33. # Test whether logsumexp() function correctly handles large inputs.
  34. a = xp.arange(200., dtype=xp.float64)
  35. desired = xp.log(xp.sum(xp.exp(a)))
  36. xp_assert_close(logsumexp(a), desired)
  37. # Now test with large numbers
  38. b = xp.asarray([1000., 1000.])
  39. desired = xp.asarray(1000.0 + math.log(2.0))
  40. xp_assert_close(logsumexp(b), desired)
  41. n = 1000
  42. b = xp.full((n,), 10000)
  43. desired = xp.asarray(10000.0 + math.log(n))
  44. xp_assert_close(logsumexp(b), desired)
  45. x = xp.asarray([1e-40] * 1000000)
  46. logx = xp.log(x)
  47. X = xp.stack([x, x])
  48. logX = xp.stack([logx, logx])
  49. xp_assert_close(xp.exp(logsumexp(logX)), xp.sum(X))
  50. xp_assert_close(xp.exp(logsumexp(logX, axis=0)), xp.sum(X, axis=0))
  51. xp_assert_close(xp.exp(logsumexp(logX, axis=1)), xp.sum(X, axis=1))
  52. # Handling special values properly
  53. inf = xp.asarray([xp.inf])
  54. nan = xp.asarray([xp.nan])
  55. xp_assert_equal(logsumexp(inf), inf[0])
  56. xp_assert_equal(logsumexp(-inf), -inf[0])
  57. xp_assert_equal(logsumexp(nan), nan[0])
  58. xp_assert_equal(logsumexp(xp.asarray([-xp.inf, -xp.inf])), -inf[0])
  59. # Handling an array with different magnitudes on the axes
  60. a = xp.asarray([[1e10, 1e-10],
  61. [-1e10, -np.inf]])
  62. ref = xp.asarray([1e10, -1e10])
  63. xp_assert_close(logsumexp(a, axis=-1), ref)
  64. # Test keeping dimensions
  65. ref = xp.expand_dims(ref, axis=-1)
  66. xp_assert_close(logsumexp(a, axis=-1, keepdims=True), ref)
  67. # Test multiple axes
  68. xp_assert_close(logsumexp(a, axis=(-1, -2)), xp.asarray(1e10))
  69. def test_logsumexp_b(self, xp):
  70. a = xp.arange(200., dtype=xp.float64)
  71. b = xp.arange(200., 0., -1.)
  72. desired = xp.log(xp.sum(b*xp.exp(a)))
  73. xp_assert_close(logsumexp(a, b=b), desired)
  74. a = xp.asarray([1000, 1000])
  75. b = xp.asarray([1.2, 1.2])
  76. desired = xp.asarray(1000 + math.log(2 * 1.2))
  77. xp_assert_close(logsumexp(a, b=b), desired)
  78. x = xp.asarray([1e-40] * 100000)
  79. b = xp.linspace(1, 1000, 100000)
  80. logx = xp.log(x)
  81. X = xp.stack((x, x))
  82. logX = xp.stack((logx, logx))
  83. B = xp.stack((b, b))
  84. xp_assert_close(xp.exp(logsumexp(logX, b=B)), xp.sum(B * X))
  85. xp_assert_close(xp.exp(logsumexp(logX, b=B, axis=0)), xp.sum(B * X, axis=0))
  86. xp_assert_close(xp.exp(logsumexp(logX, b=B, axis=1)), xp.sum(B * X, axis=1))
  87. def test_logsumexp_sign(self, xp):
  88. a = xp.asarray([1, 1, 1])
  89. b = xp.asarray([1, -1, -1])
  90. r, s = logsumexp(a, b=b, return_sign=True)
  91. xp_assert_close(r, xp.asarray(1.))
  92. xp_assert_equal(s, xp.asarray(-1.))
  93. def test_logsumexp_sign_zero(self, xp):
  94. a = xp.asarray([1, 1])
  95. b = xp.asarray([1, -1])
  96. r, s = logsumexp(a, b=b, return_sign=True)
  97. assert not xp.isfinite(r)
  98. assert not xp.isnan(r)
  99. assert r < 0
  100. assert s == 0
  101. def test_logsumexp_sign_shape(self, xp):
  102. a = xp.ones((1, 2, 3, 4))
  103. b = xp.ones_like(a)
  104. r, s = logsumexp(a, axis=2, b=b, return_sign=True)
  105. assert r.shape == s.shape == (1, 2, 4)
  106. r, s = logsumexp(a, axis=(1, 3), b=b, return_sign=True)
  107. assert r.shape == s.shape == (1,3)
  108. def test_logsumexp_complex_sign(self, xp):
  109. a = xp.asarray([1 + 1j, 2 - 1j, -2 + 3j])
  110. r, s = logsumexp(a, return_sign=True)
  111. expected_sumexp = xp.sum(xp.exp(a))
  112. # This is the numpy>=2.0 convention for np.sign
  113. expected_sign = expected_sumexp / xp.abs(expected_sumexp)
  114. xp_assert_close(s, expected_sign)
  115. xp_assert_close(s * xp.exp(r), expected_sumexp)
  116. def test_logsumexp_shape(self, xp):
  117. a = xp.ones((1, 2, 3, 4))
  118. b = xp.ones_like(a)
  119. r = logsumexp(a, axis=2, b=b)
  120. assert r.shape == (1, 2, 4)
  121. r = logsumexp(a, axis=(1, 3), b=b)
  122. assert r.shape == (1, 3)
  123. def test_logsumexp_b_zero(self, xp):
  124. a = xp.asarray([1, 10000])
  125. b = xp.asarray([1, 0])
  126. xp_assert_close(logsumexp(a, b=b), xp.asarray(1.))
  127. def test_logsumexp_b_shape(self, xp):
  128. a = xp.zeros((4, 1, 2, 1))
  129. b = xp.ones((3, 1, 5))
  130. logsumexp(a, b=b)
  131. @pytest.mark.parametrize('arg', (1, [1, 2, 3]))
  132. def test_xp_invalid_input(self, arg):
  133. assert logsumexp(arg) == logsumexp(np.asarray(np.atleast_1d(arg)))
  134. def test_array_like(self):
  135. a = [1000, 1000]
  136. desired = np.asarray(1000.0 + math.log(2.0))
  137. xp_assert_close(logsumexp(a), desired)
  138. @pytest.mark.parametrize('dtype', dtypes)
  139. def test_dtypes_a(self, dtype, xp):
  140. dtype = getattr(xp, dtype)
  141. a = xp.asarray([1000., 1000.], dtype=dtype)
  142. desired_dtype = (xp.asarray(1.).dtype if xp.isdtype(dtype, 'integral')
  143. else dtype) # true for all libraries tested
  144. desired = xp.asarray(1000.0 + math.log(2.0), dtype=desired_dtype)
  145. xp_assert_close(logsumexp(a), desired)
  146. @pytest.mark.parametrize('dtype_a', dtypes)
  147. @pytest.mark.parametrize('dtype_b', dtypes)
  148. def test_dtypes_ab(self, dtype_a, dtype_b, xp):
  149. xp_dtype_a = getattr(xp, dtype_a)
  150. xp_dtype_b = getattr(xp, dtype_b)
  151. a = xp.asarray([2, 1], dtype=xp_dtype_a)
  152. b = xp.asarray([1, -1], dtype=xp_dtype_b)
  153. if is_array_api_strict(xp):
  154. # special-case for `TypeError: array_api_strict.float32 and
  155. # and array_api_strict.int64 cannot be type promoted together`
  156. xp_float_dtypes = [dtype for dtype in [xp_dtype_a, xp_dtype_b]
  157. if not xp.isdtype(dtype, 'integral')]
  158. if len(xp_float_dtypes) < 2: # at least one is integral
  159. xp_float_dtypes.append(xp.asarray(1.).dtype)
  160. desired_dtype = xp.result_type(*xp_float_dtypes)
  161. else:
  162. desired_dtype = xp.result_type(xp_dtype_a, xp_dtype_b)
  163. if xp.isdtype(desired_dtype, 'integral'):
  164. desired_dtype = xp_default_dtype(xp)
  165. desired = xp.asarray(math.log(math.exp(2) - math.exp(1)), dtype=desired_dtype)
  166. xp_assert_close(logsumexp(a, b=b), desired)
  167. def test_gh18295(self, xp):
  168. # gh-18295 noted loss of precision when real part of one element is much
  169. # larger than the rest. Check that this is resolved.
  170. a = xp.asarray([0.0, -40.0])
  171. res = logsumexp(a)
  172. ref = xp.logaddexp(a[0], a[1])
  173. xp_assert_close(res, ref)
  174. @pytest.mark.parametrize('dtype', ['complex64', 'complex128'])
  175. def test_gh21610(self, xp, dtype):
  176. # gh-21610 noted that `logsumexp` could return imaginary components
  177. # outside the range (-pi, pi]. Check that this is resolved.
  178. # While working on this, I noticed that all other tests passed even
  179. # when the imaginary component of the result was zero. This suggested
  180. # the need of a stronger test with imaginary dtype.
  181. rng = np.random.default_rng(324984329582349862)
  182. dtype = getattr(xp, dtype)
  183. shape = (10, 100)
  184. x = rng.uniform(1, 40, shape) + 1.j * rng.uniform(1, 40, shape)
  185. x = xp.asarray(x, dtype=dtype)
  186. res = logsumexp(x, axis=1)
  187. ref = xp.log(xp.sum(xp.exp(x), axis=1))
  188. max = xp.full_like(xp.imag(res), xp.pi)
  189. xp_assert_less(xp.abs(xp.imag(res)), max)
  190. xp_assert_close(res, ref)
  191. out, sgn = logsumexp(x, return_sign=True, axis=1)
  192. ref = xp.sum(xp.exp(x), axis=1)
  193. xp_assert_less(xp.abs(xp.imag(sgn)), max)
  194. xp_assert_close(out, xp.real(xp.log(ref)))
  195. xp_assert_close(sgn, ref/xp.abs(ref))
  196. def test_gh21709_small_imaginary(self, xp):
  197. # Test that `logsumexp` does not lose relative precision of
  198. # small imaginary components
  199. x = xp.asarray([0, 0.+2.2204460492503132e-17j])
  200. res = logsumexp(x)
  201. # from mpmath import mp
  202. # mp.dps = 100
  203. # x, y = mp.mpc(0), mp.mpc('0', '2.2204460492503132e-17')
  204. # ref = complex(mp.log(mp.exp(x) + mp.exp(y)))
  205. ref = xp.asarray(0.6931471805599453+1.1102230246251566e-17j)
  206. xp_assert_close(xp.real(res), xp.real(ref))
  207. xp_assert_close(xp.imag(res), xp.imag(ref), atol=0, rtol=1e-15)
  208. @pytest.mark.parametrize('x,y', it.product(
  209. [
  210. -np.inf,
  211. np.inf,
  212. complex(-np.inf, 0.),
  213. complex(-np.inf, -0.),
  214. complex(-np.inf, np.inf),
  215. complex(-np.inf, -np.inf),
  216. complex(np.inf, 0.),
  217. complex(np.inf, -0.),
  218. complex(np.inf, np.inf),
  219. complex(np.inf, -np.inf),
  220. # Phase in each quadrant.
  221. complex(-np.inf, 0.7533),
  222. complex(-np.inf, 2.3562),
  223. complex(-np.inf, 3.9270),
  224. complex(-np.inf, 5.4978),
  225. complex(np.inf, 0.7533),
  226. complex(np.inf, 2.3562),
  227. complex(np.inf, 3.9270),
  228. complex(np.inf, 5.4978),
  229. ], repeat=2)
  230. )
  231. def test_gh22601_infinite_elements(self, x, y, xp):
  232. # Test that `logsumexp` does reasonable things in the presence of
  233. # real and complex infinities.
  234. res = logsumexp(xp.asarray([x, y]))
  235. ref = xp.log(xp.sum(xp.exp(xp.asarray([x, y]))))
  236. xp_assert_equal(res, ref)
  237. def test_no_writeback(self, xp):
  238. """Test that logsumexp doesn't accidentally write back to its parameters."""
  239. a = xp.asarray([5., 4.])
  240. b = xp.asarray([3., 2.])
  241. logsumexp(a)
  242. logsumexp(a, b=b)
  243. xp_assert_equal(a, xp.asarray([5., 4.]))
  244. xp_assert_equal(b, xp.asarray([3., 2.]))
  245. @pytest.mark.parametrize("x_raw", [1.0, 1.0j, []])
  246. def test_device(self, x_raw, xp, devices):
  247. """Test input device propagation to output."""
  248. for d in devices:
  249. x = xp.asarray(x_raw, device=d)
  250. assert xp_device(logsumexp(x)) == xp_device(x)
  251. assert xp_device(logsumexp(x, b=x)) == xp_device(x)
  252. def test_gh22903(self, xp):
  253. # gh-22903 reported that `logsumexp` produced NaN where the weight associated
  254. # with the max magnitude element was negative and `return_sign=False`, even if
  255. # the net result should be the log of a positive number.
  256. # result is log of positive number
  257. a = xp.asarray([3.06409428, 0.37251854, 3.87471931])
  258. b = xp.asarray([1.88190708, 2.84174795, -0.85016884])
  259. xp_assert_close(logsumexp(a, b=b), logsumexp(a, b=b, return_sign=True)[0])
  260. # result is log of negative number
  261. b = xp.asarray([1.88190708, 2.84174795, -3.85016884])
  262. xp_assert_close(logsumexp(a, b=b), xp.asarray(xp.nan))
  263. @pytest.mark.parametrize("a, b, sign_ref",
  264. [([np.inf], None, 1.),
  265. ([np.inf], [-1.], -1.)])
  266. def test_gh23548(self, xp, a, b, sign_ref):
  267. # gh-23548 reported that `logsumexp` with `return_sign=True` returned a sign
  268. # of NaN with infinite reals
  269. a, b = xp.asarray(a), xp.asarray(b) if b is not None else None
  270. val, sign = logsumexp(a, b=b, return_sign=True)
  271. assert xp.isinf(val)
  272. xp_assert_equal(sign, xp.asarray(sign_ref))
  273. @make_xp_test_case(softmax)
  274. class TestSoftmax:
  275. def test_softmax_fixtures(self, xp):
  276. xp_assert_close(softmax(xp.asarray([1000., 0., 0., 0.])),
  277. xp.asarray([1., 0., 0., 0.]), rtol=1e-13)
  278. xp_assert_close(softmax(xp.asarray([1., 1.])),
  279. xp.asarray([.5, .5]), rtol=1e-13)
  280. xp_assert_close(softmax(xp.asarray([0., 1.])),
  281. xp.asarray([1., np.e])/(1 + np.e),
  282. rtol=1e-13)
  283. # Expected value computed using mpmath (with mpmath.mp.dps = 200) and then
  284. # converted to float.
  285. x = xp.arange(4, dtype=xp.float64)
  286. expected = xp.asarray([0.03205860328008499,
  287. 0.08714431874203256,
  288. 0.23688281808991013,
  289. 0.6439142598879722], dtype=xp.float64)
  290. xp_assert_close(softmax(x), expected, rtol=1e-13)
  291. # Translation property. If all the values are changed by the same amount,
  292. # the softmax result does not change.
  293. xp_assert_close(softmax(x + 100), expected, rtol=1e-13)
  294. # When axis=None, softmax operates on the entire array, and preserves
  295. # the shape.
  296. xp_assert_close(softmax(xp.reshape(x, (2, 2))),
  297. xp.reshape(expected, (2, 2)), rtol=1e-13)
  298. def test_softmax_multi_axes(self, xp):
  299. xp_assert_close(softmax(xp.asarray([[1000., 0.], [1000., 0.]]), axis=0),
  300. xp.asarray([[.5, .5], [.5, .5]]), rtol=1e-13)
  301. xp_assert_close(softmax(xp.asarray([[1000., 0.], [1000., 0.]]), axis=1),
  302. xp.asarray([[1., 0.], [1., 0.]]), rtol=1e-13)
  303. # Expected value computed using mpmath (with mpmath.mp.dps = 200) and then
  304. # converted to float.
  305. x = xp.asarray([[-25., 0., 25., 50.],
  306. [ 1., 325., 749., 750.]])
  307. expected = xp.asarray([[2.678636961770877e-33,
  308. 1.9287498479371314e-22,
  309. 1.3887943864771144e-11,
  310. 0.999999999986112],
  311. [0.0,
  312. 1.9444526359919372e-185,
  313. 0.2689414213699951,
  314. 0.7310585786300048]])
  315. xp_assert_close(softmax(x, axis=1), expected, rtol=1e-13)
  316. xp_assert_close(softmax(x.T, axis=0), expected.T, rtol=1e-13)
  317. # 3-d input, with a tuple for the axis.
  318. x3d = xp.reshape(x, (2, 2, 2))
  319. xp_assert_close(softmax(x3d, axis=(1, 2)),
  320. xp.reshape(expected, (2, 2, 2)), rtol=1e-13)
  321. @pytest.mark.xfail_xp_backends("array_api_strict", reason="int->float promotion")
  322. def test_softmax_int_array(self, xp):
  323. xp_assert_close(softmax(xp.asarray([1000, 0, 0, 0])),
  324. xp.asarray([1., 0., 0., 0.]), rtol=1e-13)
  325. def test_softmax_scalar(self):
  326. xp_assert_close(softmax(1000), np.asarray(1.), rtol=1e-13)
  327. def test_softmax_array_like(self):
  328. xp_assert_close(softmax([1000, 0, 0, 0]),
  329. np.asarray([1., 0., 0., 0.]), rtol=1e-13)
  330. @make_xp_test_case(log_softmax)
  331. class TestLogSoftmax:
  332. def test_log_softmax_basic(self, xp):
  333. xp_assert_close(log_softmax(xp.asarray([1000., 1.])),
  334. xp.asarray([0., -999.]), rtol=1e-13)
  335. @pytest.mark.xfail_xp_backends("array_api_strict", reason="int->float promotion")
  336. def test_log_softmax_int_array(self, xp):
  337. xp_assert_close(log_softmax(xp.asarray([1000, 1])),
  338. xp.asarray([0., -999.]), rtol=1e-13)
  339. def test_log_softmax_scalar(self):
  340. xp_assert_close(log_softmax(1.0), 0.0, rtol=1e-13)
  341. def test_log_softmax_array_like(self):
  342. xp_assert_close(log_softmax([1000, 1]),
  343. np.asarray([0., -999.]), rtol=1e-13)
  344. @staticmethod
  345. def data_1d(xp):
  346. x = xp.arange(4, dtype=xp.float64)
  347. # Expected value computed using mpmath (with mpmath.mp.dps = 200)
  348. expect = [-3.4401896985611953,
  349. -2.4401896985611953,
  350. -1.4401896985611953,
  351. -0.44018969856119533]
  352. return x, xp.asarray(expect, dtype=xp.float64)
  353. @staticmethod
  354. def data_2d(xp):
  355. x = xp.reshape(xp.arange(8, dtype=xp.float64), (2, 4))
  356. # Expected value computed using mpmath (with mpmath.mp.dps = 200)
  357. expect = [[-3.4401896985611953,
  358. -2.4401896985611953,
  359. -1.4401896985611953,
  360. -0.44018969856119533],
  361. [-3.4401896985611953,
  362. -2.4401896985611953,
  363. -1.4401896985611953,
  364. -0.44018969856119533]]
  365. return x, xp.asarray(expect, dtype=xp.float64)
  366. @pytest.mark.parametrize("offset", [0, 100])
  367. def test_log_softmax_translation(self, offset, xp):
  368. # Translation property. If all the values are changed by the same amount,
  369. # the softmax result does not change.
  370. x, expect = self.data_1d(xp)
  371. x += offset
  372. xp_assert_close(log_softmax(x), expect, rtol=1e-13)
  373. def test_log_softmax_noneaxis(self, xp):
  374. # When axis=None, softmax operates on the entire array, and preserves
  375. # the shape.
  376. x, expect = self.data_1d(xp)
  377. x = xp.reshape(x, (2, 2))
  378. expect = xp.reshape(expect, (2, 2))
  379. xp_assert_close(log_softmax(x), expect, rtol=1e-13)
  380. @pytest.mark.parametrize('axis_2d, expected_2d', [
  381. (0, np.log(0.5) * np.ones((2, 2))),
  382. (1, [[0., -999.], [0., -999.]]),
  383. ])
  384. def test_axes(self, axis_2d, expected_2d, xp):
  385. x = xp.asarray([[1000., 1.], [1000., 1.]])
  386. xp_assert_close(log_softmax(x, axis=axis_2d),
  387. xp.asarray(expected_2d, dtype=x.dtype), rtol=1e-13)
  388. def test_log_softmax_2d_axis1(self, xp):
  389. x, expect = self.data_2d(xp)
  390. xp_assert_close(log_softmax(x, axis=1), expect, rtol=1e-13)
  391. def test_log_softmax_2d_axis0(self, xp):
  392. x, expect = self.data_2d(xp)
  393. xp_assert_close(log_softmax(x.T, axis=0), expect.T, rtol=1e-13)
  394. def test_log_softmax_3d(self, xp):
  395. # 3D input, with a tuple for the axis.
  396. x, expect = self.data_2d(xp)
  397. x = xp.reshape(x, (2, 2, 2))
  398. expect = xp.reshape(expect, (2, 2, 2))
  399. xp_assert_close(log_softmax(x, axis=(1, 2)), expect, rtol=1e-13)