modeling_timesfm.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/timesfm/modular_timesfm.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_timesfm.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Google LLC and HuggingFace Inc. team.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable, Sequence
  22. from dataclasses import dataclass
  23. import torch
  24. import torch.nn as nn
  25. import torch.nn.functional as F
  26. from ... import initialization as init
  27. from ...integrations import use_kernel_forward_from_hub
  28. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  29. from ...modeling_outputs import BaseModelOutput
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  33. from ...utils.generic import merge_with_config_defaults
  34. from ...utils.output_capturing import capture_outputs
  35. from .configuration_timesfm import TimesFmConfig
  36. logger = logging.get_logger(__name__)
  37. @dataclass
  38. @auto_docstring
  39. class TimesFmOutput(BaseModelOutput):
  40. r"""
  41. loc (`torch.Tensor` of shape `(batch_size, )`):
  42. The mean of the time series inputs.
  43. scale (`torch.Tensor` of shape `(batch_size,)`):
  44. The scale of the time series inputs.
  45. """
  46. loc: torch.Tensor | None = None
  47. scale: torch.Tensor | None = None
  48. @dataclass
  49. @auto_docstring
  50. class TimesFmOutputForPrediction(BaseModelOutput):
  51. r"""
  52. mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
  53. The mean predictions of the time series.
  54. full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
  55. The full predictions of the time series including the mean and the quantiles.
  56. loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
  57. The loss of the TimesFM model.
  58. """
  59. mean_predictions: torch.Tensor | None = None
  60. full_predictions: torch.Tensor | None = None
  61. loss: torch.Tensor | float | None = None
  62. class TimesFmMLP(nn.Module):
  63. """Pax MLP in pytorch."""
  64. def __init__(self, config: TimesFmConfig):
  65. super().__init__()
  66. hidden_size = config.hidden_size
  67. intermediate_size = config.intermediate_size
  68. self.gate_proj = nn.Linear(hidden_size, intermediate_size)
  69. self.down_proj = nn.Linear(intermediate_size, hidden_size)
  70. self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
  71. def forward(self, x, paddings=None):
  72. gate_inp = self.layer_norm(x)
  73. gate = self.gate_proj(gate_inp)
  74. gate = F.relu(gate)
  75. outputs = self.down_proj(gate)
  76. if paddings is not None:
  77. outputs = outputs * (1.0 - paddings[:, :, None])
  78. return outputs + x
  79. class TimesFmResidualBlock(nn.Module):
  80. """TimesFM residual block."""
  81. def __init__(self, input_dims, hidden_dims, output_dims):
  82. super().__init__()
  83. self.input_dims = input_dims
  84. self.hidden_dims = hidden_dims
  85. self.output_dims = output_dims
  86. self.input_layer = nn.Linear(input_dims, hidden_dims)
  87. self.activation = nn.SiLU()
  88. self.output_layer = nn.Linear(hidden_dims, output_dims)
  89. self.residual_layer = nn.Linear(input_dims, output_dims)
  90. def forward(self, x):
  91. hidden = self.input_layer(x)
  92. hidden = self.activation(hidden)
  93. output = self.output_layer(hidden)
  94. residual = self.residual_layer(x)
  95. return output + residual
  96. @use_kernel_forward_from_hub("RMSNorm")
  97. class TimesFmRMSNorm(nn.Module):
  98. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  99. """
  100. TimesFmRMSNorm is equivalent to T5LayerNorm
  101. """
  102. super().__init__()
  103. self.weight = nn.Parameter(torch.ones(hidden_size))
  104. self.variance_epsilon = eps
  105. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  106. input_dtype = hidden_states.dtype
  107. hidden_states = hidden_states.to(torch.float32)
  108. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  109. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  110. return self.weight * hidden_states.to(input_dtype)
  111. def extra_repr(self):
  112. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  113. class TimesFmPositionalEmbedding(nn.Module):
  114. """Generates position embedding for a given 1-d sequence."""
  115. def __init__(self, config: TimesFmConfig):
  116. super().__init__()
  117. min_timescale = config.min_timescale
  118. max_timescale = config.max_timescale
  119. self.min_timescale, self.max_timescale = min_timescale, max_timescale
  120. self.embedding_dims = config.hidden_size
  121. num_timescales = self.embedding_dims // 2
  122. log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
  123. self.register_buffer(
  124. "inv_timescales",
  125. min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
  126. )
  127. def forward(self, seq_length=None, position=None):
  128. """Generates a Tensor of sinusoids with different frequencies.
  129. Args:
  130. seq_length: an optional Python int defining the output sequence length.
  131. if the `position` argument is specified.
  132. position: [B, seq_length], optional position for each token in the
  133. sequence, only required when the sequence is packed.
  134. Returns:
  135. [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
  136. """
  137. if position is None and seq_length is None:
  138. raise ValueError("Either position or seq_length must be provided")
  139. if position is None:
  140. # [1, seqlen]
  141. position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
  142. elif position.ndim != 2:
  143. raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")
  144. scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
  145. signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
  146. # Padding to ensure correct embedding dimension
  147. signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
  148. return signal
  149. def simple_eager_attention_forward(
  150. module: nn.Module,
  151. query_states: torch.Tensor,
  152. key_states: torch.Tensor,
  153. value_states: torch.Tensor,
  154. attention_mask: torch.Tensor | None,
  155. scaling: float,
  156. dropout: float | int = 0.0,
  157. **kwargs: Unpack[TransformersKwargs],
  158. ):
  159. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
  160. if attention_mask is not None:
  161. attn_weights = attn_weights + attention_mask
  162. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  163. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  164. attn_output = torch.matmul(attn_weights, value_states)
  165. attn_output = attn_output.transpose(1, 2).contiguous()
  166. return attn_output, attn_weights
  167. class TimesFmAttention(nn.Module):
  168. """Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query."""
  169. def __init__(self, config: TimesFmConfig, layer_idx: int):
  170. super().__init__()
  171. self.config = config
  172. self.is_causal = True
  173. self.attention_dropout = config.attention_dropout
  174. self.layer_idx = layer_idx
  175. self.num_heads = config.num_attention_heads
  176. self.hidden_size = config.hidden_size
  177. self.head_dim = config.head_dim
  178. self.q_size = self.num_heads * self.head_dim
  179. self.kv_size = self.num_heads * self.head_dim
  180. self.scaling = nn.Parameter(torch.empty((self.head_dim,)))
  181. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
  182. self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
  183. self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
  184. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
  185. def _scale_query(self, query: torch.Tensor) -> torch.Tensor:
  186. scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim))
  187. return query * scale[None, None, None, :]
  188. def forward(
  189. self,
  190. hidden_states: torch.Tensor,
  191. attention_mask: torch.Tensor | None = None,
  192. **kwargs: Unpack[FlashAttentionKwargs],
  193. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  194. input_shape = hidden_states.shape[:-1]
  195. hidden_shape = (*input_shape, -1, self.head_dim)
  196. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  197. query_states = self._scale_query(query_states)
  198. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  199. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  200. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  201. self.config._attn_implementation, simple_eager_attention_forward
  202. )
  203. attn_output, attn_weights = attention_interface(
  204. self,
  205. query_states,
  206. key_states,
  207. value_states,
  208. attention_mask,
  209. dropout=0.0 if not self.training else self.attention_dropout,
  210. scaling=1.0,
  211. **kwargs,
  212. )
  213. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  214. attn_output = self.o_proj(attn_output)
  215. return attn_output, attn_weights
  216. class TimesFmDecoderLayer(nn.Module):
  217. """Transformer layer."""
  218. def __init__(self, config: TimesFmConfig, layer_idx: int):
  219. super().__init__()
  220. self.self_attn = TimesFmAttention(config, layer_idx=layer_idx)
  221. self.mlp = TimesFmMLP(config)
  222. self.input_layernorm = TimesFmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  223. def forward(
  224. self,
  225. hidden_states: torch.Tensor,
  226. attention_mask: torch.Tensor,
  227. paddings: torch.Tensor,
  228. **kwargs,
  229. ) -> torch.Tensor:
  230. # Self Attention
  231. residual = hidden_states
  232. hidden_states = self.input_layernorm(hidden_states)
  233. hidden_states, _ = self.self_attn(
  234. hidden_states=hidden_states,
  235. attention_mask=attention_mask,
  236. )
  237. hidden_states = residual + hidden_states
  238. # MLP
  239. hidden_states = self.mlp(hidden_states, paddings=paddings)
  240. return hidden_states
  241. @auto_docstring
  242. class TimesFmPreTrainedModel(PreTrainedModel):
  243. config: TimesFmConfig
  244. base_model_prefix = "timesfm"
  245. _no_split_modules = ["TimesFmDecoderLayer"]
  246. main_input_name = "past_values"
  247. input_modalities = ("time",)
  248. _supports_sdpa = True
  249. _can_record_outputs = {
  250. "hidden_states": TimesFmDecoderLayer,
  251. "attentions": TimesFmAttention,
  252. }
  253. @torch.no_grad()
  254. def _init_weights(self, module):
  255. super()._init_weights(module)
  256. if isinstance(module, TimesFmAttention):
  257. # Initialize scaling parameter
  258. init.ones_(module.scaling)
  259. elif isinstance(module, TimesFmPositionalEmbedding):
  260. num_timescales = module.embedding_dims // 2
  261. max_timescale, min_timescale = module.max_timescale, module.min_timescale
  262. log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
  263. num_timescales - 1, 1
  264. )
  265. init.copy_(
  266. module.inv_timescales,
  267. min_timescale
  268. * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
  269. )
  270. @auto_docstring
  271. class TimesFmModel(TimesFmPreTrainedModel):
  272. def __init__(self, config: TimesFmConfig):
  273. super().__init__(config)
  274. self.config = config
  275. self.input_ff_layer = TimesFmResidualBlock(
  276. input_dims=2 * config.patch_length,
  277. output_dims=config.hidden_size,
  278. hidden_dims=config.intermediate_size,
  279. )
  280. self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size)
  281. self.layers = nn.ModuleList(
  282. [TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  283. )
  284. if self.config.use_positional_embedding:
  285. self.position_emb = TimesFmPositionalEmbedding(config=config)
  286. # Initialize weights and apply final processing
  287. self.post_init()
  288. def _forward_transform(
  289. self, inputs: torch.Tensor, patched_pads: torch.Tensor
  290. ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
  291. """Input is of shape [B, N, P]."""
  292. mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
  293. sigma = torch.clamp(sigma, min=self.config.tolerance)
  294. # Normalize each patch
  295. outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
  296. outputs = torch.where(
  297. torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
  298. torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device),
  299. outputs,
  300. )
  301. return outputs, (mu, sigma)
  302. @merge_with_config_defaults
  303. @capture_outputs
  304. @auto_docstring
  305. def forward(
  306. self,
  307. past_values: torch.Tensor,
  308. past_values_padding: torch.LongTensor,
  309. freq: torch.Tensor,
  310. **kwargs: Unpack[TransformersKwargs],
  311. ) -> TimesFmOutput:
  312. r"""
  313. past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  314. Past values of the time series that serves as input to the model.
  315. past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  316. The padding indicator of the time series.
  317. freq (`torch.LongTensor` of shape `(batch_size,)`):
  318. Frequency indices for the time series data.
  319. """
  320. # Reshape into patches (using view for efficiency)
  321. bsize = past_values.shape[0]
  322. patched_inputs = past_values.view(bsize, -1, self.config.patch_length)
  323. patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length)
  324. patched_inputs = torch.where(
  325. torch.abs(patched_pads - 1.0) < self.config.tolerance,
  326. torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device),
  327. patched_inputs,
  328. )
  329. patched_pads = torch.where(
  330. torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
  331. torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
  332. patched_pads,
  333. )
  334. patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads)
  335. # B x N x D
  336. patched_inputs = patched_inputs * (1.0 - patched_pads)
  337. concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
  338. model_input = self.input_ff_layer(concat_inputs)
  339. # A patch should not be padded even if there is at least one zero.
  340. patched_padding = torch.min(patched_pads, dim=-1)[0] # Get the values from the min result
  341. if self.config.use_positional_embedding:
  342. pos_emb = self.position_emb(model_input.shape[1])
  343. pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
  344. pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb)
  345. model_input += pos_emb
  346. f_emb = self.freq_emb(freq) # B x 1 x D
  347. model_input += f_emb
  348. # Convert paddings to attention mask and combine with causal mask
  349. hidden_states = model_input
  350. attention_mask = self._prepare_4d_attention_mask(
  351. attention_mask=patched_padding,
  352. sequence_length=hidden_states.shape[1],
  353. dtype=hidden_states.dtype,
  354. device=hidden_states.device,
  355. is_causal=True,
  356. )
  357. for layer in self.layers[: self.config.num_hidden_layers]:
  358. hidden_states = layer(
  359. hidden_states,
  360. attention_mask=attention_mask,
  361. paddings=patched_padding,
  362. **kwargs,
  363. )
  364. return TimesFmOutput(
  365. last_hidden_state=hidden_states,
  366. loc=stats[0],
  367. scale=stats[1],
  368. )
  369. @staticmethod
  370. def _prepare_4d_attention_mask(
  371. attention_mask: torch.Tensor | None,
  372. sequence_length: int,
  373. dtype: torch.dtype,
  374. device: torch.device,
  375. is_causal: bool = True,
  376. ) -> torch.Tensor | None:
  377. """
  378. Creates 4D attention mask and combines causal and padding masks if needed.
  379. Args:
  380. attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
  381. sequence_length: Length of the sequence
  382. dtype: Data type of the mask
  383. device: Device of the mask
  384. is_causal: Whether to apply causal masking
  385. Returns:
  386. 4D attention mask of shape (batch_size, 1, seq_length, seq_length)
  387. """
  388. # Get minimum value for the dtype
  389. min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
  390. # Handle padding mask
  391. if attention_mask is not None:
  392. # Convert 2D padding mask to 4D attention mask
  393. attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
  394. attention_mask = attention_mask * min_value
  395. # Create causal mask if needed
  396. if is_causal:
  397. causal_mask = torch.triu(
  398. torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value,
  399. diagonal=1,
  400. )
  401. causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)
  402. # Combine with padding mask if it exists
  403. if attention_mask is not None:
  404. attention_mask = torch.minimum(attention_mask, causal_mask)
  405. else:
  406. attention_mask = causal_mask
  407. return attention_mask
  408. @staticmethod
  409. def _timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  410. """Calculates mean and standard deviation of `inputs` across axis 1.
  411. It excludes values where `padding` is 1.
  412. Args:
  413. inputs: A PyTorch tensor of shape [b, n, p].
  414. padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
  415. Returns:
  416. A tuple containing the mean and standard deviation.
  417. We return the statistics of the first patch with more than three non-padded values.
  418. """
  419. # Selecting the first patch with more than 3 unpadded values.
  420. def _get_patch_index(arr: torch.Tensor):
  421. indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
  422. row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
  423. return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
  424. pad_sum = torch.sum(1 - padding, dim=2)
  425. patch_indices = _get_patch_index(pad_sum)
  426. bidxs = torch.arange(inputs.shape[0])
  427. arr = inputs[bidxs, patch_indices, :]
  428. pad = padding[bidxs, patch_indices, :]
  429. # Create a mask where padding is 0
  430. mask = 1 - pad
  431. # Calculate the number of valid elements
  432. num_valid_elements = torch.sum(mask, dim=1)
  433. num_valid_elements = torch.clamp(num_valid_elements, min=1.0)
  434. # Calculate the masked sum and mean
  435. masked_sum = torch.sum(arr * mask, dim=1)
  436. masked_mean = masked_sum / num_valid_elements # [b]
  437. # Calculate the masked variance using centered values
  438. masked_centered_arr = (arr - masked_mean.unsqueeze(-1)) * mask
  439. masked_var = torch.sum(masked_centered_arr**2, dim=1) / num_valid_elements
  440. masked_var = torch.clamp(masked_var, min=0.0)
  441. masked_std = torch.sqrt(masked_var)
  442. return masked_mean, masked_std
  443. @staticmethod
  444. def _timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
  445. """Shifts rows of seq based on the first 0 in each row of the mask.
  446. Args:
  447. mask: mask tensor of shape [B, N]
  448. seq: seq tensor of shape [B, N, P]
  449. Returns:
  450. The shifted sequence.
  451. """
  452. batch_size, num_seq, feature_dim = seq.shape
  453. new_mask: torch.BoolTensor = mask == 0
  454. # Use argmax to find the first True value in each row
  455. indices = new_mask.to(torch.int32).argmax(dim=1)
  456. # Handle rows with all zeros
  457. indices[~new_mask.any(dim=1)] = -1
  458. # Create index ranges for each sequence in the batch
  459. idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim)
  460. # Calculate shifted indices for each element in each sequence
  461. shifted_idx = (idx_range - indices[:, None, None]) % num_seq
  462. # Gather values from seq using shifted indices
  463. shifted_seq = seq.gather(1, shifted_idx)
  464. return shifted_seq
  465. class TimesFmModelForPrediction(TimesFmPreTrainedModel):
  466. """TimesFM model for quantile and mean prediction."""
  467. def __init__(self, config: TimesFmConfig):
  468. super().__init__(config)
  469. self.config = config
  470. self.context_len = config.context_length
  471. self.horizon_len = config.horizon_length
  472. self.decoder = TimesFmModel(config)
  473. # quantile and mean output
  474. self.horizon_ff_layer = TimesFmResidualBlock(
  475. input_dims=config.hidden_size,
  476. output_dims=config.horizon_length * (1 + len(config.quantiles)),
  477. hidden_dims=config.intermediate_size,
  478. )
  479. # Initialize weights and apply final processing
  480. self.post_init()
  481. def _preprocess(
  482. self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None
  483. ) -> tuple[torch.Tensor, ...]:
  484. """Pad/truncate input time series to `context_len` and build a padding mask.
  485. Args:
  486. inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task.
  487. freq: Optional list of frequencies (returned as a tensor when provided).
  488. context_len: Optional context length override (defaults to `self.context_len`).
  489. Returns:
  490. Tuple of (padded_inputs, padding_mask) and optionally a freq tensor.
  491. """
  492. if context_len is None:
  493. context_len = self.context_len
  494. input_ts, input_padding = [], []
  495. for ts in inputs:
  496. input_len = ts.shape[0]
  497. padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
  498. if input_len < context_len:
  499. num_front_pad = context_len - input_len
  500. ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
  501. padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
  502. elif input_len > context_len:
  503. ts = ts[-context_len:]
  504. padding = padding[-(context_len + self.horizon_len) :]
  505. input_ts.append(ts)
  506. input_padding.append(padding)
  507. result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0))
  508. if freq is not None:
  509. result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),)
  510. return result
  511. def _postprocess_output(
  512. self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]
  513. ) -> torch.Tensor:
  514. """Postprocess output of stacked transformer."""
  515. # B x N x (H.Q)
  516. output_ts = self.horizon_ff_layer(model_output)
  517. # Reshape using view
  518. b, n, _ = output_ts.shape
  519. output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1)
  520. mu, sigma = stats
  521. return output_ts * sigma[:, None, None, None] + mu[:, None, None, None]
  522. def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
  523. losses = []
  524. for i, q in enumerate(self.config.quantiles):
  525. errors = targets - predictions[..., i]
  526. loss = torch.max((q - 1) * errors, q * errors)
  527. losses.append(loss.mean())
  528. return torch.stack(losses).mean()
  529. @can_return_tuple
  530. @auto_docstring
  531. def forward(
  532. self,
  533. past_values: Sequence[torch.Tensor],
  534. freq: Sequence[torch.Tensor | int] | None = None,
  535. window_size: int | None = None,
  536. future_values: torch.Tensor | None = None,
  537. forecast_context_len: int | None = None,
  538. return_forecast_on_context: bool = False,
  539. truncate_negative: bool = False,
  540. **kwargs: Unpack[TransformersKwargs],
  541. ) -> TimesFmOutputForPrediction:
  542. r"""
  543. past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  544. Past values of the time series that serves as input to the model.
  545. freq (`torch.LongTensor` of shape `(batch_size,)`):
  546. Frequency indices for the time series data.
  547. window_size (`int`, *optional*):
  548. Window size of trend + residual decomposition. If None then we do not do decomposition.
  549. future_values (`torch.Tensor`, *optional*):
  550. Optional future time series values to be used for loss computation.
  551. forecast_context_len (`int`, *optional*):
  552. Optional max context length.
  553. return_forecast_on_context (`bool`, *optional*):
  554. True to return the forecast on the context when available, i.e. after the first input patch.
  555. truncate_negative (`bool`, *optional*):
  556. Truncate to only non-negative values if any of the contexts have non-negative values,
  557. otherwise do nothing.
  558. Example:
  559. ```python
  560. >>> from transformers import TimesFmModelForPrediction
  561. >>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")
  562. >>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()]
  563. >>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)
  564. >>> # Generate
  565. >>> with torch.no_grad():
  566. >>> outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True)
  567. >>> point_forecast_conv = outputs.mean_predictions
  568. >>> quantile_forecast_conv = outputs.full_predictions
  569. ```
  570. """
  571. if forecast_context_len is None:
  572. fcontext_len = self.context_len
  573. else:
  574. fcontext_len = forecast_context_len
  575. device = past_values[0].device
  576. inputs = [ts[-fcontext_len:] for ts in past_values]
  577. inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
  578. if window_size is not None:
  579. new_inputs = []
  580. new_freqs = []
  581. for i, ts in enumerate(inputs):
  582. new_inputs.extend(self._timesfm_moving_average(ts, window_size))
  583. if freq is not None:
  584. new_freqs.extend([freq[i]] * 2)
  585. inputs = new_inputs
  586. if freq is not None:
  587. freq = new_freqs
  588. if freq is None:
  589. logger.info("No frequency provided via `freq`. Default to high (0).")
  590. freq = [0] * len(inputs)
  591. input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
  592. input_ts = input_ts.to(device)
  593. input_padding = input_padding.to(device)
  594. inp_freq = inp_freq.to(device)
  595. final_out = input_ts
  596. context_len = final_out.shape[1]
  597. full_outputs = []
  598. if input_padding.shape[1] != final_out.shape[1] + self.horizon_len:
  599. raise ValueError(
  600. "Length of paddings must match length of input + horizon_len:"
  601. f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}"
  602. )
  603. output_patch_len = self.config.horizon_length
  604. num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len
  605. for step_index in range(num_decode_patches):
  606. current_padding = input_padding[:, 0 : final_out.shape[1]]
  607. input_ts = final_out[:, -fcontext_len:]
  608. input_padding = current_padding[:, -fcontext_len:]
  609. decoder_output: TimesFmOutput = self.decoder(
  610. past_values=input_ts,
  611. past_values_padding=input_padding,
  612. freq=inp_freq,
  613. **kwargs,
  614. )
  615. fprop_outputs = self._postprocess_output(
  616. decoder_output.last_hidden_state,
  617. (decoder_output.loc, decoder_output.scale),
  618. )
  619. if return_forecast_on_context and step_index == 0:
  620. new_full_ts = fprop_outputs[:, :-1, : self.config.patch_length, :]
  621. new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3))
  622. full_outputs.append(new_full_ts)
  623. new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
  624. new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
  625. full_outputs.append(new_full_ts)
  626. final_out = torch.concatenate([final_out, new_ts], axis=-1)
  627. if return_forecast_on_context:
  628. full_outputs = torch.concatenate(full_outputs, axis=1)[
  629. :, : (context_len - self.config.patch_length + self.horizon_len), :
  630. ]
  631. else:
  632. full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :]
  633. mean_outputs = full_outputs[:, :, 0]
  634. if window_size is not None:
  635. mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
  636. full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
  637. if inp_min >= 0 and truncate_negative:
  638. mean_outputs = torch.maximum(mean_outputs, 0.0)
  639. full_outputs = torch.maximum(full_outputs, 0.0)
  640. loss = None
  641. if future_values is not None:
  642. mse_loss = F.mse_loss(mean_outputs, future_values)
  643. quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_values)
  644. loss = mse_loss + quantile_loss
  645. return TimesFmOutputForPrediction(
  646. last_hidden_state=decoder_output.last_hidden_state,
  647. attentions=decoder_output.attentions,
  648. hidden_states=decoder_output.hidden_states,
  649. mean_predictions=mean_outputs,
  650. full_predictions=full_outputs,
  651. loss=loss,
  652. )
  653. @staticmethod
  654. def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
  655. """Calculates the moving average using PyTorch's convolution function."""
  656. # Pad with zeros to handle initial window positions
  657. arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
  658. # Create a convolution kernel
  659. kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
  660. # Apply convolution to calculate the moving average
  661. smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
  662. return [smoothed_arr, arr - smoothed_arr]
  663. __all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"]