constraints.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761
  1. # mypy: allow-untyped-defs
  2. from collections.abc import Callable
  3. from typing import Any
  4. r"""
  5. The following constraints are implemented:
  6. - ``constraints.boolean``
  7. - ``constraints.cat``
  8. - ``constraints.corr_cholesky``
  9. - ``constraints.dependent``
  10. - ``constraints.greater_than(lower_bound)``
  11. - ``constraints.greater_than_eq(lower_bound)``
  12. - ``constraints.independent(constraint, reinterpreted_batch_ndims)``
  13. - ``constraints.integer_interval(lower_bound, upper_bound)``
  14. - ``constraints.interval(lower_bound, upper_bound)``
  15. - ``constraints.less_than(upper_bound)``
  16. - ``constraints.lower_cholesky``
  17. - ``constraints.lower_triangular``
  18. - ``constraints.MixtureSameFamilyConstraint(base_constraint)``
  19. - ``constraints.multinomial``
  20. - ``constraints.nonnegative``
  21. - ``constraints.nonnegative_integer``
  22. - ``constraints.one_hot``
  23. - ``constraints.positive_integer``
  24. - ``constraints.positive``
  25. - ``constraints.positive_semidefinite``
  26. - ``constraints.positive_definite``
  27. - ``constraints.real_vector``
  28. - ``constraints.real``
  29. - ``constraints.simplex``
  30. - ``constraints.symmetric``
  31. - ``constraints.stack``
  32. - ``constraints.square``
  33. - ``constraints.symmetric``
  34. - ``constraints.unit_interval``
  35. """
  36. import torch
  37. __all__ = [
  38. "Constraint",
  39. "boolean",
  40. "cat",
  41. "corr_cholesky",
  42. "dependent",
  43. "dependent_property",
  44. "greater_than",
  45. "greater_than_eq",
  46. "independent",
  47. "integer_interval",
  48. "interval",
  49. "half_open_interval",
  50. "is_dependent",
  51. "less_than",
  52. "lower_cholesky",
  53. "lower_triangular",
  54. "MixtureSameFamilyConstraint",
  55. "multinomial",
  56. "nonnegative",
  57. "nonnegative_integer",
  58. "one_hot",
  59. "positive",
  60. "positive_semidefinite",
  61. "positive_definite",
  62. "positive_integer",
  63. "real",
  64. "real_vector",
  65. "simplex",
  66. "square",
  67. "stack",
  68. "symmetric",
  69. "unit_interval",
  70. ]
  71. class Constraint:
  72. """
  73. Abstract base class for constraints.
  74. A constraint object represents a region over which a variable is valid,
  75. e.g. within which a variable can be optimized.
  76. Attributes:
  77. is_discrete (bool): Whether constrained space is discrete.
  78. Defaults to False.
  79. event_dim (int): Number of rightmost dimensions that together define
  80. an event. The :meth:`check` method will remove this many dimensions
  81. when computing validity.
  82. """
  83. is_discrete = False # Default to continuous.
  84. event_dim = 0 # Default to univariate.
  85. def check(self, value):
  86. """
  87. Returns a byte tensor of ``sample_shape + batch_shape`` indicating
  88. whether each event in value satisfies this constraint.
  89. """
  90. raise NotImplementedError
  91. def __repr__(self):
  92. return self.__class__.__name__[1:] + "()"
  93. class _Dependent(Constraint):
  94. """
  95. Placeholder for variables whose support depends on other variables.
  96. These variables obey no simple coordinate-wise constraints.
  97. Args:
  98. is_discrete (bool): Optional value of ``.is_discrete`` in case this
  99. can be computed statically. If not provided, access to the
  100. ``.is_discrete`` attribute will raise a NotImplementedError.
  101. event_dim (int): Optional value of ``.event_dim`` in case this
  102. can be computed statically. If not provided, access to the
  103. ``.event_dim`` attribute will raise a NotImplementedError.
  104. """
  105. def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
  106. self._is_discrete = is_discrete
  107. self._event_dim = event_dim
  108. super().__init__()
  109. @property
  110. def is_discrete(self) -> bool: # type: ignore[override]
  111. if self._is_discrete is NotImplemented:
  112. raise NotImplementedError(".is_discrete cannot be determined statically")
  113. return self._is_discrete
  114. @property
  115. def event_dim(self) -> int: # type: ignore[override]
  116. if self._event_dim is NotImplemented:
  117. raise NotImplementedError(".event_dim cannot be determined statically")
  118. return self._event_dim
  119. def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
  120. """
  121. Support for syntax to customize static attributes::
  122. constraints.dependent(is_discrete=True, event_dim=1)
  123. """
  124. if is_discrete is NotImplemented:
  125. is_discrete = self._is_discrete
  126. if event_dim is NotImplemented:
  127. event_dim = self._event_dim
  128. return _Dependent(is_discrete=is_discrete, event_dim=event_dim)
  129. def check(self, x):
  130. raise ValueError("Cannot determine validity of dependent constraint")
  131. def is_dependent(constraint):
  132. """
  133. Checks if ``constraint`` is a ``_Dependent`` object.
  134. Args:
  135. constraint : A ``Constraint`` object.
  136. Returns:
  137. ``bool``: True if ``constraint`` can be refined to the type ``_Dependent``, False otherwise.
  138. Examples:
  139. >>> import torch
  140. >>> from torch.distributions import Bernoulli
  141. >>> from torch.distributions.constraints import is_dependent
  142. >>> dist = Bernoulli(probs=torch.tensor([0.6], requires_grad=True))
  143. >>> constraint1 = dist.arg_constraints["probs"]
  144. >>> constraint2 = dist.arg_constraints["logits"]
  145. >>> for constraint in [constraint1, constraint2]:
  146. >>> if is_dependent(constraint):
  147. >>> continue
  148. """
  149. return isinstance(constraint, _Dependent)
  150. class _DependentProperty(property, _Dependent):
  151. """
  152. Decorator that extends @property to act like a `Dependent` constraint when
  153. called on a class and act like a property when called on an object.
  154. Example::
  155. class Uniform(Distribution):
  156. def __init__(self, low, high):
  157. self.low = low
  158. self.high = high
  159. @constraints.dependent_property(is_discrete=False, event_dim=0)
  160. def support(self):
  161. return constraints.interval(self.low, self.high)
  162. Args:
  163. fn (Callable): The function to be decorated.
  164. is_discrete (bool): Optional value of ``.is_discrete`` in case this
  165. can be computed statically. If not provided, access to the
  166. ``.is_discrete`` attribute will raise a NotImplementedError.
  167. event_dim (int): Optional value of ``.event_dim`` in case this
  168. can be computed statically. If not provided, access to the
  169. ``.event_dim`` attribute will raise a NotImplementedError.
  170. """
  171. def __init__(
  172. self,
  173. fn: Callable[..., Any] | None = None,
  174. *,
  175. is_discrete: bool | None = NotImplemented,
  176. event_dim: int | None = NotImplemented,
  177. ) -> None:
  178. super().__init__(fn)
  179. self._is_discrete = is_discrete
  180. self._event_dim = event_dim
  181. def __call__(self, fn: Callable[..., Any]) -> "_DependentProperty": # type: ignore[override]
  182. """
  183. Support for syntax to customize static attributes::
  184. @constraints.dependent_property(is_discrete=True, event_dim=1)
  185. def support(self): ...
  186. """
  187. return _DependentProperty(
  188. fn, is_discrete=self._is_discrete, event_dim=self._event_dim
  189. )
  190. class _IndependentConstraint(Constraint):
  191. """
  192. Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
  193. dims in :meth:`check`, so that an event is valid only if all its
  194. independent entries are valid.
  195. """
  196. def __init__(self, base_constraint, reinterpreted_batch_ndims):
  197. if not isinstance(base_constraint, Constraint):
  198. raise AssertionError(
  199. f"base_constraint must be a Constraint, got {type(base_constraint).__name__}"
  200. )
  201. if not isinstance(reinterpreted_batch_ndims, int):
  202. raise AssertionError(
  203. f"reinterpreted_batch_ndims must be an int, got {type(reinterpreted_batch_ndims).__name__}"
  204. )
  205. if reinterpreted_batch_ndims < 0:
  206. raise AssertionError(
  207. f"reinterpreted_batch_ndims must be >= 0, got {reinterpreted_batch_ndims}"
  208. )
  209. self.base_constraint = base_constraint
  210. self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
  211. super().__init__()
  212. @property
  213. def is_discrete(self) -> bool: # type: ignore[override]
  214. return self.base_constraint.is_discrete
  215. @property
  216. def event_dim(self) -> int: # type: ignore[override]
  217. return self.base_constraint.event_dim + self.reinterpreted_batch_ndims
  218. def check(self, value):
  219. result = self.base_constraint.check(value)
  220. if result.dim() < self.reinterpreted_batch_ndims:
  221. expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims
  222. raise ValueError(
  223. f"Expected value.dim() >= {expected} but got {value.dim()}"
  224. )
  225. result = result.reshape(
  226. result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,)
  227. )
  228. result = result.all(-1)
  229. return result
  230. def __repr__(self):
  231. return f"{self.__class__.__name__[1:]}({repr(self.base_constraint)}, {self.reinterpreted_batch_ndims})"
  232. class MixtureSameFamilyConstraint(Constraint):
  233. """
  234. Constraint for the :class:`~torch.distribution.MixtureSameFamily`
  235. distribution that adds back the rightmost batch dimension before
  236. performing the validity check with the component distribution
  237. constraint.
  238. Args:
  239. base_constraint: The ``Constraint`` object of
  240. the component distribution of
  241. the :class:`~torch.distribution.MixtureSameFamily` distribution.
  242. """
  243. def __init__(self, base_constraint):
  244. if not isinstance(base_constraint, Constraint):
  245. raise AssertionError(
  246. f"base_constraint must be a Constraint, got {type(base_constraint).__name__}"
  247. )
  248. self.base_constraint = base_constraint
  249. super().__init__()
  250. @property
  251. def is_discrete(self) -> bool: # type: ignore[override]
  252. return self.base_constraint.is_discrete
  253. @property
  254. def event_dim(self) -> int: # type: ignore[override]
  255. return self.base_constraint.event_dim
  256. def check(self, value):
  257. """
  258. Check validity of ``value`` as a possible outcome of sampling
  259. the :class:`~torch.distribution.MixtureSameFamily` distribution.
  260. """
  261. unsqueezed_value = value.unsqueeze(-1 - self.event_dim)
  262. result = self.base_constraint.check(unsqueezed_value)
  263. if value.dim() < self.event_dim:
  264. raise ValueError(
  265. f"Expected value.dim() >= {self.event_dim} but got {value.dim()}"
  266. )
  267. num_dim_to_keep = value.dim() - self.event_dim
  268. result = result.reshape(result.shape[:num_dim_to_keep] + (-1,))
  269. result = result.all(-1)
  270. return result
  271. def __repr__(self):
  272. return f"{self.__class__.__name__}({repr(self.base_constraint)})"
  273. class _Boolean(Constraint):
  274. """
  275. Constrain to the two values `{0, 1}`.
  276. """
  277. is_discrete = True
  278. def check(self, value):
  279. return (value == 0) | (value == 1)
  280. class _OneHot(Constraint):
  281. """
  282. Constrain to one-hot vectors.
  283. """
  284. is_discrete = True
  285. event_dim = 1
  286. def check(self, value):
  287. is_boolean = (value == 0) | (value == 1)
  288. is_normalized = value.sum(-1).eq(1)
  289. return is_boolean.all(-1) & is_normalized
  290. class _IntegerInterval(Constraint):
  291. """
  292. Constrain to an integer interval `[lower_bound, upper_bound]`.
  293. """
  294. is_discrete = True
  295. def __init__(self, lower_bound, upper_bound):
  296. self.lower_bound = lower_bound
  297. self.upper_bound = upper_bound
  298. super().__init__()
  299. def check(self, value):
  300. return (
  301. (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
  302. )
  303. def __repr__(self):
  304. fmt_string = self.__class__.__name__[1:]
  305. fmt_string += (
  306. f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
  307. )
  308. return fmt_string
  309. class _IntegerLessThan(Constraint):
  310. """
  311. Constrain to an integer interval `(-inf, upper_bound]`.
  312. """
  313. is_discrete = True
  314. def __init__(self, upper_bound):
  315. self.upper_bound = upper_bound
  316. super().__init__()
  317. def check(self, value):
  318. return (value % 1 == 0) & (value <= self.upper_bound)
  319. def __repr__(self):
  320. fmt_string = self.__class__.__name__[1:]
  321. fmt_string += f"(upper_bound={self.upper_bound})"
  322. return fmt_string
  323. class _IntegerGreaterThan(Constraint):
  324. """
  325. Constrain to an integer interval `[lower_bound, inf)`.
  326. """
  327. is_discrete = True
  328. def __init__(self, lower_bound):
  329. self.lower_bound = lower_bound
  330. super().__init__()
  331. def check(self, value):
  332. return (value % 1 == 0) & (value >= self.lower_bound)
  333. def __repr__(self):
  334. fmt_string = self.__class__.__name__[1:]
  335. fmt_string += f"(lower_bound={self.lower_bound})"
  336. return fmt_string
  337. class _Real(Constraint):
  338. """
  339. Trivially constrain to the extended real line `[-inf, inf]`.
  340. """
  341. def check(self, value):
  342. return value == value # False for NANs.
  343. class _GreaterThan(Constraint):
  344. """
  345. Constrain to a real half line `(lower_bound, inf]`.
  346. """
  347. def __init__(self, lower_bound):
  348. self.lower_bound = lower_bound
  349. super().__init__()
  350. def check(self, value):
  351. return self.lower_bound < value
  352. def __repr__(self):
  353. fmt_string = self.__class__.__name__[1:]
  354. fmt_string += f"(lower_bound={self.lower_bound})"
  355. return fmt_string
  356. class _GreaterThanEq(Constraint):
  357. """
  358. Constrain to a real half line `[lower_bound, inf)`.
  359. """
  360. def __init__(self, lower_bound):
  361. self.lower_bound = lower_bound
  362. super().__init__()
  363. def check(self, value):
  364. return self.lower_bound <= value
  365. def __repr__(self):
  366. fmt_string = self.__class__.__name__[1:]
  367. fmt_string += f"(lower_bound={self.lower_bound})"
  368. return fmt_string
  369. class _LessThan(Constraint):
  370. """
  371. Constrain to a real half line `[-inf, upper_bound)`.
  372. """
  373. def __init__(self, upper_bound):
  374. self.upper_bound = upper_bound
  375. super().__init__()
  376. def check(self, value):
  377. return value < self.upper_bound
  378. def __repr__(self):
  379. fmt_string = self.__class__.__name__[1:]
  380. fmt_string += f"(upper_bound={self.upper_bound})"
  381. return fmt_string
  382. class _Interval(Constraint):
  383. """
  384. Constrain to a real interval `[lower_bound, upper_bound]`.
  385. """
  386. def __init__(self, lower_bound, upper_bound):
  387. self.lower_bound = lower_bound
  388. self.upper_bound = upper_bound
  389. super().__init__()
  390. def check(self, value):
  391. return (self.lower_bound <= value) & (value <= self.upper_bound)
  392. def __repr__(self):
  393. fmt_string = self.__class__.__name__[1:]
  394. fmt_string += (
  395. f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
  396. )
  397. return fmt_string
  398. class _HalfOpenInterval(Constraint):
  399. """
  400. Constrain to a real interval `[lower_bound, upper_bound)`.
  401. """
  402. def __init__(self, lower_bound, upper_bound):
  403. self.lower_bound = lower_bound
  404. self.upper_bound = upper_bound
  405. super().__init__()
  406. def check(self, value):
  407. return (self.lower_bound <= value) & (value < self.upper_bound)
  408. def __repr__(self):
  409. fmt_string = self.__class__.__name__[1:]
  410. fmt_string += (
  411. f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
  412. )
  413. return fmt_string
  414. class _Simplex(Constraint):
  415. """
  416. Constrain to the unit simplex in the innermost (rightmost) dimension.
  417. Specifically: `x >= 0` and `x.sum(-1) == 1`.
  418. """
  419. event_dim = 1
  420. def check(self, value):
  421. return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
  422. class _Multinomial(Constraint):
  423. """
  424. Constrain to nonnegative integer values summing to at most an upper bound.
  425. Note due to limitations of the Multinomial distribution, this currently
  426. checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future
  427. this may be strengthened to ``value.sum(-1) == upper_bound``.
  428. """
  429. is_discrete = True
  430. event_dim = 1
  431. def __init__(self, upper_bound):
  432. self.upper_bound = upper_bound
  433. def check(self, x):
  434. return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound)
  435. class _LowerTriangular(Constraint):
  436. """
  437. Constrain to lower-triangular square matrices.
  438. """
  439. event_dim = 2
  440. def check(self, value):
  441. value_tril = value.tril()
  442. return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
  443. class _LowerCholesky(Constraint):
  444. """
  445. Constrain to lower-triangular square matrices with positive diagonals.
  446. """
  447. event_dim = 2
  448. def check(self, value):
  449. value_tril = value.tril()
  450. lower_triangular = (
  451. (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
  452. )
  453. positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0]
  454. return lower_triangular & positive_diagonal
  455. class _CorrCholesky(Constraint):
  456. """
  457. Constrain to lower-triangular square matrices with positive diagonals and each
  458. row vector being of unit length.
  459. """
  460. event_dim = 2
  461. def check(self, value):
  462. tol = (
  463. torch.finfo(value.dtype).eps * value.size(-1) * 10
  464. ) # 10 is an adjustable fudge factor
  465. row_norm = torch.linalg.norm(value.detach(), dim=-1)
  466. unit_row_norm = (row_norm - 1.0).abs().le(tol).all(dim=-1)
  467. return _LowerCholesky().check(value) & unit_row_norm
  468. class _Square(Constraint):
  469. """
  470. Constrain to square matrices.
  471. """
  472. event_dim = 2
  473. def check(self, value):
  474. return torch.full(
  475. size=value.shape[:-2],
  476. fill_value=(value.shape[-2] == value.shape[-1]),
  477. dtype=torch.bool,
  478. device=value.device,
  479. )
  480. class _Symmetric(_Square):
  481. """
  482. Constrain to Symmetric square matrices.
  483. """
  484. def check(self, value):
  485. square_check = super().check(value)
  486. if not square_check.all():
  487. return square_check
  488. return torch.isclose(value, value.mT, atol=1e-6).all(-2).all(-1)
  489. class _PositiveSemidefinite(_Symmetric):
  490. """
  491. Constrain to positive-semidefinite matrices.
  492. """
  493. def check(self, value):
  494. sym_check = super().check(value)
  495. if not sym_check.all():
  496. return sym_check
  497. return torch.linalg.eigvalsh(value).ge(0).all(-1)
  498. class _PositiveDefinite(_Symmetric):
  499. """
  500. Constrain to positive-definite matrices.
  501. """
  502. def check(self, value):
  503. sym_check = super().check(value)
  504. if not sym_check.all():
  505. return sym_check
  506. return torch.linalg.cholesky_ex(value).info.eq(0)
  507. class _Cat(Constraint):
  508. """
  509. Constraint functor that applies a sequence of constraints
  510. `cseq` at the submatrices at dimension `dim`,
  511. each of size `lengths[dim]`, in a way compatible with :func:`torch.cat`.
  512. """
  513. def __init__(self, cseq, dim=0, lengths=None):
  514. if not all(isinstance(c, Constraint) for c in cseq):
  515. raise AssertionError("All elements of cseq must be Constraint instances")
  516. self.cseq = list(cseq)
  517. if lengths is None:
  518. lengths = [1] * len(self.cseq)
  519. self.lengths = list(lengths)
  520. if len(self.lengths) != len(self.cseq):
  521. raise AssertionError(
  522. f"lengths ({len(self.lengths)}) must match cseq ({len(self.cseq)})"
  523. )
  524. self.dim = dim
  525. super().__init__()
  526. @property
  527. def is_discrete(self) -> bool: # type: ignore[override]
  528. return any(c.is_discrete for c in self.cseq)
  529. @property
  530. def event_dim(self) -> int: # type: ignore[override]
  531. return max(c.event_dim for c in self.cseq)
  532. def check(self, value):
  533. if not (-value.dim() <= self.dim < value.dim()):
  534. raise AssertionError(
  535. f"dim {self.dim} out of range for value with {value.dim()} dimensions"
  536. )
  537. checks = []
  538. start = 0
  539. for constr, length in zip(self.cseq, self.lengths):
  540. v = value.narrow(self.dim, start, length)
  541. checks.append(constr.check(v))
  542. start = start + length # avoid += for jit compat
  543. return torch.cat(checks, self.dim)
  544. class _Stack(Constraint):
  545. """
  546. Constraint functor that applies a sequence of constraints
  547. `cseq` at the submatrices at dimension `dim`,
  548. in a way compatible with :func:`torch.stack`.
  549. """
  550. def __init__(self, cseq, dim=0):
  551. if not all(isinstance(c, Constraint) for c in cseq):
  552. raise AssertionError("All elements of cseq must be Constraint instances")
  553. self.cseq = list(cseq)
  554. self.dim = dim
  555. super().__init__()
  556. @property
  557. def is_discrete(self) -> bool: # type: ignore[override]
  558. return any(c.is_discrete for c in self.cseq)
  559. @property
  560. def event_dim(self) -> int: # type: ignore[override]
  561. dim = max(c.event_dim for c in self.cseq)
  562. if self.dim + dim < 0:
  563. dim += 1
  564. return dim
  565. def check(self, value):
  566. if not (-value.dim() <= self.dim < value.dim()):
  567. raise AssertionError(
  568. f"dim {self.dim} out of range for value with {value.dim()} dimensions"
  569. )
  570. vs = [value.select(self.dim, i) for i in range(value.size(self.dim))]
  571. return torch.stack(
  572. [constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim
  573. )
  574. # Public interface.
  575. dependent = _Dependent()
  576. dependent_property = _DependentProperty
  577. independent = _IndependentConstraint
  578. boolean = _Boolean()
  579. one_hot = _OneHot()
  580. nonnegative_integer = _IntegerGreaterThan(0)
  581. positive_integer = _IntegerGreaterThan(1)
  582. integer_interval = _IntegerInterval
  583. real = _Real()
  584. real_vector = independent(real, 1)
  585. positive = _GreaterThan(0.0)
  586. nonnegative = _GreaterThanEq(0.0)
  587. greater_than = _GreaterThan
  588. greater_than_eq = _GreaterThanEq
  589. less_than = _LessThan
  590. multinomial = _Multinomial
  591. unit_interval = _Interval(0.0, 1.0)
  592. interval = _Interval
  593. half_open_interval = _HalfOpenInterval
  594. simplex = _Simplex()
  595. lower_triangular = _LowerTriangular()
  596. lower_cholesky = _LowerCholesky()
  597. corr_cholesky = _CorrCholesky()
  598. square = _Square()
  599. symmetric = _Symmetric()
  600. positive_semidefinite = _PositiveSemidefinite()
  601. positive_definite = _PositiveDefinite()
  602. cat = _Cat
  603. stack = _Stack