continuous_bernoulli.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # mypy: allow-untyped-defs
  2. import math
  3. import torch
  4. from torch import Tensor
  5. from torch.distributions import constraints
  6. from torch.distributions.exp_family import ExponentialFamily
  7. from torch.distributions.utils import (
  8. broadcast_all,
  9. clamp_probs,
  10. lazy_property,
  11. logits_to_probs,
  12. probs_to_logits,
  13. )
  14. from torch.nn.functional import binary_cross_entropy_with_logits
  15. from torch.types import _Number, _size, Number
  16. __all__ = ["ContinuousBernoulli"]
  17. class ContinuousBernoulli(ExponentialFamily):
  18. r"""
  19. Creates a continuous Bernoulli distribution parameterized by :attr:`probs`
  20. or :attr:`logits` (but not both).
  21. The distribution is supported in [0, 1] and parameterized by 'probs' (in
  22. (0,1)) or 'logits' (real-valued). Note that, unlike the Bernoulli, 'probs'
  23. does not correspond to a probability and 'logits' does not correspond to
  24. log-odds, but the same names are used due to the similarity with the
  25. Bernoulli. See [1] for more details.
  26. Example::
  27. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  28. >>> m = ContinuousBernoulli(torch.tensor([0.3]))
  29. >>> m.sample()
  30. tensor([ 0.2538])
  31. Args:
  32. probs (Number, Tensor): (0,1) valued parameters
  33. logits (Number, Tensor): real valued parameters whose sigmoid matches 'probs'
  34. [1] The continuous Bernoulli: fixing a pervasive error in variational
  35. autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019.
  36. https://arxiv.org/abs/1907.06845
  37. """
  38. # pyrefly: ignore [bad-override]
  39. arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
  40. support = constraints.unit_interval
  41. _mean_carrier_measure = 0
  42. has_rsample = True
  43. def __init__(
  44. self,
  45. probs: Tensor | Number | None = None,
  46. logits: Tensor | Number | None = None,
  47. lims: tuple[float, float] = (0.499, 0.501),
  48. validate_args: bool | None = None,
  49. ) -> None:
  50. if (probs is None) == (logits is None):
  51. raise ValueError(
  52. "Either `probs` or `logits` must be specified, but not both."
  53. )
  54. if probs is not None:
  55. is_scalar = isinstance(probs, _Number)
  56. # pyrefly: ignore [read-only]
  57. (self.probs,) = broadcast_all(probs)
  58. # validate 'probs' here if necessary as it is later clamped for numerical stability
  59. # close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
  60. if validate_args is not None:
  61. if not self.arg_constraints["probs"].check(self.probs).all():
  62. raise ValueError("The parameter probs has invalid values")
  63. # pyrefly: ignore [read-only]
  64. self.probs = clamp_probs(self.probs)
  65. else:
  66. if logits is None:
  67. raise AssertionError("logits is unexpectedly None")
  68. is_scalar = isinstance(logits, _Number)
  69. # pyrefly: ignore [read-only]
  70. (self.logits,) = broadcast_all(logits)
  71. self._param = self.probs if probs is not None else self.logits
  72. if is_scalar:
  73. batch_shape = torch.Size()
  74. else:
  75. batch_shape = self._param.size()
  76. self._lims = lims
  77. super().__init__(batch_shape, validate_args=validate_args)
  78. def expand(self, batch_shape, _instance=None):
  79. new = self._get_checked_instance(ContinuousBernoulli, _instance)
  80. new._lims = self._lims
  81. batch_shape = torch.Size(batch_shape)
  82. if "probs" in self.__dict__:
  83. new.probs = self.probs.expand(batch_shape)
  84. new._param = new.probs
  85. if "logits" in self.__dict__:
  86. new.logits = self.logits.expand(batch_shape)
  87. new._param = new.logits
  88. super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False)
  89. new._validate_args = self._validate_args
  90. return new
  91. def _new(self, *args, **kwargs):
  92. return self._param.new(*args, **kwargs)
  93. def _outside_unstable_region(self):
  94. return torch.max(
  95. torch.le(self.probs, self._lims[0]), torch.gt(self.probs, self._lims[1])
  96. )
  97. def _cut_probs(self):
  98. return torch.where(
  99. self._outside_unstable_region(),
  100. self.probs,
  101. self._lims[0] * torch.ones_like(self.probs),
  102. )
  103. def _cont_bern_log_norm(self):
  104. """computes the log normalizing constant as a function of the 'probs' parameter"""
  105. cut_probs = self._cut_probs()
  106. cut_probs_below_half = torch.where(
  107. torch.le(cut_probs, 0.5), cut_probs, torch.zeros_like(cut_probs)
  108. )
  109. cut_probs_above_half = torch.where(
  110. torch.ge(cut_probs, 0.5), cut_probs, torch.ones_like(cut_probs)
  111. )
  112. log_norm = torch.log(
  113. torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs))
  114. ) - torch.where(
  115. torch.le(cut_probs, 0.5),
  116. torch.log1p(-2.0 * cut_probs_below_half),
  117. torch.log(2.0 * cut_probs_above_half - 1.0),
  118. )
  119. x = torch.pow(self.probs - 0.5, 2)
  120. taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x
  121. return torch.where(self._outside_unstable_region(), log_norm, taylor)
  122. @property
  123. def mean(self) -> Tensor:
  124. cut_probs = self._cut_probs()
  125. mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / (
  126. torch.log1p(-cut_probs) - torch.log(cut_probs)
  127. )
  128. x = self.probs - 0.5
  129. taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch.pow(x, 2)) * x
  130. return torch.where(self._outside_unstable_region(), mus, taylor)
  131. @property
  132. def stddev(self) -> Tensor:
  133. return torch.sqrt(self.variance)
  134. @property
  135. def variance(self) -> Tensor:
  136. cut_probs = self._cut_probs()
  137. vars = cut_probs * (cut_probs - 1.0) / torch.pow(
  138. 1.0 - 2.0 * cut_probs, 2
  139. ) + 1.0 / torch.pow(torch.log1p(-cut_probs) - torch.log(cut_probs), 2)
  140. x = torch.pow(self.probs - 0.5, 2)
  141. taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x
  142. return torch.where(self._outside_unstable_region(), vars, taylor)
  143. @lazy_property
  144. def logits(self) -> Tensor:
  145. return probs_to_logits(self.probs, is_binary=True)
  146. @lazy_property
  147. def probs(self) -> Tensor:
  148. return clamp_probs(logits_to_probs(self.logits, is_binary=True))
  149. @property
  150. def param_shape(self) -> torch.Size:
  151. return self._param.size()
  152. def sample(self, sample_shape=torch.Size()):
  153. shape = self._extended_shape(sample_shape)
  154. u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
  155. with torch.no_grad():
  156. return self.icdf(u)
  157. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  158. shape = self._extended_shape(sample_shape)
  159. u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
  160. return self.icdf(u)
  161. def log_prob(self, value):
  162. if self._validate_args:
  163. self._validate_sample(value)
  164. logits, value = broadcast_all(self.logits, value)
  165. return (
  166. -binary_cross_entropy_with_logits(logits, value, reduction="none")
  167. + self._cont_bern_log_norm()
  168. )
  169. def cdf(self, value):
  170. if self._validate_args:
  171. self._validate_sample(value)
  172. cut_probs = self._cut_probs()
  173. cdfs = (
  174. torch.pow(cut_probs, value) * torch.pow(1.0 - cut_probs, 1.0 - value)
  175. + cut_probs
  176. - 1.0
  177. ) / (2.0 * cut_probs - 1.0)
  178. unbounded_cdfs = torch.where(self._outside_unstable_region(), cdfs, value)
  179. return torch.where(
  180. torch.le(value, 0.0),
  181. torch.zeros_like(value),
  182. torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs),
  183. )
  184. def icdf(self, value):
  185. cut_probs = self._cut_probs()
  186. return torch.where(
  187. self._outside_unstable_region(),
  188. (
  189. torch.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0))
  190. - torch.log1p(-cut_probs)
  191. )
  192. / (torch.log(cut_probs) - torch.log1p(-cut_probs)),
  193. value,
  194. )
  195. def entropy(self):
  196. log_probs0 = torch.log1p(-self.probs)
  197. log_probs1 = torch.log(self.probs)
  198. return (
  199. self.mean * (log_probs0 - log_probs1)
  200. - self._cont_bern_log_norm()
  201. - log_probs0
  202. )
  203. @property
  204. def _natural_params(self) -> tuple[Tensor]:
  205. return (self.logits,)
  206. # pyrefly: ignore [bad-override]
  207. def _log_normalizer(self, x):
  208. """computes the log normalizing constant as a function of the natural parameter"""
  209. out_unst_reg = torch.max(
  210. torch.le(x, self._lims[0] - 0.5), torch.gt(x, self._lims[1] - 0.5)
  211. )
  212. cut_nat_params = torch.where(
  213. out_unst_reg, x, (self._lims[0] - 0.5) * torch.ones_like(x)
  214. )
  215. log_norm = torch.log(
  216. torch.abs(torch.special.expm1(cut_nat_params))
  217. ) - torch.log(torch.abs(cut_nat_params))
  218. taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0
  219. return torch.where(out_unst_reg, log_norm, taylor)