lkj_cholesky.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # mypy: allow-untyped-defs
  2. """
  3. This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro).
  4. Original copyright notice:
  5. # Copyright: Contributors to the Pyro project.
  6. # SPDX-License-Identifier: Apache-2.0
  7. """
  8. import math
  9. import torch
  10. from torch import Tensor
  11. from torch.distributions import Beta, constraints
  12. from torch.distributions.distribution import Distribution
  13. from torch.distributions.utils import broadcast_all
# Public API of this module: only ``LKJCholesky`` is exported by ``import *``.
__all__ = ["LKJCholesky"]
  15. class LKJCholesky(Distribution):
  16. r"""
  17. LKJ distribution for lower Cholesky factor of correlation matrices.
  18. The distribution is controlled by ``concentration`` parameter :math:`\eta`
  19. to make the probability of the correlation matrix :math:`M` generated from
  20. a Cholesky factor proportional to :math:`\det(M)^{\eta - 1}`. Because of that,
  21. when ``concentration == 1``, we have a uniform distribution over Cholesky
  22. factors of correlation matrices::
  23. L ~ LKJCholesky(dim, concentration)
  24. X = L @ L' ~ LKJCorr(dim, concentration)
  25. Note that this distribution samples the
  26. Cholesky factor of correlation matrices and not the correlation matrices
  27. themselves and thereby differs slightly from the derivations in [1] for
  28. the `LKJCorr` distribution. For sampling, this uses the Onion method from
  29. [1] Section 3.
  30. Example::
  31. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  32. >>> l = LKJCholesky(3, 0.5)
  33. >>> l.sample() # l @ l.T is a sample of a correlation 3x3 matrix
  34. tensor([[ 1.0000, 0.0000, 0.0000],
  35. [ 0.3516, 0.9361, 0.0000],
  36. [-0.1899, 0.4748, 0.8593]])
  37. Args:
  38. dimension (dim): dimension of the matrices
  39. concentration (float or Tensor): concentration/shape parameter of the
  40. distribution (often referred to as eta)
  41. **References**
  42. [1] `Generating random correlation matrices based on vines and extended onion method` (2009),
  43. Daniel Lewandowski, Dorota Kurowicka, Harry Joe.
  44. Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
  45. """
  46. # pyrefly: ignore [bad-override]
  47. arg_constraints = {"concentration": constraints.positive}
  48. support = constraints.corr_cholesky
  49. def __init__(
  50. self,
  51. dim: int,
  52. concentration: Tensor | float = 1.0,
  53. validate_args: bool | None = None,
  54. ) -> None:
  55. if dim < 2:
  56. raise ValueError(
  57. f"Expected dim to be an integer greater than or equal to 2. Found dim={dim}."
  58. )
  59. self.dim = dim
  60. (self.concentration,) = broadcast_all(concentration)
  61. batch_shape = self.concentration.size()
  62. event_shape = torch.Size((dim, dim))
  63. # This is used to draw vectorized samples from the beta distribution in Sec. 3.2 of [1].
  64. marginal_conc = self.concentration + 0.5 * (self.dim - 2)
  65. offset = torch.arange(
  66. self.dim - 1,
  67. dtype=self.concentration.dtype,
  68. device=self.concentration.device,
  69. )
  70. offset = torch.cat([offset.new_zeros((1,)), offset])
  71. beta_conc1 = offset + 0.5
  72. beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset
  73. self._beta = Beta(beta_conc1, beta_conc0)
  74. super().__init__(batch_shape, event_shape, validate_args)
  75. def expand(self, batch_shape, _instance=None):
  76. new = self._get_checked_instance(LKJCholesky, _instance)
  77. batch_shape = torch.Size(batch_shape)
  78. new.dim = self.dim
  79. new.concentration = self.concentration.expand(batch_shape)
  80. new._beta = self._beta.expand(batch_shape + (self.dim,))
  81. super(LKJCholesky, new).__init__(
  82. batch_shape, self.event_shape, validate_args=False
  83. )
  84. new._validate_args = self._validate_args
  85. return new
  86. def sample(self, sample_shape=torch.Size()):
  87. # This uses the Onion method, but there are a few differences from [1] Sec. 3.2:
  88. # - This vectorizes the for loop and also works for heterogeneous eta.
  89. # - Same algorithm generalizes to n=1.
  90. # - The procedure is simplified since we are sampling the cholesky factor of
  91. # the correlation matrix instead of the correlation matrix itself. As such,
  92. # we only need to generate `w`.
  93. y = self._beta.sample(sample_shape).unsqueeze(-1)
  94. u_normal = torch.randn(
  95. self._extended_shape(sample_shape), dtype=y.dtype, device=y.device
  96. ).tril(-1)
  97. u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True)
  98. # Replace NaNs in first row
  99. u_hypersphere[..., 0, :].fill_(0.0)
  100. w = torch.sqrt(y) * u_hypersphere
  101. # Fill diagonal elements; clamp for numerical stability
  102. eps = torch.finfo(w.dtype).tiny
  103. diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt()
  104. w += torch.diag_embed(diag_elems)
  105. return w
  106. def log_prob(self, value):
  107. # See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html
  108. # The probability of a correlation matrix is proportional to
  109. # determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
  110. # Additionally, the Jacobian of the transformation from Cholesky factor to
  111. # correlation matrix is:
  112. # prod(L_ii ^ (D - i))
  113. # So the probability of a Cholesky factor is proportional to
  114. # prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i)
  115. # with order_i = 2 * concentration - 2 + D - i
  116. if self._validate_args:
  117. self._validate_sample(value)
  118. diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
  119. order = torch.arange(2, self.dim + 1, device=self.concentration.device)
  120. order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
  121. unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
  122. # Compute normalization constant (page 1999 of [1])
  123. dm1 = self.dim - 1
  124. alpha = self.concentration + 0.5 * dm1
  125. denominator = torch.lgamma(alpha) * dm1
  126. numerator = torch.mvlgamma(alpha - 0.5, dm1)
  127. # pi_constant in [1] is D * (D - 1) / 4 * log(pi)
  128. # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi)
  129. # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2
  130. pi_constant = 0.5 * dm1 * math.log(math.pi)
  131. normalize_term = pi_constant + numerator - denominator
  132. return unnormalized_log_pdf - normalize_term