# relaxed_categorical.py (5.6 KB)
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch import Tensor
  4. from torch.distributions import constraints
  5. from torch.distributions.categorical import Categorical
  6. from torch.distributions.distribution import Distribution
  7. from torch.distributions.transformed_distribution import TransformedDistribution
  8. from torch.distributions.transforms import ExpTransform
  9. from torch.distributions.utils import broadcast_all, clamp_probs
  10. from torch.types import _size
  11. __all__ = ["ExpRelaxedCategorical", "RelaxedOneHotCategorical"]
  12. class ExpRelaxedCategorical(Distribution):
  13. r"""
  14. Creates a ExpRelaxedCategorical parameterized by
  15. :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
  16. Returns the log of a point in the simplex. Based on the interface to
  17. :class:`OneHotCategorical`.
  18. Implementation based on [1].
  19. See also: :func:`torch.distributions.OneHotCategorical`
  20. Args:
  21. temperature (Tensor): relaxation temperature
  22. probs (Tensor): event probabilities
  23. logits (Tensor): unnormalized log probability for each event
  24. [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
  25. (Maddison et al., 2017)
  26. [2] Categorical Reparametrization with Gumbel-Softmax
  27. (Jang et al., 2017)
  28. """
  29. # pyrefly: ignore [bad-override]
  30. arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
  31. support = (
  32. constraints.real_vector
  33. ) # The true support is actually a submanifold of this.
  34. has_rsample = True
  35. def __init__(
  36. self,
  37. temperature: Tensor,
  38. probs: Tensor | None = None,
  39. logits: Tensor | None = None,
  40. validate_args: bool | None = None,
  41. ) -> None:
  42. self._categorical = Categorical(probs, logits)
  43. self.temperature = temperature
  44. batch_shape = self._categorical.batch_shape
  45. event_shape = self._categorical.param_shape[-1:]
  46. # pyrefly: ignore [bad-argument-type]
  47. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  48. def expand(self, batch_shape, _instance=None):
  49. new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
  50. batch_shape = torch.Size(batch_shape)
  51. new.temperature = self.temperature
  52. new._categorical = self._categorical.expand(batch_shape)
  53. super(ExpRelaxedCategorical, new).__init__(
  54. batch_shape, self.event_shape, validate_args=False
  55. )
  56. new._validate_args = self._validate_args
  57. return new
  58. def _new(self, *args, **kwargs):
  59. return self._categorical._new(*args, **kwargs)
  60. @property
  61. def param_shape(self) -> torch.Size:
  62. return self._categorical.param_shape
  63. @property
  64. def logits(self) -> Tensor:
  65. return self._categorical.logits
  66. @property
  67. def probs(self) -> Tensor:
  68. return self._categorical.probs
  69. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  70. shape = self._extended_shape(sample_shape)
  71. uniforms = clamp_probs(
  72. torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
  73. )
  74. gumbels = -((-(uniforms.log())).log())
  75. scores = (self.logits + gumbels) / self.temperature
  76. return scores - scores.logsumexp(dim=-1, keepdim=True)
  77. def log_prob(self, value):
  78. K = self._categorical._num_events
  79. if self._validate_args:
  80. self._validate_sample(value)
  81. logits, value = broadcast_all(self.logits, value)
  82. log_scale = torch.full_like(
  83. self.temperature, float(K)
  84. ).lgamma() - self.temperature.log().mul(-(K - 1))
  85. score = logits - value.mul(self.temperature)
  86. score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
  87. return score + log_scale
  88. class RelaxedOneHotCategorical(TransformedDistribution):
  89. r"""
  90. Creates a RelaxedOneHotCategorical distribution parametrized by
  91. :attr:`temperature`, and either :attr:`probs` or :attr:`logits`.
  92. This is a relaxed version of the :class:`OneHotCategorical` distribution, so
  93. its samples are on simplex, and are reparametrizable.
  94. Example::
  95. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  96. >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
  97. ... torch.tensor([0.1, 0.2, 0.3, 0.4]))
  98. >>> m.sample()
  99. tensor([ 0.1294, 0.2324, 0.3859, 0.2523])
  100. Args:
  101. temperature (Tensor): relaxation temperature
  102. probs (Tensor): event probabilities
  103. logits (Tensor): unnormalized log probability for each event
  104. """
  105. arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
  106. # pyrefly: ignore [bad-override]
  107. support = constraints.simplex
  108. has_rsample = True
  109. # pyrefly: ignore [bad-override]
  110. base_dist: ExpRelaxedCategorical
  111. def __init__(
  112. self,
  113. temperature: Tensor,
  114. probs: Tensor | None = None,
  115. logits: Tensor | None = None,
  116. validate_args: bool | None = None,
  117. ) -> None:
  118. base_dist = ExpRelaxedCategorical(
  119. temperature, probs, logits, validate_args=validate_args
  120. )
  121. super().__init__(base_dist, ExpTransform(), validate_args=validate_args)
  122. def expand(self, batch_shape, _instance=None):
  123. new = self._get_checked_instance(RelaxedOneHotCategorical, _instance)
  124. return super().expand(batch_shape, _instance=new)
  125. @property
  126. def temperature(self) -> Tensor:
  127. return self.base_dist.temperature
  128. @property
  129. def logits(self) -> Tensor:
  130. return self.base_dist.logits
  131. @property
  132. def probs(self) -> Tensor:
  133. return self.base_dist.probs