generalized_pareto.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # mypy: allow-untyped-defs
  2. import math
  3. from numbers import Number, Real
  4. import torch
  5. from torch import inf, nan
  6. from torch.distributions import constraints, Distribution
  7. from torch.distributions.utils import broadcast_all
  8. __all__ = ["GeneralizedPareto"]
  9. class GeneralizedPareto(Distribution):
  10. r"""
  11. Creates a Generalized Pareto distribution parameterized by :attr:`loc`, :attr:`scale`, and :attr:`concentration`.
  12. The Generalized Pareto distribution is a family of continuous probability distributions on the real line.
  13. Special cases include Exponential (when :attr:`loc` = 0, :attr:`concentration` = 0), Pareto (when :attr:`concentration` > 0,
  14. :attr:`loc` = :attr:`scale` / :attr:`concentration`), and Uniform (when :attr:`concentration` = -1).
  15. This distribution is often used to model the tails of other distributions. This implementation is based on the
  16. implementation in TensorFlow Probability.
  17. Example::
  18. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  19. >>> m = GeneralizedPareto(torch.tensor([0.1]), torch.tensor([2.0]), torch.tensor([0.4]))
  20. >>> m.sample() # sample from a Generalized Pareto distribution with loc=0.1, scale=2.0, and concentration=0.4
  21. tensor([ 1.5623])
  22. Args:
  23. loc (float or Tensor): Location parameter of the distribution
  24. scale (float or Tensor): Scale parameter of the distribution
  25. concentration (float or Tensor): Concentration parameter of the distribution
  26. """
  27. # pyrefly: ignore [bad-override]
  28. arg_constraints = {
  29. "loc": constraints.real,
  30. "scale": constraints.positive,
  31. "concentration": constraints.real,
  32. }
  33. has_rsample = True
  34. def __init__(self, loc, scale, concentration, validate_args=None):
  35. self.loc, self.scale, self.concentration = broadcast_all(
  36. loc, scale, concentration
  37. )
  38. if (
  39. isinstance(loc, Number)
  40. and isinstance(scale, Number)
  41. and isinstance(concentration, Number)
  42. ):
  43. batch_shape = torch.Size()
  44. else:
  45. batch_shape = self.loc.size()
  46. super().__init__(batch_shape, validate_args=validate_args)
  47. def expand(self, batch_shape, _instance=None):
  48. new = self._get_checked_instance(GeneralizedPareto, _instance)
  49. batch_shape = torch.Size(batch_shape)
  50. new.loc = self.loc.expand(batch_shape)
  51. new.scale = self.scale.expand(batch_shape)
  52. new.concentration = self.concentration.expand(batch_shape)
  53. super(GeneralizedPareto, new).__init__(batch_shape, validate_args=False)
  54. new._validate_args = self._validate_args
  55. return new
  56. def rsample(self, sample_shape=torch.Size()):
  57. shape = self._extended_shape(sample_shape)
  58. u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device)
  59. return self.icdf(u)
  60. def log_prob(self, value):
  61. if self._validate_args:
  62. self._validate_sample(value)
  63. z = self._z(value)
  64. eq_zero = torch.isclose(self.concentration, torch.tensor(0.0))
  65. safe_conc = torch.where(
  66. eq_zero, torch.ones_like(self.concentration), self.concentration
  67. )
  68. y = 1 / safe_conc + torch.ones_like(z)
  69. where_nonzero = torch.where(y == 0, y, y * torch.log1p(safe_conc * z))
  70. log_scale = (
  71. math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log()
  72. )
  73. return -log_scale - torch.where(eq_zero, z, where_nonzero)
  74. def log_survival_function(self, value):
  75. if self._validate_args:
  76. self._validate_sample(value)
  77. z = self._z(value)
  78. eq_zero = torch.isclose(self.concentration, torch.tensor(0.0))
  79. safe_conc = torch.where(
  80. eq_zero, torch.ones_like(self.concentration), self.concentration
  81. )
  82. where_nonzero = -torch.log1p(safe_conc * z) / safe_conc
  83. return torch.where(eq_zero, -z, where_nonzero)
  84. def log_cdf(self, value):
  85. return torch.log1p(-torch.exp(self.log_survival_function(value)))
  86. def cdf(self, value):
  87. return torch.exp(self.log_cdf(value))
  88. def icdf(self, value):
  89. loc = self.loc
  90. scale = self.scale
  91. concentration = self.concentration
  92. eq_zero = torch.isclose(concentration, torch.zeros_like(concentration))
  93. safe_conc = torch.where(eq_zero, torch.ones_like(concentration), concentration)
  94. logu = torch.log1p(-value)
  95. where_nonzero = loc + scale / safe_conc * torch.expm1(-safe_conc * logu)
  96. where_zero = loc - scale * logu
  97. return torch.where(eq_zero, where_zero, where_nonzero)
  98. def _z(self, x):
  99. return (x - self.loc) / self.scale
  100. @property
  101. def mean(self):
  102. concentration = self.concentration
  103. valid = concentration < 1
  104. safe_conc = torch.where(valid, concentration, 0.5)
  105. result = self.loc + self.scale / (1 - safe_conc)
  106. return torch.where(valid, result, nan)
  107. @property
  108. def variance(self):
  109. concentration = self.concentration
  110. valid = concentration < 0.5
  111. safe_conc = torch.where(valid, concentration, 0.25)
  112. # pyrefly: ignore [unsupported-operation]
  113. result = self.scale**2 / ((1 - safe_conc) ** 2 * (1 - 2 * safe_conc))
  114. return torch.where(valid, result, nan)
  115. def entropy(self):
  116. ans = torch.log(self.scale) + self.concentration + 1
  117. return torch.broadcast_to(ans, self._batch_shape)
  118. @property
  119. def mode(self):
  120. return self.loc
  121. @constraints.dependent_property(is_discrete=False, event_dim=0)
  122. # pyrefly: ignore [bad-override]
  123. def support(self):
  124. lower = self.loc
  125. upper = torch.where(
  126. self.concentration < 0, lower - self.scale / self.concentration, inf
  127. )
  128. return constraints.interval(lower, upper)