test_logit.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import numpy as np
  2. from numpy.testing import assert_equal, assert_allclose
  3. from scipy.special import logit, expit, log_expit
  4. class TestLogit:
  5. def check_logit_out(self, a, expected):
  6. actual = logit(a)
  7. assert_equal(actual.dtype, a.dtype)
  8. rtol = 16*np.finfo(a.dtype).eps
  9. assert_allclose(actual, expected, rtol=rtol)
  10. def test_float32(self):
  11. a = np.concatenate((np.linspace(0, 1, 10, dtype=np.float32),
  12. [np.float32(0.0001), np.float32(0.49999),
  13. np.float32(0.50001)]))
  14. # Expected values computed with mpmath from float32 inputs, e.g.
  15. # from mpmath import mp
  16. # mp.dps = 200
  17. # a = np.float32(1/9)
  18. # print(np.float32(mp.log(a) - mp.log1p(-a)))
  19. # prints `-2.0794415`.
  20. expected = np.array([-np.inf, -2.0794415, -1.2527629, -6.9314712e-01,
  21. -2.2314353e-01, 2.2314365e-01, 6.9314724e-01,
  22. 1.2527630, 2.0794415, np.inf,
  23. -9.2102404, -4.0054321e-05, 4.0054321e-05],
  24. dtype=np.float32)
  25. self.check_logit_out(a, expected)
  26. def test_float64(self):
  27. a = np.concatenate((np.linspace(0, 1, 10, dtype=np.float64),
  28. [1e-8, 0.4999999999999, 0.50000000001]))
  29. # Expected values computed with mpmath.
  30. expected = np.array([-np.inf,
  31. -2.079441541679836,
  32. -1.252762968495368,
  33. -0.6931471805599454,
  34. -0.22314355131420985,
  35. 0.22314355131420985,
  36. 0.6931471805599452,
  37. 1.2527629684953674,
  38. 2.0794415416798353,
  39. np.inf,
  40. -18.420680733952366,
  41. -3.999023334699814e-13,
  42. 4.000000330961484e-11])
  43. self.check_logit_out(a, expected)
  44. def test_nan(self):
  45. expected = np.array([np.nan]*4)
  46. with np.errstate(invalid='ignore'):
  47. actual = logit(np.array([-3., -2., 2., 3.]))
  48. assert_equal(expected, actual)
  49. class TestExpit:
  50. def check_expit_out(self, dtype, expected):
  51. a = np.linspace(-4, 4, 10)
  52. a = np.array(a, dtype=dtype)
  53. actual = expit(a)
  54. assert_allclose(actual, expected, atol=1.5e-7, rtol=0)
  55. assert_equal(actual.dtype, np.dtype(dtype))
  56. def test_float32(self):
  57. expected = np.array([0.01798621, 0.04265125,
  58. 0.09777259, 0.20860852,
  59. 0.39068246, 0.60931754,
  60. 0.79139149, 0.9022274,
  61. 0.95734876, 0.98201376], dtype=np.float32)
  62. self.check_expit_out('f4', expected)
  63. def test_float64(self):
  64. expected = np.array([0.01798621, 0.04265125,
  65. 0.0977726, 0.20860853,
  66. 0.39068246, 0.60931754,
  67. 0.79139147, 0.9022274,
  68. 0.95734875, 0.98201379])
  69. self.check_expit_out('f8', expected)
  70. def test_large(self):
  71. for dtype in (np.float32, np.float64, np.longdouble):
  72. for n in (88, 89, 709, 710, 11356, 11357):
  73. n = np.array(n, dtype=dtype)
  74. assert_allclose(expit(n), 1.0, atol=1e-20)
  75. assert_allclose(expit(-n), 0.0, atol=1e-20)
  76. assert_equal(expit(n).dtype, dtype)
  77. assert_equal(expit(-n).dtype, dtype)
  78. class TestLogExpit:
  79. def test_large_negative(self):
  80. x = np.array([-10000.0, -750.0, -500.0, -35.0])
  81. y = log_expit(x)
  82. assert_equal(y, x)
  83. def test_large_positive(self):
  84. x = np.array([750.0, 1000.0, 10000.0])
  85. y = log_expit(x)
  86. # y will contain -0.0, and -0.0 is used in the expected value,
  87. # but assert_equal does not check the sign of zeros, and I don't
  88. # think the sign is an essential part of the test (i.e. it would
  89. # probably be OK if log_expit(1000) returned 0.0 instead of -0.0).
  90. assert_equal(y, np.array([-0.0, -0.0, -0.0]))
  91. def test_basic_float64(self):
  92. x = np.array([-32, -20, -10, -3, -1, -0.1, -1e-9,
  93. 0, 1e-9, 0.1, 1, 10, 100, 500, 710, 725, 735])
  94. y = log_expit(x)
  95. #
  96. # Expected values were computed with mpmath:
  97. #
  98. # import mpmath
  99. #
  100. # mpmath.mp.dps = 100
  101. #
  102. # def mp_log_expit(x):
  103. # return -mpmath.log1p(mpmath.exp(-x))
  104. #
  105. # expected = [float(mp_log_expit(t)) for t in x]
  106. #
  107. expected = [-32.000000000000014, -20.000000002061153,
  108. -10.000045398899218, -3.048587351573742,
  109. -1.3132616875182228, -0.7443966600735709,
  110. -0.6931471810599453, -0.6931471805599453,
  111. -0.6931471800599454, -0.6443966600735709,
  112. -0.3132616875182228, -4.539889921686465e-05,
  113. -3.720075976020836e-44, -7.124576406741286e-218,
  114. -4.47628622567513e-309, -1.36930634e-315,
  115. -6.217e-320]
  116. # When tested locally, only one value in y was not exactly equal to
  117. # expected. That was for x=1, and the y value differed from the
  118. # expected by 1 ULP. For this test, however, I'll use rtol=1e-15.
  119. assert_allclose(y, expected, rtol=1e-15)
  120. def test_basic_float32(self):
  121. x = np.array([-32, -20, -10, -3, -1, -0.1, -1e-9,
  122. 0, 1e-9, 0.1, 1, 10, 100], dtype=np.float32)
  123. y = log_expit(x)
  124. #
  125. # Expected values were computed with mpmath:
  126. #
  127. # import mpmath
  128. #
  129. # mpmath.mp.dps = 100
  130. #
  131. # def mp_log_expit(x):
  132. # return -mpmath.log1p(mpmath.exp(-x))
  133. #
  134. # expected = [np.float32(mp_log_expit(t)) for t in x]
  135. #
  136. expected = np.array([-32.0, -20.0, -10.000046, -3.0485873,
  137. -1.3132616, -0.7443967, -0.6931472,
  138. -0.6931472, -0.6931472, -0.64439666,
  139. -0.3132617, -4.5398898e-05, -3.8e-44],
  140. dtype=np.float32)
  141. assert_allclose(y, expected, rtol=5e-7)