constraint_registry.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. # mypy: allow-untyped-defs
r"""
PyTorch provides two global :class:`ConstraintRegistry` objects that link
:class:`~torch.distributions.constraints.Constraint` objects to
:class:`~torch.distributions.transforms.Transform` objects. These objects both
take constraints as input and return transforms, but they have different
guarantees on bijectivity.

1. ``biject_to(constraint)`` looks up a bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is guaranteed to have
   ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
2. ``transform_to(constraint)`` looks up a not-necessarily bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is not guaranteed to
   implement ``.log_abs_det_jacobian()``.

The ``transform_to()`` registry is useful for performing unconstrained
optimization on constrained parameters of probability distributions, which are
indicated by each distribution's ``.arg_constraints`` dict. These transforms often
overparameterize a space in order to avoid rotation; they are thus more
suitable for coordinate-wise optimization algorithms like Adam::

    loc = torch.zeros(100, requires_grad=True)
    unconstrained = torch.zeros(100, requires_grad=True)
    scale = transform_to(Normal.arg_constraints["scale"])(unconstrained)
    loss = -Normal(loc, scale).log_prob(data).sum()

The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
samples from a probability distribution with constrained ``.support`` are
propagated in an unconstrained space, and algorithms are typically rotation
invariant::

    dist = Exponential(rate)
    unconstrained = torch.zeros(100, requires_grad=True)
    sample = biject_to(dist.support)(unconstrained)
    potential_energy = -dist.log_prob(sample).sum()

.. note::

    An example where ``transform_to`` and ``biject_to`` differ is
    ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
    exponentiates and normalizes its inputs; this is a cheap and mostly
    coordinate-wise operation appropriate for algorithms like SVI. In
    contrast, ``biject_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.StickBreakingTransform` that
    bijects its input down to a one-fewer-dimensional space; this is a more
    expensive, less numerically stable transform, but it is needed for
    algorithms like HMC.

The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
constraints and transforms using their ``.register()`` method, either as a
function on singleton constraints::

    transform_to.register(my_constraint, my_transform)

or as a decorator on parameterized constraints::

    @transform_to.register(MyConstraintClass)
    def my_factory(constraint):
        assert isinstance(constraint, MyConstraintClass)
        return MyTransform(constraint.param1, constraint.param2)

You can create your own registry by creating a new :class:`ConstraintRegistry`
object.
"""
  56. from torch.distributions import constraints, transforms
  57. from torch.types import _Number
  58. __all__ = [
  59. "ConstraintRegistry",
  60. "biject_to",
  61. "transform_to",
  62. ]
  63. class ConstraintRegistry:
  64. """
  65. Registry to link constraints to transforms.
  66. """
  67. def __init__(self):
  68. self._registry = {}
  69. super().__init__()
  70. def register(self, constraint, factory=None):
  71. """
  72. Registers a :class:`~torch.distributions.constraints.Constraint`
  73. subclass in this registry. Usage::
  74. @my_registry.register(MyConstraintClass)
  75. def construct_transform(constraint):
  76. assert isinstance(constraint, MyConstraint)
  77. return MyTransform(constraint.arg_constraints)
  78. Args:
  79. constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
  80. A subclass of :class:`~torch.distributions.constraints.Constraint`, or
  81. a singleton object of the desired class.
  82. factory (Callable): A callable that inputs a constraint object and returns
  83. a :class:`~torch.distributions.transforms.Transform` object.
  84. """
  85. # Support use as decorator.
  86. if factory is None:
  87. return lambda factory: self.register(constraint, factory)
  88. # Support calling on singleton instances.
  89. if isinstance(constraint, constraints.Constraint):
  90. constraint = type(constraint)
  91. if not isinstance(constraint, type) or not issubclass(
  92. constraint, constraints.Constraint
  93. ):
  94. raise TypeError(
  95. f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}"
  96. )
  97. self._registry[constraint] = factory
  98. return factory
  99. def __call__(self, constraint):
  100. """
  101. Looks up a transform to constrained space, given a constraint object.
  102. Usage::
  103. constraint = Normal.arg_constraints["scale"]
  104. scale = transform_to(constraint)(torch.zeros(1)) # constrained
  105. u = transform_to(constraint).inv(scale) # unconstrained
  106. Args:
  107. constraint (:class:`~torch.distributions.constraints.Constraint`):
  108. A constraint object.
  109. Returns:
  110. A :class:`~torch.distributions.transforms.Transform` object.
  111. Raises:
  112. `NotImplementedError` if no transform has been registered.
  113. """
  114. # Look up by Constraint subclass.
  115. try:
  116. factory = self._registry[type(constraint)]
  117. except KeyError:
  118. raise NotImplementedError(
  119. f"Cannot transform {type(constraint).__name__} constraints"
  120. ) from None
  121. return factory(constraint)
  122. biject_to = ConstraintRegistry()
  123. transform_to = ConstraintRegistry()
  124. ################################################################################
  125. # Registration Table
  126. ################################################################################
  127. @biject_to.register(constraints.real)
  128. @transform_to.register(constraints.real)
  129. def _transform_to_real(constraint):
  130. return transforms.identity_transform
  131. @biject_to.register(constraints.independent)
  132. def _biject_to_independent(constraint):
  133. base_transform = biject_to(constraint.base_constraint)
  134. return transforms.IndependentTransform(
  135. base_transform, constraint.reinterpreted_batch_ndims
  136. )
  137. @transform_to.register(constraints.independent)
  138. def _transform_to_independent(constraint):
  139. base_transform = transform_to(constraint.base_constraint)
  140. return transforms.IndependentTransform(
  141. base_transform, constraint.reinterpreted_batch_ndims
  142. )
  143. @biject_to.register(constraints.positive)
  144. @biject_to.register(constraints.nonnegative)
  145. @transform_to.register(constraints.positive)
  146. @transform_to.register(constraints.nonnegative)
  147. def _transform_to_positive(constraint):
  148. return transforms.ExpTransform()
  149. @biject_to.register(constraints.greater_than)
  150. @biject_to.register(constraints.greater_than_eq)
  151. @transform_to.register(constraints.greater_than)
  152. @transform_to.register(constraints.greater_than_eq)
  153. def _transform_to_greater_than(constraint):
  154. return transforms.ComposeTransform(
  155. [
  156. transforms.ExpTransform(),
  157. transforms.AffineTransform(constraint.lower_bound, 1),
  158. ]
  159. )
  160. @biject_to.register(constraints.less_than)
  161. @transform_to.register(constraints.less_than)
  162. def _transform_to_less_than(constraint):
  163. return transforms.ComposeTransform(
  164. [
  165. transforms.ExpTransform(),
  166. transforms.AffineTransform(constraint.upper_bound, -1),
  167. ]
  168. )
  169. @biject_to.register(constraints.interval)
  170. @biject_to.register(constraints.half_open_interval)
  171. @transform_to.register(constraints.interval)
  172. @transform_to.register(constraints.half_open_interval)
  173. def _transform_to_interval(constraint):
  174. # Handle the special case of the unit interval.
  175. lower_is_0 = (
  176. isinstance(constraint.lower_bound, _Number) and constraint.lower_bound == 0
  177. )
  178. upper_is_1 = (
  179. isinstance(constraint.upper_bound, _Number) and constraint.upper_bound == 1
  180. )
  181. if lower_is_0 and upper_is_1:
  182. return transforms.SigmoidTransform()
  183. loc = constraint.lower_bound
  184. scale = constraint.upper_bound - constraint.lower_bound
  185. return transforms.ComposeTransform(
  186. [transforms.SigmoidTransform(), transforms.AffineTransform(loc, scale)]
  187. )
  188. @biject_to.register(constraints.simplex)
  189. def _biject_to_simplex(constraint):
  190. return transforms.StickBreakingTransform()
  191. @transform_to.register(constraints.simplex)
  192. def _transform_to_simplex(constraint):
  193. return transforms.SoftmaxTransform()
  194. # TODO define a bijection for LowerCholeskyTransform
  195. @transform_to.register(constraints.lower_cholesky)
  196. def _transform_to_lower_cholesky(constraint):
  197. return transforms.LowerCholeskyTransform()
  198. @transform_to.register(constraints.positive_definite)
  199. @transform_to.register(constraints.positive_semidefinite)
  200. def _transform_to_positive_definite(constraint):
  201. return transforms.PositiveDefiniteTransform()
  202. @biject_to.register(constraints.corr_cholesky)
  203. @transform_to.register(constraints.corr_cholesky)
  204. def _transform_to_corr_cholesky(constraint):
  205. return transforms.CorrCholeskyTransform()
  206. @biject_to.register(constraints.cat)
  207. def _biject_to_cat(constraint):
  208. return transforms.CatTransform(
  209. [biject_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
  210. )
  211. @transform_to.register(constraints.cat)
  212. def _transform_to_cat(constraint):
  213. return transforms.CatTransform(
  214. [transform_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths
  215. )
  216. @biject_to.register(constraints.stack)
  217. def _biject_to_stack(constraint):
  218. return transforms.StackTransform(
  219. [biject_to(c) for c in constraint.cseq], constraint.dim
  220. )
  221. @transform_to.register(constraints.stack)
  222. def _transform_to_stack(constraint):
  223. return transforms.StackTransform(
  224. [transform_to(c) for c in constraint.cseq], constraint.dim
  225. )