__init__.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. r"""
The ``distributions`` package contains parameterizable probability distributions
and sampling functions. This allows the construction of stochastic computation
graphs and stochastic gradient estimators for optimization. This package
generally follows the design of the `TensorFlow Distributions`_ package.

.. _`TensorFlow Distributions`:
    https://arxiv.org/abs/1711.10604

It is not possible to directly backpropagate through random samples. However,
there are two main methods for creating surrogate functions that can be
backpropagated through. These are the score function estimator/likelihood ratio
estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly
seen as the basis for policy gradient methods in reinforcement learning, and the
pathwise derivative estimator is commonly seen in the reparameterization trick
in variational autoencoders. Whilst the score function only requires the value
of samples :math:`f(x)`, the pathwise derivative requires the derivative
:math:`f'(x)`. The next sections discuss these two in a reinforcement learning
example. For more details see
`Gradient Estimation Using Stochastic Computation Graphs`_.

.. _`Gradient Estimation Using Stochastic Computation Graphs`:
    https://arxiv.org/abs/1506.05254

Score function
^^^^^^^^^^^^^^

When the probability density function is differentiable with respect to its
parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:

.. math::

    \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}

where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.

In practice we would sample an action from the output of a network, apply this
action in an environment, and then use ``log_prob`` to construct an equivalent
loss function. Note that we use a negative because optimizers use gradient
descent, whilst the rule above assumes gradient ascent. With a categorical
policy, the code for implementing REINFORCE would be as follows::

    probs = policy_network(state)
    # Note that this is equivalent to what used to be called multinomial
    m = Categorical(probs)
    action = m.sample()
    next_state, reward = env.step(action)
    loss = -m.log_prob(action) * reward
    loss.backward()

Pathwise derivative
^^^^^^^^^^^^^^^^^^^

The other way to implement these stochastic/policy gradients would be to use the
reparameterization trick from the
:meth:`~torch.distributions.Distribution.rsample` method, where the
parameterized random variable can be constructed via a parameterized
deterministic function of a parameter-free random variable. The reparameterized
sample therefore becomes differentiable. The code for implementing the pathwise
derivative would be as follows::

    params = policy_network(state)
    m = Normal(*params)
    # Any distribution with .has_rsample == True could work based on the application
    action = m.rsample()
    next_state, reward = env.step(action)  # Assuming that reward is differentiable
    loss = -reward
    loss.backward()
  59. """
  60. from . import transforms
  61. from .bernoulli import Bernoulli
  62. from .beta import Beta
  63. from .binomial import Binomial
  64. from .categorical import Categorical
  65. from .cauchy import Cauchy
  66. from .chi2 import Chi2
  67. from .constraint_registry import biject_to, transform_to
  68. from .continuous_bernoulli import ContinuousBernoulli
  69. from .dirichlet import Dirichlet
  70. from .distribution import Distribution
  71. from .exp_family import ExponentialFamily
  72. from .exponential import Exponential
  73. from .fishersnedecor import FisherSnedecor
  74. from .gamma import Gamma
  75. from .generalized_pareto import GeneralizedPareto
  76. from .geometric import Geometric
  77. from .gumbel import Gumbel
  78. from .half_cauchy import HalfCauchy
  79. from .half_normal import HalfNormal
  80. from .independent import Independent
  81. from .inverse_gamma import InverseGamma
  82. from .kl import _add_kl_info, kl_divergence, register_kl
  83. from .kumaraswamy import Kumaraswamy
  84. from .laplace import Laplace
  85. from .lkj_cholesky import LKJCholesky
  86. from .log_normal import LogNormal
  87. from .logistic_normal import LogisticNormal
  88. from .lowrank_multivariate_normal import LowRankMultivariateNormal
  89. from .mixture_same_family import MixtureSameFamily
  90. from .multinomial import Multinomial
  91. from .multivariate_normal import MultivariateNormal
  92. from .negative_binomial import NegativeBinomial
  93. from .normal import Normal
  94. from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
  95. from .pareto import Pareto
  96. from .poisson import Poisson
  97. from .relaxed_bernoulli import RelaxedBernoulli
  98. from .relaxed_categorical import RelaxedOneHotCategorical
  99. from .studentT import StudentT
  100. from .transformed_distribution import TransformedDistribution
  101. from .transforms import * # noqa: F403
  102. from .uniform import Uniform
  103. from .von_mises import VonMises
  104. from .weibull import Weibull
  105. from .wishart import Wishart
  106. _add_kl_info()
  107. del _add_kl_info
  108. __all__ = [
  109. "Bernoulli",
  110. "Beta",
  111. "Binomial",
  112. "Categorical",
  113. "Cauchy",
  114. "Chi2",
  115. "ContinuousBernoulli",
  116. "Dirichlet",
  117. "Distribution",
  118. "Exponential",
  119. "ExponentialFamily",
  120. "FisherSnedecor",
  121. "Gamma",
  122. "GeneralizedPareto",
  123. "Geometric",
  124. "Gumbel",
  125. "HalfCauchy",
  126. "HalfNormal",
  127. "Independent",
  128. "InverseGamma",
  129. "Kumaraswamy",
  130. "LKJCholesky",
  131. "Laplace",
  132. "LogNormal",
  133. "LogisticNormal",
  134. "LowRankMultivariateNormal",
  135. "MixtureSameFamily",
  136. "Multinomial",
  137. "MultivariateNormal",
  138. "NegativeBinomial",
  139. "Normal",
  140. "OneHotCategorical",
  141. "OneHotCategoricalStraightThrough",
  142. "Pareto",
  143. "RelaxedBernoulli",
  144. "RelaxedOneHotCategorical",
  145. "StudentT",
  146. "Poisson",
  147. "Uniform",
  148. "VonMises",
  149. "Weibull",
  150. "Wishart",
  151. "TransformedDistribution",
  152. "biject_to",
  153. "kl_divergence",
  154. "register_kl",
  155. "transform_to",
  156. ]
  157. __all__.extend(transforms.__all__)