distribution.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. from typing_extensions import deprecated
  4. import torch
  5. from torch import Tensor
  6. from torch.distributions import constraints
  7. from torch.distributions.utils import lazy_property
  8. from torch.types import _size
  9. __all__ = ["Distribution"]
  10. class Distribution:
  11. r"""
  12. Distribution is the abstract base class for probability distributions.
  13. Args:
  14. batch_shape (torch.Size): The shape over which parameters are batched.
  15. event_shape (torch.Size): The shape of a single sample (without batching).
  16. validate_args (bool, optional): Whether to validate arguments. Default: None.
  17. """
  18. has_rsample = False
  19. has_enumerate_support = False
  20. _validate_args = __debug__
  21. @staticmethod
  22. def set_default_validate_args(value: bool) -> None:
  23. """
  24. Sets whether validation is enabled or disabled.
  25. The default behavior mimics Python's ``assert`` statement: validation
  26. is on by default, but is disabled if Python is run in optimized mode
  27. (via ``python -O``). Validation may be expensive, so you may want to
  28. disable it once a model is working.
  29. Args:
  30. value (bool): Whether to enable validation.
  31. """
  32. if value not in [True, False]:
  33. raise ValueError
  34. Distribution._validate_args = value
  35. def __init__(
  36. self,
  37. batch_shape: torch.Size = torch.Size(),
  38. event_shape: torch.Size = torch.Size(),
  39. validate_args: bool | None = None,
  40. ) -> None:
  41. self._batch_shape = batch_shape
  42. self._event_shape = event_shape
  43. if validate_args is not None:
  44. self._validate_args = validate_args
  45. if self._validate_args:
  46. try:
  47. arg_constraints = self.arg_constraints
  48. except NotImplementedError:
  49. arg_constraints = {}
  50. warnings.warn(
  51. f"{self.__class__} does not define `arg_constraints`. "
  52. + "Please set `arg_constraints = {}` or initialize the distribution "
  53. + "with `validate_args=False` to turn off validation.",
  54. stacklevel=2,
  55. )
  56. for param, constraint in arg_constraints.items():
  57. if constraints.is_dependent(constraint):
  58. continue # skip constraints that cannot be checked
  59. if param not in self.__dict__ and isinstance(
  60. getattr(type(self), param), lazy_property
  61. ):
  62. continue # skip checking lazily-constructed args
  63. value = getattr(self, param)
  64. valid = constraint.check(value)
  65. if not torch._is_all_true(valid):
  66. raise ValueError(
  67. f"Expected parameter {param} "
  68. f"({type(value).__name__} of shape {tuple(value.shape)}) "
  69. f"of distribution {repr(self)} "
  70. f"to satisfy the constraint {repr(constraint)}, "
  71. f"but found invalid values:\n{value}"
  72. )
  73. super().__init__()
  74. def expand(self, batch_shape: _size, _instance=None):
  75. """
  76. Returns a new distribution instance (or populates an existing instance
  77. provided by a derived class) with batch dimensions expanded to
  78. `batch_shape`. This method calls :class:`~torch.Tensor.expand` on
  79. the distribution's parameters. As such, this does not allocate new
  80. memory for the expanded distribution instance. Additionally,
  81. this does not repeat any args checking or parameter broadcasting in
  82. `__init__.py`, when an instance is first created.
  83. Args:
  84. batch_shape (torch.Size): the desired expanded size.
  85. _instance: new instance provided by subclasses that
  86. need to override `.expand`.
  87. Returns:
  88. New distribution instance with batch dimensions expanded to
  89. `batch_size`.
  90. """
  91. raise NotImplementedError
  92. @property
  93. def batch_shape(self) -> torch.Size:
  94. """
  95. Returns the shape over which parameters are batched.
  96. """
  97. return self._batch_shape
  98. @property
  99. def event_shape(self) -> torch.Size:
  100. """
  101. Returns the shape of a single sample (without batching).
  102. """
  103. return self._event_shape
  104. @property
  105. def arg_constraints(self) -> dict[str, constraints.Constraint]:
  106. """
  107. Returns a dictionary from argument names to
  108. :class:`~torch.distributions.constraints.Constraint` objects that
  109. should be satisfied by each argument of this distribution. Args that
  110. are not tensors need not appear in this dict.
  111. """
  112. raise NotImplementedError
  113. @property
  114. def support(self) -> constraints.Constraint | None:
  115. """
  116. Returns a :class:`~torch.distributions.constraints.Constraint` object
  117. representing this distribution's support.
  118. """
  119. raise NotImplementedError
  120. @property
  121. def mean(self) -> Tensor:
  122. """
  123. Returns the mean of the distribution.
  124. """
  125. raise NotImplementedError
  126. @property
  127. def mode(self) -> Tensor:
  128. """
  129. Returns the mode of the distribution.
  130. """
  131. raise NotImplementedError(f"{self.__class__} does not implement mode")
  132. @property
  133. def variance(self) -> Tensor:
  134. """
  135. Returns the variance of the distribution.
  136. """
  137. raise NotImplementedError
  138. @property
  139. def stddev(self) -> Tensor:
  140. """
  141. Returns the standard deviation of the distribution.
  142. """
  143. return self.variance.sqrt()
  144. def sample(self, sample_shape: _size = torch.Size()) -> Tensor:
  145. """
  146. Generates a sample_shape shaped sample or sample_shape shaped batch of
  147. samples if the distribution parameters are batched.
  148. """
  149. with torch.no_grad():
  150. return self.rsample(sample_shape)
  151. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  152. """
  153. Generates a sample_shape shaped reparameterized sample or sample_shape
  154. shaped batch of reparameterized samples if the distribution parameters
  155. are batched.
  156. """
  157. raise NotImplementedError
  158. @deprecated(
  159. "`sample_n(n)` will be deprecated. Use `sample((n,))` instead.",
  160. category=FutureWarning,
  161. )
  162. def sample_n(self, n: int) -> Tensor:
  163. """
  164. Generates n samples or n batches of samples if the distribution
  165. parameters are batched.
  166. """
  167. return self.sample(torch.Size((n,)))
  168. def log_prob(self, value: Tensor) -> Tensor:
  169. """
  170. Returns the log of the probability density/mass function evaluated at
  171. `value`.
  172. Args:
  173. value (Tensor):
  174. """
  175. raise NotImplementedError
  176. def cdf(self, value: Tensor) -> Tensor:
  177. """
  178. Returns the cumulative density/mass function evaluated at
  179. `value`.
  180. Args:
  181. value (Tensor):
  182. """
  183. raise NotImplementedError
  184. def icdf(self, value: Tensor) -> Tensor:
  185. """
  186. Returns the inverse cumulative density/mass function evaluated at
  187. `value`.
  188. Args:
  189. value (Tensor):
  190. """
  191. raise NotImplementedError
  192. def enumerate_support(self, expand: bool = True) -> Tensor:
  193. """
  194. Returns tensor containing all values supported by a discrete
  195. distribution. The result will enumerate over dimension 0, so the shape
  196. of the result will be `(cardinality,) + batch_shape + event_shape`
  197. (where `event_shape = ()` for univariate distributions).
  198. Note that this enumerates over all batched tensors in lock-step
  199. `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
  200. along dim 0, but with the remaining batch dimensions being
  201. singleton dimensions, `[[0], [1], ..`.
  202. To iterate over the full Cartesian product use
  203. `itertools.product(m.enumerate_support())`.
  204. Args:
  205. expand (bool): whether to expand the support over the
  206. batch dims to match the distribution's `batch_shape`.
  207. Returns:
  208. Tensor iterating over dimension 0.
  209. """
  210. raise NotImplementedError
  211. def entropy(self) -> Tensor:
  212. """
  213. Returns entropy of distribution, batched over batch_shape.
  214. Returns:
  215. Tensor of shape batch_shape.
  216. """
  217. raise NotImplementedError
  218. def perplexity(self) -> Tensor:
  219. """
  220. Returns perplexity of distribution, batched over batch_shape.
  221. Returns:
  222. Tensor of shape batch_shape.
  223. """
  224. return torch.exp(self.entropy())
  225. def _extended_shape(self, sample_shape: _size = torch.Size()) -> torch.Size:
  226. """
  227. Returns the size of the sample returned by the distribution, given
  228. a `sample_shape`. Note, that the batch and event shapes of a distribution
  229. instance are fixed at the time of construction. If this is empty, the
  230. returned shape is upcast to (1,).
  231. Args:
  232. sample_shape (torch.Size): the size of the sample to be drawn.
  233. """
  234. if not isinstance(sample_shape, torch.Size):
  235. sample_shape = torch.Size(sample_shape)
  236. return torch.Size(sample_shape + self._batch_shape + self._event_shape)
  237. def _validate_sample(self, value: Tensor) -> None:
  238. """
  239. Argument validation for distribution methods such as `log_prob`,
  240. `cdf` and `icdf`. The rightmost dimensions of a value to be
  241. scored via these methods must agree with the distribution's batch
  242. and event shapes.
  243. Args:
  244. value (Tensor): the tensor whose log probability is to be
  245. computed by the `log_prob` method.
  246. Raises
  247. ValueError: when the rightmost dimensions of `value` do not match the
  248. distribution's batch and event shapes.
  249. """
  250. if not isinstance(value, torch.Tensor):
  251. raise ValueError("The value argument to log_prob must be a Tensor")
  252. event_dim_start = len(value.size()) - len(self._event_shape)
  253. if value.size()[event_dim_start:] != self._event_shape:
  254. raise ValueError(
  255. f"The right-most size of value must match event_shape: {value.size()} vs {self._event_shape}."
  256. )
  257. actual_shape = value.size()
  258. expected_shape = self._batch_shape + self._event_shape
  259. for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
  260. if i != 1 and j != 1 and i != j:
  261. raise ValueError(
  262. f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}."
  263. )
  264. try:
  265. support = self.support
  266. except NotImplementedError:
  267. warnings.warn(
  268. f"{self.__class__} does not define `support` to enable "
  269. + "sample validation. Please initialize the distribution with "
  270. + "`validate_args=False` to turn off validation.",
  271. stacklevel=2,
  272. )
  273. return
  274. if support is None:
  275. raise AssertionError("support is unexpectedly None")
  276. valid = support.check(value)
  277. if not torch._is_all_true(valid):
  278. raise ValueError(
  279. "Expected value argument "
  280. f"({type(value).__name__} of shape {tuple(value.shape)}) "
  281. f"to be within the support ({repr(support)}) "
  282. f"of the distribution {repr(self)}, "
  283. f"but found invalid values:\n{value}"
  284. )
  285. def _get_checked_instance(self, cls, _instance=None):
  286. if _instance is None and type(self).__init__ != cls.__init__:
  287. raise NotImplementedError(
  288. f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method "
  289. "must also define a custom .expand() method."
  290. )
  291. return self.__new__(type(self)) if _instance is None else _instance
  292. def __repr__(self) -> str:
  293. param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
  294. args_string = ", ".join(
  295. [
  296. f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}"
  297. for p in param_names
  298. ]
  299. )
  300. return self.__class__.__name__ + "(" + args_string + ")"