transformed_distribution.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch import Tensor
  4. from torch.distributions import constraints
  5. from torch.distributions.distribution import Distribution
  6. from torch.distributions.independent import Independent
  7. from torch.distributions.transforms import ComposeTransform, Transform
  8. from torch.distributions.utils import _sum_rightmost
  9. from torch.types import _size
  10. __all__ = ["TransformedDistribution"]
  11. class TransformedDistribution(Distribution):
  12. r"""
  13. Extension of the Distribution class, which applies a sequence of Transforms
  14. to a base distribution. Let f be the composition of transforms applied::
  15. X ~ BaseDistribution
  16. Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
  17. log p(Y) = log p(X) + log |det (dX/dY)|
  18. Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the
  19. maximum shape of its base distribution and its transforms, since transforms
  20. can introduce correlations among events.
  21. An example for the usage of :class:`TransformedDistribution` would be::
  22. # Building a Logistic Distribution
  23. # X ~ Uniform(0, 1)
  24. # f = a + b * logit(X)
  25. # Y ~ f(X) ~ Logistic(a, b)
  26. base_distribution = Uniform(0, 1)
  27. transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
  28. logistic = TransformedDistribution(base_distribution, transforms)
  29. For more examples, please look at the implementations of
  30. :class:`~torch.distributions.gumbel.Gumbel`,
  31. :class:`~torch.distributions.half_cauchy.HalfCauchy`,
  32. :class:`~torch.distributions.half_normal.HalfNormal`,
  33. :class:`~torch.distributions.log_normal.LogNormal`,
  34. :class:`~torch.distributions.pareto.Pareto`,
  35. :class:`~torch.distributions.weibull.Weibull`,
  36. :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
  37. :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
  38. """
  39. arg_constraints: dict[str, constraints.Constraint] = {}
  40. def __init__(
  41. self,
  42. base_distribution: Distribution,
  43. transforms: Transform | list[Transform],
  44. validate_args: bool | None = None,
  45. ) -> None:
  46. if isinstance(transforms, Transform):
  47. self.transforms = [
  48. transforms,
  49. ]
  50. elif isinstance(transforms, list):
  51. if not all(isinstance(t, Transform) for t in transforms):
  52. raise ValueError(
  53. "transforms must be a Transform or a list of Transforms"
  54. )
  55. self.transforms = transforms
  56. else:
  57. raise ValueError(
  58. f"transforms must be a Transform or list, but was {transforms}"
  59. )
  60. # Reshape base_distribution according to transforms.
  61. base_shape = base_distribution.batch_shape + base_distribution.event_shape
  62. base_event_dim = len(base_distribution.event_shape)
  63. transform = ComposeTransform(self.transforms)
  64. if len(base_shape) < transform.domain.event_dim:
  65. raise ValueError(
  66. f"base_distribution needs to have shape with size at least {transform.domain.event_dim}, but got {base_shape}."
  67. )
  68. forward_shape = transform.forward_shape(base_shape)
  69. expanded_base_shape = transform.inverse_shape(forward_shape)
  70. if base_shape != expanded_base_shape:
  71. base_batch_shape = expanded_base_shape[
  72. : len(expanded_base_shape) - base_event_dim
  73. ]
  74. base_distribution = base_distribution.expand(base_batch_shape)
  75. reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
  76. if reinterpreted_batch_ndims > 0:
  77. base_distribution = Independent(
  78. base_distribution, reinterpreted_batch_ndims
  79. )
  80. self.base_dist = base_distribution
  81. # Compute shapes.
  82. transform_change_in_event_dim = (
  83. transform.codomain.event_dim - transform.domain.event_dim
  84. )
  85. event_dim = max(
  86. transform.codomain.event_dim, # the transform is coupled
  87. base_event_dim + transform_change_in_event_dim, # the base dist is coupled
  88. )
  89. if len(forward_shape) < event_dim:
  90. raise AssertionError(
  91. f"forward_shape length {len(forward_shape)} must be >= event_dim {event_dim}"
  92. )
  93. cut = len(forward_shape) - event_dim
  94. batch_shape = forward_shape[:cut]
  95. event_shape = forward_shape[cut:]
  96. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  97. def expand(self, batch_shape, _instance=None):
  98. new = self._get_checked_instance(TransformedDistribution, _instance)
  99. batch_shape = torch.Size(batch_shape)
  100. shape = batch_shape + self.event_shape
  101. for t in reversed(self.transforms):
  102. shape = t.inverse_shape(shape)
  103. base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)]
  104. new.base_dist = self.base_dist.expand(base_batch_shape)
  105. new.transforms = self.transforms
  106. super(TransformedDistribution, new).__init__(
  107. batch_shape, self.event_shape, validate_args=False
  108. )
  109. new._validate_args = self._validate_args
  110. return new
  111. @constraints.dependent_property(is_discrete=False)
  112. # pyrefly: ignore [bad-override]
  113. def support(self):
  114. if not self.transforms:
  115. return self.base_dist.support
  116. support = self.transforms[-1].codomain
  117. if len(self.event_shape) > support.event_dim:
  118. support = constraints.independent(
  119. support, len(self.event_shape) - support.event_dim
  120. )
  121. return support
  122. @property
  123. def has_rsample(self) -> bool: # type: ignore[override]
  124. return self.base_dist.has_rsample
  125. def sample(self, sample_shape=torch.Size()):
  126. """
  127. Generates a sample_shape shaped sample or sample_shape shaped batch of
  128. samples if the distribution parameters are batched. Samples first from
  129. base distribution and applies `transform()` for every transform in the
  130. list.
  131. """
  132. with torch.no_grad():
  133. x = self.base_dist.sample(sample_shape)
  134. for transform in self.transforms:
  135. x = transform(x)
  136. return x
  137. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  138. """
  139. Generates a sample_shape shaped reparameterized sample or sample_shape
  140. shaped batch of reparameterized samples if the distribution parameters
  141. are batched. Samples first from base distribution and applies
  142. `transform()` for every transform in the list.
  143. """
  144. x = self.base_dist.rsample(sample_shape)
  145. for transform in self.transforms:
  146. x = transform(x)
  147. return x
  148. def log_prob(self, value):
  149. """
  150. Scores the sample by inverting the transform(s) and computing the score
  151. using the score of the base distribution and the log abs det jacobian.
  152. """
  153. if self._validate_args:
  154. self._validate_sample(value)
  155. event_dim = len(self.event_shape)
  156. log_prob: Tensor | float = 0.0
  157. y = value
  158. for transform in reversed(self.transforms):
  159. x = transform.inv(y)
  160. event_dim += transform.domain.event_dim - transform.codomain.event_dim
  161. log_prob = log_prob - _sum_rightmost(
  162. transform.log_abs_det_jacobian(x, y),
  163. event_dim - transform.domain.event_dim,
  164. )
  165. y = x
  166. log_prob = log_prob + _sum_rightmost(
  167. self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
  168. )
  169. return log_prob
  170. def _monotonize_cdf(self, value):
  171. """
  172. This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
  173. monotone increasing.
  174. """
  175. sign = 1
  176. for transform in self.transforms:
  177. sign = sign * transform.sign
  178. if isinstance(sign, int) and sign == 1:
  179. return value
  180. return sign * (value - 0.5) + 0.5
  181. def cdf(self, value):
  182. """
  183. Computes the cumulative distribution function by inverting the
  184. transform(s) and computing the score of the base distribution.
  185. """
  186. for transform in self.transforms[::-1]:
  187. value = transform.inv(value)
  188. if self._validate_args:
  189. self.base_dist._validate_sample(value)
  190. value = self.base_dist.cdf(value)
  191. value = self._monotonize_cdf(value)
  192. return value
  193. def icdf(self, value):
  194. """
  195. Computes the inverse cumulative distribution function using
  196. transform(s) and computing the score of the base distribution.
  197. """
  198. value = self._monotonize_cdf(value)
  199. value = self.base_dist.icdf(value)
  200. for transform in self.transforms:
  201. value = transform(value)
  202. return value