multinomial.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch import inf, Tensor
  4. from torch.distributions import Categorical, constraints
  5. from torch.distributions.binomial import Binomial
  6. from torch.distributions.distribution import Distribution
  7. from torch.distributions.utils import broadcast_all
  8. __all__ = ["Multinomial"]
  9. class Multinomial(Distribution):
  10. r"""
  11. Creates a Multinomial distribution parameterized by :attr:`total_count` and
  12. either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
  13. :attr:`probs` indexes over categories. All other dimensions index over batches.
  14. Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
  15. called (see example below)
  16. .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
  17. and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
  18. will return this normalized value.
  19. The `logits` argument will be interpreted as unnormalized log probabilities
  20. and can therefore be any real number. It will likewise be normalized so that
  21. the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
  22. will return this normalized value.
  23. - :meth:`sample` requires a single shared `total_count` for all
  24. parameters and samples.
  25. - :meth:`log_prob` allows different `total_count` for each parameter and
  26. sample.
  27. Example::
  28. >>> # xdoctest: +SKIP("FIXME: found invalid values")
  29. >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
  30. >>> x = m.sample() # equal probability of 0, 1, 2, 3
  31. tensor([ 21., 24., 30., 25.])
  32. >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
  33. tensor([-4.1338])
  34. Args:
  35. total_count (int): number of trials
  36. probs (Tensor): event probabilities
  37. logits (Tensor): event log probabilities (unnormalized)
  38. """
  39. # pyrefly: ignore [bad-override]
  40. arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
  41. total_count: int
  42. @property
  43. def mean(self) -> Tensor:
  44. return self.probs * self.total_count
  45. @property
  46. def variance(self) -> Tensor:
  47. return self.total_count * self.probs * (1 - self.probs)
  48. def __init__(
  49. self,
  50. total_count: int = 1,
  51. probs: Tensor | None = None,
  52. logits: Tensor | None = None,
  53. validate_args: bool | None = None,
  54. ) -> None:
  55. if not isinstance(total_count, int):
  56. raise NotImplementedError("inhomogeneous total_count is not supported")
  57. self.total_count = total_count
  58. self._categorical = Categorical(probs=probs, logits=logits)
  59. self._binomial = Binomial(total_count=total_count, probs=self.probs)
  60. batch_shape = self._categorical.batch_shape
  61. event_shape = self._categorical.param_shape[-1:]
  62. # pyrefly: ignore [bad-argument-type]
  63. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  64. def expand(self, batch_shape, _instance=None):
  65. new = self._get_checked_instance(Multinomial, _instance)
  66. batch_shape = torch.Size(batch_shape)
  67. new.total_count = self.total_count
  68. new._categorical = self._categorical.expand(batch_shape)
  69. super(Multinomial, new).__init__(
  70. batch_shape, self.event_shape, validate_args=False
  71. )
  72. new._validate_args = self._validate_args
  73. return new
  74. def _new(self, *args, **kwargs):
  75. return self._categorical._new(*args, **kwargs)
  76. @constraints.dependent_property(is_discrete=True, event_dim=1)
  77. # pyrefly: ignore [bad-override]
  78. def support(self):
  79. return constraints.multinomial(self.total_count)
  80. @property
  81. def logits(self) -> Tensor:
  82. return self._categorical.logits
  83. @property
  84. def probs(self) -> Tensor:
  85. return self._categorical.probs
  86. @property
  87. def param_shape(self) -> torch.Size:
  88. return self._categorical.param_shape
  89. def sample(self, sample_shape=torch.Size()):
  90. sample_shape = torch.Size(sample_shape)
  91. samples = self._categorical.sample(
  92. torch.Size((self.total_count,)) + sample_shape
  93. )
  94. # samples.shape is (total_count, sample_shape, batch_shape), need to change it to
  95. # (sample_shape, batch_shape, total_count)
  96. shifted_idx = list(range(samples.dim()))
  97. shifted_idx.append(shifted_idx.pop(0))
  98. samples = samples.permute(*shifted_idx)
  99. counts = samples.new(self._extended_shape(sample_shape)).zero_()
  100. counts.scatter_add_(-1, samples, torch.ones_like(samples))
  101. return counts.type_as(self.probs)
  102. def entropy(self):
  103. n = torch.tensor(self.total_count)
  104. cat_entropy = self._categorical.entropy()
  105. term1 = n * cat_entropy - torch.lgamma(n + 1)
  106. support = self._binomial.enumerate_support(expand=False)[1:]
  107. binomial_probs = torch.exp(self._binomial.log_prob(support))
  108. weights = torch.lgamma(support + 1)
  109. term2 = (binomial_probs * weights).sum([0, -1])
  110. return term1 + term2
  111. def log_prob(self, value):
  112. if self._validate_args:
  113. self._validate_sample(value)
  114. logits, value = broadcast_all(self.logits, value)
  115. logits = logits.clone(memory_format=torch.contiguous_format)
  116. log_factorial_n = torch.lgamma(value.sum(-1) + 1)
  117. log_factorial_xs = torch.lgamma(value + 1).sum(-1)
  118. logits[(value == 0) & (logits == -inf)] = 0
  119. log_powers = (logits * value).sum(-1)
  120. return log_factorial_n - log_factorial_xs + log_powers