exp_family.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch import Tensor
  4. from torch.distributions.distribution import Distribution
  5. __all__ = ["ExponentialFamily"]
  6. class ExponentialFamily(Distribution):
  7. r"""
  8. ExponentialFamily is the abstract base class for probability distributions belonging to an
  9. exponential family, whose probability mass/density function has the form is defined below
  10. .. math::
  11. p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))
  12. where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient statistic,
  13. :math:`F(\theta)` is the log normalizer function for a given family and :math:`k(x)` is the carrier
  14. measure.
  15. Note:
  16. This class is an intermediary between the `Distribution` class and distributions which belong
  17. to an exponential family mainly to check the correctness of the `.entropy()` and analytic KL
  18. divergence methods. We use this class to compute the entropy and KL divergence using the AD
  19. framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and
  20. Cross-entropies of Exponential Families).
  21. """
  22. @property
  23. def _natural_params(self) -> tuple[Tensor, ...]:
  24. """
  25. Abstract method for natural parameters. Returns a tuple of Tensors based
  26. on the distribution
  27. """
  28. raise NotImplementedError
  29. def _log_normalizer(self, *natural_params):
  30. """
  31. Abstract method for log normalizer function. Returns a log normalizer based on
  32. the distribution and input
  33. """
  34. raise NotImplementedError
  35. @property
  36. def _mean_carrier_measure(self) -> float:
  37. """
  38. Abstract method for expected carrier measure, which is required for computing
  39. entropy.
  40. """
  41. raise NotImplementedError
  42. def entropy(self):
  43. """
  44. Method to compute the entropy using Bregman divergence of the log normalizer.
  45. """
  46. result: Tensor | float = -self._mean_carrier_measure
  47. nparams = [p.detach().requires_grad_() for p in self._natural_params]
  48. lg_normal = self._log_normalizer(*nparams)
  49. gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
  50. result += lg_normal
  51. for np, g in zip(nparams, gradients):
  52. result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1)
  53. return result