# uniform.py (3.3 KB)
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch import nan, Tensor
  4. from torch.distributions import constraints
  5. from torch.distributions.distribution import Distribution
  6. from torch.distributions.utils import broadcast_all
  7. from torch.types import _Number, _size
  8. __all__ = ["Uniform"]
  9. class Uniform(Distribution):
  10. r"""
  11. Generates uniformly distributed random samples from the half-open interval
  12. ``[low, high)``.
  13. Example::
  14. >>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
  15. >>> m.sample() # uniformly distributed in the range [0.0, 5.0)
  16. >>> # xdoctest: +SKIP
  17. tensor([ 2.3418])
  18. Args:
  19. low (float or Tensor): lower range (inclusive).
  20. high (float or Tensor): upper range (exclusive).
  21. """
  22. has_rsample = True
  23. @property
  24. def arg_constraints(self):
  25. # TODO allow (loc,scale) parameterization to allow independent constraints.
  26. return {
  27. "low": constraints.less_than(self.high),
  28. "high": constraints.greater_than(self.low),
  29. }
  30. @property
  31. def mean(self) -> Tensor:
  32. return (self.high + self.low) / 2
  33. @property
  34. def mode(self) -> Tensor:
  35. return nan * self.high
  36. @property
  37. def stddev(self) -> Tensor:
  38. return (self.high - self.low) / 12**0.5
  39. @property
  40. def variance(self) -> Tensor:
  41. return (self.high - self.low).pow(2) / 12
  42. def __init__(
  43. self,
  44. low: Tensor | float,
  45. high: Tensor | float,
  46. validate_args: bool | None = None,
  47. ) -> None:
  48. self.low, self.high = broadcast_all(low, high)
  49. if isinstance(low, _Number) and isinstance(high, _Number):
  50. batch_shape = torch.Size()
  51. else:
  52. batch_shape = self.low.size()
  53. super().__init__(batch_shape, validate_args=validate_args)
  54. def expand(self, batch_shape, _instance=None):
  55. new = self._get_checked_instance(Uniform, _instance)
  56. batch_shape = torch.Size(batch_shape)
  57. new.low = self.low.expand(batch_shape)
  58. new.high = self.high.expand(batch_shape)
  59. super(Uniform, new).__init__(batch_shape, validate_args=False)
  60. new._validate_args = self._validate_args
  61. return new
  62. @constraints.dependent_property(is_discrete=False, event_dim=0)
  63. # pyrefly: ignore [bad-override]
  64. def support(self):
  65. return constraints.interval(self.low, self.high)
  66. def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
  67. shape = self._extended_shape(sample_shape)
  68. rand = torch.rand(shape, dtype=self.low.dtype, device=self.low.device)
  69. return self.low + rand * (self.high - self.low)
  70. def log_prob(self, value):
  71. if self._validate_args:
  72. self._validate_sample(value)
  73. lb = self.low.le(value).type_as(self.low)
  74. ub = self.high.gt(value).type_as(self.low)
  75. return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)
  76. def cdf(self, value):
  77. if self._validate_args:
  78. self._validate_sample(value)
  79. result = (value - self.low) / (self.high - self.low)
  80. return result.clamp(min=0, max=1)
  81. def icdf(self, value):
  82. result = value * (self.high - self.low) + self.low
  83. return result
  84. def entropy(self):
  85. return torch.log(self.high - self.low)