poisson.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
# mypy: allow-untyped-defs
import torch
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all
# `_Number` is used for a runtime isinstance() check and `Number` for type
# annotations in the Poisson class.
from torch.types import _Number, Number


# Only the distribution class is exported by ``from ... import *``.
__all__ = ["Poisson"]
  9. class Poisson(ExponentialFamily):
  10. r"""
  11. Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.
  12. Samples are nonnegative integers, with a pmf given by
  13. .. math::
  14. \mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}
  15. Example::
  16. >>> # xdoctest: +SKIP("poisson_cpu not implemented for 'Long'")
  17. >>> m = Poisson(torch.tensor([4]))
  18. >>> m.sample()
  19. tensor([ 3.])
  20. Args:
  21. rate (Number, Tensor): the rate parameter
  22. """
  23. # pyrefly: ignore [bad-override]
  24. arg_constraints = {"rate": constraints.nonnegative}
  25. support = constraints.nonnegative_integer
  26. @property
  27. def mean(self) -> Tensor:
  28. return self.rate
  29. @property
  30. def mode(self) -> Tensor:
  31. return self.rate.floor()
  32. @property
  33. def variance(self) -> Tensor:
  34. return self.rate
  35. def __init__(
  36. self,
  37. rate: Tensor | Number,
  38. validate_args: bool | None = None,
  39. ) -> None:
  40. (self.rate,) = broadcast_all(rate)
  41. if isinstance(rate, _Number):
  42. batch_shape = torch.Size()
  43. else:
  44. batch_shape = self.rate.size()
  45. super().__init__(batch_shape, validate_args=validate_args)
  46. def expand(self, batch_shape, _instance=None):
  47. new = self._get_checked_instance(Poisson, _instance)
  48. batch_shape = torch.Size(batch_shape)
  49. new.rate = self.rate.expand(batch_shape)
  50. super(Poisson, new).__init__(batch_shape, validate_args=False)
  51. new._validate_args = self._validate_args
  52. return new
  53. def sample(self, sample_shape=torch.Size()):
  54. shape = self._extended_shape(sample_shape)
  55. with torch.no_grad():
  56. return torch.poisson(self.rate.expand(shape))
  57. def log_prob(self, value):
  58. if self._validate_args:
  59. self._validate_sample(value)
  60. rate, value = broadcast_all(self.rate, value)
  61. return value.xlogy(rate) - rate - (value + 1).lgamma()
  62. @property
  63. def _natural_params(self) -> tuple[Tensor]:
  64. return (torch.log(self.rate),)
  65. # pyrefly: ignore [bad-override]
  66. def _log_normalizer(self, x):
  67. return torch.exp(x)