modular_timesfm.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765
  1. # Copyright 2025 Google LLC and HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch TimesFM model."""
  15. import math
  16. from collections.abc import Callable, Sequence
  17. from dataclasses import dataclass
  18. import torch
  19. import torch.nn as nn
  20. import torch.nn.functional as F
  21. from ... import initialization as init
  22. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  23. from ...modeling_outputs import BaseModelOutput
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  27. from ...utils.generic import merge_with_config_defaults
  28. from ...utils.output_capturing import capture_outputs
  29. from ..llama.modeling_llama import LlamaRMSNorm
  30. from ..phi4_multimodal.modeling_phi4_multimodal import simple_eager_attention_forward
  31. from .configuration_timesfm import TimesFmConfig
# Module-level logger obtained from the transformers `logging` utilities.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring
class TimesFmOutput(BaseModelOutput):
    r"""
    loc (`torch.Tensor` of shape `(batch_size,)`):
        The mean used to normalize the time series inputs. It is computed over the non-padded values of a single
        reference patch per series: the first patch with at least three non-padded values, or the last patch if
        no patch qualifies.
    scale (`torch.Tensor` of shape `(batch_size,)`):
        The scale used to normalize the time series inputs, derived from the masked standard deviation of the
        same reference patch as `loc`.
    """

    loc: torch.Tensor | None = None
    scale: torch.Tensor | None = None
  44. @dataclass
  45. @auto_docstring
  46. class TimesFmOutputForPrediction(BaseModelOutput):
  47. r"""
  48. mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
  49. The mean predictions of the time series.
  50. full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
  51. The full predictions of the time series including the mean and the quantiles.
  52. loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
  53. The loss of the TimesFM model.
  54. """
  55. mean_predictions: torch.Tensor | None = None
  56. full_predictions: torch.Tensor | None = None
  57. loss: torch.Tensor | float | None = None
  58. class TimesFmMLP(nn.Module):
  59. """Pax MLP in pytorch."""
  60. def __init__(self, config: TimesFmConfig):
  61. super().__init__()
  62. hidden_size = config.hidden_size
  63. intermediate_size = config.intermediate_size
  64. self.gate_proj = nn.Linear(hidden_size, intermediate_size)
  65. self.down_proj = nn.Linear(intermediate_size, hidden_size)
  66. self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)
  67. def forward(self, x, paddings=None):
  68. gate_inp = self.layer_norm(x)
  69. gate = self.gate_proj(gate_inp)
  70. gate = F.relu(gate)
  71. outputs = self.down_proj(gate)
  72. if paddings is not None:
  73. outputs = outputs * (1.0 - paddings[:, :, None])
  74. return outputs + x
  75. class TimesFmResidualBlock(nn.Module):
  76. """TimesFM residual block."""
  77. def __init__(self, input_dims, hidden_dims, output_dims):
  78. super().__init__()
  79. self.input_dims = input_dims
  80. self.hidden_dims = hidden_dims
  81. self.output_dims = output_dims
  82. self.input_layer = nn.Linear(input_dims, hidden_dims)
  83. self.activation = nn.SiLU()
  84. self.output_layer = nn.Linear(hidden_dims, output_dims)
  85. self.residual_layer = nn.Linear(input_dims, output_dims)
  86. def forward(self, x):
  87. hidden = self.input_layer(x)
  88. hidden = self.activation(hidden)
  89. output = self.output_layer(hidden)
  90. residual = self.residual_layer(x)
  91. return output + residual
class TimesFmRMSNorm(LlamaRMSNorm):
    """RMS normalization used by TimesFM; reuses `LlamaRMSNorm` unchanged under a TimesFM-specific name."""

    pass
  94. class TimesFmPositionalEmbedding(nn.Module):
  95. """Generates position embedding for a given 1-d sequence."""
  96. def __init__(self, config: TimesFmConfig):
  97. super().__init__()
  98. min_timescale = config.min_timescale
  99. max_timescale = config.max_timescale
  100. self.min_timescale, self.max_timescale = min_timescale, max_timescale
  101. self.embedding_dims = config.hidden_size
  102. num_timescales = self.embedding_dims // 2
  103. log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
  104. self.register_buffer(
  105. "inv_timescales",
  106. min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
  107. )
  108. def forward(self, seq_length=None, position=None):
  109. """Generates a Tensor of sinusoids with different frequencies.
  110. Args:
  111. seq_length: an optional Python int defining the output sequence length.
  112. if the `position` argument is specified.
  113. position: [B, seq_length], optional position for each token in the
  114. sequence, only required when the sequence is packed.
  115. Returns:
  116. [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
  117. """
  118. if position is None and seq_length is None:
  119. raise ValueError("Either position or seq_length must be provided")
  120. if position is None:
  121. # [1, seqlen]
  122. position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
  123. elif position.ndim != 2:
  124. raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")
  125. scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
  126. signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
  127. # Padding to ensure correct embedding dimension
  128. signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
  129. return signal
  130. class TimesFmAttention(nn.Module):
  131. """Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query."""
  132. def __init__(self, config: TimesFmConfig, layer_idx: int):
  133. super().__init__()
  134. self.config = config
  135. self.is_causal = True
  136. self.attention_dropout = config.attention_dropout
  137. self.layer_idx = layer_idx
  138. self.num_heads = config.num_attention_heads
  139. self.hidden_size = config.hidden_size
  140. self.head_dim = config.head_dim
  141. self.q_size = self.num_heads * self.head_dim
  142. self.kv_size = self.num_heads * self.head_dim
  143. self.scaling = nn.Parameter(torch.empty((self.head_dim,)))
  144. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
  145. self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
  146. self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
  147. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
  148. def _scale_query(self, query: torch.Tensor) -> torch.Tensor:
  149. scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim))
  150. return query * scale[None, None, None, :]
  151. def forward(
  152. self,
  153. hidden_states: torch.Tensor,
  154. attention_mask: torch.Tensor | None = None,
  155. **kwargs: Unpack[FlashAttentionKwargs],
  156. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  157. input_shape = hidden_states.shape[:-1]
  158. hidden_shape = (*input_shape, -1, self.head_dim)
  159. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  160. query_states = self._scale_query(query_states)
  161. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  162. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  163. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  164. self.config._attn_implementation, simple_eager_attention_forward
  165. )
  166. attn_output, attn_weights = attention_interface(
  167. self,
  168. query_states,
  169. key_states,
  170. value_states,
  171. attention_mask,
  172. dropout=0.0 if not self.training else self.attention_dropout,
  173. scaling=1.0,
  174. **kwargs,
  175. )
  176. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  177. attn_output = self.o_proj(attn_output)
  178. return attn_output, attn_weights
  179. class TimesFmDecoderLayer(nn.Module):
  180. """Transformer layer."""
  181. def __init__(self, config: TimesFmConfig, layer_idx: int):
  182. super().__init__()
  183. self.self_attn = TimesFmAttention(config, layer_idx=layer_idx)
  184. self.mlp = TimesFmMLP(config)
  185. self.input_layernorm = TimesFmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  186. def forward(
  187. self,
  188. hidden_states: torch.Tensor,
  189. attention_mask: torch.Tensor,
  190. paddings: torch.Tensor,
  191. **kwargs,
  192. ) -> torch.Tensor:
  193. # Self Attention
  194. residual = hidden_states
  195. hidden_states = self.input_layernorm(hidden_states)
  196. hidden_states, _ = self.self_attn(
  197. hidden_states=hidden_states,
  198. attention_mask=attention_mask,
  199. )
  200. hidden_states = residual + hidden_states
  201. # MLP
  202. hidden_states = self.mlp(hidden_states, paddings=paddings)
  203. return hidden_states
  204. @auto_docstring
  205. class TimesFmPreTrainedModel(PreTrainedModel):
  206. config: TimesFmConfig
  207. base_model_prefix = "timesfm"
  208. _no_split_modules = ["TimesFmDecoderLayer"]
  209. main_input_name = "past_values"
  210. input_modalities = ("time",)
  211. _supports_sdpa = True
  212. _can_record_outputs = {
  213. "hidden_states": TimesFmDecoderLayer,
  214. "attentions": TimesFmAttention,
  215. }
  216. @torch.no_grad()
  217. def _init_weights(self, module):
  218. super()._init_weights(module)
  219. if isinstance(module, TimesFmAttention):
  220. # Initialize scaling parameter
  221. init.ones_(module.scaling)
  222. elif isinstance(module, TimesFmPositionalEmbedding):
  223. num_timescales = module.embedding_dims // 2
  224. max_timescale, min_timescale = module.max_timescale, module.min_timescale
  225. log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
  226. num_timescales - 1, 1
  227. )
  228. init.copy_(
  229. module.inv_timescales,
  230. min_timescale
  231. * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
  232. )
  233. @auto_docstring
  234. class TimesFmModel(TimesFmPreTrainedModel):
  235. def __init__(self, config: TimesFmConfig):
  236. super().__init__(config)
  237. self.config = config
  238. self.input_ff_layer = TimesFmResidualBlock(
  239. input_dims=2 * config.patch_length,
  240. output_dims=config.hidden_size,
  241. hidden_dims=config.intermediate_size,
  242. )
  243. self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size)
  244. self.layers = nn.ModuleList(
  245. [TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  246. )
  247. if self.config.use_positional_embedding:
  248. self.position_emb = TimesFmPositionalEmbedding(config=config)
  249. # Initialize weights and apply final processing
  250. self.post_init()
  251. def _forward_transform(
  252. self, inputs: torch.Tensor, patched_pads: torch.Tensor
  253. ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
  254. """Input is of shape [B, N, P]."""
  255. mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
  256. sigma = torch.clamp(sigma, min=self.config.tolerance)
  257. # Normalize each patch
  258. outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
  259. outputs = torch.where(
  260. torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
  261. torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device),
  262. outputs,
  263. )
  264. return outputs, (mu, sigma)
  265. @merge_with_config_defaults
  266. @capture_outputs
  267. @auto_docstring
  268. def forward(
  269. self,
  270. past_values: torch.Tensor,
  271. past_values_padding: torch.LongTensor,
  272. freq: torch.Tensor,
  273. **kwargs: Unpack[TransformersKwargs],
  274. ) -> TimesFmOutput:
  275. r"""
  276. past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  277. Past values of the time series that serves as input to the model.
  278. past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  279. The padding indicator of the time series.
  280. freq (`torch.LongTensor` of shape `(batch_size,)`):
  281. Frequency indices for the time series data.
  282. """
  283. # Reshape into patches (using view for efficiency)
  284. bsize = past_values.shape[0]
  285. patched_inputs = past_values.view(bsize, -1, self.config.patch_length)
  286. patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length)
  287. patched_inputs = torch.where(
  288. torch.abs(patched_pads - 1.0) < self.config.tolerance,
  289. torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device),
  290. patched_inputs,
  291. )
  292. patched_pads = torch.where(
  293. torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
  294. torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
  295. patched_pads,
  296. )
  297. patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads)
  298. # B x N x D
  299. patched_inputs = patched_inputs * (1.0 - patched_pads)
  300. concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
  301. model_input = self.input_ff_layer(concat_inputs)
  302. # A patch should not be padded even if there is at least one zero.
  303. patched_padding = torch.min(patched_pads, dim=-1)[0] # Get the values from the min result
  304. if self.config.use_positional_embedding:
  305. pos_emb = self.position_emb(model_input.shape[1])
  306. pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
  307. pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb)
  308. model_input += pos_emb
  309. f_emb = self.freq_emb(freq) # B x 1 x D
  310. model_input += f_emb
  311. # Convert paddings to attention mask and combine with causal mask
  312. hidden_states = model_input
  313. attention_mask = self._prepare_4d_attention_mask(
  314. attention_mask=patched_padding,
  315. sequence_length=hidden_states.shape[1],
  316. dtype=hidden_states.dtype,
  317. device=hidden_states.device,
  318. is_causal=True,
  319. )
  320. for layer in self.layers[: self.config.num_hidden_layers]:
  321. hidden_states = layer(
  322. hidden_states,
  323. attention_mask=attention_mask,
  324. paddings=patched_padding,
  325. **kwargs,
  326. )
  327. return TimesFmOutput(
  328. last_hidden_state=hidden_states,
  329. loc=stats[0],
  330. scale=stats[1],
  331. )
  332. @staticmethod
  333. def _prepare_4d_attention_mask(
  334. attention_mask: torch.Tensor | None,
  335. sequence_length: int,
  336. dtype: torch.dtype,
  337. device: torch.device,
  338. is_causal: bool = True,
  339. ) -> torch.Tensor | None:
  340. """
  341. Creates 4D attention mask and combines causal and padding masks if needed.
  342. Args:
  343. attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
  344. sequence_length: Length of the sequence
  345. dtype: Data type of the mask
  346. device: Device of the mask
  347. is_causal: Whether to apply causal masking
  348. Returns:
  349. 4D attention mask of shape (batch_size, 1, seq_length, seq_length)
  350. """
  351. # Get minimum value for the dtype
  352. min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
  353. # Handle padding mask
  354. if attention_mask is not None:
  355. # Convert 2D padding mask to 4D attention mask
  356. attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
  357. attention_mask = attention_mask * min_value
  358. # Create causal mask if needed
  359. if is_causal:
  360. causal_mask = torch.triu(
  361. torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value,
  362. diagonal=1,
  363. )
  364. causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)
  365. # Combine with padding mask if it exists
  366. if attention_mask is not None:
  367. attention_mask = torch.minimum(attention_mask, causal_mask)
  368. else:
  369. attention_mask = causal_mask
  370. return attention_mask
  371. @staticmethod
  372. def _timesfm_masked_mean_std(inputs: torch.Tensor, padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  373. """Calculates mean and standard deviation of `inputs` across axis 1.
  374. It excludes values where `padding` is 1.
  375. Args:
  376. inputs: A PyTorch tensor of shape [b, n, p].
  377. padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
  378. Returns:
  379. A tuple containing the mean and standard deviation.
  380. We return the statistics of the first patch with more than three non-padded values.
  381. """
  382. # Selecting the first patch with more than 3 unpadded values.
  383. def _get_patch_index(arr: torch.Tensor):
  384. indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
  385. row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
  386. return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
  387. pad_sum = torch.sum(1 - padding, dim=2)
  388. patch_indices = _get_patch_index(pad_sum)
  389. bidxs = torch.arange(inputs.shape[0])
  390. arr = inputs[bidxs, patch_indices, :]
  391. pad = padding[bidxs, patch_indices, :]
  392. # Create a mask where padding is 0
  393. mask = 1 - pad
  394. # Calculate the number of valid elements
  395. num_valid_elements = torch.sum(mask, dim=1)
  396. num_valid_elements = torch.clamp(num_valid_elements, min=1.0)
  397. # Calculate the masked sum and mean
  398. masked_sum = torch.sum(arr * mask, dim=1)
  399. masked_mean = masked_sum / num_valid_elements # [b]
  400. # Calculate the masked variance using centered values
  401. masked_centered_arr = (arr - masked_mean.unsqueeze(-1)) * mask
  402. masked_var = torch.sum(masked_centered_arr**2, dim=1) / num_valid_elements
  403. masked_var = torch.clamp(masked_var, min=0.0)
  404. masked_std = torch.sqrt(masked_var)
  405. return masked_mean, masked_std
  406. @staticmethod
  407. def _timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
  408. """Shifts rows of seq based on the first 0 in each row of the mask.
  409. Args:
  410. mask: mask tensor of shape [B, N]
  411. seq: seq tensor of shape [B, N, P]
  412. Returns:
  413. The shifted sequence.
  414. """
  415. batch_size, num_seq, feature_dim = seq.shape
  416. new_mask: torch.BoolTensor = mask == 0
  417. # Use argmax to find the first True value in each row
  418. indices = new_mask.to(torch.int32).argmax(dim=1)
  419. # Handle rows with all zeros
  420. indices[~new_mask.any(dim=1)] = -1
  421. # Create index ranges for each sequence in the batch
  422. idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim)
  423. # Calculate shifted indices for each element in each sequence
  424. shifted_idx = (idx_range - indices[:, None, None]) % num_seq
  425. # Gather values from seq using shifted indices
  426. shifted_seq = seq.gather(1, shifted_idx)
  427. return shifted_seq
  428. class TimesFmModelForPrediction(TimesFmPreTrainedModel):
  429. """TimesFM model for quantile and mean prediction."""
  430. def __init__(self, config: TimesFmConfig):
  431. super().__init__(config)
  432. self.config = config
  433. self.context_len = config.context_length
  434. self.horizon_len = config.horizon_length
  435. self.decoder = TimesFmModel(config)
  436. # quantile and mean output
  437. self.horizon_ff_layer = TimesFmResidualBlock(
  438. input_dims=config.hidden_size,
  439. output_dims=config.horizon_length * (1 + len(config.quantiles)),
  440. hidden_dims=config.intermediate_size,
  441. )
  442. # Initialize weights and apply final processing
  443. self.post_init()
  444. def _preprocess(
  445. self, inputs: Sequence[torch.Tensor], freq: Sequence[int] | None = None, context_len: int | None = None
  446. ) -> tuple[torch.Tensor, ...]:
  447. """Pad/truncate input time series to `context_len` and build a padding mask.
  448. Args:
  449. inputs: A list of 1d Tensors. Each Tensor is the context time series of a single forecast task.
  450. freq: Optional list of frequencies (returned as a tensor when provided).
  451. context_len: Optional context length override (defaults to `self.context_len`).
  452. Returns:
  453. Tuple of (padded_inputs, padding_mask) and optionally a freq tensor.
  454. """
  455. if context_len is None:
  456. context_len = self.context_len
  457. input_ts, input_padding = [], []
  458. for ts in inputs:
  459. input_len = ts.shape[0]
  460. padding = torch.zeros(input_len + self.horizon_len, dtype=ts.dtype, device=ts.device)
  461. if input_len < context_len:
  462. num_front_pad = context_len - input_len
  463. ts = torch.cat([torch.zeros(num_front_pad, dtype=ts.dtype, device=ts.device), ts], dim=0)
  464. padding = torch.cat([torch.ones(num_front_pad, dtype=ts.dtype, device=padding.device), padding], dim=0)
  465. elif input_len > context_len:
  466. ts = ts[-context_len:]
  467. padding = padding[-(context_len + self.horizon_len) :]
  468. input_ts.append(ts)
  469. input_padding.append(padding)
  470. result = (torch.stack(input_ts, dim=0), torch.stack(input_padding, dim=0))
  471. if freq is not None:
  472. result = result + (torch.tensor(freq[: len(inputs)], dtype=torch.int32).reshape(-1, 1),)
  473. return result
  474. def _postprocess_output(
  475. self, model_output: torch.Tensor, stats: tuple[torch.Tensor, torch.Tensor]
  476. ) -> torch.Tensor:
  477. """Postprocess output of stacked transformer."""
  478. # B x N x (H.Q)
  479. output_ts = self.horizon_ff_layer(model_output)
  480. # Reshape using view
  481. b, n, _ = output_ts.shape
  482. output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1)
  483. mu, sigma = stats
  484. return output_ts * sigma[:, None, None, None] + mu[:, None, None, None]
  485. def _quantile_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
  486. losses = []
  487. for i, q in enumerate(self.config.quantiles):
  488. errors = targets - predictions[..., i]
  489. loss = torch.max((q - 1) * errors, q * errors)
  490. losses.append(loss.mean())
  491. return torch.stack(losses).mean()
  492. @can_return_tuple
  493. @auto_docstring
  494. def forward(
  495. self,
  496. past_values: Sequence[torch.Tensor],
  497. freq: Sequence[torch.Tensor | int] | None = None,
  498. window_size: int | None = None,
  499. future_values: torch.Tensor | None = None,
  500. forecast_context_len: int | None = None,
  501. return_forecast_on_context: bool = False,
  502. truncate_negative: bool = False,
  503. **kwargs: Unpack[TransformersKwargs],
  504. ) -> TimesFmOutputForPrediction:
  505. r"""
  506. past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  507. Past values of the time series that serves as input to the model.
  508. freq (`torch.LongTensor` of shape `(batch_size,)`):
  509. Frequency indices for the time series data.
  510. window_size (`int`, *optional*):
  511. Window size of trend + residual decomposition. If None then we do not do decomposition.
  512. future_values (`torch.Tensor`, *optional*):
  513. Optional future time series values to be used for loss computation.
  514. forecast_context_len (`int`, *optional*):
  515. Optional max context length.
  516. return_forecast_on_context (`bool`, *optional*):
  517. True to return the forecast on the context when available, i.e. after the first input patch.
  518. truncate_negative (`bool`, *optional*):
  519. Truncate to only non-negative values if any of the contexts have non-negative values,
  520. otherwise do nothing.
  521. Example:
  522. ```python
  523. >>> from transformers import TimesFmModelForPrediction
  524. >>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")
  525. >>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()]
  526. >>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)
  527. >>> # Generate
  528. >>> with torch.no_grad():
  529. >>> outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True)
  530. >>> point_forecast_conv = outputs.mean_predictions
  531. >>> quantile_forecast_conv = outputs.full_predictions
  532. ```
  533. """
  534. if forecast_context_len is None:
  535. fcontext_len = self.context_len
  536. else:
  537. fcontext_len = forecast_context_len
  538. device = past_values[0].device
  539. inputs = [ts[-fcontext_len:] for ts in past_values]
  540. inp_min = torch.min(torch.stack([torch.min(ts) for ts in inputs]))
  541. if window_size is not None:
  542. new_inputs = []
  543. new_freqs = []
  544. for i, ts in enumerate(inputs):
  545. new_inputs.extend(self._timesfm_moving_average(ts, window_size))
  546. if freq is not None:
  547. new_freqs.extend([freq[i]] * 2)
  548. inputs = new_inputs
  549. if freq is not None:
  550. freq = new_freqs
  551. if freq is None:
  552. logger.info("No frequency provided via `freq`. Default to high (0).")
  553. freq = [0] * len(inputs)
  554. input_ts, input_padding, inp_freq = self._preprocess(inputs, freq)
  555. input_ts = input_ts.to(device)
  556. input_padding = input_padding.to(device)
  557. inp_freq = inp_freq.to(device)
  558. final_out = input_ts
  559. context_len = final_out.shape[1]
  560. full_outputs = []
  561. if input_padding.shape[1] != final_out.shape[1] + self.horizon_len:
  562. raise ValueError(
  563. "Length of paddings must match length of input + horizon_len:"
  564. f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}"
  565. )
  566. output_patch_len = self.config.horizon_length
  567. num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len
  568. for step_index in range(num_decode_patches):
  569. current_padding = input_padding[:, 0 : final_out.shape[1]]
  570. input_ts = final_out[:, -fcontext_len:]
  571. input_padding = current_padding[:, -fcontext_len:]
  572. decoder_output: TimesFmOutput = self.decoder(
  573. past_values=input_ts,
  574. past_values_padding=input_padding,
  575. freq=inp_freq,
  576. **kwargs,
  577. )
  578. fprop_outputs = self._postprocess_output(
  579. decoder_output.last_hidden_state,
  580. (decoder_output.loc, decoder_output.scale),
  581. )
  582. if return_forecast_on_context and step_index == 0:
  583. new_full_ts = fprop_outputs[:, :-1, : self.config.patch_length, :]
  584. new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1, new_full_ts.size(3))
  585. full_outputs.append(new_full_ts)
  586. new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
  587. new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
  588. full_outputs.append(new_full_ts)
  589. final_out = torch.concatenate([final_out, new_ts], axis=-1)
  590. if return_forecast_on_context:
  591. full_outputs = torch.concatenate(full_outputs, axis=1)[
  592. :, : (context_len - self.config.patch_length + self.horizon_len), :
  593. ]
  594. else:
  595. full_outputs = torch.concatenate(full_outputs, axis=1)[:, 0 : self.horizon_len, :]
  596. mean_outputs = full_outputs[:, :, 0]
  597. if window_size is not None:
  598. mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
  599. full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
  600. if inp_min >= 0 and truncate_negative:
  601. mean_outputs = torch.maximum(mean_outputs, 0.0)
  602. full_outputs = torch.maximum(full_outputs, 0.0)
  603. loss = None
  604. if future_values is not None:
  605. mse_loss = F.mse_loss(mean_outputs, future_values)
  606. quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_values)
  607. loss = mse_loss + quantile_loss
  608. return TimesFmOutputForPrediction(
  609. last_hidden_state=decoder_output.last_hidden_state,
  610. attentions=decoder_output.attentions,
  611. hidden_states=decoder_output.hidden_states,
  612. mean_predictions=mean_outputs,
  613. full_predictions=full_outputs,
  614. loss=loss,
  615. )
  616. @staticmethod
  617. def _timesfm_moving_average(arr: torch.Tensor, window_size: int) -> list[torch.Tensor]:
  618. """Calculates the moving average using PyTorch's convolution function."""
  619. # Pad with zeros to handle initial window positions
  620. arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0)
  621. # Create a convolution kernel
  622. kernel = torch.ones(window_size, dtype=arr.dtype, device=arr.device) / window_size
  623. # Apply convolution to calculate the moving average
  624. smoothed_arr = F.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze()
  625. return [smoothed_arr, arr - smoothed_arr]
# Public names exported by this module.
__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"]