exponential.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch import Tensor
  4. from torch.distributions import constraints
  5. from torch.distributions.exp_family import ExponentialFamily
  6. from torch.distributions.utils import broadcast_all
  7. from torch.types import _Number, _size
  8. __all__ = ["Exponential"]
  9. class Exponential(ExponentialFamily):
  10. r"""
  11. Creates a Exponential distribution parameterized by :attr:`rate`.
  12. Example::
  13. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  14. >>> m = Exponential(torch.tensor([1.0]))
  15. >>> m.sample() # Exponential distributed with rate=1
  16. tensor([ 0.1046])
  17. Args:
  18. rate (float or Tensor): rate = 1 / scale of the distribution
  19. """
  20. # pyrefly: ignore [bad-override]
  21. arg_constraints = {"rate": constraints.positive}
  22. support = constraints.nonnegative
  23. has_rsample = True
  24. _mean_carrier_measure = 0
  25. @property
  26. def mean(self) -> Tensor:
  27. return self.rate.reciprocal()
  28. @property
  29. def mode(self) -> Tensor:
  30. return torch.zeros_like(self.rate)
  31. @property
  32. def stddev(self) -> Tensor:
  33. return self.rate.reciprocal()
  34. @property
  35. def variance(self) -> Tensor:
  36. return self.rate.pow(-2)
  37. def __init__(
  38. self,
  39. rate: Tensor | float,
  40. validate_args: bool | None = None,
  41. ) -> None:
  42. (self.rate,) = broadcast_all(rate)
  43. batch_shape = torch.Size() if isinstance(rate, _Number) else self.rate.size()
  44. super().__init__(batch_shape, validate_args=validate_args)
  45. def expand(self, batch_shape, _instance=None):
  46. new = self._get_checked_instance(Exponential, _instance)
  47. batch_shape = torch.Size(batch_shape)
  48. new.rate = self.rate.expand(batch_shape)
  49. super(Exponential, new).__init__(batch_shape, validate_args=False)
  50. new._validate_args = self._validate_args
  51. return new
  52. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  53. shape = self._extended_shape(sample_shape)
  54. return self.rate.new(shape).exponential_() / self.rate
  55. def log_prob(self, value):
  56. if self._validate_args:
  57. self._validate_sample(value)
  58. return self.rate.log() - self.rate * value
  59. def cdf(self, value):
  60. if self._validate_args:
  61. self._validate_sample(value)
  62. return 1 - torch.exp(-self.rate * value)
  63. def icdf(self, value):
  64. return -torch.log1p(-value) / self.rate
  65. def entropy(self):
  66. return 1.0 - torch.log(self.rate)
  67. @property
  68. def _natural_params(self) -> tuple[Tensor]:
  69. return (-self.rate,)
  70. # pyrefly: ignore [bad-override]
  71. def _log_normalizer(self, x):
  72. return -torch.log(-x)