wishart.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. # mypy: allow-untyped-defs
  2. import math
  3. import warnings
  4. import torch
  5. from torch import nan, Tensor
  6. from torch.distributions import constraints
  7. from torch.distributions.exp_family import ExponentialFamily
  8. from torch.distributions.multivariate_normal import _precision_to_scale_tril
  9. from torch.distributions.utils import lazy_property
  10. from torch.types import _Number, _size, Number
  11. __all__ = ["Wishart"]
  12. _log_2 = math.log(2)
  13. def _mvdigamma(x: Tensor, p: int) -> Tensor:
  14. if not x.gt((p - 1) / 2).all():
  15. raise AssertionError("Wrong domain for multivariate digamma function.")
  16. return torch.digamma(
  17. x.unsqueeze(-1)
  18. - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,))
  19. ).sum(-1)
  20. def _clamp_above_eps(x: Tensor) -> Tensor:
  21. # We assume positive input for this function
  22. return x.clamp(min=torch.finfo(x.dtype).eps)
  23. class Wishart(ExponentialFamily):
  24. r"""
  25. Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
  26. or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`
  27. Example:
  28. >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional")
  29. >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
  30. >>> m.sample() # Wishart distributed with mean=`df * I` and
  31. >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j
  32. Args:
  33. df (float or Tensor): real-valued parameter larger than the (dimension of Square matrix) - 1
  34. covariance_matrix (Tensor): positive-definite covariance matrix
  35. precision_matrix (Tensor): positive-definite precision matrix
  36. scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
  37. Note:
  38. Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
  39. :attr:`scale_tril` can be specified.
  40. Using :attr:`scale_tril` will be more efficient: all computations internally
  41. are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
  42. :attr:`precision_matrix` is passed instead, it is only used to compute
  43. the corresponding lower triangular matrices using a Cholesky decomposition.
  44. 'torch.distributions.LKJCholesky' is a restricted Wishart distribution.[1]
  45. **References**
  46. [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`.
  47. [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`.
  48. [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`.
  49. [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a SampleCovariance Matrix`. JASA, 61(313):199-203.
  50. [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`.
  51. """
  52. support = constraints.positive_definite
  53. has_rsample = True
  54. _mean_carrier_measure = 0
  55. @property
  56. def arg_constraints(self):
  57. return {
  58. "covariance_matrix": constraints.positive_definite,
  59. "precision_matrix": constraints.positive_definite,
  60. "scale_tril": constraints.lower_cholesky,
  61. "df": constraints.greater_than(self.event_shape[-1] - 1),
  62. }
  63. def __init__(
  64. self,
  65. df: Tensor | Number,
  66. covariance_matrix: Tensor | None = None,
  67. precision_matrix: Tensor | None = None,
  68. scale_tril: Tensor | None = None,
  69. validate_args: bool | None = None,
  70. ) -> None:
  71. if (
  72. (covariance_matrix is not None)
  73. + (scale_tril is not None)
  74. + (precision_matrix is not None)
  75. ) != 1:
  76. raise AssertionError(
  77. "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
  78. )
  79. param = next(
  80. p
  81. for p in (covariance_matrix, precision_matrix, scale_tril)
  82. if p is not None
  83. )
  84. if param.dim() < 2:
  85. raise ValueError(
  86. "scale_tril must be at least two-dimensional, with optional leading batch dimensions"
  87. )
  88. if isinstance(df, _Number):
  89. batch_shape = torch.Size(param.shape[:-2])
  90. self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
  91. else:
  92. batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
  93. self.df = df.expand(batch_shape)
  94. event_shape = param.shape[-2:]
  95. if self.df.le(event_shape[-1] - 1).any():
  96. raise ValueError(
  97. f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1] - 1}."
  98. )
  99. if scale_tril is not None:
  100. # pyrefly: ignore [read-only]
  101. self.scale_tril = param.expand(batch_shape + (-1, -1))
  102. elif covariance_matrix is not None:
  103. # pyrefly: ignore [read-only]
  104. self.covariance_matrix = param.expand(batch_shape + (-1, -1))
  105. elif precision_matrix is not None:
  106. # pyrefly: ignore [read-only]
  107. self.precision_matrix = param.expand(batch_shape + (-1, -1))
  108. if self.df.lt(event_shape[-1]).any():
  109. warnings.warn(
  110. "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim.",
  111. stacklevel=2,
  112. )
  113. # pyrefly: ignore [bad-argument-type]
  114. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  115. self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]
  116. if scale_tril is not None:
  117. self._unbroadcasted_scale_tril = scale_tril
  118. elif covariance_matrix is not None:
  119. self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
  120. else: # precision_matrix is not None
  121. self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
  122. # Chi2 distribution is needed for Bartlett decomposition sampling
  123. self._dist_chi2 = torch.distributions.chi2.Chi2(
  124. df=(
  125. self.df.unsqueeze(-1)
  126. - torch.arange(
  127. self._event_shape[-1],
  128. dtype=self._unbroadcasted_scale_tril.dtype,
  129. device=self._unbroadcasted_scale_tril.device,
  130. ).expand(batch_shape + (-1,))
  131. )
  132. )
  133. def expand(self, batch_shape, _instance=None):
  134. new = self._get_checked_instance(Wishart, _instance)
  135. batch_shape = torch.Size(batch_shape)
  136. cov_shape = batch_shape + self.event_shape
  137. new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
  138. new.df = self.df.expand(batch_shape)
  139. new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]
  140. if "covariance_matrix" in self.__dict__:
  141. new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
  142. if "scale_tril" in self.__dict__:
  143. new.scale_tril = self.scale_tril.expand(cov_shape)
  144. if "precision_matrix" in self.__dict__:
  145. new.precision_matrix = self.precision_matrix.expand(cov_shape)
  146. # Chi2 distribution is needed for Bartlett decomposition sampling
  147. new._dist_chi2 = torch.distributions.chi2.Chi2(
  148. df=(
  149. new.df.unsqueeze(-1)
  150. - torch.arange(
  151. self.event_shape[-1],
  152. dtype=new._unbroadcasted_scale_tril.dtype,
  153. device=new._unbroadcasted_scale_tril.device,
  154. ).expand(batch_shape + (-1,))
  155. )
  156. )
  157. super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
  158. new._validate_args = self._validate_args
  159. return new
  160. @lazy_property
  161. def scale_tril(self) -> Tensor:
  162. return self._unbroadcasted_scale_tril.expand(
  163. self._batch_shape + self._event_shape
  164. )
  165. @lazy_property
  166. def covariance_matrix(self) -> Tensor:
  167. return (
  168. self._unbroadcasted_scale_tril
  169. @ self._unbroadcasted_scale_tril.transpose(-2, -1)
  170. ).expand(self._batch_shape + self._event_shape)
  171. @lazy_property
  172. def precision_matrix(self) -> Tensor:
  173. identity = torch.eye(
  174. self._event_shape[-1],
  175. device=self._unbroadcasted_scale_tril.device,
  176. dtype=self._unbroadcasted_scale_tril.dtype,
  177. )
  178. return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
  179. self._batch_shape + self._event_shape
  180. )
  181. @property
  182. def mean(self) -> Tensor:
  183. return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix
  184. @property
  185. def mode(self) -> Tensor:
  186. factor = self.df - self.covariance_matrix.shape[-1] - 1
  187. factor[factor <= 0] = nan
  188. return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix
  189. @property
  190. def variance(self) -> Tensor:
  191. V = self.covariance_matrix # has shape (batch_shape x event_shape)
  192. diag_V = V.diagonal(dim1=-2, dim2=-1)
  193. return self.df.view(self._batch_shape + (1, 1)) * (
  194. V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V)
  195. )
  196. def _bartlett_sampling(self, sample_shape=torch.Size()):
  197. p = self._event_shape[-1] # has singleton shape
  198. # Implemented Sampling using Bartlett decomposition
  199. noise = _clamp_above_eps(
  200. self._dist_chi2.rsample(sample_shape).sqrt()
  201. ).diag_embed(dim1=-2, dim2=-1)
  202. i, j = torch.tril_indices(p, p, offset=-1)
  203. noise[..., i, j] = torch.randn(
  204. torch.Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2),),
  205. dtype=noise.dtype,
  206. device=noise.device,
  207. )
  208. chol = self._unbroadcasted_scale_tril @ noise
  209. return chol @ chol.transpose(-2, -1)
  210. def rsample(
  211. self, sample_shape: _size = torch.Size(), max_try_correction=None
  212. ) -> Tensor:
  213. r"""
  214. .. warning::
  215. In some cases, sampling algorithm based on Bartlett decomposition may return singular matrix samples.
  216. Several tries to correct singular samples are performed by default, but it may end up returning
  217. singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`.
  218. In those cases, the user should validate the samples and either fix the value of `df`
  219. or adjust `max_try_correction` value for argument in `.rsample` accordingly.
  220. """
  221. if max_try_correction is None:
  222. max_try_correction = 3 if torch._C._get_tracing_state() else 10
  223. sample_shape = torch.Size(sample_shape)
  224. sample = self._bartlett_sampling(sample_shape)
  225. # Below part is to improve numerical stability temporally and should be removed in the future
  226. is_singular = self.support.check(sample)
  227. if self._batch_shape:
  228. is_singular = is_singular.amax(self._batch_dims)
  229. if torch._C._get_tracing_state():
  230. # Less optimized version for JIT
  231. for _ in range(max_try_correction):
  232. sample_new = self._bartlett_sampling(sample_shape)
  233. sample = torch.where(is_singular, sample_new, sample)
  234. is_singular = ~self.support.check(sample)
  235. if self._batch_shape:
  236. is_singular = is_singular.amax(self._batch_dims)
  237. else:
  238. # More optimized version with data-dependent control flow.
  239. if is_singular.any():
  240. warnings.warn("Singular sample detected.", stacklevel=2)
  241. for _ in range(max_try_correction):
  242. sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
  243. sample[is_singular] = sample_new
  244. is_singular_new = ~self.support.check(sample_new)
  245. if self._batch_shape:
  246. is_singular_new = is_singular_new.amax(self._batch_dims)
  247. is_singular[is_singular.clone()] = is_singular_new
  248. if not is_singular.any():
  249. break
  250. return sample
  251. def log_prob(self, value):
  252. if self._validate_args:
  253. self._validate_sample(value)
  254. nu = self.df # has shape (batch_shape)
  255. p = self._event_shape[-1] # has singleton shape
  256. return (
  257. -nu
  258. * (
  259. p * _log_2 / 2
  260. + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
  261. .log()
  262. .sum(-1)
  263. )
  264. - torch.mvlgamma(nu / 2, p=p)
  265. + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
  266. - torch.cholesky_solve(value, self._unbroadcasted_scale_tril)
  267. .diagonal(dim1=-2, dim2=-1)
  268. .sum(dim=-1)
  269. / 2
  270. )
  271. def entropy(self):
  272. nu = self.df # has shape (batch_shape)
  273. p = self._event_shape[-1] # has singleton shape
  274. return (
  275. (p + 1)
  276. * (
  277. p * _log_2 / 2
  278. + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
  279. .log()
  280. .sum(-1)
  281. )
  282. + torch.mvlgamma(nu / 2, p=p)
  283. - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
  284. + nu * p / 2
  285. )
  286. @property
  287. def _natural_params(self) -> tuple[Tensor, Tensor]:
  288. nu = self.df # has shape (batch_shape)
  289. p = self._event_shape[-1] # has singleton shape
  290. return -self.precision_matrix / 2, (nu - p - 1) / 2
  291. # pyrefly: ignore [bad-override]
  292. def _log_normalizer(self, x, y):
  293. p = self._event_shape[-1]
  294. return (y + (p + 1) / 2) * (
  295. -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p
  296. ) + torch.mvlgamma(y + (p + 1) / 2, p=p)