independent.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # mypy: allow-untyped-defs
  2. from typing import Generic, TypeVar
  3. import torch
  4. from torch import Size, Tensor
  5. from torch.distributions import constraints
  6. from torch.distributions.distribution import Distribution
  7. from torch.distributions.utils import _sum_rightmost
  8. from torch.types import _size
  9. __all__ = ["Independent"]
  10. D = TypeVar("D", bound=Distribution)
  11. class Independent(Distribution, Generic[D]):
  12. r"""
  13. Reinterprets some of the batch dims of a distribution as event dims.
  14. This is mainly useful for changing the shape of the result of
  15. :meth:`log_prob`. For example to create a diagonal Normal distribution with
  16. the same shape as a Multivariate Normal distribution (so they are
  17. interchangeable), you can::
  18. >>> from torch.distributions.multivariate_normal import MultivariateNormal
  19. >>> from torch.distributions.normal import Normal
  20. >>> loc = torch.zeros(3)
  21. >>> scale = torch.ones(3)
  22. >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
  23. >>> [mvn.batch_shape, mvn.event_shape]
  24. [torch.Size([]), torch.Size([3])]
  25. >>> normal = Normal(loc, scale)
  26. >>> [normal.batch_shape, normal.event_shape]
  27. [torch.Size([3]), torch.Size([])]
  28. >>> diagn = Independent(normal, 1)
  29. >>> [diagn.batch_shape, diagn.event_shape]
  30. [torch.Size([]), torch.Size([3])]
  31. Args:
  32. base_distribution (torch.distributions.distribution.Distribution): a
  33. base distribution
  34. reinterpreted_batch_ndims (int): the number of batch dims to
  35. reinterpret as event dims
  36. """
  37. arg_constraints: dict[str, constraints.Constraint] = {}
  38. base_dist: D
  39. def __init__(
  40. self,
  41. base_distribution: D,
  42. reinterpreted_batch_ndims: int,
  43. validate_args: bool | None = None,
  44. ) -> None:
  45. if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
  46. raise ValueError(
  47. "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
  48. f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}"
  49. )
  50. shape: Size = base_distribution.batch_shape + base_distribution.event_shape
  51. event_dim: int = reinterpreted_batch_ndims + len(base_distribution.event_shape)
  52. batch_shape = shape[: len(shape) - event_dim]
  53. event_shape = shape[len(shape) - event_dim :]
  54. self.base_dist = base_distribution
  55. self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
  56. # pyrefly: ignore [bad-argument-type]
  57. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  58. def expand(self, batch_shape, _instance=None):
  59. new = self._get_checked_instance(Independent, _instance)
  60. batch_shape = torch.Size(batch_shape)
  61. new.base_dist = self.base_dist.expand(
  62. batch_shape + self.event_shape[: self.reinterpreted_batch_ndims]
  63. )
  64. new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
  65. super(Independent, new).__init__(
  66. batch_shape, self.event_shape, validate_args=False
  67. )
  68. new._validate_args = self._validate_args
  69. return new
  70. @property
  71. def has_rsample(self) -> bool: # type: ignore[override]
  72. return self.base_dist.has_rsample
  73. @property
  74. def has_enumerate_support(self) -> bool: # type: ignore[override]
  75. if self.reinterpreted_batch_ndims > 0:
  76. return False
  77. return self.base_dist.has_enumerate_support
  78. @constraints.dependent_property
  79. # pyrefly: ignore [bad-override]
  80. def support(self):
  81. result = self.base_dist.support
  82. if self.reinterpreted_batch_ndims:
  83. result = constraints.independent(result, self.reinterpreted_batch_ndims)
  84. return result
  85. @property
  86. def mean(self) -> Tensor:
  87. return self.base_dist.mean
  88. @property
  89. def mode(self) -> Tensor:
  90. return self.base_dist.mode
  91. @property
  92. def variance(self) -> Tensor:
  93. return self.base_dist.variance
  94. def sample(self, sample_shape=torch.Size()) -> Tensor:
  95. return self.base_dist.sample(sample_shape)
  96. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  97. return self.base_dist.rsample(sample_shape)
  98. def log_prob(self, value):
  99. log_prob = self.base_dist.log_prob(value)
  100. return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)
  101. def entropy(self):
  102. entropy = self.base_dist.entropy()
  103. return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)
  104. def enumerate_support(self, expand=True):
  105. if self.reinterpreted_batch_ndims > 0:
  106. raise NotImplementedError(
  107. "Enumeration over cartesian product is not implemented"
  108. )
  109. return self.base_dist.enumerate_support(expand=expand)
  110. def __repr__(self):
  111. return (
  112. self.__class__.__name__
  113. + f"({self.base_dist}, {self.reinterpreted_batch_ndims})"
  114. )