geometric.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch import Tensor
  4. from torch.distributions import constraints
  5. from torch.distributions.distribution import Distribution
  6. from torch.distributions.utils import (
  7. broadcast_all,
  8. lazy_property,
  9. logits_to_probs,
  10. probs_to_logits,
  11. )
  12. from torch.nn.functional import binary_cross_entropy_with_logits
  13. from torch.types import _Number, Number
# Only the Geometric distribution is exported from this module.
__all__ = ["Geometric"]
  15. class Geometric(Distribution):
  16. r"""
  17. Creates a Geometric distribution parameterized by :attr:`probs`,
  18. where :attr:`probs` is the probability of success of Bernoulli trials.
  19. .. math::
  20. P(X=k) = (1-p)^{k} p, k = 0, 1, ...
  21. .. note::
  22. :func:`torch.distributions.geometric.Geometric` :math:`(k+1)`-th trial is the first success
  23. hence draws samples in :math:`\{0, 1, \ldots\}`, whereas
  24. :func:`torch.Tensor.geometric_` `k`-th trial is the first success hence draws samples in :math:`\{1, 2, \ldots\}`.
  25. Example::
  26. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  27. >>> m = Geometric(torch.tensor([0.3]))
  28. >>> m.sample() # underlying Bernoulli has 30% chance 1; 70% chance 0
  29. tensor([ 2.])
  30. Args:
  31. probs (Number, Tensor): the probability of sampling `1`. Must be in range (0, 1]
  32. logits (Number, Tensor): the log-odds of sampling `1`.
  33. """
  34. # pyrefly: ignore [bad-override]
  35. arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
  36. support = constraints.nonnegative_integer
  37. def __init__(
  38. self,
  39. probs: Tensor | Number | None = None,
  40. logits: Tensor | Number | None = None,
  41. validate_args: bool | None = None,
  42. ) -> None:
  43. if (probs is None) == (logits is None):
  44. raise ValueError(
  45. "Either `probs` or `logits` must be specified, but not both."
  46. )
  47. if probs is not None:
  48. # pyrefly: ignore [read-only]
  49. (self.probs,) = broadcast_all(probs)
  50. else:
  51. if logits is None:
  52. raise AssertionError("logits is unexpectedly None")
  53. # pyrefly: ignore [read-only]
  54. (self.logits,) = broadcast_all(logits)
  55. probs_or_logits = probs if probs is not None else logits
  56. if isinstance(probs_or_logits, _Number):
  57. batch_shape = torch.Size()
  58. else:
  59. if probs_or_logits is None:
  60. raise AssertionError("probs_or_logits is unexpectedly None")
  61. batch_shape = probs_or_logits.size()
  62. super().__init__(batch_shape, validate_args=validate_args)
  63. if self._validate_args and probs is not None:
  64. # Add an extra check beyond unit_interval
  65. value = self.probs
  66. valid = value > 0
  67. if not valid.all():
  68. invalid_value = value.data[~valid]
  69. raise ValueError(
  70. "Expected parameter probs "
  71. f"({type(value).__name__} of shape {tuple(value.shape)}) "
  72. f"of distribution {repr(self)} "
  73. f"to be positive but found invalid values:\n{invalid_value}"
  74. )
  75. def expand(self, batch_shape, _instance=None):
  76. new = self._get_checked_instance(Geometric, _instance)
  77. batch_shape = torch.Size(batch_shape)
  78. if "probs" in self.__dict__:
  79. new.probs = self.probs.expand(batch_shape)
  80. if "logits" in self.__dict__:
  81. new.logits = self.logits.expand(batch_shape)
  82. super(Geometric, new).__init__(batch_shape, validate_args=False)
  83. new._validate_args = self._validate_args
  84. return new
  85. @property
  86. def mean(self) -> Tensor:
  87. return 1.0 / self.probs - 1.0
  88. @property
  89. def mode(self) -> Tensor:
  90. return torch.zeros_like(self.probs)
  91. @property
  92. def variance(self) -> Tensor:
  93. return (1.0 / self.probs - 1.0) / self.probs
  94. @lazy_property
  95. def logits(self) -> Tensor:
  96. return probs_to_logits(self.probs, is_binary=True)
  97. @lazy_property
  98. def probs(self) -> Tensor:
  99. return logits_to_probs(self.logits, is_binary=True)
  100. def sample(self, sample_shape=torch.Size()):
  101. shape = self._extended_shape(sample_shape)
  102. tiny = torch.finfo(self.probs.dtype).tiny
  103. with torch.no_grad():
  104. if torch._C._get_tracing_state():
  105. # [JIT WORKAROUND] lack of support for .uniform_()
  106. u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
  107. u = u.clamp(min=tiny)
  108. else:
  109. u = self.probs.new(shape).uniform_(tiny, 1)
  110. return (u.log() / (-self.probs).log1p()).floor()
  111. def log_prob(self, value):
  112. if self._validate_args:
  113. self._validate_sample(value)
  114. value, probs = broadcast_all(value, self.probs)
  115. probs = probs.clone(memory_format=torch.contiguous_format)
  116. probs[(probs == 1) & (value == 0)] = 0
  117. return value * (-probs).log1p() + self.probs.log()
  118. def entropy(self):
  119. return (
  120. binary_cross_entropy_with_logits(self.logits, self.probs, reduction="none")
  121. / self.probs
  122. )