studentT.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # mypy: allow-untyped-defs
  2. import math
  3. import torch
  4. from torch import inf, nan, Tensor
  5. from torch.distributions import Chi2, constraints
  6. from torch.distributions.distribution import Distribution
  7. from torch.distributions.utils import _standard_normal, broadcast_all
  8. from torch.types import _size
  9. __all__ = ["StudentT"]
  10. class StudentT(Distribution):
  11. r"""
  12. Creates a Student's t-distribution parameterized by degree of
  13. freedom :attr:`df`, mean :attr:`loc` and scale :attr:`scale`.
  14. Example::
  15. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  16. >>> m = StudentT(torch.tensor([2.0]))
  17. >>> m.sample() # Student's t-distributed with degrees of freedom=2
  18. tensor([ 0.1046])
  19. Args:
  20. df (float or Tensor): degrees of freedom
  21. loc (float or Tensor): mean of the distribution
  22. scale (float or Tensor): scale of the distribution
  23. """
  24. # pyrefly: ignore [bad-override]
  25. arg_constraints = {
  26. "df": constraints.positive,
  27. "loc": constraints.real,
  28. "scale": constraints.positive,
  29. }
  30. support = constraints.real
  31. has_rsample = True
  32. @property
  33. def mean(self) -> Tensor:
  34. m = self.loc.clone(memory_format=torch.contiguous_format)
  35. m[self.df <= 1] = nan
  36. return m
  37. @property
  38. def mode(self) -> Tensor:
  39. return self.loc
  40. @property
  41. def variance(self) -> Tensor:
  42. m = self.df.clone(memory_format=torch.contiguous_format)
  43. m[self.df > 2] = (
  44. self.scale[self.df > 2].pow(2)
  45. * self.df[self.df > 2]
  46. / (self.df[self.df > 2] - 2)
  47. )
  48. m[(self.df <= 2) & (self.df > 1)] = inf
  49. m[self.df <= 1] = nan
  50. return m
  51. def __init__(
  52. self,
  53. df: Tensor | float,
  54. loc: Tensor | float = 0.0,
  55. scale: Tensor | float = 1.0,
  56. validate_args: bool | None = None,
  57. ) -> None:
  58. self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
  59. self._chi2 = Chi2(self.df)
  60. batch_shape = self.df.size()
  61. super().__init__(batch_shape, validate_args=validate_args)
  62. def expand(self, batch_shape, _instance=None):
  63. new = self._get_checked_instance(StudentT, _instance)
  64. batch_shape = torch.Size(batch_shape)
  65. new.df = self.df.expand(batch_shape)
  66. new.loc = self.loc.expand(batch_shape)
  67. new.scale = self.scale.expand(batch_shape)
  68. new._chi2 = self._chi2.expand(batch_shape)
  69. super(StudentT, new).__init__(batch_shape, validate_args=False)
  70. new._validate_args = self._validate_args
  71. return new
  72. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  73. # NOTE: This does not agree with scipy implementation as much as other distributions.
  74. # (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb). Using DoubleTensor
  75. # parameters seems to help.
  76. # X ~ Normal(0, 1)
  77. # Z ~ Chi2(df)
  78. # Y = X / sqrt(Z / df) ~ StudentT(df)
  79. shape = self._extended_shape(sample_shape)
  80. X = _standard_normal(shape, dtype=self.df.dtype, device=self.df.device)
  81. Z = self._chi2.rsample(sample_shape)
  82. Y = X * torch.rsqrt(Z / self.df)
  83. return self.loc + self.scale * Y
  84. def log_prob(self, value):
  85. if self._validate_args:
  86. self._validate_sample(value)
  87. y = (value - self.loc) / self.scale
  88. Z = (
  89. self.scale.log()
  90. + 0.5 * self.df.log()
  91. + 0.5 * math.log(math.pi)
  92. + torch.lgamma(0.5 * self.df)
  93. - torch.lgamma(0.5 * (self.df + 1.0))
  94. )
  95. return -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z
  96. def entropy(self):
  97. lbeta = (
  98. torch.lgamma(0.5 * self.df)
  99. + math.lgamma(0.5)
  100. - torch.lgamma(0.5 * (self.df + 1))
  101. )
  102. return (
  103. self.scale.log()
  104. + 0.5
  105. * (self.df + 1)
  106. * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df))
  107. + 0.5 * self.df.log()
  108. + lbeta
  109. )