dirichlet.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch import Tensor
  4. from torch.autograd import Function
  5. from torch.autograd.function import once_differentiable
  6. from torch.distributions import constraints
  7. from torch.distributions.exp_family import ExponentialFamily
  8. from torch.types import _size
  9. __all__ = ["Dirichlet"]
  10. # This helper is exposed for testing.
  11. def _Dirichlet_backward(x, concentration, grad_output):
  12. total = concentration.sum(-1, True).expand_as(concentration)
  13. grad = torch._dirichlet_grad(x, concentration, total)
  14. return grad * (grad_output - (x * grad_output).sum(-1, True))
  15. class _Dirichlet(Function):
  16. @staticmethod
  17. # pyrefly: ignore [bad-override]
  18. def forward(ctx, concentration):
  19. x = torch._sample_dirichlet(concentration)
  20. ctx.save_for_backward(x, concentration)
  21. return x
  22. @staticmethod
  23. @once_differentiable
  24. # pyrefly: ignore [bad-override]
  25. def backward(ctx, grad_output):
  26. x, concentration = ctx.saved_tensors
  27. return _Dirichlet_backward(x, concentration, grad_output)
  28. class Dirichlet(ExponentialFamily):
  29. r"""
  30. Creates a Dirichlet distribution parameterized by concentration :attr:`concentration`.
  31. Example::
  32. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  33. >>> m = Dirichlet(torch.tensor([0.5, 0.5]))
  34. >>> m.sample() # Dirichlet distributed with concentration [0.5, 0.5]
  35. tensor([ 0.1046, 0.8954])
  36. Args:
  37. concentration (Tensor): concentration parameter of the distribution
  38. (often referred to as alpha)
  39. """
  40. # pyrefly: ignore [bad-override]
  41. arg_constraints = {
  42. "concentration": constraints.independent(constraints.positive, 1)
  43. }
  44. support = constraints.simplex
  45. has_rsample = True
  46. def __init__(
  47. self,
  48. concentration: Tensor,
  49. validate_args: bool | None = None,
  50. ) -> None:
  51. if concentration.dim() < 1:
  52. raise ValueError(
  53. "`concentration` parameter must be at least one-dimensional."
  54. )
  55. self.concentration = concentration
  56. batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
  57. # pyrefly: ignore [bad-argument-type]
  58. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  59. def expand(self, batch_shape, _instance=None):
  60. new = self._get_checked_instance(Dirichlet, _instance)
  61. batch_shape = torch.Size(batch_shape)
  62. new.concentration = self.concentration.expand(batch_shape + self.event_shape)
  63. super(Dirichlet, new).__init__(
  64. batch_shape, self.event_shape, validate_args=False
  65. )
  66. new._validate_args = self._validate_args
  67. return new
  68. def rsample(self, sample_shape: _size = ()) -> Tensor:
  69. shape = self._extended_shape(sample_shape)
  70. concentration = self.concentration.expand(shape)
  71. return _Dirichlet.apply(concentration)
  72. def log_prob(self, value):
  73. if self._validate_args:
  74. self._validate_sample(value)
  75. return (
  76. torch.xlogy(self.concentration - 1.0, value).sum(-1)
  77. + torch.lgamma(self.concentration.sum(-1))
  78. - torch.lgamma(self.concentration).sum(-1)
  79. )
  80. @property
  81. def mean(self) -> Tensor:
  82. return self.concentration / self.concentration.sum(-1, True)
  83. @property
  84. def mode(self) -> Tensor:
  85. concentrationm1 = (self.concentration - 1).clamp(min=0.0)
  86. mode = concentrationm1 / concentrationm1.sum(-1, True)
  87. mask = (self.concentration < 1).all(dim=-1)
  88. mode[mask] = torch.nn.functional.one_hot(
  89. mode[mask].argmax(dim=-1), concentrationm1.shape[-1]
  90. ).to(mode)
  91. return mode
  92. @property
  93. def variance(self) -> Tensor:
  94. con0 = self.concentration.sum(-1, True)
  95. return (
  96. self.concentration
  97. * (con0 - self.concentration)
  98. / (con0.pow(2) * (con0 + 1))
  99. )
  100. def entropy(self):
  101. k = self.concentration.size(-1)
  102. a0 = self.concentration.sum(-1)
  103. return (
  104. torch.lgamma(self.concentration).sum(-1)
  105. - torch.lgamma(a0)
  106. - (k - a0) * torch.digamma(a0)
  107. - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1)
  108. )
  109. @property
  110. def _natural_params(self) -> tuple[Tensor]:
  111. return (self.concentration,)
  112. # pyrefly: ignore [bad-override]
  113. def _log_normalizer(self, x):
  114. return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))