von_mises.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. # mypy: allow-untyped-defs
  2. import math
  3. import torch
  4. import torch.jit
  5. from torch import Tensor
  6. from torch.distributions import constraints
  7. from torch.distributions.distribution import Distribution
  8. from torch.distributions.utils import broadcast_all, lazy_property
# Only the distribution class is exported; the helpers below are private.
__all__ = ["VonMises"]
  10. def _eval_poly(y, coef):
  11. coef = list(coef)
  12. result = coef.pop()
  13. while coef:
  14. result = coef.pop() + y * result
  15. return result
# Polynomial coefficients consumed by ``_log_modified_bessel_fn`` through
# ``_eval_poly``; each list is ordered from the constant term upward.
# NOTE(review): the values appear to be the Abramowitz & Stegun polynomial
# approximations 9.8.1-9.8.4 for I_0 / I_1 -- confirm before editing them.
#
# *_SMALL tables are evaluated at y = (x / 3.75) ** 2 and used for x < 3.75:
#   I_0(x) ~= poly(y),   I_1(x) ~= |x| * poly(y).
# *_LARGE tables are evaluated at y = 3.75 / x and used for x >= 3.75:
#   log(I_n(x)) ~= x - 0.5 * log(x) + log(poly(y)).

# I_0, small-argument branch.
_I0_COEF_SMALL = [
    1.0,
    3.5156229,
    3.0899424,
    1.2067492,
    0.2659732,
    0.360768e-1,
    0.45813e-2,
]
# I_0, large-argument branch.
_I0_COEF_LARGE = [
    0.39894228,
    0.1328592e-1,
    0.225319e-2,
    -0.157565e-2,
    0.916281e-2,
    -0.2057706e-1,
    0.2635537e-1,
    -0.1647633e-1,
    0.392377e-2,
]
# I_1, small-argument branch (the caller multiplies the result by |x|).
_I1_COEF_SMALL = [
    0.5,
    0.87890594,
    0.51498869,
    0.15084934,
    0.2658733e-1,
    0.301532e-2,
    0.32411e-3,
]
# I_1, large-argument branch.
_I1_COEF_LARGE = [
    0.39894228,
    -0.3988024e-1,
    -0.362018e-2,
    0.163801e-2,
    -0.1031555e-1,
    0.2282967e-1,
    -0.2895312e-1,
    0.1787654e-1,
    -0.420059e-2,
]
# Indexed by Bessel order (0 or 1).
_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]
  58. def _log_modified_bessel_fn(x, order=0):
  59. """
  60. Returns ``log(I_order(x))`` for ``x > 0``,
  61. where `order` is either 0 or 1.
  62. """
  63. if order != 0 and order != 1:
  64. raise AssertionError(f"order must be 0 or 1, got {order}")
  65. # compute small solution
  66. y = x / 3.75
  67. y = y * y
  68. small = _eval_poly(y, _COEF_SMALL[order])
  69. if order == 1:
  70. small = x.abs() * small
  71. small = small.log()
  72. # compute large solution
  73. y = 3.75 / x
  74. large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log()
  75. result = torch.where(x < 3.75, small, large)
  76. return result
@torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x):
    """
    Draw von Mises samples by rejection, filling ``x`` element-wise.

    This is the rejection loop behind ``VonMises.sample`` (Best & Fisher,
    1979). Each iteration draws three uniforms per element of ``x`` (dtype
    and device taken from ``loc``). An element is marked done once one of its
    proposals is accepted, and the loop repeats until every element is done.
    There is no iteration cap.

    Args:
        loc: location angle in radians; presumably broadcastable to
            ``x.shape`` -- TODO confirm against callers.
        concentration: concentration parameter (same broadcasting assumption).
        proposal_r: ``r`` parameter of the proposal, as computed by
            ``VonMises._proposal_r``.
        x: buffer whose shape sets the sample shape; its initial contents are
            replaced wherever a proposal is accepted.

    Returns:
        Tensor of angles ``x + loc`` wrapped to (nominally) ``[-pi, pi)``.
    """
    # Which elements have had at least one accepted proposal so far.
    done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
    # pyrefly: ignore [bad-assignment, missing-attribute]
    while not done.all():
        u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
        u1, u2, u3 = u.unbind()
        z = torch.cos(math.pi * u1)
        # f is the cosine of the proposed angle (the angle is +/- acos(f)).
        f = (1 + proposal_r * z) / (proposal_r + z)
        c = concentration * (proposal_r - f)
        # Accept if either inequality holds; both are evaluated for every
        # element since `|` on tensors does not short-circuit.
        accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
        if accept.any():
            # u3 picks the sign of the angle. `accept` is not masked by
            # `~done`, so an already-done element may be overwritten by a
            # later accepted proposal.
            # pyrefly: ignore [no-matching-overload]
            x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
            done = done | accept
    # Shift by loc and wrap; tensor `%` takes the sign of the divisor, so the
    # remainder is in [0, 2*pi) before subtracting pi.
    return (x + math.pi + loc) % (2 * math.pi) - math.pi
  93. class VonMises(Distribution):
  94. """
  95. A circular von Mises distribution.
  96. This implementation uses polar coordinates. The ``loc`` and ``value`` args
  97. can be any real number (to facilitate unconstrained optimization), but are
  98. interpreted as angles modulo 2 pi.
  99. Example::
  100. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  101. >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
  102. >>> m.sample() # von Mises distributed with loc=1 and concentration=1
  103. tensor([1.9777])
  104. :param torch.Tensor loc: an angle in radians.
  105. :param torch.Tensor concentration: concentration parameter
  106. """
  107. # pyrefly: ignore [bad-override]
  108. arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
  109. support = constraints.real
  110. has_rsample = False
  111. def __init__(
  112. self,
  113. loc: Tensor,
  114. concentration: Tensor,
  115. validate_args: bool | None = None,
  116. ) -> None:
  117. self.loc, self.concentration = broadcast_all(loc, concentration)
  118. batch_shape = self.loc.shape
  119. event_shape = torch.Size()
  120. super().__init__(batch_shape, event_shape, validate_args)
  121. def log_prob(self, value):
  122. if self._validate_args:
  123. self._validate_sample(value)
  124. log_prob = self.concentration * torch.cos(value - self.loc)
  125. log_prob = (
  126. log_prob
  127. - math.log(2 * math.pi)
  128. - _log_modified_bessel_fn(self.concentration, order=0)
  129. )
  130. return log_prob
  131. @lazy_property
  132. def _loc(self) -> Tensor:
  133. return self.loc.to(torch.double)
  134. @lazy_property
  135. def _concentration(self) -> Tensor:
  136. return self.concentration.to(torch.double)
  137. @lazy_property
  138. def _proposal_r(self) -> Tensor:
  139. kappa = self._concentration
  140. # pyrefly: ignore [unsupported-operation]
  141. tau = 1 + (1 + 4 * kappa**2).sqrt()
  142. rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
  143. # pyrefly: ignore [unsupported-operation]
  144. _proposal_r = (1 + rho**2) / (2 * rho)
  145. # second order Taylor expansion around 0 for small kappa
  146. _proposal_r_taylor = 1 / kappa + kappa
  147. return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r)
  148. @torch.no_grad()
  149. def sample(self, sample_shape=torch.Size()):
  150. """
  151. The sampling algorithm for the von Mises distribution is based on the
  152. following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the
  153. von Mises distribution." Applied Statistics (1979): 152-157.
  154. Sampling is always done in double precision internally to avoid a hang
  155. in _rejection_sample() for small values of the concentration, which
  156. starts to happen for single precision around 1e-4 (see issue #88443).
  157. """
  158. shape = self._extended_shape(sample_shape)
  159. x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device)
  160. return _rejection_sample(
  161. self._loc, self._concentration, self._proposal_r, x
  162. ).to(self.loc.dtype)
  163. def expand(self, batch_shape, _instance=None):
  164. try:
  165. return super().expand(batch_shape)
  166. except NotImplementedError:
  167. validate_args = self.__dict__.get("_validate_args")
  168. loc = self.loc.expand(batch_shape)
  169. concentration = self.concentration.expand(batch_shape)
  170. return type(self)(loc, concentration, validate_args=validate_args)
  171. @property
  172. def mean(self) -> Tensor:
  173. """
  174. The provided mean is the circular one.
  175. """
  176. return self.loc
  177. @property
  178. def mode(self) -> Tensor:
  179. return self.loc
  180. @lazy_property
  181. def variance(self) -> Tensor: # type: ignore[override]
  182. """
  183. The provided variance is the circular one.
  184. """
  185. return (
  186. 1
  187. - (
  188. _log_modified_bessel_fn(self.concentration, order=1)
  189. - _log_modified_bessel_fn(self.concentration, order=0)
  190. ).exp()
  191. )