test_bdtr.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import warnings
  2. import numpy as np
  3. import scipy.special as sc
  4. import pytest
  5. from numpy.testing import assert_allclose, assert_array_equal
  6. class TestBdtr:
  7. def test(self):
  8. val = sc.bdtr(0, 1, 0.5)
  9. assert_allclose(val, 0.5)
  10. def test_sum_is_one(self):
  11. val = sc.bdtr([0, 1, 2], 2, 0.5)
  12. assert_array_equal(val, [0.25, 0.75, 1.0])
  13. def test_rounding(self):
  14. double_val = sc.bdtr([0.1, 1.1, 2.1], 2, 0.5)
  15. int_val = sc.bdtr([0, 1, 2], 2, 0.5)
  16. assert_array_equal(double_val, int_val)
  17. @pytest.mark.parametrize('k, n, p', [
  18. (np.inf, 2, 0.5),
  19. (1.0, np.inf, 0.5),
  20. (1.0, 2, np.inf)
  21. ])
  22. def test_inf(self, k, n, p):
  23. with warnings.catch_warnings():
  24. warnings.simplefilter("ignore", DeprecationWarning)
  25. val = sc.bdtr(k, n, p)
  26. assert np.isnan(val)
  27. def test_domain(self):
  28. val = sc.bdtr(-1.1, 1, 0.5)
  29. assert np.isnan(val)
  30. class TestBdtrc:
  31. def test_value(self):
  32. val = sc.bdtrc(0, 1, 0.5)
  33. assert_allclose(val, 0.5)
  34. def test_sum_is_one(self):
  35. val = sc.bdtrc([0, 1, 2], 2, 0.5)
  36. assert_array_equal(val, [0.75, 0.25, 0.0])
  37. def test_rounding(self):
  38. double_val = sc.bdtrc([0.1, 1.1, 2.1], 2, 0.5)
  39. int_val = sc.bdtrc([0, 1, 2], 2, 0.5)
  40. assert_array_equal(double_val, int_val)
  41. @pytest.mark.parametrize('k, n, p', [
  42. (np.inf, 2, 0.5),
  43. (1.0, np.inf, 0.5),
  44. (1.0, 2, np.inf)
  45. ])
  46. def test_inf(self, k, n, p):
  47. with warnings.catch_warnings():
  48. warnings.simplefilter("ignore", DeprecationWarning)
  49. val = sc.bdtrc(k, n, p)
  50. assert np.isnan(val)
  51. def test_domain(self):
  52. val = sc.bdtrc(-1.1, 1, 0.5)
  53. val2 = sc.bdtrc(2.1, 1, 0.5)
  54. assert np.isnan(val2)
  55. assert_allclose(val, 1.0)
  56. def test_bdtr_bdtrc_sum_to_one(self):
  57. bdtr_vals = sc.bdtr([0, 1, 2], 2, 0.5)
  58. bdtrc_vals = sc.bdtrc([0, 1, 2], 2, 0.5)
  59. vals = bdtr_vals + bdtrc_vals
  60. assert_allclose(vals, [1.0, 1.0, 1.0])
  61. class TestBdtri:
  62. def test_value(self):
  63. val = sc.bdtri(0, 1, 0.5)
  64. assert_allclose(val, 0.5)
  65. def test_sum_is_one(self):
  66. val = sc.bdtri([0, 1], 2, 0.5)
  67. actual = np.asarray([1 - 1/np.sqrt(2), 1/np.sqrt(2)])
  68. assert_allclose(val, actual)
  69. def test_rounding(self):
  70. double_val = sc.bdtri([0.1, 1.1], 2, 0.5)
  71. int_val = sc.bdtri([0, 1], 2, 0.5)
  72. assert_allclose(double_val, int_val)
  73. @pytest.mark.parametrize('k, n, p', [
  74. (np.inf, 2, 0.5),
  75. (1.0, np.inf, 0.5),
  76. (1.0, 2, np.inf)
  77. ])
  78. def test_inf(self, k, n, p):
  79. with warnings.catch_warnings():
  80. warnings.simplefilter("ignore", DeprecationWarning)
  81. val = sc.bdtri(k, n, p)
  82. assert np.isnan(val)
  83. @pytest.mark.parametrize('k, n, p', [
  84. (-1.1, 1, 0.5),
  85. (2.1, 1, 0.5)
  86. ])
  87. def test_domain(self, k, n, p):
  88. val = sc.bdtri(k, n, p)
  89. assert np.isnan(val)
  90. def test_bdtr_bdtri_roundtrip(self):
  91. bdtr_vals = sc.bdtr([0, 1, 2], 2, 0.5)
  92. roundtrip_vals = sc.bdtri([0, 1, 2], 2, bdtr_vals)
  93. assert_allclose(roundtrip_vals, [0.5, 0.5, np.nan])