gumbel.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # mypy: allow-untyped-defs
  2. import math
  3. import torch
  4. from torch import Tensor
  5. from torch.distributions import constraints
  6. from torch.distributions.transformed_distribution import TransformedDistribution
  7. from torch.distributions.transforms import AffineTransform, ExpTransform
  8. from torch.distributions.uniform import Uniform
  9. from torch.distributions.utils import broadcast_all, euler_constant
  10. from torch.types import _Number
  11. __all__ = ["Gumbel"]
  12. class Gumbel(TransformedDistribution):
  13. r"""
  14. Samples from a Gumbel Distribution.
  15. Examples::
  16. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  17. >>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
  18. >>> m.sample() # sample from Gumbel distribution with loc=1, scale=2
  19. tensor([ 1.0124])
  20. Args:
  21. loc (float or Tensor): Location parameter of the distribution
  22. scale (float or Tensor): Scale parameter of the distribution
  23. """
  24. arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
  25. # pyrefly: ignore [bad-override]
  26. support = constraints.real
  27. def __init__(
  28. self,
  29. loc: Tensor | float,
  30. scale: Tensor | float,
  31. validate_args: bool | None = None,
  32. ) -> None:
  33. self.loc, self.scale = broadcast_all(loc, scale)
  34. finfo = torch.finfo(self.loc.dtype)
  35. if isinstance(loc, _Number) and isinstance(scale, _Number):
  36. base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args)
  37. else:
  38. base_dist = Uniform(
  39. torch.full_like(self.loc, finfo.tiny),
  40. torch.full_like(self.loc, 1 - finfo.eps),
  41. validate_args=validate_args,
  42. )
  43. transforms = [
  44. ExpTransform().inv,
  45. AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
  46. ExpTransform().inv,
  47. AffineTransform(loc=loc, scale=-self.scale),
  48. ]
  49. super().__init__(base_dist, transforms, validate_args=validate_args)
  50. def expand(self, batch_shape, _instance=None):
  51. new = self._get_checked_instance(Gumbel, _instance)
  52. new.loc = self.loc.expand(batch_shape)
  53. new.scale = self.scale.expand(batch_shape)
  54. return super().expand(batch_shape, _instance=new)
  55. # Explicitly defining the log probability function for Gumbel due to precision issues
  56. def log_prob(self, value):
  57. if self._validate_args:
  58. self._validate_sample(value)
  59. y = (self.loc - value) / self.scale
  60. return (y - y.exp()) - self.scale.log()
  61. @property
  62. def mean(self) -> Tensor:
  63. return self.loc + self.scale * euler_constant
  64. @property
  65. def mode(self) -> Tensor:
  66. return self.loc
  67. @property
  68. def stddev(self) -> Tensor:
  69. return (math.pi / math.sqrt(6)) * self.scale
  70. @property
  71. def variance(self) -> Tensor:
  72. return self.stddev.pow(2)
  73. def entropy(self):
  74. return self.scale.log() + (1 + euler_constant)