modeling_patchtst.py 83 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973
  1. # Copyright 2023 IBM & Hugging Face. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch PatchTST model."""
  15. import math
  16. from collections.abc import Callable
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import nn
  20. from ... import initialization as init
  21. from ...activations import ACT2CLS
  22. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  23. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  24. from ...modeling_outputs import BaseModelOutput
  25. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
  28. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
  29. from .configuration_patchtst import PatchTSTConfig
  # Module-level logger, named after this module.
  logger = logging.get_logger(__name__)
  # Adapted from transformers.models.bert.modeling_bert.eager_attention_forward
  def eager_attention_forward(
      module: nn.Module,
      query: torch.Tensor,
      key: torch.Tensor,
      value: torch.Tensor,
      attention_mask: torch.Tensor | None,
      scaling: float | None = None,
      dropout: float = 0.0,
      **kwargs: Unpack[TransformersKwargs],
  ):
      """
      Reference (eager) attention: `softmax(query @ key^T * scaling + attention_mask) @ value`.

      Args:
          module (`nn.Module`):
              The calling attention module; only its `training` flag is read (to enable dropout).
          query (`torch.Tensor`):
              Query states; the last dimension is the head dimension. `key` has its dimensions 2 and 3 swapped for
              the dot product, so inputs must be at least 4D (presumably `(batch, heads, seq, head_dim)`).
          key (`torch.Tensor`):
              Key states, same layout as `query`.
          value (`torch.Tensor`):
              Value states, same layout as `key`.
          attention_mask (`torch.Tensor`, *optional*):
              Additive mask summed onto the raw scores before the softmax.
          scaling (`float`, *optional*):
              Score multiplier; when `None`, `query.size(-1) ** -0.5` is used.
          dropout (`float`, defaults to 0.0):
              Dropout probability on the attention weights, only active when `module.training` is true.
          **kwargs:
              Accepted for interface compatibility with the other attention backends; ignored here.

      Returns:
          `tuple(torch.Tensor, torch.Tensor)`: the attention output with dimensions 1 and 2 swapped (contiguous), and
          the attention weights after softmax and dropout.
      """
      scale = query.size(-1) ** -0.5 if scaling is None else scaling
      # Raw scores of every query position against every key position.
      scores = torch.matmul(query, key.transpose(2, 3)) * scale
      if attention_mask is not None:
          scores = scores + attention_mask
      probs = nn.functional.dropout(nn.functional.softmax(scores, dim=-1), p=dropout, training=module.training)
      context = torch.matmul(probs, value)
      return context.transpose(1, 2).contiguous(), probs
  # Adapted from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->PatchTST
  class PatchTSTAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper"""

      def __init__(
          self,
          embed_dim: int,
          num_heads: int,
          dropout: float = 0.0,
          is_decoder: bool = False,
          bias: bool = True,
          is_causal: bool = False,
          config: PatchTSTConfig | None = None,
      ):
          """
          Args:
              embed_dim (`int`):
                  Total hidden size, split evenly across the heads.
              num_heads (`int`):
                  Number of attention heads; must divide `embed_dim`.
              dropout (`float`, defaults to 0.0):
                  Dropout on the attention weights, only applied in training mode.
              is_decoder (`bool`, defaults to `False`):
                  Stored as `self.is_decoder`; not read by `forward`.
              bias (`bool`, defaults to `True`):
                  Whether the q/k/v/out projections have a bias.
              is_causal (`bool`, defaults to `False`):
                  Stored as `self.is_causal`; presumably consumed by non-eager attention backends.
              config (`PatchTSTConfig`, *optional*):
                  Its `_attn_implementation` selects the attention backend. When `None`, the eager implementation
                  is used.

          Raises:
              ValueError: If `embed_dim` is not divisible by `num_heads`.
          """
          super().__init__()
          self.embed_dim = embed_dim
          self.num_heads = num_heads
          self.dropout = dropout
          self.head_dim = embed_dim // num_heads
          self.config = config

          if (self.head_dim * num_heads) != self.embed_dim:
              raise ValueError(
                  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                  f" and `num_heads`: {num_heads})."
              )
          self.scaling = self.head_dim**-0.5
          self.is_decoder = is_decoder
          self.is_causal = is_causal

          self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
          self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
          self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
          self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

      def forward(
          self,
          hidden_states: torch.Tensor,
          key_value_states: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          output_attentions: bool | None = False,
          # TODO: we need a refactor so that the different attention modules can get their specific kwargs
          # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
          """Input shape: Batch x Time x Channel

          When `key_value_states` is given, keys and values are projected from it (cross-attention); otherwise from
          `hidden_states` (self-attention).

          Returns:
              `tuple`: `(attn_output, attn_weights, None)`; the third element is always `None`.
          """
          is_cross_attention = key_value_states is not None

          # Split the projections into heads: [..., seq, embed] -> [..., num_heads, seq, head_dim] via view + transpose.
          input_shape = hidden_states.shape[:-1]
          hidden_shape = (*input_shape, -1, self.head_dim)
          query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)

          current_states = key_value_states if is_cross_attention else hidden_states
          kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
          key_states = self.k_proj(current_states).view(kv_shape).transpose(1, 2)
          value_states = self.v_proj(current_states).view(kv_shape).transpose(1, 2)

          # `config` is optional in the constructor; without it, use the eager implementation instead of failing
          # on `None._attn_implementation`.
          if self.config is None:
              attention_interface: Callable = eager_attention_forward
          else:
              attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface(
                  self.config._attn_implementation, eager_attention_forward
              )

          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.dropout,
              scaling=self.scaling,
              output_attentions=output_attentions,
              **kwargs,
          )

          # Merge the heads back into the embedding dimension before the output projection.
          attn_output = attn_output.reshape(*input_shape, -1).contiguous()
          attn_output = self.out_proj(attn_output)

          return attn_output, attn_weights, None
  class PatchTSTBatchNorm(nn.Module):
      """
      Compute batch normalization over the sequence length (time) dimension.

      The statistics are computed per `d_model` feature, pooled over the batch and sequence dimensions.
      """

      def __init__(self, config: PatchTSTConfig):
          super().__init__()
          self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps)

      def forward(self, inputs: torch.Tensor):
          """
          Parameters:
              inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
                  input for Batch norm calculation
          Returns:
              `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
          """
          # BatchNorm1d normalizes dimension 1, so move d_model there for the call and back afterwards.
          channels_first = inputs.transpose(1, 2)  # (batch_size, d_model, sequence_length)
          return self.batchnorm(channels_first).transpose(1, 2)
  def random_masking(
      inputs: torch.Tensor,
      mask_ratio: float,
      unmasked_channel_indices: list | None = None,
      channel_consistent_masking: bool = False,
      mask_value: int = 0,
  ):
      """random_masking: Mask the input considering the control variables.

      Each (sample, channel) row keeps `int(sequence_length * (1 - mask_ratio))` randomly chosen patches and masks
      the remaining ones.

      Args:
          inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
              The input tensor to mask.
          mask_ratio (`float`):
              Masking ratio applied to mask the input data during random pretraining. Must satisfy
              `0 <= mask_ratio < 1`.
          unmasked_channel_indices (list, *optional*):
              Indices of channels that will not be masked.
          channel_consistent_masking (bool, *optional*, defaults to `False`):
              When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will
              vary across channels.
          mask_value (int, *optional*, defaults to 0):
              Define the value of masked patches for pretraining.

      Returns:
          `tuple(torch.Tensor)`: the masked input (same shape as `inputs`) and a float mask of shape
          `(batch_size, num_channels, sequence_length)` that is 1 at masked positions and 0 elsewhere.

      Raises:
          ValueError: If `mask_ratio` is outside `[0, 1)`.
      """
      if mask_ratio < 0 or mask_ratio >= 1:
          raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.")

      batch_size, num_channels, num_patches, num_features = inputs.shape
      device = inputs.device
      num_keep = int(num_patches * (1 - mask_ratio))

      # Random score per patch; in each row the `num_keep` lowest-scoring patches are kept.
      if channel_consistent_masking:
          # One score row per sample, repeated over channels so every channel gets the same mask.
          noise = torch.rand(batch_size, 1, num_patches, device=device).repeat(1, num_channels, 1)
      else:
          noise = torch.rand(batch_size, num_channels, num_patches, device=device)

      # Build the mask in score-sorted order (first `num_keep` kept), then map it back to patch order.
      sorted_mask = torch.ones(batch_size, num_channels, num_patches, device=device)
      sorted_mask[:, :, :num_keep] = 0
      restore_order = torch.argsort(torch.argsort(noise, dim=-1), dim=-1)
      patch_mask = torch.gather(sorted_mask, dim=-1, index=restore_order)

      # Broadcast over the feature axis so it lines up with `inputs` for masked_fill.
      full_mask = patch_mask.unsqueeze(-1).repeat(1, 1, 1, num_features)
      if unmasked_channel_indices is not None:
          full_mask[:, unmasked_channel_indices, :, :] = 0

      return inputs.masked_fill(full_mask.bool(), mask_value), full_mask[..., 0]
  def forecast_masking(
      inputs: torch.Tensor,
      num_forecast_mask_patches: list | int,
      unmasked_channel_indices: list | None = None,
      mask_value: int = 0,
  ):
      """Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.

      If num_forecast_mask_patches is a list, the batch is split as evenly as possible among the listed values (any
      rounding remainder goes to one group), and the samples are then randomly permuted, so each sample gets one of the
      listed numbers of trailing masked patches.

      Parameters:
          inputs (`torch.Tensor`):
              Input of shape `(bs, num_channels, num_patch, patch_length)`
          num_forecast_mask_patches (`list` or `int`):
              Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
          unmasked_channel_indices (`list`, *optional*):
              Indices of channels that are not masked.
          mask_value (`int`, *optional*, defaults to 0):
              Values in the masked patches will be filled by `mask_value`.

      Returns:
          `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape
          `(bs, num_channels, num_patch)`

      Raises:
          ValueError: If `num_forecast_mask_patches` is empty, or any entry is not in `[1, num_patch - 1]`.
      """
      if isinstance(num_forecast_mask_patches, int):
          num_forecast_mask_patches = [num_forecast_mask_patches]
      if len(num_forecast_mask_patches) == 0:
          # Without at least one group, the batch cannot be assigned any mask length.
          raise ValueError("num_forecast_mask_patches should contain at least one value.")
      forecast_mask_ratios = [1 for _ in num_forecast_mask_patches]

      batch_size, num_channels, sequence_length, num_features = inputs.shape
      mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device)

      # Each entry: [number of masked patches, ratio, number of samples assigned to it].
      t_list = []
      total_length = 0
      total_ratio = sum(forecast_mask_ratios)
      for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios):
          if patch_length <= 0 or patch_length >= sequence_length:
              raise ValueError(
                  f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches."
              )
          temp_len = int(batch_size * ratio / total_ratio)
          t_list.append([patch_length, ratio, temp_len])
          total_length += temp_len

      # Reconcile rounding so that the group sizes add up to exactly batch_size.
      t_list = sorted(t_list, key=lambda x: x[2])
      if total_length < batch_size:
          t_list[0][2] = t_list[0][2] + (batch_size - total_length)
      elif total_length > batch_size:
          # Over-allocation has to be removed from the largest group, not added to it.
          t_list[-1][2] = t_list[-1][2] - (total_length - batch_size)

      # Mask the trailing patches of each contiguous group of samples.
      batch1 = 0
      for patch_len, _, temp_len in t_list:
          batch2 = batch1 + temp_len
          mask[batch1:batch2, :, -patch_len:] = 1
          batch1 = batch2

      # Shuffle samples so the mask length is not tied to the position in the batch.
      perm = torch.randperm(mask.shape[0])
      mask = mask[perm]

      mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features)  # mask: [bs x num_channels x num_patch x patch_len]
      if unmasked_channel_indices is not None:
          mask[:, unmasked_channel_indices, :, :] = 0

      inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
      return inputs_mask, mask[..., 0]
  class PatchTSTPatchify(nn.Module):
      """
      A class to patchify the time series sequence into different patches

      Returns:
          `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
      """

      def __init__(self, config: PatchTSTConfig):
          super().__init__()
          self.sequence_length = config.context_length
          self.patch_length = config.patch_length
          self.patch_stride = config.patch_stride

          if self.sequence_length <= self.patch_length:
              raise ValueError(
                  f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})"
              )

          # Number of full `patch_length` windows when stepping by `patch_stride`
          # (sequence_length > patch_length is guaranteed by the check above).
          self.num_patches = (self.sequence_length - self.patch_length) // self.patch_stride + 1
          # Leading time steps that no full window covers are dropped from the start of the sequence.
          covered_length = self.patch_length + self.patch_stride * (self.num_patches - 1)
          self.sequence_start = self.sequence_length - covered_length

      def forward(self, past_values: torch.Tensor):
          """
          Parameters:
              past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
                  Input for patchification

          Returns:
              `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
          """
          observed_length = past_values.shape[-2]
          if observed_length != self.sequence_length:
              raise ValueError(
                  f"Input sequence length ({observed_length}) doesn't match model configuration ({self.sequence_length})."
              )

          # [bs x covered_length x num_channels]
          trimmed = past_values[:, self.sequence_start :, :]
          # unfold over time -> [bs x num_patches x num_channels x patch_length]
          windows = trimmed.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride)
          # channel-first -> [bs x num_channels x num_patches x patch_length]
          return windows.transpose(-2, -3).contiguous()
  282. class PatchTSTMasking(nn.Module):
  283. """
  284. Class to perform random or forecast masking.
  285. Parameters:
  286. config (`PatchTSTConfig`): model config
  287. Returns:
  288. x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
  289. Masked patched input
  290. mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
  291. Bool tensor indicating True on masked points
  292. """
  293. def __init__(self, config: PatchTSTConfig):
  294. super().__init__()
  295. self.random_mask_ratio = config.random_mask_ratio
  296. self.channel_consistent_masking = config.channel_consistent_masking
  297. self.mask_type = config.mask_type
  298. self.num_forecast_mask_patches = config.num_forecast_mask_patches
  299. self.unmasked_channel_indices = config.unmasked_channel_indices
  300. self.mask_value = config.mask_value
  301. if self.unmasked_channel_indices is not None:
  302. self.unmasked_channel_indices = sorted(self.unmasked_channel_indices)
  303. def forward(self, patch_input: torch.Tensor):
  304. """
  305. Parameters:
  306. patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
  307. Patch input
  308. Return:
  309. masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
  310. Masked patched input
  311. mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
  312. Bool tensor indicating True on masked points
  313. """
  314. if self.mask_type == "random":
  315. masked_input, mask = random_masking(
  316. inputs=patch_input,
  317. mask_ratio=self.random_mask_ratio,
  318. unmasked_channel_indices=self.unmasked_channel_indices,
  319. channel_consistent_masking=self.channel_consistent_masking,
  320. mask_value=self.mask_value,
  321. )
  322. elif self.mask_type == "forecast":
  323. masked_input, mask = forecast_masking(
  324. inputs=patch_input,
  325. num_forecast_mask_patches=self.num_forecast_mask_patches,
  326. unmasked_channel_indices=self.unmasked_channel_indices,
  327. mask_value=self.mask_value,
  328. )
  329. else:
  330. raise ValueError(f"Invalid mask type {self.mask_type}.")
  331. # mask: [bs x num_input_channels x num_patch]
  332. mask = mask.bool()
  333. return masked_input, mask
  class PatchTSTEncoderLayer(nn.Module):
      """
      PatchTST encoder layer.

      Applies, in order: attention across time (per channel), optional attention across channels (per time step)
      when `config.channel_attention` is set, and a position-wise feed-forward network. Each sublayer has a residual
      connection with path dropout and a normalization layer, applied before (`config.pre_norm`) or after the sublayer.
      The same `self_attn` module is used for both the time and the channel attention.
      """

      def __init__(self, config: PatchTSTConfig):
          super().__init__()

          self.channel_attention = config.channel_attention
          self.self_attn = PatchTSTAttention(
              embed_dim=config.d_model,
              num_heads=config.num_attention_heads,
              dropout=config.attention_dropout,
              config=config,
          )

          # Add & Norm of sublayer 1 (attention across time)
          self.dropout_path1 = self._make_path_dropout(config)
          self.norm_sublayer1 = self._make_norm(config)

          # Add & Norm of sublayer 2 (attention across channels), only built when it is enabled
          if self.channel_attention:
              self.dropout_path2 = self._make_path_dropout(config)
              self.norm_sublayer2 = self._make_norm(config)

          # Position-wise Feed-Forward
          self.ff = nn.Sequential(
              nn.Linear(config.d_model, config.ffn_dim, bias=config.bias),
              ACT2CLS[config.activation_function](),
              nn.Dropout(config.ff_dropout) if config.ff_dropout > 0 else nn.Identity(),
              nn.Linear(config.ffn_dim, config.d_model, bias=config.bias),
          )

          # Add & Norm of sublayer 3 (feed-forward)
          self.dropout_path3 = self._make_path_dropout(config)
          self.norm_sublayer3 = self._make_norm(config)

          self.pre_norm = config.pre_norm

      @staticmethod
      def _make_path_dropout(config: PatchTSTConfig) -> nn.Module:
          """Residual-path dropout, or an identity when `config.path_dropout` is not positive."""
          return nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity()

      @staticmethod
      def _make_norm(config: PatchTSTConfig) -> nn.Module:
          """Normalization layer chosen by `config.norm_type` ("batchnorm" or "layernorm")."""
          if config.norm_type == "batchnorm":
              return PatchTSTBatchNorm(config)
          if config.norm_type == "layernorm":
              return nn.LayerNorm(config.d_model, eps=config.norm_eps)
          raise ValueError(f"{config.norm_type} is not a supported norm layer type.")

      def _attention_sublayer(self, hidden_state, norm, dropout_path, output_attentions):
          """
          Run `self.self_attn` on a 3D `hidden_state` with a residual connection.

          With pre-norm, `norm` is applied to the attention input; otherwise it is applied to the residual sum.
          Returns the updated hidden state and the attention weights.
          """
          if self.pre_norm:
              attn_output, attn_weights, _ = self.self_attn(
                  hidden_states=norm(hidden_state), output_attentions=output_attentions
              )
              return hidden_state + dropout_path(attn_output), attn_weights

          attn_output, attn_weights, _ = self.self_attn(hidden_states=hidden_state, output_attentions=output_attentions)
          return norm(hidden_state + dropout_path(attn_output)), attn_weights

      def forward(self, hidden_state: torch.Tensor, output_attentions: bool | None = None):
          """
          Parameters:
              hidden_state (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`, *required*):
                  Past values of the time series
              output_attentions (`bool`, *optional*):
                  Whether or not to return the output attention of all layers

          Return:
              `tuple`: `(hidden_state,)` with `hidden_state` of shape `(batch_size, num_channels, sequence_length,
              d_model)`. When `output_attentions` is truthy the time-attention weights are appended, followed by the
              channel-attention weights if channel attention is enabled.
          """
          batch_size, num_input_channels, sequence_length, d_model = hidden_state.shape

          # Sublayer 1: attention across time; every (sample, channel) series is its own sequence.
          # hidden_state: [(bs*num_channels) x sequence_length x d_model]
          hidden_state = hidden_state.view(batch_size * num_input_channels, sequence_length, d_model)
          hidden_state, attn_weights = self._attention_sublayer(
              hidden_state, self.norm_sublayer1, self.dropout_path1, output_attentions
          )
          # hidden_state: [bs x num_channels x sequence_length x d_model]
          hidden_state = hidden_state.reshape(batch_size, num_input_channels, sequence_length, d_model)

          # Sublayer 2 (optional): attention across channels at each time step.
          if self.channel_attention:
              # hidden_state: [bs x sequence_length x num_channels x d_model]
              hidden_state = hidden_state.transpose(2, 1).contiguous()
              # hidden_state: [(bs*sequence_length) x num_channels x d_model]
              hidden_state = hidden_state.view(batch_size * sequence_length, num_input_channels, d_model)
              hidden_state, channel_attn_weights = self._attention_sublayer(
                  hidden_state, self.norm_sublayer2, self.dropout_path2, output_attentions
              )
              # hidden_state: [bs x sequence_length x num_channels x d_model]
              hidden_state = hidden_state.reshape(batch_size, sequence_length, num_input_channels, d_model)
              # hidden_state: [bs x num_channels x sequence_length x d_model]
              hidden_state = hidden_state.transpose(1, 2).contiguous()

          # Sublayer 3: position-wise feed-forward.
          # hidden_state: [(bs*num_channels) x sequence_length x d_model]
          hidden_state = hidden_state.view(batch_size * num_input_channels, sequence_length, d_model)
          if self.pre_norm:
              hidden_state = hidden_state + self.dropout_path3(self.ff(self.norm_sublayer3(hidden_state)))
          else:
              hidden_state = self.norm_sublayer3(hidden_state + self.dropout_path3(self.ff(hidden_state)))
          # hidden_state: [bs x num_channels x sequence_length x d_model]
          hidden_state = hidden_state.reshape(batch_size, num_input_channels, sequence_length, d_model)

          if not output_attentions:
              return (hidden_state,)
          if self.channel_attention:
              return (hidden_state, attn_weights, channel_attn_weights)
          return (hidden_state, attn_weights)
  @auto_docstring
  class PatchTSTPreTrainedModel(PreTrainedModel):
      config: PatchTSTConfig
      base_model_prefix = "model"
      main_input_name = "past_values"
      input_modalities = ("time",)
      supports_gradient_checkpointing = False
      _supports_flash_attn = True
      _supports_sdpa = True
      _supports_flex_attn = True

      @torch.no_grad()
      def _init_weights(self, module: nn.Module):
          """
          Initialize the weights of one submodule.

          - `PatchTSTPositionalEncoding`: `cls_token` (when `config.use_cls_token`) is drawn from N(0, 0.02), and
            `position_enc` is overwritten with the table returned by `module._init_pe`.
          - `nn.LayerNorm` / `nn.BatchNorm1d`: weight set to 1, bias set to 0, running statistics reset when present.
          - `nn.Linear`: weight drawn from N(0, `config.init_std`), bias (if any) set to 0.
          """
          if isinstance(module, PatchTSTPositionalEncoding):
              # Same patch count formula as PatchTSTPatchify, plus one slot for the cls token when enabled.
              context_length = self.config.context_length
              patch_length = self.config.patch_length
              num_patches = (max(context_length, patch_length) - patch_length) // self.config.patch_stride + 1
              if self.config.use_cls_token:
                  init.normal_(module.cls_token, std=0.02)
                  num_patches += 1

              position_enc = module._init_pe(self.config, num_patches)
              if not is_deepspeed_zero3_enabled():
                  init.copy_(module.position_enc, position_enc)
              else:
                  import deepspeed

                  # NOTE(review): DeepSpeed describes `modifier_rank=None` as read-only gathering; if so, the copy
                  # below may not be written back to the partitioned parameter. Confirm whether a modifier rank
                  # (e.g. 0) is intended here.
                  with deepspeed.zero.GatheredParameters(module.position_enc, modifier_rank=None):
                      # Only copy into a non-empty gathered tensor.
                      if module.position_enc.numel() > 0:
                          init.copy_(module.position_enc, position_enc)
          elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
              init.zeros_(module.bias)
              init.ones_(module.weight)
              # BatchNorm keeps running statistics that must also be reset; LayerNorm has none.
              if getattr(module, "running_mean", None) is not None:
                  init.zeros_(module.running_mean)
                  init.ones_(module.running_var)
                  init.zeros_(module.num_batches_tracked)
          elif isinstance(module, nn.Linear):
              init.normal_(module.weight, mean=0.0, std=self.config.init_std)
              if module.bias is not None:
                  init.zeros_(module.bias)

      def _set_gradient_checkpointing(self, module, value=False):
          """Set `gradient_checkpointing` to `value` on `PatchTSTEncoder` modules; other modules are untouched."""
          if isinstance(module, PatchTSTEncoder):
              module.gradient_checkpointing = value
  class PatchTSTEmbedding(nn.Module):
      """
      Project every patch (a vector of length `patch_length`) to `d_model` dimensions.

      With `config.share_embedding` a single linear layer serves all channels; otherwise each input channel has its
      own linear layer.
      """

      def __init__(self, config: PatchTSTConfig):
          super().__init__()
          self.num_input_channels = config.num_input_channels
          self.share_embedding = config.share_embedding
          # Input encoding: projection of feature vectors onto a d-dim vector space
          if self.share_embedding:
              self.input_embedding = nn.Linear(config.patch_length, config.d_model)
          else:
              self.input_embedding = nn.ModuleList(
                  nn.Linear(config.patch_length, config.d_model) for _ in range(config.num_input_channels)
              )

      def forward(self, patch_input: torch.Tensor):
          """
          Parameters:
              patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                  Patch input for embedding

          Returns:
              `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`

          Raises:
              ValueError: If the channel count of `patch_input` differs from `config.num_input_channels`.
          """
          num_input_channels = patch_input.shape[1]
          if num_input_channels != self.num_input_channels:
              raise ValueError(
                  f"The defined number of input channels ({self.num_input_channels}) in the config "
                  f"has to be the same as the number of channels in the batch input ({num_input_channels})"
              )

          if self.share_embedding:
              return self.input_embedding(patch_input)  # [bs x num_channels x num_patches x d_model]

          # One projection per channel, then restack along the channel dimension.
          per_channel = [
              projection(channel_patches)
              for projection, channel_patches in zip(self.input_embedding, patch_input.unbind(dim=1))
          ]
          return torch.stack(per_channel, dim=1)
  532. class PatchTSTPositionalEncoding(nn.Module):
  533. """
  534. Class for positional encoding
  535. """
  536. def __init__(self, config: PatchTSTConfig, num_patches: int):
  537. super().__init__()
  538. self.use_cls_token = config.use_cls_token
  539. self.num_input_channels = config.num_input_channels
  540. if config.use_cls_token:
  541. # cls_token: [1 x num_input_channels x 1 x d_model]
  542. self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, config.d_model))
  543. num_patches += 1
  544. # positional encoding: [num_patches x d_model]
  545. self.position_enc = self._init_pe(config, num_patches)
  546. # Positional dropout
  547. self.positional_dropout = (
  548. nn.Dropout(config.positional_dropout) if config.positional_dropout > 0 else nn.Identity()
  549. )
  550. @staticmethod
  551. def _init_pe(config: PatchTSTConfig, num_patches: int) -> nn.Parameter:
  552. # Positional encoding
  553. if config.positional_encoding_type == "random":
  554. position_enc = nn.Parameter(torch.randn(num_patches, config.d_model), requires_grad=True)
  555. elif config.positional_encoding_type == "sincos":
  556. position_enc = torch.zeros(num_patches, config.d_model)
  557. position = torch.arange(0, num_patches).unsqueeze(1)
  558. div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model))
  559. position_enc[:, 0::2] = torch.sin(position * div_term)
  560. position_enc[:, 1::2] = torch.cos(position * div_term)
  561. position_enc = position_enc - position_enc.mean()
  562. position_enc = position_enc / (position_enc.std() * 10)
  563. position_enc = nn.Parameter(position_enc, requires_grad=False)
  564. else:
  565. raise ValueError(
  566. f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
  567. )
  568. return position_enc
  569. def forward(self, patch_input: torch.Tensor):
  570. if self.use_cls_token:
  571. # patch_input: [bs x num_channels x num_patches x d_model]
  572. patch_input = self.positional_dropout(patch_input + self.position_enc[1:, :])
  573. # append cls token where cls_token: [1 x num_channels x 1 x d_model]
  574. cls_token = self.cls_token + self.position_enc[:1, :]
  575. # get the same copy of cls_token for all the samples in batch: [bs x num_channels x 1 x d_model]
  576. cls_tokens = cls_token.expand(patch_input.shape[0], self.num_input_channels, -1, -1)
  577. # hidden_state: [bs x num_channels x (num_patches+1) x d_model]
  578. hidden_state = torch.cat((cls_tokens, patch_input), dim=2)
  579. else:
  580. # hidden_state: [bs x num_channels x num_patches x d_model]
  581. hidden_state = self.positional_dropout(patch_input + self.position_enc)
  582. return hidden_state
  583. class PatchTSTEncoder(PatchTSTPreTrainedModel):
  584. """
  585. PatchTST Encoder
  586. """
  587. def __init__(self, config: PatchTSTConfig, num_patches: int):
  588. super().__init__(config)
  589. self.gradient_checkpointing = False
  590. # Input embedding: projection of feature vectors onto a d-dim vector space
  591. self.embedder = PatchTSTEmbedding(config)
  592. # Positional encoding
  593. self.positional_encoder = PatchTSTPositionalEncoding(config, num_patches)
  594. # Encoder
  595. self.layers = nn.ModuleList([PatchTSTEncoderLayer(config) for i in range(config.num_hidden_layers)])
  596. # Initialize weights and apply final processing
  597. self.post_init()
  598. def forward(
  599. self,
  600. patch_input: torch.Tensor,
  601. output_hidden_states: bool | None = None,
  602. output_attentions: bool | None = None,
  603. **kwargs,
  604. ) -> BaseModelOutput:
  605. """
  606. Parameters:
  607. patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
  608. Past values of the time series
  609. output_hidden_states (bool, optional): Indicates if hidden states should be outputted.
  610. output_attentions (bool, optional): Indicates if attentions should be outputted.
  611. return:
  612. `BaseModelOutput`
  613. """
  614. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  615. output_hidden_states = (
  616. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  617. )
  618. # Input embedding
  619. patch_input = self.embedder(patch_input)
  620. # Positional encoding
  621. hidden_state = self.positional_encoder(patch_input)
  622. encoder_states = () if output_hidden_states else None
  623. all_attentions = () if output_attentions else None
  624. for encoder_layer in self.layers:
  625. if output_hidden_states:
  626. encoder_states = encoder_states + (hidden_state,)
  627. layer_outputs = encoder_layer(hidden_state=hidden_state, output_attentions=output_attentions)
  628. # get hidden state. hidden_state shape is [bs x num_channels x num_patches x d_model]
  629. # or [bs x num_channels x (num_patches+1) x d_model] if use cls_token
  630. hidden_state = layer_outputs[0]
  631. # append attention matrix at each layer
  632. if output_attentions:
  633. all_attentions = all_attentions + (layer_outputs[1],)
  634. # return past_values, hidden_states
  635. return BaseModelOutput(last_hidden_state=hidden_state, hidden_states=encoder_states, attentions=all_attentions)
  636. @dataclass
  637. @auto_docstring(
  638. custom_intro="""
  639. Base class for model's outputs, with potential hidden states.
  640. """
  641. )
  642. class PatchTSTModelOutput(ModelOutput):
  643. r"""
  644. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
  645. Sequence of hidden-states at the output of the last layer of the model.
  646. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  647. Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  648. one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of
  649. the model at the output of each layer plus the optional initial embedding outputs.
  650. mask (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*):
  651. Bool masked tensor indicating which patches are masked
  652. loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
  653. Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
  654. scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
  655. Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
  656. patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
  657. Patched input to the Transformer
  658. """
  659. last_hidden_state: torch.FloatTensor | None = None
  660. hidden_states: tuple[torch.FloatTensor] | None = None
  661. attentions: tuple[torch.FloatTensor] | None = None
  662. mask: torch.FloatTensor | None = None
  663. loc: torch.FloatTensor | None = None
  664. scale: torch.FloatTensor | None = None
  665. patch_input: torch.FloatTensor | None = None
  666. @dataclass
  667. @auto_docstring(
  668. custom_intro="""
  669. Output type of [`PatchTSTForPretraining`].
  670. """
  671. )
  672. class PatchTSTForPretrainingOutput(ModelOutput):
  673. r"""
  674. loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  675. MSE loss.
  676. prediction_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  677. Prediction outputs of the time series modeling heads.
  678. """
  679. loss: torch.FloatTensor | None = None
  680. prediction_output: torch.FloatTensor | None = None
  681. hidden_states: tuple[torch.FloatTensor] | None = None
  682. attentions: tuple[torch.FloatTensor] | None = None
  683. @dataclass
  684. @auto_docstring(
  685. custom_intro="""
  686. Output type of [`PatchTSTForRegression`].
  687. """
  688. )
  689. class PatchTSTForRegressionOutput(ModelOutput):
  690. r"""
  691. loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  692. MSE loss.
  693. regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
  694. Regression outputs of the time series modeling heads.
  695. """
  696. loss: torch.FloatTensor | None = None
  697. regression_outputs: torch.FloatTensor | None = None
  698. hidden_states: tuple[torch.FloatTensor] | None = None
  699. attentions: tuple[torch.FloatTensor] | None = None
  700. @dataclass
  701. @auto_docstring(
  702. custom_intro="""
  703. Output type of [`PatchTSTForPrediction`].
  704. """
  705. )
  706. class PatchTSTForPredictionOutput(ModelOutput):
  707. r"""
  708. loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  709. MSE loss.
  710. prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, -1)`):
  711. Prediction outputs of the time series modeling heads.
  712. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  713. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  714. sequence_length)`.
  715. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  716. heads.
  717. loc: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*)
  718. Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
  719. scale: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*)
  720. Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
  721. """
  722. loss: torch.FloatTensor | None = None
  723. prediction_outputs: torch.FloatTensor | None = None
  724. hidden_states: tuple[torch.FloatTensor] | None = None
  725. attentions: tuple[torch.FloatTensor] | None = None
  726. loc: torch.FloatTensor | None = None
  727. scale: torch.FloatTensor | None = None
  728. @dataclass
  729. @auto_docstring(
  730. custom_intro="""
  731. Output type of [`PatchTSTForClassification`].
  732. """
  733. )
  734. class PatchTSTForClassificationOutput(ModelOutput):
  735. r"""
  736. loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  737. Total loss as the sum of the masked language modeling loss and the next sequence prediction
  738. (classification) loss.
  739. prediction_logits (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
  740. Prediction scores of the PatchTST modeling head (scores before SoftMax).
  741. """
  742. loss: torch.FloatTensor | None = None
  743. prediction_logits: torch.FloatTensor | None = None
  744. hidden_states: tuple[torch.FloatTensor] | None = None
  745. attentions: tuple[torch.FloatTensor] | None = None
  746. @dataclass
  747. @auto_docstring(
  748. custom_intro="""
  749. Base class for time series model's predictions outputs that contains the sampled values from the chosen
  750. distribution.
  751. """
  752. )
  753. class SamplePatchTSTOutput(ModelOutput):
  754. r"""
  755. sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, num_targets)`):
  756. Sampled values from the chosen distribution.
  757. """
  758. sequences: torch.FloatTensor | None = None
  759. # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll
  760. def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
  761. """
  762. Computes the negative log likelihood loss from input distribution with respect to target.
  763. """
  764. return -input.log_prob(target)
  765. # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average
  766. def weighted_average(input_tensor: torch.Tensor, weights: torch.Tensor | None = None, dim=None) -> torch.Tensor:
  767. """
  768. Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
  769. meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.
  770. Args:
  771. input_tensor (`torch.FloatTensor`):
  772. Input tensor, of which the average must be computed.
  773. weights (`torch.FloatTensor`, *optional*):
  774. Weights tensor, of the same shape as `input_tensor`.
  775. dim (`int`, *optional*):
  776. The dim along which to average `input_tensor`.
  777. Returns:
  778. `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
  779. """
  780. if weights is not None:
  781. weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
  782. sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0)
  783. return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights
  784. else:
  785. return input_tensor.mean(dim=dim)
  786. # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST
  787. class PatchTSTStdScaler(nn.Module):
  788. """
  789. Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by
  790. subtracting from the mean and dividing by the standard deviation.
  791. """
  792. def __init__(self, config: PatchTSTConfig):
  793. super().__init__()
  794. self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
  795. self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
  796. self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5
  797. def forward(
  798. self, data: torch.Tensor, observed_indicator: torch.Tensor
  799. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  800. """
  801. Parameters:
  802. data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  803. input for Batch norm calculation
  804. observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  805. Calculating the scale on the observed indicator.
  806. Returns:
  807. tuple of `torch.Tensor` of shapes
  808. (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
  809. `(batch_size, 1, num_input_channels)`)
  810. """
  811. denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim)
  812. denominator = denominator.clamp_min(1.0)
  813. loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator
  814. variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
  815. scale = torch.sqrt(variance + self.minimum_scale)
  816. return (data - loc) / scale, loc, scale
  817. # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST
  818. class PatchTSTMeanScaler(nn.Module):
  819. """
  820. Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
  821. accordingly.
  822. """
  823. def __init__(self, config: PatchTSTConfig):
  824. super().__init__()
  825. self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
  826. self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
  827. self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10
  828. self.default_scale = config.default_scale if hasattr(config, "default_scale") else None
  829. def forward(
  830. self, data: torch.Tensor, observed_indicator: torch.Tensor
  831. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  832. """
  833. Parameters:
  834. data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  835. input for Batch norm calculation
  836. observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  837. Calculating the scale on the observed indicator.
  838. Returns:
  839. tuple of `torch.Tensor` of shapes
  840. (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
  841. `(batch_size, 1, num_input_channels)`)
  842. """
  843. ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
  844. num_observed = observed_indicator.sum(self.dim, keepdim=True)
  845. scale = ts_sum / torch.clamp(num_observed, min=1)
  846. # If `default_scale` is provided, we use it, otherwise we use the scale
  847. # of the batch.
  848. if self.default_scale is None:
  849. batch_sum = ts_sum.sum(dim=0)
  850. batch_observations = torch.clamp(num_observed.sum(0), min=1)
  851. default_scale = torch.squeeze(batch_sum / batch_observations)
  852. else:
  853. default_scale = self.default_scale * torch.ones_like(scale)
  854. # apply default scale where there are no observations
  855. scale = torch.where(num_observed > 0, scale, default_scale)
  856. # ensure the scale is at least `self.minimum_scale`
  857. scale = torch.clamp(scale, min=self.minimum_scale)
  858. scaled_data = data / scale
  859. if not self.keepdim:
  860. scale = scale.squeeze(dim=self.dim)
  861. return scaled_data, torch.zeros_like(scale), scale
  862. # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST
  863. class PatchTSTNOPScaler(nn.Module):
  864. """
  865. Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
  866. """
  867. def __init__(self, config: PatchTSTConfig):
  868. super().__init__()
  869. self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
  870. self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
  871. def forward(
  872. self, data: torch.Tensor, observed_indicator: torch.Tensor | None = None
  873. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  874. """
  875. Parameters:
  876. data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  877. input for Batch norm calculation
  878. Returns:
  879. tuple of `torch.Tensor` of shapes
  880. (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
  881. `(batch_size, 1, num_input_channels)`)
  882. """
  883. scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
  884. loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
  885. return data, loc, scale
  886. class PatchTSTScaler(nn.Module):
  887. def __init__(self, config: PatchTSTConfig):
  888. super().__init__()
  889. if config.scaling == "mean" or config.scaling is True:
  890. self.scaler = PatchTSTMeanScaler(config)
  891. elif config.scaling == "std":
  892. self.scaler = PatchTSTStdScaler(config)
  893. else:
  894. self.scaler = PatchTSTNOPScaler(config)
  895. def forward(
  896. self, data: torch.Tensor, observed_indicator: torch.Tensor
  897. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  898. """
  899. Parameters:
  900. data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  901. Input for scaler calculation
  902. observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  903. Calculating the scale on the observed indicator.
  904. Returns:
  905. tuple of `torch.Tensor` of shapes
  906. (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
  907. `(batch_size, 1, um_input_channels)`)
  908. """
  909. data, loc, scale = self.scaler(data, observed_indicator)
  910. return data, loc, scale
  911. @auto_docstring
  912. class PatchTSTModel(PatchTSTPreTrainedModel):
  913. def __init__(self, config: PatchTSTConfig):
  914. super().__init__(config)
  915. self.scaler = PatchTSTScaler(config)
  916. self.patchifier = PatchTSTPatchify(config)
  917. self.do_mask_input = config.do_mask_input
  918. # get num_patches information from PatchTSTPatchify
  919. num_patches = self.patchifier.num_patches
  920. if self.do_mask_input:
  921. self.masking = PatchTSTMasking(config)
  922. else:
  923. self.masking = nn.Identity()
  924. self.encoder = PatchTSTEncoder(config, num_patches=num_patches)
  925. # Initialize weights and apply final processing
  926. self.post_init()
  927. def forward(
  928. self,
  929. past_values: torch.Tensor,
  930. past_observed_mask: torch.Tensor | None = None,
  931. future_values: torch.Tensor | None = None,
  932. output_hidden_states: bool | None = None,
  933. output_attentions: bool | None = None,
  934. return_dict: bool | None = None,
  935. **kwargs,
  936. ) -> tuple | PatchTSTModelOutput:
  937. r"""
  938. Parameters:
  939. past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
  940. Input sequence to the model
  941. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  942. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  943. in `[0, 1]`:
  944. - 1 for values that are **observed**,
  945. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  946. future_values (`torch.BoolTensor` of shape `(batch_size, prediction_length, num_input_channels)`, *optional*):
  947. Future target values associated with the `past_values`
  948. output_hidden_states (`bool`, *optional*):
  949. Whether or not to return the hidden states of all layers
  950. output_attentions (`bool`, *optional*):
  951. Whether or not to return the output attention of all layers
  952. return_dict (`bool`, *optional*):
  953. Whether or not to return a `ModelOutput` instead of a plain tuple.
  954. Returns:
  955. `PatchTSTModelOutput` or tuple of `torch.Tensor` (if `return_dict`=False or `config.return_dict`=False)
  956. Examples:
  957. ```python
  958. >>> from huggingface_hub import hf_hub_download
  959. >>> import torch
  960. >>> from transformers import PatchTSTModel
  961. >>> file = hf_hub_download(
  962. ... repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
  963. ... )
  964. >>> batch = torch.load(file)
  965. >>> model = PatchTSTModel.from_pretrained("namctin/patchtst_etth1_pretrain")
  966. >>> # during training, one provides both past and future values
  967. >>> outputs = model(
  968. ... past_values=batch["past_values"],
  969. ... future_values=batch["future_values"],
  970. ... )
  971. >>> last_hidden_state = outputs.last_hidden_state
  972. ```"""
  973. return_dict = return_dict if return_dict is not None else self.config.return_dict
  974. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  975. output_hidden_states = (
  976. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  977. )
  978. if past_observed_mask is None:
  979. past_observed_mask = torch.ones_like(past_values)
  980. # x: tensor [bs x sequence_length x num_input_channels]
  981. scaled_past_values, loc, scale = self.scaler(past_values, past_observed_mask)
  982. # patched_values: [bs x num_input_channels x num_patches x patch_length] for pretrain
  983. patched_values = self.patchifier(scaled_past_values)
  984. if self.do_mask_input:
  985. masked_values, mask = self.masking(patched_values)
  986. else:
  987. masked_values, mask = self.masking(patched_values), None
  988. encoder_output = self.encoder(
  989. patch_input=masked_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
  990. )
  991. if not return_dict:
  992. outputs = (encoder_output.last_hidden_state, encoder_output.hidden_states, encoder_output.attentions)
  993. outputs = outputs + (mask, loc, scale, patched_values)
  994. return tuple(v for v in outputs if v is not None)
  995. return PatchTSTModelOutput(
  996. last_hidden_state=encoder_output.last_hidden_state,
  997. hidden_states=encoder_output.hidden_states,
  998. attentions=encoder_output.attentions,
  999. mask=mask,
  1000. loc=loc,
  1001. scale=scale,
  1002. patch_input=patched_values,
  1003. )
  1004. class PatchTSTMaskPretrainHead(nn.Module):
  1005. """
  1006. Pretraining head for mask modelling
  1007. """
  1008. def __init__(self, config: PatchTSTConfig):
  1009. super().__init__()
  1010. self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
  1011. self.linear = nn.Linear(config.d_model, config.patch_length)
  1012. self.use_cls_token = config.use_cls_token
  1013. def forward(self, embedding: torch.Tensor) -> torch.Tensor:
  1014. """
  1015. Parameters:
  1016. embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
  1017. `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
  1018. Embedding from the model
  1019. Returns:
  1020. `torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
  1021. `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True
  1022. """
  1023. embedding = self.linear(self.dropout(embedding)) # [bs x num_channels x num_patches x patch_length]
  1024. if self.use_cls_token:
  1025. embedding = embedding[:, :, 1:, :] # remove the first cls token
  1026. return embedding
  1027. @auto_docstring(
  1028. custom_intro="""
  1029. The PatchTST for pretrain model.
  1030. """
  1031. )
  1032. class PatchTSTForPretraining(PatchTSTPreTrainedModel):
  1033. def __init__(self, config: PatchTSTConfig):
  1034. super().__init__(config)
  1035. config.do_mask_input = True
  1036. self.model = PatchTSTModel(config=config)
  1037. self.head = PatchTSTMaskPretrainHead(config)
  1038. # Initialize weights and apply final processing
  1039. self.post_init()
  1040. def forward(
  1041. self,
  1042. past_values: torch.Tensor,
  1043. past_observed_mask: torch.Tensor | None = None,
  1044. output_hidden_states: bool | None = None,
  1045. output_attentions: bool | None = None,
  1046. return_dict: bool | None = None,
  1047. **kwargs,
  1048. ) -> tuple | PatchTSTForPretrainingOutput:
  1049. r"""
  1050. Parameters:
  1051. past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
  1052. Input sequence to the model
  1053. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1054. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1055. in `[0, 1]`:
  1056. - 1 for values that are **observed**,
  1057. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1058. output_hidden_states (`bool`, *optional*):
  1059. Whether or not to return the hidden states of all layers
  1060. output_attentions (`bool`, *optional*):
  1061. Whether or not to return the output attention of all layers
  1062. return_dict (`bool`, *optional*): Whether or not to return a `ModelOutput` instead of a plain tuple.
  1063. Returns:
  1064. `PatchTSTForPretrainingOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
  1065. `config.return_dict`=False)
  1066. Examples:
  1067. ```python
  1068. >>> from huggingface_hub import hf_hub_download
  1069. >>> import torch
  1070. >>> from transformers import PatchTSTConfig, PatchTSTForPretraining
  1071. >>> file = hf_hub_download(
  1072. ... repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
  1073. ... )
  1074. >>> batch = torch.load(file)
  1075. >>> # Config for random mask pretraining
  1076. >>> config = PatchTSTConfig(
  1077. ... num_input_channels=7,
  1078. ... context_length=512,
  1079. ... patch_length=12,
  1080. ... stride=12,
  1081. ... mask_type='random',
  1082. ... random_mask_ratio=0.4,
  1083. ... use_cls_token=True,
  1084. ... )
  1085. >>> # Config for forecast mask pretraining
  1086. >>> config = PatchTSTConfig(
  1087. ... num_input_channels=7,
  1088. ... context_length=512,
  1089. ... patch_length=12,
  1090. ... stride=12,
  1091. ... mask_type='forecast',
  1092. ... num_forecast_mask_patches=5,
  1093. ... use_cls_token=True,
  1094. ... )
  1095. >>> model = PatchTSTForPretraining(config)
  1096. >>> # during training, one provides both past and future values
  1097. >>> outputs = model(past_values=batch["past_values"])
  1098. >>> loss = outputs.loss
  1099. >>> loss.backward()
  1100. ```"""
  1101. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1102. # past_values: [bs x num_channels x num_patches x d_model] or
  1103. # [bs x num_channels x (num_patches+1) x d_model] if use cls_token
  1104. model_output = self.model(
  1105. past_values=past_values,
  1106. past_observed_mask=past_observed_mask,
  1107. output_hidden_states=output_hidden_states,
  1108. output_attentions=output_attentions,
  1109. return_dict=True,
  1110. )
  1111. # last_hidden_state: [bs x num_channels x num_patches x patch_length] or
  1112. # [bs x num_channels x (num_patches+1) x patch_length] if use cls_token
  1113. x_hat = self.head(model_output.last_hidden_state)
  1114. # calculate masked_loss
  1115. loss = nn.MSELoss(reduction="none")
  1116. loss_val = loss(x_hat, model_output.patch_input)
  1117. masked_loss = (loss_val.mean(dim=-1) * model_output.mask).sum() / (model_output.mask.sum() + 1e-10)
  1118. encoder_states = model_output.hidden_states
  1119. if not return_dict:
  1120. outputs = (x_hat,) + model_output[1:-4]
  1121. outputs = (masked_loss,) + outputs if masked_loss is not None else outputs
  1122. return outputs
  1123. return PatchTSTForPretrainingOutput(
  1124. loss=masked_loss, prediction_output=x_hat, hidden_states=encoder_states, attentions=model_output.attentions
  1125. )
  # Classification head: pooled patch embeddings -> class logits.
class PatchTSTClassificationHead(nn.Module):
    """
    Classification head for PatchTST.

    Reduces the patch axis of the encoder output (cls token, mean pooling or max pooling), concatenates the
    channels and maps the result to `config.num_targets` logits with a single linear layer.
    """

    def __init__(self, config: PatchTSTConfig):
        super().__init__()
        self.use_cls_token = config.use_cls_token
        self.pooling_type = config.pooling_type
        self.flatten = nn.Flatten(start_dim=1)
        if config.head_dropout > 0:
            self.dropout = nn.Dropout(config.head_dropout)
        else:
            self.dropout = nn.Identity()
        in_features = config.num_input_channels * config.d_model
        self.linear = nn.Linear(in_features, config.num_targets)

    def forward(self, embedding: torch.Tensor):
        """
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model

        Returns:
            `torch.Tensor` of shape `(bs, num_targets)`

        Raises:
            ValueError: if no cls token is used and `pooling_type` is neither "mean" nor "max".
        """
        # Collapse the patch axis (dim 2): summary is [bs x num_channels x d_model]
        if self.use_cls_token:
            # the cls token is the first output token along the patch axis
            summary = embedding[:, :, 0]
        elif self.pooling_type == "mean":
            summary = embedding.mean(dim=2)
        elif self.pooling_type == "max":
            summary = embedding.max(dim=2).values
        else:
            raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")

        # [bs x num_channels x d_model] -> [bs x (num_channels * d_model)] -> [bs x num_targets]
        features = self.dropout(self.flatten(summary))
        return self.linear(features)
  # PatchTST encoder with a classification head.
@auto_docstring(
    custom_intro="""
    The PatchTST for classification model.
    """
)
class PatchTSTForClassification(PatchTSTPreTrainedModel):
    def __init__(self, config: PatchTSTConfig):
        super().__init__(config)

        # Turn off masking. Note that this updates the config object that was passed in.
        if config.do_mask_input:
            logger.warning("Setting `do_mask_input` parameter to False.")
            config.do_mask_input = False

        self.model = PatchTSTModel(config)
        self.head = PatchTSTClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        past_values: torch.Tensor,
        target_values: torch.Tensor | None = None,
        past_observed_mask: torch.Tensor | None = None,
        output_hidden_states: bool | None = None,
        output_attentions: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | PatchTSTForClassificationOutput:
        r"""
        past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
            Input sequence to the model
        target_values (`torch.Tensor`, *optional*):
            Labels associated with the `past_values`. When given, they are scored against the logits of shape
            `(bs, num_targets)` with `nn.CrossEntropyLoss`.
        past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:

            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Examples:

        ```python
        >>> import torch
        >>> from transformers import PatchTSTConfig, PatchTSTForClassification

        >>> # classification task with two input channels and 3 classes
        >>> config = PatchTSTConfig(
        ...     num_input_channels=2,
        ...     num_targets=3,
        ...     context_length=512,
        ...     patch_length=12,
        ...     stride=12,
        ...     use_cls_token=True,
        ... )
        >>> model = PatchTSTForClassification(config=config)

        >>> # during inference, one only provides past values
        >>> past_values = torch.randn(20, 512, 2)
        >>> outputs = model(past_values=past_values)
        >>> labels = outputs.prediction_logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        model_output = self.model(
            past_values=past_values,
            past_observed_mask=past_observed_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=True,
        )
        # y_hat: class logits, [bs x num_targets]
        y_hat = self.head(model_output.last_hidden_state)

        loss_val = None
        if target_values is not None:
            loss_val = nn.CrossEntropyLoss()(y_hat, target_values)

        if not return_dict:
            # NOTE(review): the slice keeps the optional hidden_states/attentions entries of the encoder output and
            # drops the trailing per-sample statistics; confirm the bounds against PatchTSTModelOutput.
            outputs = (y_hat,) + model_output[1:-3]
            return (loss_val,) + outputs if loss_val is not None else outputs

        return PatchTSTForClassificationOutput(
            loss=loss_val,
            prediction_logits=y_hat,
            hidden_states=model_output.hidden_states,
            attentions=model_output.attentions,
        )
  # Forecasting head: projects (pooled) patch embeddings of every channel onto the prediction horizon.
@auto_docstring(
    custom_intro="""
    The PatchTST prediction head.
    """
)
class PatchTSTPredictionHead(nn.Module):
    def __init__(self, config: PatchTSTConfig, num_patches: int, distribution_output=None):
        r"""
        num_patches (`int`):
            The number of patches in the input sequence.
        distribution_output (`DistributionOutput`, *optional*):
            The distribution output layer for probabilistic forecasting. If None, a linear output layer is used.
        """
        super().__init__()
        self.share_projection = config.share_projection
        self.num_input_channels = config.num_input_channels
        self.use_cls_token = config.use_cls_token
        self.pooling_type = config.pooling_type

        # With a cls token or pooling, each channel is summarized by a single d_model vector; otherwise the
        # embeddings of all patches of a channel are concatenated.
        if self.pooling_type or self.use_cls_token:
            head_dim = config.d_model
        else:
            head_dim = config.d_model * num_patches

        if not self.share_projection:
            # if each channel has its own head
            self.projections = nn.ModuleList()
            self.dropouts = nn.ModuleList()
            self.flattens = nn.ModuleList()
            for _ in range(self.num_input_channels):
                # Each per-channel head receives a single-channel slice ([bs x d_model] or
                # [bs x num_patches x d_model]), so everything after the batch axis is flattened.
                # `start_dim=2` would be a no-op on the 3-D slice and out of range on the 2-D one.
                self.flattens.append(nn.Flatten(start_dim=1))
                if distribution_output is None:
                    # use linear head
                    self.projections.append(nn.Linear(head_dim, config.prediction_length))
                else:
                    # use distribution head
                    self.projections.append(distribution_output.get_parameter_projection(head_dim))
                self.dropouts.append(nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity())
        else:
            # all the channels share the same head; the channel axis is kept, so flatten from dim 2
            self.flatten = nn.Flatten(start_dim=2)
            if distribution_output is None:
                # use linear head
                self.projection = nn.Linear(head_dim, config.prediction_length)
            else:
                # use distribution head
                self.projection = distribution_output.get_parameter_projection(head_dim)
            self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()

    def forward(self, embedding: torch.Tensor):
        """
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model

        Returns:
            `torch.Tensor` of shape `(bs, forecast_len, num_channels)`, or a tuple of such tensors (one per
            distribution parameter) when a distribution head is used
        """
        if self.use_cls_token:
            # pooled_embedding: [bs x num_channels x d_model]
            pooled_embedding = embedding[:, :, 0, :]
        elif self.pooling_type == "mean":
            # pooled_embedding: [bs x num_channels x d_model]
            pooled_embedding = embedding.mean(dim=2)
        elif self.pooling_type == "max":
            # pooled_embedding: [bs x num_channels x d_model]
            pooled_embedding = embedding.max(dim=2).values
        else:
            # pooled_embedding: [bs x num_channels x num_patches x d_model]
            pooled_embedding = embedding

        if not self.share_projection:
            channel_outputs = []
            heads = zip(self.flattens, self.dropouts, self.projections)
            for i, (flatten, dropout, projection) in enumerate(heads):
                # A separate name keeps `pooled_embedding` intact for the following channels.
                # channel_embedding: [bs x (num_patches * d_model)] or [bs x d_model]
                channel_embedding = dropout(flatten(pooled_embedding[:, i]))
                # [bs x forecast_len], or a tuple of such tensors if using a distribution head
                channel_outputs.append(projection(channel_embedding))
            if isinstance(channel_outputs[0], tuple):
                # distribution head: stack each parameter separately
                # output: tuple of [bs x num_channels x forecast_len]
                output = tuple(torch.stack(param, dim=1) for param in zip(*channel_outputs))
            else:
                # output: [bs x num_channels x forecast_len]
                output = torch.stack(channel_outputs, dim=1)
        else:
            # pooled_embedding: [bs x num_channels x (d_model * num_patches)] or [bs x num_channels x d_model]
            pooled_embedding = self.flatten(pooled_embedding)
            pooled_embedding = self.dropout(pooled_embedding)
            # output: [bs x num_channels x forecast_len] or
            # a tuple of [bs x num_channels x forecast_len] if using a distribution head
            output = self.projection(pooled_embedding)

        if isinstance(output, tuple):
            # output: tuple of [bs x forecast_len x num_channels]
            output = tuple(z.transpose(2, 1) for z in output)
        else:
            output = output.transpose(2, 1)  # [bs x forecast_len x num_channels]
        return output
  # PatchTST encoder with a forecasting head.
@auto_docstring(
    custom_intro="""
    The PatchTST for prediction model.
    """
)
class PatchTSTForPrediction(PatchTSTPreTrainedModel):
    def __init__(self, config: PatchTSTConfig):
        super().__init__(config)

        # Turn off masking. Note that this updates the config object that was passed in.
        if config.do_mask_input:
            logger.warning("Setting `do_mask_input` parameter to False.")
            config.do_mask_input = False

        self.model = PatchTSTModel(config)

        # `loss == "mse"` -> linear point-forecast head; any other loss -> distribution head whose parameters
        # span `prediction_length` steps
        if config.loss == "mse":
            self.distribution_output = None
        else:
            if config.distribution_output == "student_t":
                self.distribution_output = StudentTOutput(dim=config.prediction_length)
            elif config.distribution_output == "normal":
                self.distribution_output = NormalOutput(dim=config.prediction_length)
            elif config.distribution_output == "negative_binomial":
                self.distribution_output = NegativeBinomialOutput(dim=config.prediction_length)
            else:
                raise ValueError(f"Unknown distribution output {config.distribution_output}")

        self.head = PatchTSTPredictionHead(
            config, self.model.patchifier.num_patches, distribution_output=self.distribution_output
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        past_values: torch.Tensor,
        past_observed_mask: torch.Tensor | None = None,
        future_values: torch.Tensor | None = None,
        output_hidden_states: bool | None = None,
        output_attentions: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | PatchTSTForPredictionOutput:
        r"""
        Parameters:
            past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
                Input sequence to the model
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

            future_values (`torch.Tensor` of shape `(bs, forecast_len, num_input_channels)`, *optional*):
                Future target values associated with the `past_values`
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers
            output_attentions (`bool`, *optional*):
                Whether or not to return the output attention of all layers
            return_dict (`bool`, *optional*):
                Whether or not to return a `ModelOutput` instead of a plain tuple.

        Returns:
            `PatchTSTForPredictionOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
            `config.return_dict`=False)

        Examples:

        ```python
        >>> from huggingface_hub import hf_hub_download
        >>> import torch
        >>> from transformers import PatchTSTConfig, PatchTSTForPrediction

        >>> file = hf_hub_download(
        ...     repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
        ... )
        >>> batch = torch.load(file)

        >>> # Prediction task with 7 input channels and prediction length is 96
        >>> model = PatchTSTForPrediction.from_pretrained("namctin/patchtst_etth1_forecast")

        >>> # during training, one provides both past and future values
        >>> outputs = model(
        ...     past_values=batch["past_values"],
        ...     future_values=batch["future_values"],
        ... )

        >>> loss = outputs.loss
        >>> loss.backward()

        >>> # during inference, one only provides past values, the model outputs future values
        >>> outputs = model(past_values=batch["past_values"])
        >>> prediction_outputs = outputs.prediction_outputs
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # get model output
        model_output = self.model(
            past_values=past_values,
            past_observed_mask=past_observed_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=True,
        )
        # get output head
        y_hat = self.head(model_output.last_hidden_state)

        if self.distribution_output is not None:
            # distribution parameters are returned as-is; loc/scale are applied when building the distribution
            y_hat_out = y_hat
        else:
            # rescale with the model's loc/scale (presumably undoing the input normalization)
            y_hat_out = y_hat * model_output.scale + model_output.loc

        loss_val = None
        if future_values is not None:
            if self.distribution_output is not None:
                distribution = self.distribution_output.distribution(
                    y_hat, loc=model_output.loc, scale=model_output.scale
                )
                # negative log-likelihood, averaged with `weighted_average`
                loss_val = weighted_average(nll(distribution, future_values))
            else:
                loss_val = nn.MSELoss(reduction="mean")(y_hat_out, future_values)

        if not return_dict:
            outputs = (y_hat_out,) + model_output[1:-1]
            return (loss_val,) + outputs if loss_val is not None else outputs

        return PatchTSTForPredictionOutput(
            loss=loss_val,
            prediction_outputs=y_hat_out,
            hidden_states=model_output.hidden_states,
            attentions=model_output.attentions,
            loc=model_output.loc,
            scale=model_output.scale,
        )

    @torch.no_grad()
    def generate(
        self,
        past_values: torch.Tensor,
        past_observed_mask: torch.Tensor | None = None,
    ) -> SamplePatchTSTOutput:
        """
        Generate sequences of sample predictions from a model with a probability distribution head.

        Parameters:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
            samples, prediction_length, 1)` or `(batch_size, number of samples, prediction_length, num_input_channels)`
            for multivariate predictions. Without a distribution head the single point forecast is returned as one
            sample.
        """
        # get number of samples
        num_parallel_samples = self.config.num_parallel_samples

        # Force a ModelOutput: attribute access below would fail on the plain tuple returned when
        # `config.return_dict` is False.
        outputs = self(
            past_values=past_values,
            future_values=None,
            past_observed_mask=past_observed_mask,
            output_hidden_states=False,
            return_dict=True,
        )

        if self.distribution_output is not None:
            # get distribution
            distribution = self.distribution_output.distribution(
                outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale
            )
            # get samples: list of [bs x forecast_len x num_channels]
            samples = [distribution.sample() for _ in range(num_parallel_samples)]
            # samples: [bs x num_samples x forecast_len x num_channels]
            samples = torch.stack(samples, dim=1)
        else:
            samples = outputs.prediction_outputs.unsqueeze(1)

        return SamplePatchTSTOutput(sequences=samples)
  # Regression head: pooled patch embeddings -> regression targets.
class PatchTSTRegressionHead(nn.Module):
    """
    Regression head.

    Pools each channel's patch embeddings (cls token, mean or max), concatenates the channels and projects them to
    `config.num_targets` values, or to distribution parameters when `distribution_output` is given. A linear head
    can optionally be squashed into `config.output_range` with a sigmoid.
    """

    def __init__(self, config: PatchTSTConfig, distribution_output=None):
        super().__init__()
        # optional (low, high) bounds applied to the linear head's output
        self.y_range = config.output_range
        self.use_cls_token = config.use_cls_token
        self.pooling_type = config.pooling_type
        self.distribution_output = distribution_output

        in_features = config.num_input_channels * config.d_model
        self.flatten = nn.Flatten(start_dim=1)
        if config.head_dropout > 0:
            self.dropout = nn.Dropout(config.head_dropout)
        else:
            self.dropout = nn.Identity()

        if distribution_output is not None:
            self.projection = distribution_output.get_parameter_projection(in_features)
        else:
            self.projection = nn.Linear(in_features, config.num_targets)

    def forward(self, embedding: torch.Tensor):
        """
        Parameters:
            embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
                `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
                Embedding from the model

        Returns:
            `torch.Tensor` of shape `(bs, output_dim)`, or the projection's tuple of such tensors for a
            distribution head

        Raises:
            ValueError: if no cls token is used and `pooling_type` is neither "mean" nor "max".
        """
        # Collapse the patch axis (dim 2): summary is [bs x num_channels x d_model]
        if self.use_cls_token:
            # the first output token along the patch axis
            summary = embedding[:, :, 0]
        elif self.pooling_type == "mean":
            summary = embedding.mean(dim=2)
        elif self.pooling_type == "max":
            summary = embedding.max(dim=2).values
        else:
            raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")

        # [bs x (num_channels * d_model)]
        features = self.flatten(summary)
        features = self.dropout(features)
        output = self.projection(features)

        # only a linear head is bounded; distribution parameters are left untouched
        bounded = self.distribution_output is None and self.y_range is not None
        if bounded:
            low, high = self.y_range[0], self.y_range[1]
            output = torch.sigmoid(output) * (high - low) + low
        return output
  # PatchTST encoder with a regression head.
@auto_docstring(
    custom_intro="""
    The PatchTST for regression model.
    """
)
class PatchTSTForRegression(PatchTSTPreTrainedModel):
    def __init__(self, config: PatchTSTConfig):
        super().__init__(config)

        # Turn off masking. Note that this updates the config object that was passed in.
        if config.do_mask_input:
            logger.warning("Setting `do_mask_input` parameter to False.")
            config.do_mask_input = False

        self.model = PatchTSTModel(config)

        # `loss == "mse"` -> linear head; any other loss -> distribution head over `num_targets`
        if config.loss == "mse":
            self.distribution_output = None
        else:
            if config.distribution_output == "student_t":
                self.distribution_output = StudentTOutput(dim=config.num_targets)
            elif config.distribution_output == "normal":
                self.distribution_output = NormalOutput(dim=config.num_targets)
            elif config.distribution_output == "negative_binomial":
                self.distribution_output = NegativeBinomialOutput(dim=config.num_targets)
            else:
                raise ValueError(f"Unknown distribution output {config.distribution_output}")

        self.head = PatchTSTRegressionHead(config, self.distribution_output)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        past_values: torch.Tensor,
        target_values: torch.Tensor | None = None,
        past_observed_mask: torch.Tensor | None = None,
        output_hidden_states: bool | None = None,
        output_attentions: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | PatchTSTForRegressionOutput:
        r"""
        past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
            Input sequence to the model
        target_values (`torch.Tensor` of shape `(bs, num_targets)`, *optional*):
            Target values associated with the `past_values`
        past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:

            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Examples:

        ```python
        >>> import torch
        >>> from transformers import PatchTSTConfig, PatchTSTForRegression

        >>> # Regression task with 6 input channels and regress 2 targets
        >>> model = PatchTSTForRegression.from_pretrained("namctin/patchtst_etth1_regression")

        >>> # during inference, one only provides past values, the model outputs the regression targets
        >>> past_values = torch.randn(20, 512, 6)
        >>> outputs = model(past_values=past_values)
        >>> regression_outputs = outputs.regression_outputs
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        model_output = self.model(
            past_values=past_values,
            past_observed_mask=past_observed_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=True,
        )
        # y_hat: [bs x num_targets], or a tuple of distribution parameters for a distribution head
        y_hat = self.head(model_output.last_hidden_state)

        loss = None
        if target_values is not None:
            if self.distribution_output is not None:
                distribution = self.distribution_output.distribution(y_hat)
                # reshape every distribution parameter to [bs x num_targets]
                y_hat = tuple(item.view(-1, self.config.num_targets) for item in y_hat)
                # negative log-likelihood, averaged with `weighted_average`
                loss = weighted_average(nll(distribution, target_values))
            else:
                loss = nn.MSELoss(reduction="mean")(y_hat, target_values)

        if not return_dict:
            # NOTE(review): the slice keeps the optional hidden_states/attentions entries of the encoder output and
            # drops the trailing per-sample statistics; confirm the bounds against PatchTSTModelOutput.
            outputs = (y_hat,) + model_output[1:-3]
            return (loss,) + outputs if loss is not None else outputs

        return PatchTSTForRegressionOutput(
            loss=loss,
            regression_outputs=y_hat,
            hidden_states=model_output.hidden_states,
            attentions=model_output.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        past_values: torch.Tensor,
        past_observed_mask: torch.Tensor | None = None,
    ) -> SamplePatchTSTOutput:
        """
        Generate sequences of sample predictions from a model with a probability distribution head.

        Parameters:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.
            past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
            samples, num_targets)`. Without a distribution head (MSE loss) the point prediction is returned as a
            single sample, i.e. shape `(batch_size, 1, num_targets)`.
        """
        # get number of samples
        num_parallel_samples = self.config.num_parallel_samples

        # Force a ModelOutput: attribute access below would fail on the plain tuple returned when
        # `config.return_dict` is False.
        outputs = self(
            past_values=past_values,
            target_values=None,
            past_observed_mask=past_observed_mask,
            output_hidden_states=False,
            return_dict=True,
        )

        if self.distribution_output is None:
            # nothing to sample from a point-estimate head: [bs x 1 x num_targets]
            return SamplePatchTSTOutput(sequences=outputs.regression_outputs.unsqueeze(1))

        # get distribution
        distribution = self.distribution_output.distribution(outputs.regression_outputs)
        # get samples: list of [bs x num_targets]
        samples = [distribution.sample() for _ in range(num_parallel_samples)]
        # samples: [bs x num_samples x num_targets]
        samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
        return SamplePatchTSTOutput(sequences=samples)
  # Public names exported by this module (alphabetical).
__all__ = [
    "PatchTSTForClassification",
    "PatchTSTForPrediction",
    "PatchTSTForPretraining",
    "PatchTSTForRegression",
    "PatchTSTModel",
    "PatchTSTPreTrainedModel",
]