modeling_xlstm.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600
  1. # Copyright 2025 NXAI GmbH. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch xLSTM Model."""
  15. from dataclasses import dataclass
  16. import torch
  17. import torch.nn.functional as F
  18. from torch import nn
  19. from torch.nn import CrossEntropyLoss
  20. from ... import initialization as init
  21. from ...generation import GenerationMixin
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_utils import PreTrainedModel
  24. from ...processing_utils import Unpack
  25. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, is_xlstm_available
  26. from ...utils.generic import merge_with_config_defaults
  27. from ...utils.output_capturing import capture_outputs
  28. from .configuration_xlstm import xLSTMConfig
  29. if is_xlstm_available():
  30. from xlstm.xlstm_large.model import RMSNorm as xLSTMRMSNorm
  31. from xlstm.xlstm_large.model import mLSTMBlock, mLSTMStateType, soft_cap
# The external `xlstm` package is available: its RMSNorm, mLSTMBlock, mLSTMStateType and soft_cap are imported above
# instead of using the in-file fallback implementations.
external_xlstm = True


class xLSTMBlock(GradientCheckpointingLayer, mLSTMBlock):
    """mLSTM block from the external `xlstm` package, mixed with `GradientCheckpointingLayer`.

    Declares no members of its own; all behavior comes from the two bases
    (`GradientCheckpointingLayer` comes first in the MRO).
    """

    pass
  35. else:
  36. from collections.abc import Callable
  37. from functools import partial
  38. from typing import Literal
  39. from .configuration_xlstm import round_up_to_next_multiple_of
# Recurrent state of one mLSTM layer: three tensors. The kernels in this file order them as (C, n, m):
# matrix memory, normalizer and max (stabilizer) state.
mLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor, torch.Tensor]
# Collection of per-layer states keyed by an int (presumably the layer index -- confirm against the model code).
mLSTMStateType = dict[int, mLSTMLayerStateType]
# The in-file fallback kernels below are used instead of the external `xlstm` package.
external_xlstm = False
  43. def soft_cap(values: torch.Tensor, cap_value: float | torch.Tensor | None = None) -> torch.Tensor:
  44. """
  45. Soft caps a tensor to a value.
  46. Performs a tanh operation on the logits and scales the result to the cap value. Common technique in attention
  47. and output language heads to prevent large logits from dominating the softmax. See for example Gemma2:
  48. https://huggingface.co/papers/2408.00118
  49. Args:
  50. values: The tensor to cap.
  51. cap_value: The value to cap the values to. If None, no cap is applied.
  52. Returns:
  53. The capped values.
  54. """
  55. if cap_value is None:
  56. return values
  57. return cap_value * torch.tanh(values / cap_value)
  58. def mlstm_chunkwise_recurrent_fw_C(
  59. matK: torch.Tensor,
  60. matV: torch.Tensor,
  61. vecB: torch.Tensor,
  62. vecI: torch.Tensor,
  63. matC_states: torch.Tensor | None = None,
  64. vecN_states: torch.Tensor | None = None,
  65. scaMinter_states: torch.Tensor | None = None,
  66. matC_initial: torch.Tensor | None = None,
  67. vecN_initial: torch.Tensor | None = None,
  68. scaMinter_initial: torch.Tensor | None = None,
  69. qk_scale: float | None = None,
  70. chunk_size: int = 64,
  71. num_chunks: int = 1,
  72. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  73. batch_size, nh, _, dhqk, dhhv = *matK.shape, matV.shape[-1]
  74. nc = num_chunks
  75. _dtype, _device = matK.dtype, matK.device
  76. if qk_scale is None:
  77. qk_scale = dhqk**-0.5
  78. # initialize the states tensors
  79. if matC_states is None:
  80. matC_states = torch.zeros((batch_size, nh, (nc + 1) * dhqk, dhhv), dtype=_dtype, device=_device)
  81. if vecN_states is None:
  82. vecN_states = torch.zeros((batch_size, nh, (nc + 1) * dhqk), dtype=_dtype, device=_device)
  83. if scaMinter_states is None:
  84. scaMinter_states = torch.zeros((batch_size, nh, (nc + 1)), dtype=_dtype, device=_device)
  85. # assign the initial states to the running states
  86. matC_k = (
  87. torch.zeros((batch_size, nh, dhqk, dhhv), dtype=_dtype, device=_device)
  88. if matC_initial is None
  89. else matC_initial
  90. )
  91. vecN_k = (
  92. torch.zeros((batch_size, nh, dhqk), dtype=_dtype, device=_device) if vecN_initial is None else vecN_initial
  93. )
  94. scaM_inter_k = (
  95. torch.zeros((batch_size, nh, 1), dtype=_dtype, device=_device)
  96. if scaMinter_initial is None
  97. else scaMinter_initial
  98. )
  99. vecA = vecB[..., -1, None] - vecB + vecI
  100. scaG = vecB[..., -1]
  101. scaA_max = vecA.max(-1).values
  102. scaM_inter_k = scaM_inter_k.squeeze(-1)
  103. for key in range(0, num_chunks):
  104. # store the states from the previous iteration before updating them
  105. # in the first iteration, these are the initial states
  106. matC_states[:, :, key * dhqk : (key + 1) * dhqk, :] = matC_k
  107. vecN_states[:, :, key * dhqk : (key + 1) * dhqk] = vecN_k
  108. scaMinter_states[:, :, key] = scaM_inter_k
  109. # m_k update
  110. scaA_max_k = scaA_max[:, :, key]
  111. scaG_k = scaG[:, :, key]
  112. scaM_inter_k_next = torch.max(scaG_k + scaM_inter_k, scaA_max_k)
  113. # C_k update
  114. matK_chunk = matK[:, :, key * chunk_size : (key + 1) * chunk_size, :] # * qk_scale
  115. matV_chunk = matV[:, :, key * chunk_size : (key + 1) * chunk_size, :]
  116. vecA_k = vecA[:, :, key, :]
  117. vecAbar_k = torch.exp(vecA_k - scaM_inter_k_next[..., None])[:, :, :, None]
  118. matK_chunk_gated = matK_chunk * vecAbar_k
  119. scaGbar_k = torch.exp(scaG_k + scaM_inter_k - scaM_inter_k_next)[:, :, None]
  120. # NOTE: no update in-place (i.e. +=) as this gives error for autograd backward
  121. matC_k_next = scaGbar_k[..., None] * matC_k + matK_chunk_gated.transpose(-2, -1) @ (matV_chunk)
  122. # n_k update
  123. vecN_k_next = scaGbar_k * vecN_k + matK_chunk_gated.transpose(-2, -1).sum(-1)
  124. # move to the next iteration
  125. scaM_inter_k = scaM_inter_k_next
  126. matC_k = matC_k_next
  127. vecN_k = vecN_k_next
  128. # store the states from the last iteration
  129. matC_states[:, :, -dhqk:, :] = matC_k
  130. vecN_states[:, :, -dhqk:] = vecN_k
  131. scaMinter_states[:, :, -1] = scaM_inter_k
  132. return matC_states, vecN_states, scaMinter_states
  133. def mlstm_chunkwise_parallel_fw_H(
  134. matQ: torch.Tensor,
  135. matK: torch.Tensor,
  136. matV: torch.Tensor,
  137. # these states must be all states up to the last chunk, i.e. :-1
  138. matC_states: torch.Tensor,
  139. vecN_states: torch.Tensor,
  140. scaMinter_states: torch.Tensor,
  141. vecI: torch.Tensor,
  142. vecB: torch.Tensor,
  143. qk_scale: float,
  144. chunk_size: int = 64,
  145. num_chunks: int = 1,
  146. eps: float = 1e-6,
  147. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  148. _device = matQ.device
  149. nc = num_chunks
  150. batch_size, nh, dqk, dhv = matC_states.shape
  151. dhqk = dqk // nc
  152. matC_k_states = matC_states.view(batch_size, nh, nc, dhqk, dhv)
  153. vecN_k_states = vecN_states.view(batch_size, nh, nc, dhqk)
  154. scaMinter_k_states = scaMinter_states
  155. matQ = matQ.view(batch_size, nh, nc, chunk_size, dhqk)
  156. matK = matK.view(batch_size, nh, nc, chunk_size, dhqk)
  157. matV = matV.view(batch_size, nh, nc, chunk_size, dhv)
  158. ltr = torch.tril(
  159. torch.ones(
  160. (chunk_size, chunk_size),
  161. dtype=torch.bool,
  162. device=_device,
  163. )
  164. )
  165. # Compute intra chunk contribution: H_intra
  166. matF_logsig_chunk = vecB[:, :, :, :, None] - vecB[:, :, :, None, :]
  167. matF_logsig_mask_chunk = torch.where(ltr, matF_logsig_chunk, -float("inf"))
  168. matLogD_chunk = matF_logsig_mask_chunk + vecI[:, :, :, None, :]
  169. # max_state intra
  170. vecMintra_k = torch.max(matLogD_chunk, dim=-1, keepdim=False).values
  171. # max_state combined
  172. vecM_b_inter = vecB + scaMinter_k_states[:, :, :, None]
  173. vecM_k_combine = torch.maximum(vecM_b_inter, vecMintra_k)
  174. vecM_k_combine = vecM_k_combine[:, :, :, :, None]
  175. vecM_b_inter = vecM_b_inter[:, :, :, :, None]
  176. matLogD_stabilized_chunk = matLogD_chunk - vecM_k_combine
  177. matD_chunk = torch.exp(matLogD_stabilized_chunk)
  178. matS_chunk = (matQ @ matK.transpose(-2, -1)) * qk_scale
  179. matM_chunk = matS_chunk * matD_chunk
  180. # ? Combine H_intra with H_inter
  181. vecBbar = torch.exp(vecM_b_inter - vecM_k_combine)
  182. matQ_chunk_gated = matQ * vecBbar * qk_scale
  183. matNumerator_common = matQ_chunk_gated @ matC_k_states + matM_chunk @ matV
  184. vecDenom_l_common = matQ_chunk_gated @ vecN_k_states.unsqueeze(-1) + matM_chunk.sum(dim=-1, keepdim=True)
  185. vecDenom_max_common = torch.maximum(torch.abs(vecDenom_l_common), torch.exp(-vecM_k_combine))
  186. matH_k_chunk = matNumerator_common / (vecDenom_max_common + eps)
  187. matH_out = matH_k_chunk.view(batch_size, nh, nc * chunk_size, dhv)
  188. # we need the denominator and the overall max state for the backward pass
  189. vecN_out = vecDenom_max_common.reshape(batch_size, nh, nc * chunk_size)
  190. vecM_out = vecM_k_combine.reshape(batch_size, nh, nc * chunk_size)
  191. return matH_out, vecN_out, vecM_out
  192. def mlstm_chunkwise_fw(
  193. query: torch.Tensor,
  194. key: torch.Tensor,
  195. value: torch.Tensor,
  196. igate: torch.Tensor,
  197. fgate: torch.Tensor,
  198. cstate: torch.Tensor | None = None,
  199. nstate: torch.Tensor | None = None,
  200. mstate: torch.Tensor | None = None,
  201. qk_scale: float | None = None,
  202. return_last_states: bool = False,
  203. return_all_states: bool = False,
  204. chunk_size: int = 64,
  205. eps: float = 1e-6,
  206. ) -> tuple[
  207. torch.Tensor,
  208. torch.Tensor,
  209. torch.Tensor,
  210. tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None,
  211. tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None,
  212. ]:
  213. batch_size, nh, sequence_length, dhqk = query.shape
  214. if sequence_length % chunk_size != 0:
  215. raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.")
  216. nc = sequence_length // chunk_size
  217. vecI = igate.view(batch_size, nh, nc, chunk_size)
  218. vecF = fgate.view(batch_size, nh, nc, chunk_size)
  219. # compute the gates, the g and the a and b vectors
  220. vecF_logsig = fgate.logsigmoid(vecF)
  221. vecB = vecF_logsig.cumsum(-1)
  222. if qk_scale is None:
  223. qk_scale = dhqk**-0.5
  224. #! materialize the C_k, n_k, m_k states for each chunk
  225. matC_k_states, vecN_k_states, scaMinter_k_states = mlstm_chunkwise_recurrent_fw_C(
  226. matK=key,
  227. matV=value,
  228. vecB=vecB,
  229. vecI=vecI,
  230. matC_initial=cstate,
  231. vecN_initial=nstate,
  232. scaMinter_initial=mstate,
  233. qk_scale=qk_scale,
  234. chunk_size=chunk_size,
  235. num_chunks=nc,
  236. )
  237. #! compute the outputs within each chunk
  238. matH_out, vecN_out, vecM_out = mlstm_chunkwise_parallel_fw_H(
  239. matQ=query,
  240. matK=key,
  241. matV=value,
  242. matC_states=matC_k_states[:, :, :-dhqk, :],
  243. vecN_states=vecN_k_states[:, :, :-dhqk],
  244. scaMinter_states=scaMinter_k_states[:, :, :-1],
  245. vecI=vecI,
  246. vecB=vecB,
  247. qk_scale=qk_scale,
  248. chunk_size=chunk_size,
  249. num_chunks=nc,
  250. eps=eps,
  251. )
  252. ret_tuple = (matH_out, vecN_out, vecM_out)
  253. if return_last_states:
  254. ret_tuple += (
  255. (matC_k_states[:, :, -dhqk:, :], vecN_k_states[:, :, -dhqk:], scaMinter_k_states[:, :, -1:]),
  256. )
  257. else:
  258. ret_tuple += (None,)
  259. if return_all_states:
  260. ret_tuple += ((matC_k_states, vecN_k_states, scaMinter_k_states),)
  261. else:
  262. ret_tuple += (None,)
  263. return ret_tuple
  264. def mlstm_chunkwise_native_autograd(
  265. query: torch.Tensor,
  266. key: torch.Tensor,
  267. value: torch.Tensor,
  268. igate: torch.Tensor,
  269. fgate: torch.Tensor,
  270. c_initial: torch.Tensor | None = None,
  271. n_initial: torch.Tensor | None = None,
  272. m_initial: torch.Tensor | None = None,
  273. return_last_states: bool = False,
  274. eps: float = 1e-6,
  275. chunk_size: int = 64,
  276. **kwargs,
  277. ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
  278. batch_size, nh, sequence_length, dhqk = query.shape
  279. if sequence_length % chunk_size != 0:
  280. raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.")
  281. nc = sequence_length // chunk_size
  282. vecI = igate.view(batch_size, nh, nc, chunk_size)
  283. vecF = fgate.view(batch_size, nh, nc, chunk_size)
  284. # compute the gates, the g and the a and b vectors
  285. vecF_logsig = F.logsigmoid(vecF)
  286. vecB = vecF_logsig.cumsum(-1)
  287. qk_scale = dhqk**-0.5
  288. #! materialize the C_k, n_k, m_k states for each chunk
  289. matC_k_states, vecN_k_states, scaMinter_k_states = mlstm_chunkwise_recurrent_fw_C(
  290. matK=key,
  291. matV=value,
  292. vecB=vecB,
  293. vecI=vecI,
  294. matC_initial=c_initial,
  295. vecN_initial=n_initial,
  296. scaMinter_initial=m_initial,
  297. qk_scale=qk_scale,
  298. chunk_size=chunk_size,
  299. num_chunks=nc,
  300. )
  301. #! compute the outputs within each chunk
  302. matH_out, vecN_out, vecM_out = mlstm_chunkwise_parallel_fw_H(
  303. matQ=query,
  304. matK=key,
  305. matV=value,
  306. matC_states=matC_k_states[:, :, :-dhqk, :],
  307. vecN_states=vecN_k_states[:, :, :-dhqk],
  308. scaMinter_states=scaMinter_k_states[:, :, :-1],
  309. vecI=vecI,
  310. vecB=vecB,
  311. qk_scale=qk_scale,
  312. chunk_size=chunk_size,
  313. num_chunks=nc,
  314. eps=eps,
  315. )
  316. last_states = (matC_k_states[:, :, -dhqk:, :], vecN_k_states[:, :, -dhqk:], scaMinter_k_states[:, :, -1:])
  317. if return_last_states:
  318. return matH_out, last_states
  319. else:
  320. return matH_out
  321. def mlstm_recurrent_step_native(
  322. query: torch.Tensor,
  323. key: torch.Tensor,
  324. value: torch.Tensor,
  325. igate: torch.Tensor,
  326. fgate: torch.Tensor,
  327. cstate: torch.Tensor,
  328. nstate: torch.Tensor,
  329. mstate: torch.Tensor,
  330. eps: float = 1e-6,
  331. dtype_state: torch.dtype = torch.float32,
  332. **kwargs,
  333. ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
  334. """This is a single step of the mLSTM operation in recurrent form."""
  335. dtype_qkv = query.dtype
  336. matC_old = cstate.to(dtype=dtype_state)
  337. vecN_old = nstate.to(dtype=dtype_state)
  338. scaM_old = mstate.to(dtype=dtype_state)
  339. batch_size, nh, dhqk = query.shape
  340. _, _, dhhv = value.shape
  341. if query.shape != key.shape:
  342. raise ValueError("query and key must have the same shape")
  343. if matC_old.shape != (batch_size, nh, dhqk, dhhv):
  344. raise ValueError(f"matC_old has wrong shape, got {matC_old.shape}")
  345. if vecN_old.shape != (batch_size, nh, dhqk):
  346. raise ValueError(f"vecN_old has wrong shape, got {vecN_old.shape}")
  347. if scaM_old.shape != (batch_size, nh, 1):
  348. raise ValueError(f"scaM_old has wrong shape, got {scaM_old.shape}")
  349. if igate.shape != (batch_size, nh, 1):
  350. raise ValueError(f"scaI has wrong shape, got {igate.shape}")
  351. if fgate.shape != (batch_size, nh, 1):
  352. raise ValueError(f"scaF has wrong shape, got {fgate.shape}")
  353. # gates
  354. scaF_log = torch.nn.functional.logsigmoid(fgate)
  355. # update rule
  356. scaM_state_new = torch.max(scaF_log + scaM_old, igate)
  357. scaF_act = torch.exp(scaF_log + scaM_old - scaM_state_new)
  358. scaI_act = torch.exp(igate - scaM_state_new)
  359. vecQ_scaled = query * (dhqk ** (-0.5))
  360. matC_state_new = scaF_act[:, :, :, None] * matC_old.clone() + scaI_act[:, :, :, None] * (
  361. key[:, :, :, None] @ value[:, :, None, :]
  362. )
  363. vecN_state_new = scaF_act * vecN_old.clone() + scaI_act * key
  364. h_num = vecQ_scaled[:, :, None, :] @ matC_state_new.to(dtype=dtype_qkv)
  365. h_num = h_num.squeeze(2).to(dtype=dtype_state)
  366. qn_dotproduct = vecQ_scaled[:, :, None, :] @ vecN_state_new[:, :, :, None].to(dtype=dtype_qkv)
  367. qn_dotproduct = qn_dotproduct.squeeze(2)
  368. max_val = torch.exp(-scaM_state_new)
  369. h_denom = (torch.maximum(qn_dotproduct.abs(), max_val) + eps).to(dtype=dtype_state)
  370. h = h_num / h_denom
  371. h = h.to(dtype=dtype_qkv)
  372. matC_state_new = matC_state_new.to(dtype=dtype_state)
  373. vecN_state_new = vecN_state_new.to(dtype=dtype_state)
  374. scaM_state_new = scaM_state_new.to(dtype=dtype_state)
  375. return h, (matC_state_new, vecN_state_new, scaM_state_new)
  376. def mlstm_recurrent_sequence_native(
  377. query: torch.Tensor,
  378. key: torch.Tensor,
  379. value: torch.Tensor,
  380. igate: torch.Tensor,
  381. fgate: torch.Tensor,
  382. c_initial: torch.Tensor | None = None,
  383. n_initial: torch.Tensor | None = None,
  384. m_initial: torch.Tensor | None = None,
  385. return_last_states: bool = False,
  386. eps: float = 1e-6,
  387. dtype_state: torch.dtype = torch.float32,
  388. **kwargs,
  389. ) -> tuple[
  390. torch.Tensor,
  391. torch.Tensor,
  392. torch.Tensor,
  393. tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None,
  394. tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None,
  395. ]:
  396. batch_size, nh, sequence_length, dhqk = query.shape
  397. dhv = value.shape[-1]
  398. device = query.device
  399. if c_initial is not None:
  400. if n_initial is None or m_initial is None:
  401. raise ValueError("Initial states must be provided together.")
  402. if n_initial is None or m_initial is None:
  403. raise ValueError("Initial states must be provided together.")
  404. matC_state, vecN_state, vecM_state = (
  405. c_initial.to(dtype=dtype_state),
  406. n_initial.to(dtype=dtype_state),
  407. m_initial.to(dtype=dtype_state),
  408. )
  409. else:
  410. # memory state
  411. matC_state = torch.zeros((batch_size, nh, dhqk, dhv), dtype=dtype_state, device=device)
  412. # normalizer state
  413. vecN_state = torch.zeros((batch_size, nh, dhqk), dtype=dtype_state, device=device)
  414. # max state
  415. vecM_state = torch.zeros((batch_size, nh, 1), dtype=dtype_state, device=device)
  416. vecH_list = []
  417. for t in range(sequence_length):
  418. # gates
  419. vecF_t, vecI_t = fgate[:, :, t, None], igate[:, :, t, None]
  420. # projections
  421. vecQ_t, vecK_t, vecV_t = query[:, :, t, :], key[:, :, t, :], value[:, :, t, :]
  422. # step
  423. vecH, (matC_state, vecN_state, vecM_state) = mlstm_recurrent_step_native(
  424. cstate=matC_state,
  425. nstate=vecN_state,
  426. mstate=vecM_state,
  427. query=vecQ_t,
  428. key=vecK_t,
  429. value=vecV_t,
  430. igate=vecI_t,
  431. fgate=vecF_t,
  432. eps=eps,
  433. dtype_state=dtype_state,
  434. **kwargs,
  435. )
  436. vecH_list.append(vecH)
  437. matH = torch.stack(vecH_list, dim=-2)
  438. if return_last_states:
  439. return matH, (matC_state, vecN_state, vecM_state)
  440. else:
  441. return matH
  442. def wrap_chunkwise_pad_zeros(
  443. mlstm_chunkwise_kernel: Callable,
  444. query: torch.Tensor,
  445. key: torch.Tensor,
  446. value: torch.Tensor,
  447. fgate: torch.Tensor,
  448. igate: torch.Tensor,
  449. c_initial: torch.Tensor | None = None,
  450. n_initial: torch.Tensor | None = None,
  451. m_initial: torch.Tensor | None = None,
  452. return_last_states: bool = False,
  453. eps: float = 1e-6,
  454. autocast_kernel_dtype: torch.dtype = torch.bfloat16,
  455. chunk_size: int = 64,
  456. **kwargs,
  457. ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
  458. if return_last_states:
  459. raise ValueError(
  460. "We are padding zeros, so we cannot return last states,",
  461. "as they would be not the true last states.",
  462. )
  463. batch_size, nh, sequence_length, dhqk = query.shape
  464. S_unpadded = sequence_length
  465. # padding to chunk size for kernels
  466. if sequence_length % chunk_size != 0:
  467. S_padded = ((sequence_length + chunk_size - 1) // chunk_size) * chunk_size
  468. q_pad = query.new_zeros(batch_size, nh, S_padded, query.shape[3])
  469. k_pad = key.new_zeros(batch_size, nh, S_padded, key.shape[3])
  470. v_pad = value.new_zeros(batch_size, nh, S_padded, value.shape[3])
  471. i_pad = igate.new_zeros(batch_size, nh, S_padded)
  472. f_pad = fgate.new_zeros(batch_size, nh, S_padded)
  473. q_pad[:, :, :S_unpadded, :] = query
  474. k_pad[:, :, :S_unpadded, :] = key
  475. v_pad[:, :, :S_unpadded, :] = value
  476. i_pad[:, :, :S_unpadded] = igate
  477. f_pad[:, :, :S_unpadded] = fgate
  478. else:
  479. q_pad = query
  480. k_pad = key
  481. v_pad = value
  482. i_pad = igate
  483. f_pad = fgate
  484. matH = mlstm_chunkwise_kernel(
  485. query=q_pad,
  486. key=k_pad,
  487. value=v_pad,
  488. igate=i_pad,
  489. fgate=f_pad,
  490. c_initial=c_initial,
  491. n_initial=n_initial,
  492. m_initial=m_initial,
  493. return_last_states=return_last_states,
  494. eps=eps,
  495. autocast_kernel_dtype=autocast_kernel_dtype,
  496. chunk_size=chunk_size,
  497. **kwargs,
  498. )
  499. matH = matH[:, :, :S_unpadded, :]
  500. return matH
  501. def wrap_chunkwise_arbitrary_sequence_length(
  502. mlstm_chunkwise_kernel: Callable,
  503. mlstm_sequence_kernel: Callable,
  504. mlstm_step_kernel: Callable,
  505. query: torch.Tensor,
  506. key: torch.Tensor,
  507. value: torch.Tensor,
  508. fgate: torch.Tensor,
  509. igate: torch.Tensor,
  510. c_initial: torch.Tensor | None = None,
  511. n_initial: torch.Tensor | None = None,
  512. m_initial: torch.Tensor | None = None,
  513. return_last_states: bool = True,
  514. eps: float = 1e-6,
  515. autocast_kernel_dtype: torch.dtype = torch.bfloat16,
  516. chunk_size: int = 64,
  517. enable_logging: bool = False,
  518. ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
  519. """This function computes the last hidden state and matH outputs of the mLSTM, independently of the sequence length.
  520. For this it uses three kernels:
  521. - mlstm_chunkwise_kernel: mlstm chunkwise kernels that processes chunks of a given chunk size in parallel.
  522. - mlstm_sequence_kernel: mlstm kernel that processes the remaining sequence length in a single step recurrence.
  523. - mlstm_step_kernel: mlstm kernel that processes a sequence length of 1 in a single step.
  524. It tries to maximize the chunksizes to improve performance.
  525. It will start with the given chunk size and then divides the chunksize by 2 until the chunk size is smaller than 16.
  526. At every chunksize it will process the maximal number of chunks that fit into the remaining sequence length.
  527. E.g. for chunk_size = 64, this function will try the chunksizes [64, 32, 16] if necessary.
  528. For the remaining sequence length, which is smaller than 16, we use a different kernel that computes the mLSTM
  529. in a single step and loop over this in pytorch.
  530. Args:
  531. mlstm_chunkwise_kernel: The mLSTM chunkwise kernel that processes chunks of a given chunk size in parallel
  532. mlstm_sequence_kernel: The mLSTM kernel that processes the remaining sequence length in a single step recurrence
  533. query: The query tensor (batch_size, nh, sequence_length, dhqk)
  534. key: The key tensor (batch_size, nh, sequence_length, dhqk)
  535. value: The value tensor (batch_size, nh, sequence_length, dhhv)
  536. fgate: The forget gate tensor (batch_size, nh, sequence_length)
  537. igate: The input gate tensor (batch_size, nh, sequence_length)
  538. c_initial: The initial cell state tensor (batch_size, nh, dhqk, dhhv)
  539. n_initial: The initial hidden state tensor (batch_size, nh, dhqk)
  540. m_initial: The initial memory state tensor (batch_size, nh, 1)
  541. return_last_states: If True, the function will return the last states of the mLSTM
  542. eps: The epsilon value used for numerical stability
  543. autocast_kernel_dtype: The dtype used for the kernel computation
  544. chunk_size: The chunk size used for the chunkwise kernel
  545. enable_logging: If True, the function will log debug information. Default is False.
  546. Returns:
  547. The last hidden state tensor (batch_size, nh, sequence_length, dhhv) or a tuple containing the last hidden state tensor and the last states of the mLSTM
  548. Last states are (cstate (batch_size, nh, dhqk, dhhv), nstate (batch_size, nh, dhqk), mstate (batch_size, nh, 1)).
  549. """
  550. batch_size, nh, sequence_length, dhqk = key.shape
  551. dhhv = value.shape[-1]
  552. c_state = (
  553. c_initial
  554. if c_initial is not None
  555. else torch.zeros(batch_size, nh, dhqk, dhhv, device=key.device, dtype=torch.float32)
  556. )
  557. n_state = (
  558. n_initial
  559. if n_initial is not None
  560. else torch.zeros(batch_size, nh, dhqk, device=key.device, dtype=torch.float32)
  561. )
  562. m_state = (
  563. m_initial
  564. if m_initial is not None
  565. else torch.zeros(batch_size, nh, 1, device=key.device, dtype=torch.float32)
  566. )
  567. if sequence_length > 1:
  568. # process the sequence length in chunks
  569. h_outs = []
  570. seq_len_start_idx = 0
  571. remaining_seq_len = sequence_length - seq_len_start_idx
  572. num_chunks = remaining_seq_len // chunk_size
  573. if num_chunks > 0:
  574. iter_seq_len = chunk_size * num_chunks
  575. seq_len_idx = seq_len_start_idx + iter_seq_len
  576. h_out, (c_state, n_state, m_state) = mlstm_chunkwise_kernel(
  577. query=query[..., seq_len_start_idx:seq_len_idx, :].contiguous(),
  578. key=key[..., seq_len_start_idx:seq_len_idx, :].contiguous(),
  579. value=value[..., seq_len_start_idx:seq_len_idx, :].contiguous(),
  580. fgate=fgate[..., seq_len_start_idx:seq_len_idx].contiguous(),
  581. igate=igate[..., seq_len_start_idx:seq_len_idx].contiguous(),
  582. c_initial=c_state,
  583. n_initial=n_state,
  584. m_initial=m_state,
  585. chunk_size=chunk_size,
  586. return_last_states=True,
  587. autocast_kernel_dtype=autocast_kernel_dtype,
  588. eps=eps,
  589. )
  590. seq_len_start_idx += iter_seq_len
  591. h_outs.append(h_out)
  592. remaining_seq_len = sequence_length - seq_len_start_idx
  593. if remaining_seq_len > 0:
  594. # we use here matK as query as this kernel does not need a query, since we do not care about the outputs only about the last state
  595. h_out, (c_state, n_state, m_state) = mlstm_sequence_kernel(
  596. query=query[..., seq_len_start_idx:sequence_length, :].contiguous(),
  597. key=key[..., seq_len_start_idx:sequence_length, :].contiguous(),
  598. value=value[..., seq_len_start_idx:sequence_length, :].contiguous(),
  599. igate=igate[..., seq_len_start_idx:sequence_length].contiguous(),
  600. fgate=fgate[..., seq_len_start_idx:sequence_length].contiguous(),
  601. c_initial=c_state,
  602. n_initial=n_state,
  603. m_initial=m_state,
  604. return_last_states=True,
  605. eps=eps,
  606. )
  607. h_outs.append(h_out)
  608. h_out = torch.concatenate(h_outs, dim=2)
  609. else:
  610. if sequence_length != 1:
  611. raise ValueError(
  612. f"Received empty sequence (sequence_length={sequence_length}), require at least single element in the sequence."
  613. )
  614. # process the sequence length in a single step
  615. # while this case is also captured by the regular mode above,
  616. # it avoids the overhead of the loop and calls the step kernel directly
  617. # The step function does not want a sequence dimension
  618. # qkv shape is (batch_size, nh, dhqk/dhv)
  619. # igate, fgate shape is (batch_size, nh, 1)
  620. h_out, (c_state, n_state, m_state) = mlstm_step_kernel(
  621. query=query.squeeze(2),
  622. key=key.squeeze(2),
  623. value=value.squeeze(2),
  624. igate=igate,
  625. fgate=fgate,
  626. cstate=c_state,
  627. nstate=n_state,
  628. mstate=m_state,
  629. eps=eps,
  630. )
  631. h_out = h_out[:, :, None, :]
  632. if return_last_states:
  633. return h_out, (c_state, n_state, m_state)
  634. else:
  635. return h_out
  class xLSTMBackend(nn.Module):
      """xLSTM backend module for PyTorch.

      Wraps the mLSTM kernels behind a single `forward`. Depending on the configured (or per-call) mode,
      it dispatches to the chunkwise training kernel or to the inference wrapper. The inference wrapper
      combines the chunkwise kernel with the recurrent sequence and step kernels.

      Args:
          config: The xLSTM configuration. Reads `mode`, `chunk_size`, `eps`, `return_last_states`,
              `inference_state_dtype` and `autocast_kernel_dtype`. The last two are attribute names of
              `torch` dtypes.
      """

      config_class = xLSTMConfig

      def __init__(self, config: xLSTMConfig):
          super().__init__()
          self.config = config
          self.chunkwise_kernel_fn = mlstm_chunkwise_native_autograd
          self.sequence_kernel_fn = mlstm_recurrent_sequence_native
          self.step_kernel_fn = mlstm_recurrent_step_native

          # The config stores dtypes by their attribute name on `torch`; resolve them once.
          inference_state_dtype = getattr(torch, config.inference_state_dtype)
          autocast_kernel_dtype = getattr(torch, config.autocast_kernel_dtype)

          self._inference_fn = partial(
              wrap_chunkwise_arbitrary_sequence_length,
              mlstm_chunkwise_kernel=self.chunkwise_kernel_fn,
              mlstm_sequence_kernel=partial(self.sequence_kernel_fn, dtype_state=inference_state_dtype),
              mlstm_step_kernel=partial(self.step_kernel_fn, dtype_state=inference_state_dtype),
              chunk_size=config.chunk_size,
              eps=config.eps,
              autocast_kernel_dtype=autocast_kernel_dtype,
              return_last_states=True,
          )

          train_kernel_fn = partial(
              self.chunkwise_kernel_fn,
              autocast_kernel_dtype=autocast_kernel_dtype,
              eps=config.eps,
              chunk_size=config.chunk_size,
          )
          # The padding wrapper is chosen from the *configured* mode, once, at construction time.
          if "with_padding" in config.mode:
              train_kernel_fn = partial(wrap_chunkwise_pad_zeros, mlstm_chunkwise_kernel=train_kernel_fn)
          self._train_fn = train_kernel_fn

      def forward(
          self,
          query: torch.Tensor,
          key: torch.Tensor,
          value: torch.Tensor,
          igate: torch.Tensor,
          fgate: torch.Tensor,
          c_initial: torch.Tensor | None = None,
          n_initial: torch.Tensor | None = None,
          m_initial: torch.Tensor | None = None,
          return_last_states: bool | None = None,
          mode: Literal["train", "inference"] | None = None,
      ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
          """Forward pass of the mLSTM backend.

          Depending on the mode, this method calls the training kernel or the inference wrapper.

          Args:
              query: The query tensor of shape (batch_size, nh, sequence_length, dhqk).
              key: The key tensor of shape (batch_size, nh, sequence_length, dhqk).
              value: The value tensor of shape (batch_size, nh, sequence_length, dhhv).
              igate: The input gate preactivation tensor of shape (batch_size, nh, sequence_length).
              fgate: The forget gate preactivation tensor of shape (batch_size, nh, sequence_length).
              c_initial: The initial cell state tensor of shape (batch_size, nh, dhqk, dhhv).
                  Defaults to None.
              n_initial: The initial normalizer state tensor of shape (batch_size, nh, dhqk).
                  Defaults to None.
              m_initial: The initial max state tensor of shape (batch_size, nh, 1). Defaults to None.
              return_last_states: Whether to return the last states of the sequence (training only).
                  If None, the value from the config is used. Defaults to None.
              mode: Overrides `config.mode` for this call. The branch is chosen by substring match on
                  "train" / "inference". Defaults to None.

          Returns:
              hidden states of shape (batch_size, nh, sequence_length, dhhv), or hidden states and the last
              states. The last states are the cell state cstate (batch_size, nh, dhqk, dhhv), the
              normalizer state nstate (batch_size, nh, dhqk), and the max state mstate (batch_size, nh, 1).
              Inference mode always returns the last states.

          Raises:
              ValueError: If `return_last_states` is requested with a padding training kernel, or if the
                  mode is neither a training nor an inference mode.
          """
          if mode is None:
              mode = self.config.mode

          if "train" in mode:
              if return_last_states is None:
                  return_last_states = self.config.return_last_states
              # `_train_fn` was wrapped with zero padding in `__init__` based on `config.mode`, so the
              # restriction uses the same predicate on the configured mode.
              if "with_padding" in self.config.mode and return_last_states:
                  raise ValueError("return_last_states=True is not supported with train_with_padding mode.")
              return self._train_fn(
                  query=query,
                  key=key,
                  value=value,
                  igate=igate,
                  fgate=fgate,
                  c_initial=c_initial,
                  n_initial=n_initial,
                  m_initial=m_initial,
                  return_last_states=return_last_states,
              )

          if "inference" in mode:
              # inference mode always returns the last states
              return self._inference_fn(
                  query=query,
                  key=key,
                  value=value,
                  igate=igate,
                  fgate=fgate,
                  c_initial=c_initial,
                  n_initial=n_initial,
                  m_initial=m_initial,
              )

          # Report the mode actually dispatched on, which may come from the call rather than the config.
          raise ValueError(f"Unknown mode: {mode}")

      def extra_repr(self) -> str:
          return f"{self.config}"
  class xLSTMRMSNorm(nn.Module):
      """Root-mean-square normalization over the last dimension.

      Similar to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html. The input is divided by
      `sqrt(mean(x ** 2, dim=-1) + eps)`, cast back to its original dtype, and then optionally scaled by a
      learnable weight and shifted by a learnable bias.

      Args:
          num_features: Size of the last dimension; also the length of the weight and bias vectors.
          eps: Constant added to the mean square before the reciprocal square root.
          use_weight: Whether to create a learnable scale (initialized to ones).
          use_bias: Whether to create a learnable shift (initialized to zeros).
          force_float32_reductions: Whether to upcast to float32 for the normalization.
      """

      def __init__(
          self,
          num_features: int,
          eps: float = 1e-6,
          use_weight: bool = True,
          use_bias: bool = False,
          force_float32_reductions: bool = True,
      ):
          super().__init__()
          self.num_features = num_features
          self.eps = eps
          self.force_float32_reductions = force_float32_reductions
          # Disabled affine terms are stored as plain `None` attributes.
          self.weight = nn.Parameter(torch.ones(num_features)) if use_weight else None
          self.bias = nn.Parameter(torch.zeros(num_features)) if use_bias else None

      def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
          scaled = x if self.weight is None else x * self.weight
          return scaled if self.bias is None else scaled + self.bias

      def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
          original_dtype = x.dtype
          if self.force_float32_reductions:
              x = x.to(torch.float32)
          mean_square = x.pow(2).mean(dim=-1, keepdim=True)
          normalized = x * torch.rsqrt(mean_square + self.eps)
          return normalized.to(original_dtype)

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          return self._apply_weight_bias(self._rms_normalize(x))
  class xLSTMMultiHeadLayerNorm(nn.Module):
      """Per-head layer normalization followed by a shared affine transform.

      Expects input of shape (batch_size, sequence_length, nh, DH), where nh is the number of heads and
      DH the head dimension. Each head vector (the last dimension) is normalized to zero mean and unit
      (biased) variance. The heads are then flattened and the optional weight/bias over `nh * DH`
      features is applied.

      Args:
          num_heads: The number of heads.
          head_dim: The head dimension.
          eps: Constant added to the variance before the reciprocal square root.
          use_weight: Whether to create a learnable scale (initialized to ones).
          use_bias: Whether to create a learnable shift (initialized to zeros).
          force_float32_reductions: Whether to upcast to float32 for the normalization.

      Returns:
          The normalized tensor with the shape (batch_size, sequence_length, nh * DH).
      """

      def __init__(
          self,
          num_heads: int,
          head_dim: int,
          eps: float = 1e-6,
          use_weight: bool = True,
          use_bias: bool = False,
          force_float32_reductions: bool = True,
      ):
          super().__init__()
          total_features = num_heads * head_dim
          self.num_features = total_features
          self.eps = eps
          self.force_float32_reductions = force_float32_reductions
          self.weight = nn.Parameter(torch.ones(total_features)) if use_weight else None
          self.bias = nn.Parameter(torch.zeros(total_features)) if use_bias else None
          self.num_heads = num_heads
          self.head_dim = head_dim

      def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
          scaled = x if self.weight is None else x * self.weight
          return scaled if self.bias is None else scaled + self.bias

      def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
          original_dtype = x.dtype
          if self.force_float32_reductions:
              x = x.to(torch.float32)
          mean = x.mean(dim=-1, keepdim=True)
          inv_std = torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
          return ((x - mean) * inv_std).to(original_dtype)

      def forward(
          self,
          x: torch.Tensor,
      ) -> torch.Tensor:
          batch_size, sequence_length, nh, DH = x.shape
          if nh != self.num_heads:
              raise ValueError(f"Expected {self.num_heads} heads, got {nh}, input shape: {x.shape}")
          if self.head_dim != DH:
              raise ValueError(f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}")
          # Normalize each head independently, then merge heads before the shared affine transform.
          merged = self._layer_normalize(x).reshape(batch_size, sequence_length, -1)
          return self._apply_weight_bias(merged)
  class xLSTMFeedForward(nn.Module):
      """Gated feed-forward block: `proj_down(silu(gate(x)) * up(x))`.

      The inner width is `hidden_size * ffn_proj_factor`, rounded up to a multiple of
      `ffn_round_up_to_multiple_of`. With `weight_mode == "single"` the gate and up projections are
      separate layers. With `weight_mode == "fused"` a single layer produces both, concatenated as
      [gate | up] along the last dimension.
      """

      def __init__(self, config: xLSTMConfig):
          super().__init__()
          self.config = config
          hidden = config.hidden_size
          self.up_proj_dim = round_up_to_next_multiple_of(
              hidden * config.ffn_proj_factor,
              config.ffn_round_up_to_multiple_of,
          )
          inner = self.up_proj_dim

          if config.weight_mode == "single":
              self.proj_up_gate = nn.Linear(hidden, inner, bias=config.use_bias)
              self.proj_up = nn.Linear(hidden, inner, bias=config.use_bias)
          elif config.weight_mode == "fused":
              self.proj_up_gate_z = nn.Linear(hidden, 2 * inner, bias=config.use_bias)

          self.proj_down = nn.Linear(inner, hidden, bias=config.use_bias)
          self.act_fn = nn.SiLU()

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          weight_mode = self.config.weight_mode
          if weight_mode == "single":
              gated = self.act_fn(self.proj_up_gate(x))
              x = gated * self.proj_up(x)
          elif weight_mode == "fused":
              # The fused output is exactly 2 * up_proj_dim wide: first half gate, second half up.
              gate, z = self.proj_up_gate_z(x).chunk(2, dim=-1)
              x = self.act_fn(gate) * z
          return self.proj_down(x)
  class xLSTMLayer(nn.Module):
      """mLSTM mixing layer.

      Projects the input to queries, keys, values, an output-gate pre-activation and per-head input/forget
      gate pre-activations. It runs them through `xLSTMBackend` and normalizes each head with
      `xLSTMMultiHeadLayerNorm`. The result is gated with a sigmoid output gate and projected back to
      `hidden_size`.

      With `weight_mode == "single"` every projection is its own linear layer. With
      `weight_mode == "fused"` one layer produces [q | k | v | o] and another produces [igate | fgate].
      """

      def __init__(self, config: xLSTMConfig):
          super().__init__()
          self.config = config
          d_model = config.hidden_size
          n_heads = config.num_heads
          self.v_dim = int(d_model * config.v_dim_factor)
          self.qk_dim = int(d_model * config.qk_dim_factor)

          if config.weight_mode == "single":
              # Gate pre-activations always carry a bias, independent of `use_bias`.
              self.q = nn.Linear(d_model, self.qk_dim, bias=config.use_bias)
              self.k = nn.Linear(d_model, self.qk_dim, bias=config.use_bias)
              self.v = nn.Linear(d_model, self.v_dim, bias=config.use_bias)
              self.ogate_preact = nn.Linear(d_model, self.v_dim, bias=config.use_bias)
              self.igate_preact = nn.Linear(d_model, n_heads, bias=True)
              self.fgate_preact = nn.Linear(d_model, n_heads, bias=True)
          elif config.weight_mode == "fused":
              self.qkv_opreact = nn.Linear(d_model, 2 * self.qk_dim + 2 * self.v_dim, bias=config.use_bias)
              self.ifgate_preact = nn.Linear(d_model, 2 * n_heads, bias=True)

          self.ogate_act_fn = nn.Sigmoid()
          self.mlstm_backend = xLSTMBackend(config=config)
          self.multihead_norm = xLSTMMultiHeadLayerNorm(
              num_heads=n_heads,
              head_dim=self.v_dim // n_heads,
              eps=config.norm_eps,
              use_weight=True,
              use_bias=config.use_bias,
              force_float32_reductions=config.norm_reduction_force_float32,
          )
          self.out_proj = nn.Linear(self.v_dim, d_model, bias=config.use_bias)

      def forward(
          self, x: torch.Tensor, state: mLSTMLayerStateType | None = None
      ) -> tuple[torch.Tensor, mLSTMLayerStateType | None]:
          """Apply the layer to `x` of shape (batch_size, sequence_length, hidden_size).

          Args:
              x: Rank-3 input tensor.
              state: Optional `(c, n, m)` recurrent state handed to the backend as its initial state.

          Returns:
              `(y, state)`, where `y` has shape (batch_size, sequence_length, hidden_size) and `state` is
              the state returned by the backend.
          """
          if x.ndim != 3:
              raise ValueError(f"Input must have shape [batch_size, sequence_length, HD], got {x.shape}")
          batch_size, sequence_length, _ = x.shape
          n_heads = self.config.num_heads
          cap = self.config.gate_soft_cap

          if self.config.weight_mode == "single":
              query, key, value = self.q(x), self.k(x), self.v(x)
              o_preact = self.ogate_preact(x)
              i_preact = soft_cap(self.igate_preact(x), cap_value=cap)
              f_preact = soft_cap(self.fgate_preact(x), cap_value=cap)
          elif self.config.weight_mode == "fused":
              boundaries = (self.qk_dim, 2 * self.qk_dim, 2 * self.qk_dim + self.v_dim)
              query, key, value, o_preact = torch.tensor_split(self.qkv_opreact(x), boundaries, dim=-1)
              gate_preacts = soft_cap(self.ifgate_preact(x), cap_value=cap)
              i_preact, f_preact = torch.tensor_split(gate_preacts, (n_heads,), dim=-1)

          def split_heads(t: torch.Tensor) -> torch.Tensor:
              # (B, S, nh * D) -> (B, nh, S, D)
              return t.reshape(batch_size, sequence_length, n_heads, -1).transpose(1, 2)

          query, key, value = split_heads(query), split_heads(key), split_heads(value)
          # (B, S, nh) -> (B, nh, S)
          i_preact, f_preact = i_preact.transpose(1, 2), f_preact.transpose(1, 2)

          c_initial, n_initial, m_initial = (None, None, None) if state is None else state
          h, state = self.mlstm_backend(
              query=query,
              key=key,
              value=value,
              igate=i_preact,
              fgate=f_preact,
              c_initial=c_initial,
              n_initial=n_initial,
              m_initial=m_initial,
          )

          head_dim_v = self.v_dim // n_heads
          expected_h_shape = (batch_size, n_heads, sequence_length, head_dim_v)
          if h.shape != expected_h_shape:
              raise ValueError(f"Got {h.shape}, expected {expected_h_shape}")

          # (B, nh, S, Dv) -> (B, S, nh, Dv) for the per-head norm, which merges heads on output.
          h_norm = self.multihead_norm(h.transpose(1, 2)).reshape(batch_size, sequence_length, -1)
          gated = self.ogate_act_fn(o_preact) * h_norm
          return self.out_proj(gated), state
  class xLSTMBlock(GradientCheckpointingLayer):
      """Pre-norm residual block.

      The input goes through RMSNorm, then the mLSTM layer, and is added back as a residual. The result
      then goes through RMSNorm, then the gated feed-forward, and is added back as a second residual.
      The recurrent state of the mLSTM layer is threaded through.
      """

      def __init__(self, config: xLSTMConfig):
          super().__init__()
          self.config = config
          # Both normalization layers use identical settings.
          norm_kwargs = {
              "num_features": config.hidden_size,
              "eps": config.norm_eps,
              "use_weight": True,
              "use_bias": config.use_bias,
              "force_float32_reductions": config.norm_reduction_force_float32,
          }
          self.norm_mlstm = xLSTMRMSNorm(**norm_kwargs)
          self.mlstm_layer = xLSTMLayer(config)
          self.norm_ffn = xLSTMRMSNorm(**norm_kwargs)
          self.ffn = xLSTMFeedForward(config)

      def forward(self, x: torch.Tensor, state: mLSTMStateType | None = None) -> tuple[torch.Tensor, mLSTMStateType]:
          mixed, state = self.mlstm_layer(self.norm_mlstm(x), state)
          residual = x + mixed
          return residual + self.ffn(self.norm_ffn(residual)), state
  def small_init_method(dim):
      """Return an initializer that fills a tensor from N(0, sqrt(2 / (5 * dim))).

      Adapted from: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
      The method is described in "Transformers without Tears: Improving the Normalization of Self-Attention"
      (Nguyen, T. & Salazar, J., 2019). The returned callable fills its argument in place and returns it.
      """
      std = (2 / (5 * dim)) ** (1 / 2)
      return partial(init.normal_, mean=0.0, std=std)
  def wang_init_method(n_layers, dim):
      """Return an initializer that fills a tensor from N(0, 2 / (n_layers * sqrt(dim))).

      Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
      The returned callable fills its argument in place and returns it.
      """
      std = 2 / n_layers / dim ** (1 / 2)
      return partial(init.normal_, mean=0.0, std=std)
  class xLSTMPreTrainedModel(PreTrainedModel):
      """
      An abstract class for an interface to loading a pre-trained xLSTM model.
      """

      config_class = xLSTMConfig
      base_model_prefix = "backbone"
      _no_split_modules = ["xLSTMBlock"]
      supports_gradient_checkpointing = True
      _is_stateful = True
      _can_record_outputs = {
          "hidden_states": xLSTMBlock,
      }

      def _module_name_map(self, module):
          """Return the qualified name of `module` inside this model, or "" if it is not a submodule."""
          for name, mod in self.named_modules():
              if mod is module:
                  return name
          return ""

      @torch.no_grad()
      def _init_weights(self, module):
          """Initialize the parameters of one submodule.

          - Embeddings use the "small init" normal distribution.
          - Linear biases are zeroed.
          - Single mode: linear layers whose name contains "gate" get zero weights. Input-gate biases are
            set to -10, and forget-gate biases are spread linearly over [3, 6] across heads.
          - Fused mode: the fused [igate | fgate] projection gets zero weights, the same -10 biases for the
            first `num_heads` outputs, and [3, 6] biases for the remaining outputs.
          - `proj_down` / `out_proj` use the Wang init. All other linear layers use small init.
          - Norm layers get ones for the weight and zeros for the bias, when present.
          """
          if isinstance(module, nn.Embedding):
              # Initialize the embedding we were given rather than `self.embeddings`: the latter does not
              # exist on wrappers such as the causal-LM model, and it is the wrong table when a fresh
              # embedding is initialized before being attached (e.g. when resizing the vocabulary).
              small_init_method(self.config.hidden_size)(module.weight)
          elif isinstance(module, nn.Linear):
              # Resolve the name once; each lookup walks every submodule of the model.
              name = self._module_name_map(module)
              if module.bias is not None:
                  init.zeros_(module.bias)
              if self.config.weight_mode == "single" and "gate" in name:
                  init.zeros_(module.weight)
                  if "igate" in name:
                      init.copy_(module.bias, -10.0 * torch.ones_like(module.bias))
                  elif "fgate" in name:
                      init.copy_(
                          module.bias,
                          torch.linspace(3.0, 6.0, module.bias.shape[-1]).to(
                              device=module.bias.device,
                              dtype=module.bias.dtype,
                          ),
                      )
              elif self.config.weight_mode == "fused" and "ifgate" in name:
                  # Only the fused gate projection is targeted. `xLSTMLayer.forward` splits its output at
                  # `num_heads`: the first part is igate and the rest is fgate. Other "gate"-named layers
                  # (e.g. the FFN's `proj_up_gate_z`) fall through to the generic init below.
                  num_heads = self.config.num_heads
                  init.zeros_(module.weight)
                  igate_bias = torch.full((num_heads,), -10.0)
                  fgate_bias = torch.linspace(3.0, 6.0, module.bias.shape[-1] - num_heads)
                  init.copy_(
                      module.bias,
                      torch.cat([igate_bias, fgate_bias]).to(
                          device=module.bias.device,
                          dtype=module.bias.dtype,
                      ),
                  )
              elif "proj_down" in name:
                  wang_init_method(dim=module.weight.shape[1], n_layers=self.config.num_hidden_layers)(module.weight)
              elif "out_proj" in name:
                  wang_init_method(dim=self.config.hidden_size, n_layers=self.config.num_hidden_layers)(module.weight)
              elif module.weight is not None:
                  small_init_method(self.config.hidden_size)(module.weight)
          elif isinstance(module, xLSTMRMSNorm) or hasattr(module, "_layer_normalize"):
              # Norm layers can be built without a weight and/or bias (use_weight / use_bias = False).
              if getattr(module, "weight", None) is not None:
                  init.ones_(module.weight)
              if getattr(module, "bias", None) is not None:
                  init.zeros_(module.bias)
  class xLSTMCache:
      """
      Cache for xLSTM model which does not have attention mechanism and key value states.

      Instead of keys and values, the cache keeps one recurrent mLSTM state per layer. Each state is a tuple
      of three zero-initialized tensors `(c, n, m)` with shapes `(max_batch_size, num_heads, qk_head_dim,
      v_head_dim)`, `(max_batch_size, num_heads, qk_head_dim)` and `(max_batch_size, num_heads, 1)`.

      Arguments:
          config (`PreTrainedConfig`):
              The configuration file defining the shape-related attributes required to initialize the static cache.
          max_batch_size (`int`):
              The batch size with which the model will be used.
          dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
              The default `dtype` to use when initializing the layer.
          device (`torch.device` or `str`, *optional*):
              The device on which the cache should be initialized. Should be the same as the layer.

      Attributes:
          seqlen_offset (`int`):
              Starts at 0. The model advances it by the input length on every forward call with `use_cache=True`.
          dtype (`torch.dtype`):
              The dtype the state tensors were created with.
          rnn_state (`dict[int, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]`):
              Per-layer `(c, n, m)` states, keyed by layer index.
          rnn_state_initial (`bool`):
              `True` until the model first writes a state into the cache (and again after `reset`).

      Example:

      ```python
      >>> from transformers import AutoTokenizer, xLSTMForCausalLM
      >>> from transformers.models.xlstm.modeling_xlstm import xLSTMCache

      >>> model = xLSTMForCausalLM.from_pretrained("NX-AI/xLSTM-7b")
      >>> tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")
      >>> inputs = tokenizer(text="I am an xLSTM", return_tensors="pt")

      >>> # Prepare a cache class and pass it to model's forward
      >>> cache_params = xLSTMCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
      >>> outputs = model(**inputs, cache_params=cache_params, use_cache=True)
      >>> type(outputs.cache_params).__name__
      'xLSTMCache'
      ```
      """

      def __init__(
          self,
          config: xLSTMConfig,
          max_batch_size: int,
          dtype: torch.dtype = torch.bfloat16,
          device: str | None = None,
          **kwargs,
      ):
          self.seqlen_offset = 0
          self.dtype = dtype
          self.config = config
          # Cleared by the model once it stores a computed state; initialized here so it always exists.
          self.rnn_state_initial = True
          self.rnn_state = {
              layer: (
                  torch.zeros(
                      [max_batch_size, config.num_heads, config.qk_head_dim, config.v_head_dim],
                      dtype=dtype,
                      device=device,
                  ),
                  torch.zeros([max_batch_size, config.num_heads, config.qk_head_dim], dtype=dtype, device=device),
                  torch.zeros([max_batch_size, config.num_heads, 1], dtype=dtype, device=device),
              )
              for layer in range(config.num_hidden_layers)
          }

      def reset(self):
          """Replace every layer's state with new zero tensors of the same shape, dtype and device.

          `seqlen_offset` is left unchanged.
          """
          self.rnn_state = {
              layer: tuple(torch.zeros_like(state) for state in states) for layer, states in self.rnn_state.items()
          }
          self.rnn_state_initial = True
  @dataclass
  @auto_docstring
  class xLSTMOutput(ModelOutput):
      r"""
      cache_params (`xLSTMCache`):
          The recurrent `(c, n, m)` state of every layer after the last processed time step. Pass it back to the
          model together with the next `input_ids` to continue without re-feeding the earlier `input_ids`.
      """

      # Output of the final normalization layer.
      last_hidden_state: torch.FloatTensor | None
      # Recurrent state carried between calls.
      cache_params: xLSTMCache | None = None
      # Per-block outputs, populated when hidden states are requested.
      hidden_states: tuple[torch.FloatTensor] | None = None
  @auto_docstring
  class xLSTMModel(xLSTMPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          # use embedding_dim and num_blocks once here to make use of them
          self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
          self.blocks = nn.ModuleList([xLSTMBlock(config) for _ in range(config.num_blocks)])
          self.out_norm = xLSTMRMSNorm(config.hidden_size, eps=config.norm_eps)
          self.gradient_checkpointing = False
          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.embeddings

      def set_input_embeddings(self, new_embedding):
          self.embeddings = new_embedding

      @staticmethod
      def _store_layer_state(cache_params, layer_idx, rnn_state):
          """Copy a layer's new `(c, n, m)` state into the cache tensors in place and mark the cache as used."""
          for cached, new in zip(cache_params.rnn_state[layer_idx], rnn_state, strict=True):
              cached.copy_(new)
          cache_params.rnn_state_initial = False

      @merge_with_config_defaults
      @capture_outputs
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          cache_params: xLSTMCache | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | xLSTMOutput:
          r"""
          cache_params (`xLSTMCache`, *optional*):
              The xLSTMCache that carries the RNN states.
          """
          # Resolved here (not just by @capture_outputs) because the chunked inference path below
          # is incompatible with hidden state collection and we need the value to pick the right branch.
          output_hidden_states = kwargs.get("output_hidden_states")
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states

          # Exactly one of the two inputs must be given.
          if (input_ids is None) ^ (inputs_embeds is not None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
          if inputs_embeds is None:
              inputs_embeds = self.embeddings(input_ids)

          if use_cache and cache_params is None:
              cache_params = xLSTMCache(
                  self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
              )

          hidden_states = inputs_embeds
          chunk_size = self.config.max_inference_chunksize
          seq_len = hidden_states.shape[1]

          if not self.training and chunk_size < seq_len and not output_hidden_states:
              # Chunked inference: run the whole stack on `max_inference_chunksize` positions at a time,
              # carrying the recurrent state from one chunk to the next through `cache_params`.
              with torch.no_grad():
                  if cache_params is None:
                      # Create the carrier state on the inputs' device/dtype (as the use_cache path does);
                      # the default cache would live on the CPU regardless of where the model runs.
                      cache_params = xLSTMCache(
                          config=self.config,
                          max_batch_size=hidden_states.shape[0],
                          device=hidden_states.device,
                          dtype=hidden_states.dtype,
                      )
                  final_state = torch.zeros_like(hidden_states)
                  for start in range(0, seq_len, chunk_size):
                      end = min(start + chunk_size, seq_len)
                      hidden_states_chunk = hidden_states[:, start:end]
                      for layer_idx, xlstm_block in enumerate(self.blocks):
                          hidden_states_chunk, rnn_state = xlstm_block(
                              hidden_states_chunk,
                              state=cache_params.rnn_state[layer_idx],
                          )
                          self._store_layer_state(cache_params, layer_idx, rnn_state)
                      final_state[:, start:end] = hidden_states_chunk
              hidden_states = final_state
          else:
              for layer_idx, xlstm_block in enumerate(self.blocks):
                  hidden_states, rnn_state = xlstm_block(
                      hidden_states,
                      cache_params.rnn_state[layer_idx] if cache_params is not None else None,
                  )
                  if cache_params is not None:
                      self._store_layer_state(cache_params, layer_idx, rnn_state)

          if use_cache:
              cache_params.seqlen_offset += inputs_embeds.shape[1]

          hidden_states = self.out_norm(hidden_states)

          return xLSTMOutput(
              last_hidden_state=hidden_states,
              cache_params=cache_params,
          )
  @dataclass
  @auto_docstring
  class xLSTMCausalLMOutput(ModelOutput):
      r"""
      loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
          Next-token prediction (language modeling) loss.
      logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
          Soft-capped scores of the language modeling head for each vocabulary token (before SoftMax).
      cache_params (`xLSTMCache`, *optional*):
          The recurrent (RNN) state of the model at the last time step. Pass it back to the model together
          with the next `input_ids` to continue without re-feeding the earlier `input_ids`.
      """

      loss: torch.FloatTensor | None = None
      logits: torch.FloatTensor | None = None
      cache_params: xLSTMCache | None = None
      # Per-block outputs forwarded from the backbone when requested.
      hidden_states: tuple[torch.FloatTensor] | None = None
  @auto_docstring
  class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
      def __init__(self, config):
          super().__init__(config)
          self.backbone = xLSTMModel(config)
          self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
          # Initialize weights and apply final processing
          self.post_init()

      def get_output_embeddings(self):
          return self.lm_head

      def set_output_embeddings(self, new_embeddings):
          self.lm_head = new_embeddings

      def get_input_embeddings(self):
          return self.backbone.get_input_embeddings()

      def set_input_embeddings(self, new_embeddings):
          return self.backbone.set_input_embeddings(new_embeddings)

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          cache_params: xLSTMCache | None = None,
          labels: torch.LongTensor | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | xLSTMCausalLMOutput:
          r"""
          cache_params (`xLSTMCache`, *optional*):
              The xLSTMCache that carries the RNN states.
          """
          xlstm_outputs = self.backbone(
              input_ids,
              cache_params=cache_params,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              **kwargs,
          )
          hidden_states = xlstm_outputs[0]

          # Project in the head's dtype, then upcast the logits to float32.
          logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()

          chunk_size = self.config.max_inference_chunksize
          seq_len = logits.shape[1]
          if not self.training and chunk_size < seq_len:
              # Soft-cap long inference sequences in place, `max_inference_chunksize` positions at a time
              # (presumably to bound peak memory on long inputs).
              with torch.no_grad():
                  for start in range(0, seq_len, chunk_size):
                      end = min(start + chunk_size, seq_len)
                      logits[:, start:end] = soft_cap(logits[:, start:end], self.config.output_logit_soft_cap)
          else:
              logits = soft_cap(logits, self.config.output_logit_soft_cap)

          loss = None
          if labels is not None:
              # move labels to correct device
              labels = labels.to(logits.device)
              # Shift so that tokens < n predict n
              shift_logits = logits[..., :-1, :].contiguous()
              shift_labels = labels[..., 1:].contiguous()
              # Flatten the tokens
              loss_fct = CrossEntropyLoss()
              loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

          return xLSTMCausalLMOutput(
              loss=loss,
              logits=logits,
              cache_params=xlstm_outputs.cache_params,
              hidden_states=xlstm_outputs.hidden_states,
          )
  # Public model classes exported by this module.
  __all__ = ["xLSTMForCausalLM", "xLSTMModel", "xLSTMPreTrainedModel"]