| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600 |
- # Copyright 2025 NXAI GmbH. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch xLSTM Model."""
- from dataclasses import dataclass
- import torch
- import torch.nn.functional as F
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from ... import initialization as init
- from ...generation import GenerationMixin
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, is_xlstm_available
- from ...utils.generic import merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from .configuration_xlstm import xLSTMConfig
- if is_xlstm_available():
- from xlstm.xlstm_large.model import RMSNorm as xLSTMRMSNorm
- from xlstm.xlstm_large.model import mLSTMBlock, mLSTMStateType, soft_cap
- external_xlstm = True
- class xLSTMBlock(GradientCheckpointingLayer, mLSTMBlock):
- pass
- else:
- from collections.abc import Callable
- from functools import partial
- from typing import Literal
- from .configuration_xlstm import round_up_to_next_multiple_of
- mLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- mLSTMStateType = dict[int, mLSTMLayerStateType]
- external_xlstm = False
- def soft_cap(values: torch.Tensor, cap_value: float | torch.Tensor | None = None) -> torch.Tensor:
- """
- Soft caps a tensor to a value.
- Performs a tanh operation on the logits and scales the result to the cap value. Common technique in attention
- and output language heads to prevent large logits from dominating the softmax. See for example Gemma2:
- https://huggingface.co/papers/2408.00118
- Args:
- values: The tensor to cap.
- cap_value: The value to cap the values to. If None, no cap is applied.
- Returns:
- The capped values.
- """
- if cap_value is None:
- return values
- return cap_value * torch.tanh(values / cap_value)
- def mlstm_chunkwise_recurrent_fw_C(
- matK: torch.Tensor,
- matV: torch.Tensor,
- vecB: torch.Tensor,
- vecI: torch.Tensor,
- matC_states: torch.Tensor | None = None,
- vecN_states: torch.Tensor | None = None,
- scaMinter_states: torch.Tensor | None = None,
- matC_initial: torch.Tensor | None = None,
- vecN_initial: torch.Tensor | None = None,
- scaMinter_initial: torch.Tensor | None = None,
- qk_scale: float | None = None,
- chunk_size: int = 64,
- num_chunks: int = 1,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- batch_size, nh, _, dhqk, dhhv = *matK.shape, matV.shape[-1]
- nc = num_chunks
- _dtype, _device = matK.dtype, matK.device
- if qk_scale is None:
- qk_scale = dhqk**-0.5
- # initialize the states tensors
- if matC_states is None:
- matC_states = torch.zeros((batch_size, nh, (nc + 1) * dhqk, dhhv), dtype=_dtype, device=_device)
- if vecN_states is None:
- vecN_states = torch.zeros((batch_size, nh, (nc + 1) * dhqk), dtype=_dtype, device=_device)
- if scaMinter_states is None:
- scaMinter_states = torch.zeros((batch_size, nh, (nc + 1)), dtype=_dtype, device=_device)
- # assign the initial states to the running states
- matC_k = (
- torch.zeros((batch_size, nh, dhqk, dhhv), dtype=_dtype, device=_device)
- if matC_initial is None
- else matC_initial
- )
- vecN_k = (
- torch.zeros((batch_size, nh, dhqk), dtype=_dtype, device=_device) if vecN_initial is None else vecN_initial
- )
- scaM_inter_k = (
- torch.zeros((batch_size, nh, 1), dtype=_dtype, device=_device)
- if scaMinter_initial is None
- else scaMinter_initial
- )
- vecA = vecB[..., -1, None] - vecB + vecI
- scaG = vecB[..., -1]
- scaA_max = vecA.max(-1).values
- scaM_inter_k = scaM_inter_k.squeeze(-1)
- for key in range(0, num_chunks):
- # store the states from the previous iteration before updating them
- # in the first iteration, these are the initial states
- matC_states[:, :, key * dhqk : (key + 1) * dhqk, :] = matC_k
- vecN_states[:, :, key * dhqk : (key + 1) * dhqk] = vecN_k
- scaMinter_states[:, :, key] = scaM_inter_k
- # m_k update
- scaA_max_k = scaA_max[:, :, key]
- scaG_k = scaG[:, :, key]
- scaM_inter_k_next = torch.max(scaG_k + scaM_inter_k, scaA_max_k)
- # C_k update
- matK_chunk = matK[:, :, key * chunk_size : (key + 1) * chunk_size, :] # * qk_scale
- matV_chunk = matV[:, :, key * chunk_size : (key + 1) * chunk_size, :]
- vecA_k = vecA[:, :, key, :]
- vecAbar_k = torch.exp(vecA_k - scaM_inter_k_next[..., None])[:, :, :, None]
- matK_chunk_gated = matK_chunk * vecAbar_k
- scaGbar_k = torch.exp(scaG_k + scaM_inter_k - scaM_inter_k_next)[:, :, None]
- # NOTE: no update in-place (i.e. +=) as this gives error for autograd backward
- matC_k_next = scaGbar_k[..., None] * matC_k + matK_chunk_gated.transpose(-2, -1) @ (matV_chunk)
- # n_k update
- vecN_k_next = scaGbar_k * vecN_k + matK_chunk_gated.transpose(-2, -1).sum(-1)
- # move to the next iteration
- scaM_inter_k = scaM_inter_k_next
- matC_k = matC_k_next
- vecN_k = vecN_k_next
- # store the states from the last iteration
- matC_states[:, :, -dhqk:, :] = matC_k
- vecN_states[:, :, -dhqk:] = vecN_k
- scaMinter_states[:, :, -1] = scaM_inter_k
- return matC_states, vecN_states, scaMinter_states
def mlstm_chunkwise_parallel_fw_H(
    matQ: torch.Tensor,
    matK: torch.Tensor,
    matV: torch.Tensor,
    # these states must be all states up to the last chunk, i.e. :-1
    matC_states: torch.Tensor,
    vecN_states: torch.Tensor,
    scaMinter_states: torch.Tensor,
    vecI: torch.Tensor,
    vecB: torch.Tensor,
    qk_scale: float,
    chunk_size: int = 64,
    num_chunks: int = 1,
    eps: float = 1e-6,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the mLSTM outputs of all chunks in parallel from the per-chunk entry states.

    Each chunk combines an intra-chunk part (causally masked, gate-weighted q·k scores applied to the
    values) with an inter-chunk part (queries applied to the state entering the chunk). Both parts share
    one max-stabilizer per position.

    Args:
        matQ: Queries, viewed as (batch_size, nh, num_chunks, chunk_size, dhqk).
        matK: Keys, viewed like `matQ`.
        matV: Values, viewed as (batch_size, nh, num_chunks, chunk_size, dhv).
        matC_states: Stacked C states entering each chunk, (batch_size, nh, num_chunks * dhqk, dhv).
        vecN_states: Stacked n states entering each chunk, (batch_size, nh, num_chunks * dhqk).
        scaMinter_states: Max states entering each chunk, (batch_size, nh, num_chunks).
        vecI: Input gate pre-activations, (batch_size, nh, num_chunks, chunk_size).
        vecB: Cumulative log forget gates, (batch_size, nh, num_chunks, chunk_size).
        qk_scale: Scale applied to the q·k scores and to the gated queries.
        chunk_size: Number of time steps per chunk.
        num_chunks: Number of chunks.
        eps: Added to the denominator before the division.

    Returns:
        `(matH_out, vecN_out, vecM_out)`: outputs of shape (batch_size, nh, num_chunks * chunk_size, dhv), and
        the stabilized denominator and combined max state, each (batch_size, nh, num_chunks * chunk_size).
    """
    batch_size, nh, stacked_dqk, dhv = matC_states.shape
    dhqk = stacked_dqk // num_chunks
    seq_len = num_chunks * chunk_size

    # Split the stacked states and the sequences into one slice per chunk.
    matC_per_chunk = matC_states.view(batch_size, nh, num_chunks, dhqk, dhv)
    vecN_per_chunk = vecN_states.view(batch_size, nh, num_chunks, dhqk)
    matQ = matQ.view(batch_size, nh, num_chunks, chunk_size, dhqk)
    matK = matK.view(batch_size, nh, num_chunks, chunk_size, dhqk)
    matV = matV.view(batch_size, nh, num_chunks, chunk_size, dhv)

    causal_mask = torch.tril(torch.ones((chunk_size, chunk_size), dtype=torch.bool, device=matQ.device))

    # Intra-chunk log gate matrix: forget-gate differences below the diagonal, -inf above, plus input gate.
    log_fgate_diff = vecB[:, :, :, :, None] - vecB[:, :, :, None, :]
    matLogD = torch.where(causal_mask, log_fgate_diff, -float("inf")) + vecI[:, :, :, None, :]

    # Stabilizers: intra-chunk row maximum versus the decayed max state entering the chunk.
    vecM_intra = matLogD.max(dim=-1).values
    vecM_inter = vecB + scaMinter_states[:, :, :, None]
    vecM_combined = torch.maximum(vecM_inter, vecM_intra)[:, :, :, :, None]
    vecM_inter = vecM_inter[:, :, :, :, None]

    # Intra-chunk gate-weighted scores.
    matD = torch.exp(matLogD - vecM_combined)
    matS = (matQ @ matK.transpose(-2, -1)) * qk_scale
    matM = matS * matD

    # Inter-chunk contribution: queries rescaled to the shared stabilizer.
    inter_weight = torch.exp(vecM_inter - vecM_combined)
    matQ_gated = matQ * inter_weight * qk_scale

    numerator = matQ_gated @ matC_per_chunk + matM @ matV
    denominator = matQ_gated @ vecN_per_chunk.unsqueeze(-1) + matM.sum(dim=-1, keepdim=True)
    # Lower-bound the denominator magnitude by exp(-m).
    denominator_floor = torch.maximum(torch.abs(denominator), torch.exp(-vecM_combined))

    matH_out = (numerator / (denominator_floor + eps)).view(batch_size, nh, seq_len, dhv)

    # Denominator and max state are returned too (the original notes they are needed for the backward pass).
    vecN_out = denominator_floor.reshape(batch_size, nh, seq_len)
    vecM_out = vecM_combined.reshape(batch_size, nh, seq_len)
    return matH_out, vecN_out, vecM_out
def mlstm_chunkwise_fw(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    igate: torch.Tensor,
    fgate: torch.Tensor,
    cstate: torch.Tensor | None = None,
    nstate: torch.Tensor | None = None,
    mstate: torch.Tensor | None = None,
    qk_scale: float | None = None,
    return_last_states: bool = False,
    return_all_states: bool = False,
    chunk_size: int = 64,
    eps: float = 1e-6,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None,
    tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None,
]:
    """Chunkwise mLSTM forward pass returning outputs, normalizer, max state and optional states.

    The sequence is split into `sequence_length // chunk_size` chunks. The chunk boundary states are
    materialized recurrently, then all chunk outputs are computed in parallel from them.

    Args:
        query: Queries of shape (batch_size, nh, sequence_length, dhqk).
        key: Keys, passed to both kernels alongside `query`.
        value: Values; the last dimension is the value head dimension.
        igate: Input gate pre-activations, viewed as (batch_size, nh, num_chunks, chunk_size).
        fgate: Forget gate pre-activations, viewed like `igate`.
        cstate: Optional initial C state; zeros are used by the recurrent kernel if None.
        nstate: Optional initial n state; zeros if None.
        mstate: Optional initial m state; zeros if None.
        qk_scale: Scale for the q·k scores. Defaults to `dhqk ** -0.5`.
        return_last_states: If True, the 4th tuple element holds the states after the last chunk.
        return_all_states: If True, the 5th tuple element holds all chunk boundary states.
        chunk_size: Number of time steps per chunk; must divide `sequence_length`.
        eps: Added to the output denominator.

    Returns:
        A 5-tuple `(matH_out, vecN_out, vecM_out, last_states, all_states)`. `last_states` is
        `(C (batch_size, nh, dhqk, dhhv), n (batch_size, nh, dhqk), m (batch_size, nh, 1))` or None;
        `all_states` is the tuple of stacked boundary states or None.

    Raises:
        ValueError: If `sequence_length` is not divisible by `chunk_size`.
    """
    batch_size, nh, sequence_length, dhqk = query.shape
    if sequence_length % chunk_size != 0:
        raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.")
    nc = sequence_length // chunk_size

    vecI = igate.view(batch_size, nh, nc, chunk_size)
    vecF = fgate.view(batch_size, nh, nc, chunk_size)

    # Cumulative log forget gates within each chunk. Uses the functional log-sigmoid; the previous
    # `fgate.logsigmoid(vecF)` was a method call on the gate tensor with a stray argument.
    vecB = F.logsigmoid(vecF).cumsum(-1)

    if qk_scale is None:
        qk_scale = dhqk**-0.5

    # Materialize the C_k, n_k, m_k states at every chunk boundary.
    matC_k_states, vecN_k_states, scaMinter_k_states = mlstm_chunkwise_recurrent_fw_C(
        matK=key,
        matV=value,
        vecB=vecB,
        vecI=vecI,
        matC_initial=cstate,
        vecN_initial=nstate,
        scaMinter_initial=mstate,
        qk_scale=qk_scale,
        chunk_size=chunk_size,
        num_chunks=nc,
    )

    # Compute the outputs within each chunk; the parallel kernel only needs the states entering
    # each chunk, so the final boundary state is sliced off.
    matH_out, vecN_out, vecM_out = mlstm_chunkwise_parallel_fw_H(
        matQ=query,
        matK=key,
        matV=value,
        matC_states=matC_k_states[:, :, :-dhqk, :],
        vecN_states=vecN_k_states[:, :, :-dhqk],
        scaMinter_states=scaMinter_k_states[:, :, :-1],
        vecI=vecI,
        vecB=vecB,
        qk_scale=qk_scale,
        chunk_size=chunk_size,
        num_chunks=nc,
        eps=eps,
    )

    last_states = None
    if return_last_states:
        last_states = (
            matC_k_states[:, :, -dhqk:, :],
            vecN_k_states[:, :, -dhqk:],
            scaMinter_k_states[:, :, -1:],
        )

    all_states = (matC_k_states, vecN_k_states, scaMinter_k_states) if return_all_states else None

    return matH_out, vecN_out, vecM_out, last_states, all_states
def mlstm_chunkwise_native_autograd(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    igate: torch.Tensor,
    fgate: torch.Tensor,
    c_initial: torch.Tensor | None = None,
    n_initial: torch.Tensor | None = None,
    m_initial: torch.Tensor | None = None,
    return_last_states: bool = False,
    eps: float = 1e-6,
    chunk_size: int = 64,
    **kwargs,
) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Chunkwise mLSTM forward pass in plain PyTorch, differentiated by autograd.

    Args:
        query: Queries of shape (batch_size, nh, sequence_length, dhqk).
        key: Keys, passed to both kernels alongside `query`.
        value: Values; the last dimension is the value head dimension.
        igate: Input gate pre-activations, viewed as (batch_size, nh, num_chunks, chunk_size).
        fgate: Forget gate pre-activations, viewed like `igate`.
        c_initial: Optional initial C state; zeros are used by the recurrent kernel if None.
        n_initial: Optional initial n state; zeros if None.
        m_initial: Optional initial m state; zeros if None.
        return_last_states: If True, also return the states after the last chunk.
        eps: Added to the output denominator.
        chunk_size: Number of time steps per chunk; must divide `sequence_length`.
        **kwargs: Accepted and ignored, so callers may pass extra kernel options.

    Returns:
        The outputs `matH_out`, or `(matH_out, (C, n, m))` when `return_last_states` is True, with
        C (batch_size, nh, dhqk, dhhv), n (batch_size, nh, dhqk) and m (batch_size, nh, 1).

    Raises:
        ValueError: If `sequence_length` is not divisible by `chunk_size`.
    """
    batch_size, nh, sequence_length, dhqk = query.shape
    if sequence_length % chunk_size != 0:
        raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.")
    num_chunks = sequence_length // chunk_size
    chunked_shape = (batch_size, nh, num_chunks, chunk_size)

    # Gates in chunked layout; the forget gate becomes a cumulative log-sigmoid within each chunk.
    vecI = igate.view(chunked_shape)
    vecB = F.logsigmoid(fgate.view(chunked_shape)).cumsum(-1)

    qk_scale = dhqk**-0.5
    shared_kwargs = {"vecI": vecI, "vecB": vecB, "qk_scale": qk_scale, "chunk_size": chunk_size, "num_chunks": num_chunks}

    # Materialize the C_k, n_k, m_k states at every chunk boundary.
    matC_all, vecN_all, scaM_all = mlstm_chunkwise_recurrent_fw_C(
        matK=key,
        matV=value,
        matC_initial=c_initial,
        vecN_initial=n_initial,
        scaMinter_initial=m_initial,
        **shared_kwargs,
    )

    # Compute the outputs within each chunk from the states entering it (final boundary state excluded).
    matH_out, _vecN_out, _vecM_out = mlstm_chunkwise_parallel_fw_H(
        matQ=query,
        matK=key,
        matV=value,
        matC_states=matC_all[:, :, :-dhqk, :],
        vecN_states=vecN_all[:, :, :-dhqk],
        scaMinter_states=scaM_all[:, :, :-1],
        eps=eps,
        **shared_kwargs,
    )

    if not return_last_states:
        return matH_out

    last_states = (matC_all[:, :, -dhqk:, :], vecN_all[:, :, -dhqk:], scaM_all[:, :, -1:])
    return matH_out, last_states
def mlstm_recurrent_step_native(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    igate: torch.Tensor,
    fgate: torch.Tensor,
    cstate: torch.Tensor,
    nstate: torch.Tensor,
    mstate: torch.Tensor,
    eps: float = 1e-6,
    dtype_state: torch.dtype = torch.float32,
    **kwargs,
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Advance the mLSTM by a single time step (recurrent formulation).

    Args:
        query: Query of shape (batch_size, nh, dhqk).
        key: Key, same shape as `query`.
        value: Value of shape (batch_size, nh, dhhv).
        igate: Input gate pre-activation of shape (batch_size, nh, 1).
        fgate: Forget gate pre-activation of shape (batch_size, nh, 1).
        cstate: Previous C state of shape (batch_size, nh, dhqk, dhhv).
        nstate: Previous n state of shape (batch_size, nh, dhqk).
        mstate: Previous m (max) state of shape (batch_size, nh, 1).
        eps: Added to the output denominator.
        dtype_state: Dtype the states are converted to on entry and returned in.
        **kwargs: Accepted and ignored.

    Returns:
        `(h, (C_new, n_new, m_new))` where `h` has shape (batch_size, nh, dhhv) in the dtype of `query`
        and the new states are in `dtype_state`.

    Raises:
        ValueError: If any argument does not have the shape listed above.
    """
    compute_dtype = query.dtype
    matC_prev = cstate.to(dtype=dtype_state)
    vecN_prev = nstate.to(dtype=dtype_state)
    scaM_prev = mstate.to(dtype=dtype_state)

    batch_size, nh, dhqk = query.shape
    _, _, dhhv = value.shape
    gate_shape = (batch_size, nh, 1)

    # Shape validation, in the same order as the checks are reported to callers.
    if query.shape != key.shape:
        raise ValueError("query and key must have the same shape")
    if matC_prev.shape != (batch_size, nh, dhqk, dhhv):
        raise ValueError(f"matC_old has wrong shape, got {matC_prev.shape}")
    if vecN_prev.shape != (batch_size, nh, dhqk):
        raise ValueError(f"vecN_old has wrong shape, got {vecN_prev.shape}")
    if scaM_prev.shape != gate_shape:
        raise ValueError(f"scaM_old has wrong shape, got {scaM_prev.shape}")
    if igate.shape != gate_shape:
        raise ValueError(f"scaI has wrong shape, got {igate.shape}")
    if fgate.shape != gate_shape:
        raise ValueError(f"scaF has wrong shape, got {fgate.shape}")

    # Gates in log space, stabilized by the new max state.
    log_fgate = torch.nn.functional.logsigmoid(fgate)
    scaM_new = torch.max(log_fgate + scaM_prev, igate)
    fgate_act = torch.exp(log_fgate + scaM_prev - scaM_new)
    igate_act = torch.exp(igate - scaM_new)

    query_scaled = query * (dhqk ** (-0.5))

    # State update: decay the old state and add the gated key/value outer product (resp. key).
    kv_outer = key[:, :, :, None] @ value[:, :, None, :]
    matC_new = fgate_act[:, :, :, None] * matC_prev.clone() + igate_act[:, :, :, None] * kv_outer
    vecN_new = fgate_act * vecN_prev.clone() + igate_act * key

    # Readout: matmuls run in the query dtype, the division in the state dtype.
    query_row = query_scaled[:, :, None, :]
    h_num = (query_row @ matC_new.to(dtype=compute_dtype)).squeeze(2).to(dtype=dtype_state)
    qn_dot = (query_row @ vecN_new[:, :, :, None].to(dtype=compute_dtype)).squeeze(2)
    # Lower-bound the denominator magnitude by exp(-m).
    h_denom = (torch.maximum(qn_dot.abs(), torch.exp(-scaM_new)) + eps).to(dtype=dtype_state)
    h = (h_num / h_denom).to(dtype=compute_dtype)

    new_states = (
        matC_new.to(dtype=dtype_state),
        vecN_new.to(dtype=dtype_state),
        scaM_new.to(dtype=dtype_state),
    )
    return h, new_states
- def mlstm_recurrent_sequence_native(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- igate: torch.Tensor,
- fgate: torch.Tensor,
- c_initial: torch.Tensor | None = None,
- n_initial: torch.Tensor | None = None,
- m_initial: torch.Tensor | None = None,
- return_last_states: bool = False,
- eps: float = 1e-6,
- dtype_state: torch.dtype = torch.float32,
- **kwargs,
- ) -> tuple[
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None,
- tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None,
- ]:
- batch_size, nh, sequence_length, dhqk = query.shape
- dhv = value.shape[-1]
- device = query.device
- if c_initial is not None:
- if n_initial is None or m_initial is None:
- raise ValueError("Initial states must be provided together.")
- if n_initial is None or m_initial is None:
- raise ValueError("Initial states must be provided together.")
- matC_state, vecN_state, vecM_state = (
- c_initial.to(dtype=dtype_state),
- n_initial.to(dtype=dtype_state),
- m_initial.to(dtype=dtype_state),
- )
- else:
- # memory state
- matC_state = torch.zeros((batch_size, nh, dhqk, dhv), dtype=dtype_state, device=device)
- # normalizer state
- vecN_state = torch.zeros((batch_size, nh, dhqk), dtype=dtype_state, device=device)
- # max state
- vecM_state = torch.zeros((batch_size, nh, 1), dtype=dtype_state, device=device)
- vecH_list = []
- for t in range(sequence_length):
- # gates
- vecF_t, vecI_t = fgate[:, :, t, None], igate[:, :, t, None]
- # projections
- vecQ_t, vecK_t, vecV_t = query[:, :, t, :], key[:, :, t, :], value[:, :, t, :]
- # step
- vecH, (matC_state, vecN_state, vecM_state) = mlstm_recurrent_step_native(
- cstate=matC_state,
- nstate=vecN_state,
- mstate=vecM_state,
- query=vecQ_t,
- key=vecK_t,
- value=vecV_t,
- igate=vecI_t,
- fgate=vecF_t,
- eps=eps,
- dtype_state=dtype_state,
- **kwargs,
- )
- vecH_list.append(vecH)
- matH = torch.stack(vecH_list, dim=-2)
- if return_last_states:
- return matH, (matC_state, vecN_state, vecM_state)
- else:
- return matH
- def wrap_chunkwise_pad_zeros(
- mlstm_chunkwise_kernel: Callable,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- fgate: torch.Tensor,
- igate: torch.Tensor,
- c_initial: torch.Tensor | None = None,
- n_initial: torch.Tensor | None = None,
- m_initial: torch.Tensor | None = None,
- return_last_states: bool = False,
- eps: float = 1e-6,
- autocast_kernel_dtype: torch.dtype = torch.bfloat16,
- chunk_size: int = 64,
- **kwargs,
- ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
- if return_last_states:
- raise ValueError(
- "We are padding zeros, so we cannot return last states,",
- "as they would be not the true last states.",
- )
- batch_size, nh, sequence_length, dhqk = query.shape
- S_unpadded = sequence_length
- # padding to chunk size for kernels
- if sequence_length % chunk_size != 0:
- S_padded = ((sequence_length + chunk_size - 1) // chunk_size) * chunk_size
- q_pad = query.new_zeros(batch_size, nh, S_padded, query.shape[3])
- k_pad = key.new_zeros(batch_size, nh, S_padded, key.shape[3])
- v_pad = value.new_zeros(batch_size, nh, S_padded, value.shape[3])
- i_pad = igate.new_zeros(batch_size, nh, S_padded)
- f_pad = fgate.new_zeros(batch_size, nh, S_padded)
- q_pad[:, :, :S_unpadded, :] = query
- k_pad[:, :, :S_unpadded, :] = key
- v_pad[:, :, :S_unpadded, :] = value
- i_pad[:, :, :S_unpadded] = igate
- f_pad[:, :, :S_unpadded] = fgate
- else:
- q_pad = query
- k_pad = key
- v_pad = value
- i_pad = igate
- f_pad = fgate
- matH = mlstm_chunkwise_kernel(
- query=q_pad,
- key=k_pad,
- value=v_pad,
- igate=i_pad,
- fgate=f_pad,
- c_initial=c_initial,
- n_initial=n_initial,
- m_initial=m_initial,
- return_last_states=return_last_states,
- eps=eps,
- autocast_kernel_dtype=autocast_kernel_dtype,
- chunk_size=chunk_size,
- **kwargs,
- )
- matH = matH[:, :, :S_unpadded, :]
- return matH
- def wrap_chunkwise_arbitrary_sequence_length(
- mlstm_chunkwise_kernel: Callable,
- mlstm_sequence_kernel: Callable,
- mlstm_step_kernel: Callable,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- fgate: torch.Tensor,
- igate: torch.Tensor,
- c_initial: torch.Tensor | None = None,
- n_initial: torch.Tensor | None = None,
- m_initial: torch.Tensor | None = None,
- return_last_states: bool = True,
- eps: float = 1e-6,
- autocast_kernel_dtype: torch.dtype = torch.bfloat16,
- chunk_size: int = 64,
- enable_logging: bool = False,
- ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
- """This function computes the last hidden state and matH outputs of the mLSTM, independently of the sequence length.
- For this it uses three kernels:
- - mlstm_chunkwise_kernel: mlstm chunkwise kernels that processes chunks of a given chunk size in parallel.
- - mlstm_sequence_kernel: mlstm kernel that processes the remaining sequence length in a single step recurrence.
- - mlstm_step_kernel: mlstm kernel that processes a sequence length of 1 in a single step.
- It tries to maximize the chunksizes to improve performance.
- It will start with the given chunk size and then divides the chunksize by 2 until the chunk size is smaller than 16.
- At every chunksize it will process the maximal number of chunks that fit into the remaining sequence length.
- E.g. for chunk_size = 64, this function will try the chunksizes [64, 32, 16] if necessary.
- For the remaining sequence length, which is smaller than 16, we use a different kernel that computes the mLSTM
- in a single step and loop over this in pytorch.
- Args:
- mlstm_chunkwise_kernel: The mLSTM chunkwise kernel that processes chunks of a given chunk size in parallel
- mlstm_sequence_kernel: The mLSTM kernel that processes the remaining sequence length in a single step recurrence
- query: The query tensor (batch_size, nh, sequence_length, dhqk)
- key: The key tensor (batch_size, nh, sequence_length, dhqk)
- value: The value tensor (batch_size, nh, sequence_length, dhhv)
- fgate: The forget gate tensor (batch_size, nh, sequence_length)
- igate: The input gate tensor (batch_size, nh, sequence_length)
- c_initial: The initial cell state tensor (batch_size, nh, dhqk, dhhv)
- n_initial: The initial hidden state tensor (batch_size, nh, dhqk)
- m_initial: The initial memory state tensor (batch_size, nh, 1)
- return_last_states: If True, the function will return the last states of the mLSTM
- eps: The epsilon value used for numerical stability
- autocast_kernel_dtype: The dtype used for the kernel computation
- chunk_size: The chunk size used for the chunkwise kernel
- enable_logging: If True, the function will log debug information. Default is False.
- Returns:
- The last hidden state tensor (batch_size, nh, sequence_length, dhhv) or a tuple containing the last hidden state tensor and the last states of the mLSTM
- Last states are (cstate (batch_size, nh, dhqk, dhhv), nstate (batch_size, nh, dhqk), mstate (batch_size, nh, 1)).
- """
- batch_size, nh, sequence_length, dhqk = key.shape
- dhhv = value.shape[-1]
- c_state = (
- c_initial
- if c_initial is not None
- else torch.zeros(batch_size, nh, dhqk, dhhv, device=key.device, dtype=torch.float32)
- )
- n_state = (
- n_initial
- if n_initial is not None
- else torch.zeros(batch_size, nh, dhqk, device=key.device, dtype=torch.float32)
- )
- m_state = (
- m_initial
- if m_initial is not None
- else torch.zeros(batch_size, nh, 1, device=key.device, dtype=torch.float32)
- )
- if sequence_length > 1:
- # process the sequence length in chunks
- h_outs = []
- seq_len_start_idx = 0
- remaining_seq_len = sequence_length - seq_len_start_idx
- num_chunks = remaining_seq_len // chunk_size
- if num_chunks > 0:
- iter_seq_len = chunk_size * num_chunks
- seq_len_idx = seq_len_start_idx + iter_seq_len
- h_out, (c_state, n_state, m_state) = mlstm_chunkwise_kernel(
- query=query[..., seq_len_start_idx:seq_len_idx, :].contiguous(),
- key=key[..., seq_len_start_idx:seq_len_idx, :].contiguous(),
- value=value[..., seq_len_start_idx:seq_len_idx, :].contiguous(),
- fgate=fgate[..., seq_len_start_idx:seq_len_idx].contiguous(),
- igate=igate[..., seq_len_start_idx:seq_len_idx].contiguous(),
- c_initial=c_state,
- n_initial=n_state,
- m_initial=m_state,
- chunk_size=chunk_size,
- return_last_states=True,
- autocast_kernel_dtype=autocast_kernel_dtype,
- eps=eps,
- )
- seq_len_start_idx += iter_seq_len
- h_outs.append(h_out)
- remaining_seq_len = sequence_length - seq_len_start_idx
- if remaining_seq_len > 0:
- # we use here matK as query as this kernel does not need a query, since we do not care about the outputs only about the last state
- h_out, (c_state, n_state, m_state) = mlstm_sequence_kernel(
- query=query[..., seq_len_start_idx:sequence_length, :].contiguous(),
- key=key[..., seq_len_start_idx:sequence_length, :].contiguous(),
- value=value[..., seq_len_start_idx:sequence_length, :].contiguous(),
- igate=igate[..., seq_len_start_idx:sequence_length].contiguous(),
- fgate=fgate[..., seq_len_start_idx:sequence_length].contiguous(),
- c_initial=c_state,
- n_initial=n_state,
- m_initial=m_state,
- return_last_states=True,
- eps=eps,
- )
- h_outs.append(h_out)
- h_out = torch.concatenate(h_outs, dim=2)
- else:
- if sequence_length != 1:
- raise ValueError(
- f"Received empty sequence (sequence_length={sequence_length}), require at least single element in the sequence."
- )
- # process the sequence length in a single step
- # while this case is also captured by the regular mode above,
- # it avoids the overhead of the loop and calls the step kernel directly
- # The step function does not want a sequence dimension
- # qkv shape is (batch_size, nh, dhqk/dhv)
- # igate, fgate shape is (batch_size, nh, 1)
- h_out, (c_state, n_state, m_state) = mlstm_step_kernel(
- query=query.squeeze(2),
- key=key.squeeze(2),
- value=value.squeeze(2),
- igate=igate,
- fgate=fgate,
- cstate=c_state,
- nstate=n_state,
- mstate=m_state,
- eps=eps,
- )
- h_out = h_out[:, :, None, :]
- if return_last_states:
- return h_out, (c_state, n_state, m_state)
- else:
- return h_out
class xLSTMBackend(nn.Module):
    """xLSTM backend module for PyTorch.

    Wraps the mLSTM kernels and provides a high-level interface for training and inference.
    The kernel callables are bound to the configuration values once in `__init__`;
    `forward` only selects which of the two bound callables to run.
    """

    config_class = xLSTMConfig

    def __init__(self, config: xLSTMConfig):
        super().__init__()
        self.config = config
        self.chunkwise_kernel_fn = mlstm_chunkwise_native_autograd
        self.sequence_kernel_fn = mlstm_recurrent_sequence_native
        self.step_kernel_fn = mlstm_recurrent_step_native

        # The config stores dtypes by name; resolve them to torch dtypes once.
        state_dtype = getattr(torch, config.inference_state_dtype)
        kernel_dtype = getattr(torch, config.autocast_kernel_dtype)

        # Inference callable: always asks the wrapper for the last states.
        self._inference_fn = partial(
            wrap_chunkwise_arbitrary_sequence_length,
            mlstm_chunkwise_kernel=self.chunkwise_kernel_fn,
            mlstm_sequence_kernel=partial(self.sequence_kernel_fn, dtype_state=state_dtype),
            mlstm_step_kernel=partial(self.step_kernel_fn, dtype_state=state_dtype),
            chunk_size=config.chunk_size,
            eps=config.eps,
            autocast_kernel_dtype=kernel_dtype,
            return_last_states=True,
        )

        # Training callable: the chunkwise kernel, optionally wrapped with zero padding.
        train_kernel_fn = partial(
            self.chunkwise_kernel_fn,
            autocast_kernel_dtype=kernel_dtype,
            eps=config.eps,
            chunk_size=config.chunk_size,
        )
        if "with_padding" in config.mode:
            train_kernel_fn = partial(wrap_chunkwise_pad_zeros, mlstm_chunkwise_kernel=train_kernel_fn)
        self._train_fn = train_kernel_fn

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        igate: torch.Tensor,
        fgate: torch.Tensor,
        c_initial: torch.Tensor | None = None,
        n_initial: torch.Tensor | None = None,
        m_initial: torch.Tensor | None = None,
        return_last_states: bool | None = None,
        mode: Literal["train", "inference"] | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Forward pass of the mLSTM backend.

        Depending on the selected mode, this method calls the training or the inference kernel function.

        Args:
            query: The query tensor of shape (batch_size, nh, sequence_length, dhqk).
            key: The key tensor of shape (batch_size, nh, sequence_length, dhqk).
            value: The value tensor of shape (batch_size, nh, sequence_length, dhhv).
            igate: The input gate preactivation tensor of shape (batch_size, nh, sequence_length).
            fgate: The forget gate preactivation tensor of shape (batch_size, nh, sequence_length).
            c_initial: The initial cell state tensor of shape (batch_size, nh, dhqk, dhhv).
                Defaults to None.
            n_initial: The initial normalizer state tensor of shape (batch_size, nh, dhqk). Defaults to None.
            m_initial: The initial max state tensor of shape (batch_size, nh, 1). Defaults to None.
            return_last_states: Whether to return the last states of the sequence. Defaults to None.
                If None, the value from the config is used. Only consulted in train mode;
                the inference callable always returns the last states.
            mode: Overrides the configured mode for this call. Defaults to None, in which case
                `config.mode` is used. Any value containing "train" selects the training kernel,
                any value containing "inference" selects the inference kernel.

        Returns:
            hidden states of shape (batch_size, nh, sequence_length, dhhv)
            hidden states and last states the last states are the cell state cstate (batch_size, nh, dhqk, dhhv),
            the normalizer state nstate (batch_size, nh, dhqk), and the max state mstate (batch_size, nh, 1)

        Raises:
            ValueError: If the mode is neither a train nor an inference mode, or if last states are
                requested while the backend was configured as "train_with_padding".
        """
        if mode is None:
            mode = self.config.mode

        if "train" in mode:
            if return_last_states is None:
                return_last_states = self.config.return_last_states

            # The padding wrapper is chosen from the config in __init__, so the config (not the
            # per-call override) decides whether last states are available.
            if self.config.mode == "train_with_padding":
                if return_last_states:
                    raise ValueError("return_last_states=True is not supported with train_with_padding mode.")

            return self._train_fn(
                query=query,
                key=key,
                value=value,
                igate=igate,
                fgate=fgate,
                c_initial=c_initial,
                n_initial=n_initial,
                m_initial=m_initial,
                return_last_states=return_last_states,
            )

        if "inference" in mode:
            # inference mode always returns the last states
            return self._inference_fn(
                query=query,
                key=key,
                value=value,
                igate=igate,
                fgate=fgate,
                c_initial=c_initial,
                n_initial=n_initial,
                m_initial=m_initial,
            )

        # Report the mode that was actually evaluated, which may be the per-call override.
        raise ValueError(f"Unknown mode: {mode}")

    def extra_repr(self) -> str:
        return f"{self.config}"
class xLSTMRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Similar to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html.
    The input is divided by the root mean square of its last dimension, then an
    optional learnable scale and an optional learnable shift are applied.

    Args:
        num_features: The number of features in the input tensor.
        eps: A small value to avoid division by zero.
        use_weight: Whether to use a learnable weight.
        use_bias: Whether to use a learnable bias.
        force_float32_reductions: Whether to force float32 reductions.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-6,
        use_weight: bool = True,
        use_bias: bool = False,
        force_float32_reductions: bool = True,
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.force_float32_reductions = force_float32_reductions
        # A disabled scale/shift is stored as a plain None attribute.
        self.weight = nn.Parameter(torch.ones(num_features)) if use_weight else None
        self.bias = nn.Parameter(torch.zeros(num_features)) if use_bias else None

    def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the learnable scale and shift, skipping whichever is disabled."""
        scaled = x if self.weight is None else x * self.weight
        return scaled if self.bias is None else scaled + self.bias

    def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize over the last dimension and cast back to the input dtype."""
        in_dtype = x.dtype
        work = x.float() if self.force_float32_reductions else x
        mean_square = work.pow(2).mean(dim=-1, keepdim=True)
        normalized = work * torch.rsqrt(mean_square + self.eps)
        return normalized.to(in_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._apply_weight_bias(self._rms_normalize(x))
class xLSTMMultiHeadLayerNorm(nn.Module):
    """Layer normalization applied independently to every head.

    The input is assumed to have the shape (batch_size, sequence_length, nh, DH), where:
    batch_size: batch size
    sequence_length: sequence length
    nh: number of heads
    DH: head dimension

    Each head vector (last dimension, DH) is normalized on its own. The heads are then
    flattened and the learnable scale/shift, which span all heads, are applied.

    Args:
        num_heads: The number of heads.
        head_dim: The head dimension.
        eps: A small value to avoid division by zero.
        use_weight: Whether to use a learnable weight.
        use_bias: Whether to use a learnable bias.
        force_float32_reductions: Whether to force float32 reductions

    Returns:
        The normalized tensor with the shape (batch_size, sequence_length, nh * DH).
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        eps: float = 1e-6,
        use_weight: bool = True,
        use_bias: bool = False,
        force_float32_reductions: bool = True,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_features = num_heads * head_dim
        self.eps = eps
        self.force_float32_reductions = force_float32_reductions
        # Scale and shift cover the flattened (nh * DH) feature axis.
        self.weight = nn.Parameter(torch.ones(self.num_features)) if use_weight else None
        self.bias = nn.Parameter(torch.zeros(self.num_features)) if use_bias else None

    def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the learnable scale and shift, skipping whichever is disabled."""
        scaled = x if self.weight is None else x * self.weight
        return scaled if self.bias is None else scaled + self.bias

    def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Center and rescale the last dimension, then cast back to the input dtype."""
        in_dtype = x.dtype
        work = x.float() if self.force_float32_reductions else x
        x_centered = work - work.mean(dim=-1, keepdim=True)
        # Biased variance (unbiased=False), with eps added before the reciprocal square root.
        variance = work.var(dim=-1, keepdim=True, unbiased=False)
        return (x_centered * torch.rsqrt(variance + self.eps)).to(in_dtype)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, sequence_length, nh, DH = x.shape
        if nh != self.num_heads:
            raise ValueError(f"Expected {self.num_heads} heads, got {nh}, input shape: {x.shape}")
        if self.head_dim != DH:
            raise ValueError(f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}")

        per_head = self._layer_normalize(x)
        flattened = per_head.reshape(batch_size, sequence_length, -1)
        return self._apply_weight_bias(flattened)
class xLSTMFeedForward(nn.Module):
    """Gated feed-forward block: `proj_down(SiLU(gate) * up)`.

    With `weight_mode == "single"` the gate and up projections are two separate linear
    layers; with `weight_mode == "fused"` a single linear layer produces both and its
    output is split in half along the last dimension.
    """

    def __init__(self, config: xLSTMConfig):
        super().__init__()
        self.config = config

        # Inner width: hidden_size * ffn_proj_factor, rounded up by the project helper.
        self.up_proj_dim = round_up_to_next_multiple_of(
            config.hidden_size * config.ffn_proj_factor,
            config.ffn_round_up_to_multiple_of,
        )
        hidden_size = config.hidden_size
        use_bias = self.config.use_bias
        weight_mode = self.config.weight_mode

        if weight_mode == "single":
            self.proj_up_gate = nn.Linear(hidden_size, self.up_proj_dim, bias=use_bias)
            self.proj_up = nn.Linear(hidden_size, self.up_proj_dim, bias=use_bias)
        elif weight_mode == "fused":
            self.proj_up_gate_z = nn.Linear(hidden_size, 2 * self.up_proj_dim, bias=use_bias)

        self.proj_down = nn.Linear(self.up_proj_dim, hidden_size, bias=use_bias)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight_mode = self.config.weight_mode
        if weight_mode == "single":
            activated_gate = self.act_fn(self.proj_up_gate(x))
            x = activated_gate * self.proj_up(x)
        elif weight_mode == "fused":
            fused = self.proj_up_gate_z(x)
            # First up_proj_dim features are the gate, the remainder is the up projection.
            gate, z = torch.tensor_split(fused, (self.up_proj_dim,), dim=-1)
            x = self.act_fn(gate) * z
        return self.proj_down(x)
class xLSTMLayer(nn.Module):
    """mLSTM sequence-mixing layer.

    Projects the input to query/key/value, an output-gate preactivation and per-head
    input/forget gate preactivations, runs the mLSTM backend, applies the multi-head
    layer norm and the sigmoid output gate, and projects back to `hidden_size`.
    The projections are separate linear layers for `weight_mode == "single"` and two
    fused linear layers for `weight_mode == "fused"`.
    """

    def __init__(self, config: xLSTMConfig):
        super().__init__()
        self.config = config
        self.v_dim = int(config.hidden_size * config.v_dim_factor)
        self.qk_dim = int(config.hidden_size * config.qk_dim_factor)

        hidden_size = self.config.hidden_size
        num_heads = self.config.num_heads
        use_bias = self.config.use_bias
        weight_mode = self.config.weight_mode

        if weight_mode == "single":
            self.q = nn.Linear(hidden_size, self.qk_dim, bias=use_bias)
            self.k = nn.Linear(hidden_size, self.qk_dim, bias=use_bias)
            self.v = nn.Linear(hidden_size, self.v_dim, bias=use_bias)
            self.ogate_preact = nn.Linear(hidden_size, self.v_dim, bias=use_bias)
            # The input/forget gate projections always carry a bias.
            self.igate_preact = nn.Linear(hidden_size, num_heads, bias=True)
            self.fgate_preact = nn.Linear(hidden_size, num_heads, bias=True)
        elif weight_mode == "fused":
            # Output layout: [query | key | value | output-gate preactivation].
            self.qkv_opreact = nn.Linear(hidden_size, 2 * self.qk_dim + 2 * self.v_dim, bias=use_bias)
            # Output layout: [input-gate | forget-gate], num_heads features each.
            self.ifgate_preact = nn.Linear(hidden_size, 2 * num_heads, bias=True)

        self.ogate_act_fn = nn.Sigmoid()
        self.mlstm_backend = xLSTMBackend(config=self.config)
        self.multihead_norm = xLSTMMultiHeadLayerNorm(
            num_heads=num_heads,
            head_dim=self.v_dim // num_heads,
            eps=self.config.norm_eps,
            use_weight=True,
            use_bias=use_bias,
            force_float32_reductions=self.config.norm_reduction_force_float32,
        )
        self.out_proj = nn.Linear(self.v_dim, hidden_size, bias=use_bias)

    def forward(
        self, x: torch.Tensor, state: mLSTMLayerStateType | None = None
    ) -> tuple[torch.Tensor, mLSTMLayerStateType | None]:
        """Run the layer on `x` of shape (batch_size, sequence_length, HD).

        `state`, when given, is unpacked into the (c, n, m) initial states for the backend.
        Returns the projected output and the state returned by the backend.
        """
        if x.ndim != 3:
            raise ValueError(f"Input must have shape [batch_size, sequence_length, HD], got {x.shape}")
        batch_size, sequence_length, _ = x.shape
        num_heads = self.config.num_heads
        gate_cap = self.config.gate_soft_cap
        weight_mode = self.config.weight_mode

        if weight_mode == "single":
            query = self.q(x)
            key = self.k(x)
            value = self.v(x)
            o_preact = self.ogate_preact(x)
            i_preact = soft_cap(self.igate_preact(x), cap_value=gate_cap)
            f_preact = soft_cap(self.fgate_preact(x), cap_value=gate_cap)
        elif weight_mode == "fused":
            split_points = (self.qk_dim, 2 * self.qk_dim, 2 * self.qk_dim + self.v_dim)
            query, key, value, o_preact = torch.tensor_split(self.qkv_opreact(x), split_points, dim=-1)
            if_preact = soft_cap(self.ifgate_preact(x), cap_value=gate_cap)
            i_preact, f_preact = torch.tensor_split(if_preact, (num_heads,), dim=-1)

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (batch, seq, features) -> (batch, heads, seq, features // heads)
            return projected.reshape(batch_size, sequence_length, num_heads, -1).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)
        # Gate preactivations already have one feature per head: (batch, seq, heads) -> (batch, heads, seq).
        i_preact = i_preact.transpose(1, 2)
        f_preact = f_preact.transpose(1, 2)

        c_initial, n_initial, m_initial = (None, None, None) if state is None else state

        h, state = self.mlstm_backend(
            query=query,
            key=key,
            value=value,
            igate=i_preact,
            fgate=f_preact,
            c_initial=c_initial,
            n_initial=n_initial,
            m_initial=m_initial,
        )

        expected_h_shape = (batch_size, num_heads, sequence_length, self.v_dim // num_heads)
        if h.shape != expected_h_shape:
            raise ValueError(f"Got {h.shape}, expected {expected_h_shape}")

        # Back to (batch, seq, heads, head_dim) for the per-head normalization.
        h_norm = self.multihead_norm(h.transpose(1, 2))
        h_norm = h_norm.reshape(batch_size, sequence_length, -1)

        gated = self.ogate_act_fn(o_preact) * h_norm
        return self.out_proj(gated), state
class xLSTMBlock(GradientCheckpointingLayer):
    """Pre-norm residual block: an mLSTM layer followed by a gated feed-forward layer."""

    def __init__(self, config: xLSTMConfig):
        super().__init__()
        self.config = config

        # Both RMS norms are configured identically.
        norm_kwargs = {
            "num_features": config.hidden_size,
            "eps": config.norm_eps,
            "use_weight": True,
            "use_bias": config.use_bias,
            "force_float32_reductions": config.norm_reduction_force_float32,
        }
        self.norm_mlstm = xLSTMRMSNorm(**norm_kwargs)
        self.mlstm_layer = xLSTMLayer(config)
        self.norm_ffn = xLSTMRMSNorm(**norm_kwargs)
        self.ffn = xLSTMFeedForward(config)

    def forward(self, x: torch.Tensor, state: mLSTMStateType | None = None) -> tuple[torch.Tensor, mLSTMStateType]:
        """Apply both residual sub-layers and return the output with the mLSTM layer's state."""
        mlstm_out, state = self.mlstm_layer(self.norm_mlstm(x), state)
        x = x + mlstm_out

        ffn_out = self.ffn(self.norm_ffn(x))
        x = x + ffn_out
        return x, state
def small_init_method(dim):
    """
    Adapted from: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py

    Build an initializer that fills a tensor in place from a zero-mean normal distribution
    with standard deviation sqrt(2 / (5 * dim)), the method described in Transformers without
    Tears: Improving the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019).
    The returned callable takes the tensor to fill and returns the result of `init.normal_`."""
    variance = 2 / (5 * dim)
    std = variance ** (1 / 2)

    def init_(tensor):
        return init.normal_(tensor, mean=0.0, std=std)

    return init_
def wang_init_method(n_layers, dim):
    """
    Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py

    Build an initializer that fills a tensor in place from a zero-mean normal distribution
    with standard deviation 2 / n_layers / sqrt(dim). The returned callable takes the tensor
    to fill and returns the result of `init.normal_`.
    """
    root_dim = dim ** (1 / 2)
    std = 2 / n_layers / root_dim

    def init_(tensor):
        return init.normal_(tensor, mean=0.0, std=std)

    return init_
class xLSTMPreTrainedModel(PreTrainedModel):
    """
    An abstract class for an interface to loading a pre-trained xLSTM model.
    """

    config_class = xLSTMConfig
    base_model_prefix = "backbone"
    _no_split_modules = ["xLSTMBlock"]
    supports_gradient_checkpointing = True
    _is_stateful = True
    _can_record_outputs = {
        "hidden_states": xLSTMBlock,
    }

    def _module_name_map(self, module):
        """Return the dotted name of `module` inside `self`, or "" when it is not a submodule."""
        for name, mod in self.named_modules():
            if mod is module:
                return name
        return ""

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize one submodule according to its type and its name inside the model.

        - Embeddings: small init.
        - Linear layers: biases are zeroed first, then
            - `weight_mode == "single"`: every layer whose name contains "gate" gets zero weights;
              the input-gate bias is set to -10 and the forget-gate bias to linspace(3, 6).
            - `weight_mode == "fused"`: the fused input/forget gate layer gets zero weights and a
              bias of [-10 ... | linspace(3, 6)], matching the [input | forget] split used in
              `xLSTMLayer.forward`.
            - "proj_down" / "out_proj": Wang init.
            - everything else: small init.
        - Norm layers: weight to one, bias (if any) to zero.
        """
        if isinstance(module, nn.Embedding):
            # Initialize the module that was passed in; `self.embeddings` only exists on the backbone.
            small_init_method(self.config.hidden_size)(module.weight)
        elif isinstance(module, nn.Linear):
            # Resolving the name walks all submodules, so do it once per call.
            name = self._module_name_map(module)
            weight_mode = self.config.weight_mode

            if module.bias is not None:
                init.zeros_(module.bias)

            if weight_mode == "single" and "gate" in name:
                init.zeros_(module.weight)
                if "igate" in name:
                    init.copy_(module.bias, -10.0 * torch.ones_like(module.bias))
                elif "fgate" in name:
                    init.copy_(
                        module.bias,
                        torch.linspace(
                            3.0,
                            6.0,
                            module.bias.shape[-1],
                        ).to(
                            device=module.bias.device,
                            dtype=module.bias.dtype,
                        ),
                    )
            elif weight_mode == "fused" and "ifgate" in name:
                # Only the fused input/forget gate layer is handled here. The fused feed-forward
                # projection ("proj_up_gate_z") also contains "gate" in its name, but it is not a
                # per-head gate and may have no bias; zeroing both of its halves would make
                # SiLU(gate) * z start at zero with zero gradient, so it falls through to small init.
                init.zeros_(module.weight)
                if module.bias is not None:
                    num_heads = self.config.num_heads
                    igate_bias = -10.0 * torch.ones(num_heads)
                    fgate_bias = torch.linspace(3.0, 6.0, num_heads)
                    # One full-size copy: first num_heads entries are the input gate, the rest the
                    # forget gate.
                    init.copy_(
                        module.bias,
                        torch.cat((igate_bias, fgate_bias)).to(
                            device=module.bias.device,
                            dtype=module.bias.dtype,
                        ),
                    )
            elif "proj_down" in name:
                wang_init_method(dim=module.weight.shape[1], n_layers=self.config.num_hidden_layers)(module.weight)
            elif "out_proj" in name:
                wang_init_method(dim=self.config.hidden_size, n_layers=self.config.num_hidden_layers)(module.weight)
            elif module.weight is not None:
                small_init_method(self.config.hidden_size)(module.weight)
        elif isinstance(module, xLSTMRMSNorm) or hasattr(module, "_layer_normalize"):
            # Both norm layers store None for a disabled weight or bias.
            if getattr(module, "weight", None) is not None:
                init.ones_(module.weight)
            if hasattr(module, "bias") and module.bias is not None:
                init.zeros_(module.bias)
class xLSTMCache:
    """
    Cache for xLSTM model which does not have attention mechanism and key value states.

    Arguments:
        config (`PreTrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The batch size with which the model will be used.
        dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
            The default `dtype` to use when initializing the layer.
        device (`torch.device` or `str`, *optional*):
            The device on which the cache should be initialized. Should be the same as the layer.

    Attributes:
        seqlen_offset: int
        dtype: torch.dtype
        rnn_state: dict mapping each layer index to its (cell, normalizer, max) state tensors
        rnn_state_initial: bool, True until the model writes states into the cache

    Example:

        ```python
        >>> from transformers import AutoTokenizer, xLSTMForCausalLM, xLSTMCache

        >>> model = xLSTMForCausalLM.from_pretrained("NX-AI/xLSTM-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")

        >>> inputs = tokenizer(text="I am an xLSTM", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> cache_params = xLSTMCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, cache_params=cache_params, use_cache=True)
        >>> outputs.cache_params
        xLSTMCache()
        ```
    """

    def __init__(
        self,
        config: xLSTMConfig,
        max_batch_size: int,
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device | str | None = None,
        **kwargs,
    ):
        self.seqlen_offset = 0
        self.dtype = dtype
        self.config = config
        # xLSTMModel.forward sets this flag to False after writing states into the cache;
        # it is defined here so the attribute exists on a fresh cache as well.
        self.rnn_state_initial = True
        # Per layer: cell state, normalizer state and max state, all starting at zero.
        self.rnn_state = {
            layer: (
                torch.zeros(
                    [max_batch_size, config.num_heads, config.qk_head_dim, config.v_head_dim],
                    dtype=dtype,
                    device=device,
                ),
                torch.zeros([max_batch_size, config.num_heads, config.qk_head_dim], dtype=dtype, device=device),
                torch.zeros([max_batch_size, config.num_heads, 1], dtype=dtype, device=device),
            )
            for layer in range(config.num_hidden_layers)
        }

    def reset(self):
        """Return the cache to its freshly constructed state.

        The state tensors are replaced by zero tensors of the same shape, dtype and device,
        and the bookkeeping fields are reset with them so a reused cache does not carry over
        the token count of the previous sequence.
        """
        self.rnn_state = {
            layer: tuple(torch.zeros_like(state) for state in states)
            for layer, states in self.rnn_state.items()
        }
        self.seqlen_offset = 0
        self.rnn_state_initial = True
@dataclass
@auto_docstring
class xLSTMOutput(ModelOutput):
    r"""
    cache_params (`xLSTMCache`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.
    """

    # Every field is optional and defaults to None, in line with xLSTMCausalLMOutput.
    last_hidden_state: torch.FloatTensor | None = None
    cache_params: xLSTMCache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
@auto_docstring
class xLSTMModel(xLSTMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # use embbeding_dim and num_blocks once here to make use of them
        self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
        self.blocks = nn.ModuleList([xLSTMBlock(config) for _ in range(config.num_blocks)])
        self.out_norm = xLSTMRMSNorm(config.hidden_size, eps=config.norm_eps)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embedding):
        self.embeddings = new_embedding

    def _run_blocks(self, hidden_states: torch.Tensor, cache_params: xLSTMCache | None) -> torch.Tensor:
        # Runs every block in order. When a cache is given, each block starts from its cached
        # state and the state it returns is copied back into the cache in place.
        for layer_idx, xlstm_block in enumerate(self.blocks):
            layer_state = cache_params.rnn_state[layer_idx] if cache_params is not None else None
            hidden_states, rnn_state = xlstm_block(hidden_states, layer_state)
            if cache_params is not None:
                for cached_state, new_state in zip(cache_params.rnn_state[layer_idx], rnn_state):
                    cached_state.copy_(new_state)
                cache_params.rnn_state_initial = False
        return hidden_states

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        cache_params: xLSTMCache | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | xLSTMOutput:
        r"""
        cache_params (`xLSTMCache`, *optional*):
            The xLSTMCache that carries the RNN states.
        """
        # Resolved here (not just by @capture_outputs) because the chunked inference path below
        # is incompatible with hidden state collection and we need the value to pick the right branch.
        output_hidden_states = kwargs.get("output_hidden_states")
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        if use_cache and cache_params is None:
            cache_params = xLSTMCache(
                self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )

        hidden_states = inputs_embeds
        sequence_length = hidden_states.shape[1]
        chunk_size = self.config.max_inference_chunksize

        if not self.training and chunk_size < sequence_length and not output_hidden_states:
            # Chunked inference: the sequence is processed chunk by chunk without gradients,
            # with the cache carrying the recurrent state from one chunk to the next.
            with torch.no_grad():
                if cache_params is None:
                    # The cache must live on the same device (and use the same dtype) as the
                    # activations, exactly like the cache created for `use_cache` above.
                    cache_params = xLSTMCache(
                        self.config,
                        hidden_states.shape[0],
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                final_state = torch.zeros_like(hidden_states)
                # range() rejects a zero step, so a misconfigured chunk size fails instead of hanging.
                for offset in range(0, sequence_length, chunk_size):
                    chunk = slice(offset, min(offset + chunk_size, sequence_length))
                    final_state[:, chunk] = self._run_blocks(hidden_states[:, chunk], cache_params)
                hidden_states = final_state
        else:
            hidden_states = self._run_blocks(hidden_states, cache_params)

        if use_cache:
            cache_params.seqlen_offset += inputs_embeds.shape[1]

        hidden_states = self.out_norm(hidden_states)

        return xLSTMOutput(
            last_hidden_state=hidden_states,
            cache_params=cache_params,
        )
@dataclass
@auto_docstring
class xLSTMCausalLMOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    cache_params (`xLSTMCache`, *optional*):
        The cache carrying the RNN states of the model at the last time step. Can be passed to a later forward
        call together with the next `input_ids` to avoid providing the old `input_ids` again.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    cache_params: xLSTMCache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
@auto_docstring
class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)
        self.backbone = xLSTMModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        return self.backbone.set_input_embeddings(new_embeddings)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        cache_params: xLSTMCache | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | xLSTMCausalLMOutput:
        r"""
        cache_params (`xLSTMCache`, *optional*):
            The xLSTMCache that carries the RNN states.
        """
        backbone_outputs = self.backbone(
            input_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        last_hidden = backbone_outputs[0]

        # Project in the head's dtype, then keep the logits in float32.
        logits = self.lm_head(last_hidden.to(self.lm_head.weight.dtype)).float()

        logit_cap = self.config.output_logit_soft_cap
        chunk_size = self.config.max_inference_chunksize
        num_positions = logits.shape[1]

        if self.training or chunk_size >= num_positions:
            logits = soft_cap(logits, logit_cap)
        else:
            # Long sequences outside training: soft-cap the logits in place, one window of
            # positions at a time, without recording gradients.
            offset = 0
            with torch.no_grad():
                while offset < num_positions:
                    window = slice(offset, min(offset + chunk_size, num_positions))
                    logits[:, window] = soft_cap(logits[:, window], logit_cap)
                    offset += chunk_size

        loss = None
        if labels is not None:
            # move labels to correct device
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))
            flat_labels = shift_labels.view(-1)
            loss = CrossEntropyLoss()(flat_logits, flat_labels)

        return xLSTMCausalLMOutput(
            loss=loss,
            logits=logits,
            cache_params=backbone_outputs.cache_params,
            hidden_states=backbone_outputs.hidden_states,
        )
# Public API of this module.
__all__ = ["xLSTMForCausalLM", "xLSTMModel", "xLSTMPreTrainedModel"]
|