mixture_same_family.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch import Tensor
  4. from torch.distributions import Categorical, constraints
  5. from torch.distributions.constraints import MixtureSameFamilyConstraint
  6. from torch.distributions.distribution import Distribution
# Public API of this module: the only name exported by ``from ... import *``.
__all__ = ["MixtureSameFamily"]
  8. class MixtureSameFamily(Distribution):
  9. r"""
  10. The `MixtureSameFamily` distribution implements a (batch of) mixture
  11. distribution where all component are from different parameterizations of
  12. the same distribution type. It is parameterized by a `Categorical`
  13. "selecting distribution" (over `k` component) and a component
  14. distribution, i.e., a `Distribution` with a rightmost batch shape
  15. (equal to `[k]`) which indexes each (batch of) component.
  16. Examples::
  17. >>> # xdoctest: +SKIP("undefined vars")
  18. >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
  19. >>> # weighted normal distributions
  20. >>> mix = D.Categorical(torch.ones(5,))
  21. >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
  22. >>> gmm = MixtureSameFamily(mix, comp)
  23. >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
  24. >>> # weighted bivariate normal distributions
  25. >>> mix = D.Categorical(torch.ones(5,))
  26. >>> comp = D.Independent(D.Normal(
  27. ... torch.randn(5,2), torch.rand(5,2)), 1)
  28. >>> gmm = MixtureSameFamily(mix, comp)
  29. >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
  30. >>> # consisting of 5 random weighted bivariate normal distributions
  31. >>> mix = D.Categorical(torch.rand(3,5))
  32. >>> comp = D.Independent(D.Normal(
  33. ... torch.randn(3,5,2), torch.rand(3,5,2)), 1)
  34. >>> gmm = MixtureSameFamily(mix, comp)
  35. Args:
  36. mixture_distribution: `torch.distributions.Categorical`-like
  37. instance. Manages the probability of selecting component.
  38. The number of categories must match the rightmost batch
  39. dimension of the `component_distribution`. Must have either
  40. scalar `batch_shape` or `batch_shape` matching
  41. `component_distribution.batch_shape[:-1]`
  42. component_distribution: `torch.distributions.Distribution`-like
  43. instance. Right-most batch dimension indexes component.
  44. """
  45. arg_constraints: dict[str, constraints.Constraint] = {}
  46. has_rsample = False
  47. def __init__(
  48. self,
  49. mixture_distribution: Categorical,
  50. component_distribution: Distribution,
  51. validate_args: bool | None = None,
  52. ) -> None:
  53. self._mixture_distribution = mixture_distribution
  54. self._component_distribution = component_distribution
  55. if not isinstance(self._mixture_distribution, Categorical):
  56. raise ValueError(
  57. " The Mixture distribution needs to be an "
  58. " instance of torch.distributions.Categorical"
  59. )
  60. if not isinstance(self._component_distribution, Distribution):
  61. raise ValueError(
  62. "The Component distribution need to be an "
  63. "instance of torch.distributions.Distribution"
  64. )
  65. # Check that batch size matches
  66. mdbs = self._mixture_distribution.batch_shape
  67. cdbs = self._component_distribution.batch_shape[:-1]
  68. for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
  69. if size1 != 1 and size2 != 1 and size1 != size2:
  70. raise ValueError(
  71. f"`mixture_distribution.batch_shape` ({mdbs}) is not "
  72. "compatible with `component_distribution."
  73. f"batch_shape`({cdbs})"
  74. )
  75. # Check that the number of mixture component matches
  76. km = self._mixture_distribution.logits.shape[-1]
  77. kc = self._component_distribution.batch_shape[-1]
  78. if km is not None and kc is not None and km != kc:
  79. raise ValueError(
  80. f"`mixture_distribution component` ({km}) does not"
  81. " equal `component_distribution.batch_shape[-1]`"
  82. f" ({kc})"
  83. )
  84. self._num_component = km
  85. event_shape = self._component_distribution.event_shape
  86. self._event_ndims = len(event_shape)
  87. super().__init__(
  88. # pyrefly: ignore [bad-argument-type]
  89. batch_shape=cdbs,
  90. event_shape=event_shape,
  91. validate_args=validate_args,
  92. )
  93. def expand(self, batch_shape, _instance=None):
  94. batch_shape = torch.Size(batch_shape)
  95. batch_shape_comp = batch_shape + (self._num_component,)
  96. new = self._get_checked_instance(MixtureSameFamily, _instance)
  97. new._component_distribution = self._component_distribution.expand(
  98. batch_shape_comp
  99. )
  100. new._mixture_distribution = self._mixture_distribution.expand(batch_shape)
  101. new._num_component = self._num_component
  102. new._event_ndims = self._event_ndims
  103. event_shape = new._component_distribution.event_shape
  104. super(MixtureSameFamily, new).__init__(
  105. batch_shape=batch_shape, event_shape=event_shape, validate_args=False
  106. )
  107. new._validate_args = self._validate_args
  108. return new
  109. @constraints.dependent_property
  110. # pyrefly: ignore [bad-override]
  111. def support(self):
  112. return MixtureSameFamilyConstraint(self._component_distribution.support)
  113. @property
  114. def mixture_distribution(self) -> Categorical:
  115. return self._mixture_distribution
  116. @property
  117. def component_distribution(self) -> Distribution:
  118. return self._component_distribution
  119. @property
  120. def mean(self) -> Tensor:
  121. probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
  122. return torch.sum(
  123. probs * self.component_distribution.mean, dim=-1 - self._event_ndims
  124. ) # [B, E]
  125. @property
  126. def variance(self) -> Tensor:
  127. # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
  128. probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
  129. mean_cond_var = torch.sum(
  130. probs * self.component_distribution.variance, dim=-1 - self._event_ndims
  131. )
  132. var_cond_mean = torch.sum(
  133. probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0),
  134. dim=-1 - self._event_ndims,
  135. )
  136. return mean_cond_var + var_cond_mean
  137. def cdf(self, x):
  138. x = self._pad(x)
  139. cdf_x = self.component_distribution.cdf(x)
  140. mix_prob = self.mixture_distribution.probs
  141. return torch.sum(cdf_x * mix_prob, dim=-1)
  142. def log_prob(self, x):
  143. if self._validate_args:
  144. self._validate_sample(x)
  145. x = self._pad(x)
  146. log_prob_x = self.component_distribution.log_prob(x) # [S, B, k]
  147. log_mix_prob = torch.log_softmax(
  148. self.mixture_distribution.logits, dim=-1
  149. ) # [B, k]
  150. return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1) # [S, B]
  151. def sample(self, sample_shape=torch.Size()):
  152. with torch.no_grad():
  153. sample_len = len(sample_shape)
  154. batch_len = len(self.batch_shape)
  155. gather_dim = sample_len + batch_len
  156. es = self.event_shape
  157. # mixture samples [n, B]
  158. mix_sample = self.mixture_distribution.sample(sample_shape)
  159. mix_shape = mix_sample.shape
  160. # component samples [n, B, k, E]
  161. comp_samples = self.component_distribution.sample(sample_shape)
  162. # Gather along the k dimension
  163. mix_sample_r = mix_sample.reshape(
  164. mix_shape + torch.Size([1] * (len(es) + 1))
  165. )
  166. mix_sample_r = mix_sample_r.repeat(
  167. torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es
  168. )
  169. samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
  170. return samples.squeeze(gather_dim)
  171. def _pad(self, x):
  172. return x.unsqueeze(-1 - self._event_ndims)
  173. def _pad_mixture_dimensions(self, x):
  174. dist_batch_ndims = len(self.batch_shape)
  175. cat_batch_ndims = len(self.mixture_distribution.batch_shape)
  176. pad_ndims = 0 if cat_batch_ndims == 1 else dist_batch_ndims - cat_batch_ndims
  177. xs = x.shape
  178. x = x.reshape(
  179. xs[:-1]
  180. + torch.Size(pad_ndims * [1])
  181. + xs[-1:]
  182. + torch.Size(self._event_ndims * [1])
  183. )
  184. return x
  185. def __repr__(self):
  186. args_string = (
  187. f"\n {self.mixture_distribution},\n {self.component_distribution}"
  188. )
  189. return "MixtureSameFamily" + "(" + args_string + ")"