multivariate_normal.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. # mypy: allow-untyped-defs
  2. import math
  3. import torch
  4. from torch import Tensor
  5. from torch.distributions import constraints
  6. from torch.distributions.distribution import Distribution
  7. from torch.distributions.utils import _standard_normal, lazy_property
  8. from torch.types import _size
  9. __all__ = ["MultivariateNormal"]
  10. def _batch_mv(bmat, bvec):
  11. r"""
  12. Performs a batched matrix-vector product, with compatible but different batch shapes.
  13. This function takes as input `bmat`, containing :math:`n \times n` matrices, and
  14. `bvec`, containing length :math:`n` vectors.
  15. Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
  16. to a batch shape. They are not necessarily assumed to have the same batch shape,
  17. just ones which can be broadcasted.
  18. """
  19. return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
  20. def _batch_mahalanobis(bL, bx):
  21. r"""
  22. Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
  23. for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.
  24. Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch
  25. shape, but `bL` one should be able to broadcasted to `bx` one.
  26. """
  27. n = bx.size(-1)
  28. bx_batch_shape = bx.shape[:-1]
  29. # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
  30. # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve
  31. bx_batch_dims = len(bx_batch_shape)
  32. bL_batch_dims = bL.dim() - 2
  33. outer_batch_dims = bx_batch_dims - bL_batch_dims
  34. old_batch_dims = outer_batch_dims + bL_batch_dims
  35. new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
  36. # Reshape bx with the shape (..., 1, i, j, 1, n)
  37. bx_new_shape = bx.shape[:outer_batch_dims]
  38. for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
  39. bx_new_shape += (sx // sL, sL)
  40. bx_new_shape += (n,)
  41. bx = bx.reshape(bx_new_shape)
  42. # Permute bx to make it have shape (..., 1, j, i, 1, n)
  43. permute_dims = (
  44. list(range(outer_batch_dims))
  45. + list(range(outer_batch_dims, new_batch_dims, 2))
  46. + list(range(outer_batch_dims + 1, new_batch_dims, 2))
  47. + [new_batch_dims]
  48. )
  49. bx = bx.permute(permute_dims)
  50. flat_L = bL.reshape(-1, n, n) # shape = b x n x n
  51. flat_x = bx.reshape(-1, flat_L.size(0), n) # shape = c x b x n
  52. flat_x_swap = flat_x.permute(1, 2, 0) # shape = b x n x c
  53. M_swap = (
  54. torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)
  55. ) # shape = b x c
  56. M = M_swap.t() # shape = c x b
  57. # Now we revert the above reshape and permute operators.
  58. permuted_M = M.reshape(bx.shape[:-1]) # shape = (..., 1, j, i, 1)
  59. permute_inv_dims = list(range(outer_batch_dims))
  60. for i in range(bL_batch_dims):
  61. permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
  62. reshaped_M = permuted_M.permute(permute_inv_dims) # shape = (..., 1, i, j, 1)
  63. return reshaped_M.reshape(bx_batch_shape)
  64. def _precision_to_scale_tril(P):
  65. # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
  66. Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
  67. L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
  68. Id = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device)
  69. L = torch.linalg.solve_triangular(L_inv, Id, upper=False)
  70. return L
  71. class MultivariateNormal(Distribution):
  72. r"""
  73. Creates a multivariate normal (also called Gaussian) distribution
  74. parameterized by a mean vector and a covariance matrix.
  75. The multivariate normal distribution can be parameterized either
  76. in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
  77. or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
  78. or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
  79. diagonal entries, such that
  80. :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
  81. can be obtained via e.g. Cholesky decomposition of the covariance.
  82. Example:
  83. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
  84. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  85. >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
  86. >>> m.sample() # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
  87. tensor([-0.2102, -0.5429])
  88. Args:
  89. loc (Tensor): mean of the distribution
  90. covariance_matrix (Tensor): positive-definite covariance matrix
  91. precision_matrix (Tensor): positive-definite precision matrix
  92. scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
  93. Note:
  94. Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
  95. :attr:`scale_tril` can be specified.
  96. Using :attr:`scale_tril` will be more efficient: all computations internally
  97. are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
  98. :attr:`precision_matrix` is passed instead, it is only used to compute
  99. the corresponding lower triangular matrices using a Cholesky decomposition.
  100. """
  101. # pyrefly: ignore [bad-override]
  102. arg_constraints = {
  103. "loc": constraints.real_vector,
  104. "covariance_matrix": constraints.positive_definite,
  105. "precision_matrix": constraints.positive_definite,
  106. "scale_tril": constraints.lower_cholesky,
  107. }
  108. support = constraints.real_vector
  109. has_rsample = True
  110. def __init__(
  111. self,
  112. loc: Tensor,
  113. covariance_matrix: Tensor | None = None,
  114. precision_matrix: Tensor | None = None,
  115. scale_tril: Tensor | None = None,
  116. validate_args: bool | None = None,
  117. ) -> None:
  118. if loc.dim() < 1:
  119. raise ValueError("loc must be at least one-dimensional.")
  120. if (covariance_matrix is not None) + (scale_tril is not None) + (
  121. precision_matrix is not None
  122. ) != 1:
  123. raise ValueError(
  124. "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
  125. )
  126. if scale_tril is not None:
  127. if scale_tril.dim() < 2:
  128. raise ValueError(
  129. "scale_tril matrix must be at least two-dimensional, "
  130. "with optional leading batch dimensions"
  131. )
  132. batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
  133. # pyrefly: ignore [read-only]
  134. self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
  135. elif covariance_matrix is not None:
  136. if covariance_matrix.dim() < 2:
  137. raise ValueError(
  138. "covariance_matrix must be at least two-dimensional, "
  139. "with optional leading batch dimensions"
  140. )
  141. batch_shape = torch.broadcast_shapes(
  142. covariance_matrix.shape[:-2], loc.shape[:-1]
  143. )
  144. # pyrefly: ignore [read-only]
  145. self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
  146. else:
  147. if precision_matrix is None:
  148. raise AssertionError("precision_matrix is unexpectedly None")
  149. if precision_matrix.dim() < 2:
  150. raise ValueError(
  151. "precision_matrix must be at least two-dimensional, "
  152. "with optional leading batch dimensions"
  153. )
  154. batch_shape = torch.broadcast_shapes(
  155. precision_matrix.shape[:-2], loc.shape[:-1]
  156. )
  157. # pyrefly: ignore [read-only]
  158. self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
  159. self.loc = loc.expand(batch_shape + (-1,))
  160. event_shape = self.loc.shape[-1:]
  161. # pyrefly: ignore [bad-argument-type]
  162. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  163. if scale_tril is not None:
  164. self._unbroadcasted_scale_tril = scale_tril
  165. elif covariance_matrix is not None:
  166. self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
  167. else: # precision_matrix is not None
  168. self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
  169. def expand(self, batch_shape, _instance=None):
  170. new = self._get_checked_instance(MultivariateNormal, _instance)
  171. batch_shape = torch.Size(batch_shape)
  172. loc_shape = batch_shape + self.event_shape
  173. cov_shape = batch_shape + self.event_shape + self.event_shape
  174. new.loc = self.loc.expand(loc_shape)
  175. new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
  176. if "covariance_matrix" in self.__dict__:
  177. new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
  178. if "scale_tril" in self.__dict__:
  179. new.scale_tril = self.scale_tril.expand(cov_shape)
  180. if "precision_matrix" in self.__dict__:
  181. new.precision_matrix = self.precision_matrix.expand(cov_shape)
  182. super(MultivariateNormal, new).__init__(
  183. batch_shape, self.event_shape, validate_args=False
  184. )
  185. new._validate_args = self._validate_args
  186. return new
  187. @lazy_property
  188. def scale_tril(self) -> Tensor:
  189. return self._unbroadcasted_scale_tril.expand(
  190. self._batch_shape + self._event_shape + self._event_shape
  191. )
  192. @lazy_property
  193. def covariance_matrix(self) -> Tensor:
  194. return torch.matmul(
  195. self._unbroadcasted_scale_tril, self._unbroadcasted_scale_tril.mT
  196. ).expand(self._batch_shape + self._event_shape + self._event_shape)
  197. @lazy_property
  198. def precision_matrix(self) -> Tensor:
  199. return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand(
  200. self._batch_shape + self._event_shape + self._event_shape
  201. )
  202. @property
  203. def mean(self) -> Tensor:
  204. return self.loc
  205. @property
  206. def mode(self) -> Tensor:
  207. return self.loc
  208. @property
  209. def variance(self) -> Tensor:
  210. return (
  211. self._unbroadcasted_scale_tril.pow(2)
  212. .sum(-1)
  213. .expand(self._batch_shape + self._event_shape)
  214. )
  215. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  216. shape = self._extended_shape(sample_shape)
  217. eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
  218. return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)
  219. def log_prob(self, value):
  220. if self._validate_args:
  221. self._validate_sample(value)
  222. diff = value - self.loc
  223. M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
  224. half_log_det = (
  225. self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
  226. )
  227. return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det
  228. def entropy(self):
  229. half_log_det = (
  230. self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
  231. )
  232. H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
  233. if len(self._batch_shape) == 0:
  234. return H
  235. else:
  236. return H.expand(self._batch_shape)