time_series_utils.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # Copyright 2023 The HuggingFace Inc. team.
  2. # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Time series distributional output classes and utilities.
  17. """
  18. from collections.abc import Callable
  19. import torch
  20. from torch import nn
  21. from torch.distributions import (
  22. AffineTransform,
  23. Distribution,
  24. Independent,
  25. NegativeBinomial,
  26. Normal,
  27. StudentT,
  28. TransformedDistribution,
  29. )
  30. class AffineTransformed(TransformedDistribution):
  31. def __init__(self, base_distribution: Distribution, loc=None, scale=None, event_dim=0):
  32. self.scale = 1.0 if scale is None else scale
  33. self.loc = 0.0 if loc is None else loc
  34. super().__init__(base_distribution, [AffineTransform(loc=self.loc, scale=self.scale, event_dim=event_dim)])
  35. @property
  36. def mean(self):
  37. """
  38. Returns the mean of the distribution.
  39. """
  40. return self.base_dist.mean * self.scale + self.loc
  41. @property
  42. def variance(self):
  43. """
  44. Returns the variance of the distribution.
  45. """
  46. return self.base_dist.variance * self.scale**2
  47. @property
  48. def stddev(self):
  49. """
  50. Returns the standard deviation of the distribution.
  51. """
  52. return self.variance.sqrt()
  53. class ParameterProjection(nn.Module):
  54. def __init__(
  55. self, in_features: int, args_dim: dict[str, int], domain_map: Callable[..., tuple[torch.Tensor]], **kwargs
  56. ) -> None:
  57. super().__init__(**kwargs)
  58. self.args_dim = args_dim
  59. self.proj = nn.ModuleList([nn.Linear(in_features, dim) for dim in args_dim.values()])
  60. self.domain_map = domain_map
  61. def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]:
  62. params_unbounded = [proj(x) for proj in self.proj]
  63. return self.domain_map(*params_unbounded)
  64. class LambdaLayer(nn.Module):
  65. def __init__(self, function):
  66. super().__init__()
  67. self.function = function
  68. def forward(self, x, *args):
  69. return self.function(x, *args)
  70. class DistributionOutput:
  71. distribution_class: type
  72. in_features: int
  73. args_dim: dict[str, int]
  74. def __init__(self, dim: int = 1) -> None:
  75. self.dim = dim
  76. self.args_dim = {k: dim * self.args_dim[k] for k in self.args_dim}
  77. def _base_distribution(self, distr_args):
  78. if self.dim == 1:
  79. return self.distribution_class(*distr_args)
  80. else:
  81. return Independent(self.distribution_class(*distr_args), 1)
  82. def distribution(
  83. self,
  84. distr_args,
  85. loc: torch.Tensor | None = None,
  86. scale: torch.Tensor | None = None,
  87. ) -> Distribution:
  88. distr = self._base_distribution(distr_args)
  89. if loc is None and scale is None:
  90. return distr
  91. else:
  92. return AffineTransformed(distr, loc=loc, scale=scale, event_dim=self.event_dim)
  93. @property
  94. def event_shape(self) -> tuple:
  95. r"""
  96. Shape of each individual event contemplated by the distributions that this object constructs.
  97. """
  98. return () if self.dim == 1 else (self.dim,)
  99. @property
  100. def event_dim(self) -> int:
  101. r"""
  102. Number of event dimensions, i.e., length of the `event_shape` tuple, of the distributions that this object
  103. constructs.
  104. """
  105. return len(self.event_shape)
  106. @property
  107. def value_in_support(self) -> float:
  108. r"""
  109. A float that will have a valid numeric value when computing the log-loss of the corresponding distribution. By
  110. default 0.0. This value will be used when padding data series.
  111. """
  112. return 0.0
  113. def get_parameter_projection(self, in_features: int) -> nn.Module:
  114. r"""
  115. Return the parameter projection layer that maps the input to the appropriate parameters of the distribution.
  116. """
  117. return ParameterProjection(
  118. in_features=in_features,
  119. args_dim=self.args_dim,
  120. domain_map=LambdaLayer(self.domain_map),
  121. )
  122. def domain_map(self, *args: torch.Tensor):
  123. r"""
  124. Converts arguments to the right shape and domain. The domain depends on the type of distribution, while the
  125. correct shape is obtained by reshaping the trailing axis in such a way that the returned tensors define a
  126. distribution of the right event_shape.
  127. """
  128. raise NotImplementedError()
  129. @staticmethod
  130. def squareplus(x: torch.Tensor) -> torch.Tensor:
  131. r"""
  132. Helper to map inputs to the positive orthant by applying the square-plus operation. Reference:
  133. https://twitter.com/jon_barron/status/1387167648669048833
  134. """
  135. return (x + torch.sqrt(torch.square(x) + 4.0)) / 2.0
  136. class StudentTOutput(DistributionOutput):
  137. """
  138. Student-T distribution output class.
  139. """
  140. args_dim: dict[str, int] = {"df": 1, "loc": 1, "scale": 1}
  141. distribution_class: type = StudentT
  142. @classmethod
  143. def domain_map(cls, df: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor):
  144. scale = cls.squareplus(scale).clamp_min(torch.finfo(scale.dtype).eps)
  145. df = 2.0 + cls.squareplus(df)
  146. return df.squeeze(-1), loc.squeeze(-1), scale.squeeze(-1)
  147. class NormalOutput(DistributionOutput):
  148. """
  149. Normal distribution output class.
  150. """
  151. args_dim: dict[str, int] = {"loc": 1, "scale": 1}
  152. distribution_class: type = Normal
  153. @classmethod
  154. def domain_map(cls, loc: torch.Tensor, scale: torch.Tensor):
  155. scale = cls.squareplus(scale).clamp_min(torch.finfo(scale.dtype).eps)
  156. return loc.squeeze(-1), scale.squeeze(-1)
  157. class NegativeBinomialOutput(DistributionOutput):
  158. """
  159. Negative Binomial distribution output class.
  160. """
  161. args_dim: dict[str, int] = {"total_count": 1, "logits": 1}
  162. distribution_class: type = NegativeBinomial
  163. @classmethod
  164. def domain_map(cls, total_count: torch.Tensor, logits: torch.Tensor):
  165. total_count = cls.squareplus(total_count)
  166. return total_count.squeeze(-1), logits.squeeze(-1)
  167. def _base_distribution(self, distr_args) -> Distribution:
  168. total_count, logits = distr_args
  169. if self.dim == 1:
  170. return self.distribution_class(total_count=total_count, logits=logits)
  171. else:
  172. return Independent(self.distribution_class(total_count=total_count, logits=logits), 1)
  173. # Overwrites the parent class method. We cannot scale using the affine
  174. # transformation since negative binomial should return integers. Instead
  175. # we scale the parameters.
  176. def distribution(
  177. self, distr_args, loc: torch.Tensor | None = None, scale: torch.Tensor | None = None
  178. ) -> Distribution:
  179. total_count, logits = distr_args
  180. if scale is not None:
  181. # See scaling property of Gamma.
  182. logits += scale.log()
  183. return self._base_distribution((total_count, logits))