modular_zamba2.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058
  1. # Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import math
  16. from collections.abc import Callable
  17. from itertools import cycle
  18. import torch
  19. from torch import nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache
  23. from ...integrations.hub_kernels import lazy_load_kernel
  24. from ...masking_utils import create_causal_mask
  25. from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
  26. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  27. from ...processing_utils import Unpack
  28. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
  29. from ...utils.generic import merge_with_config_defaults
  30. from ...utils.import_utils import resolve_internal_import
  31. from ...utils.output_capturing import capture_outputs
  32. from ..llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
  33. from ..mamba2.modeling_mamba2 import pad_tensor_by_size, reshape_into_chunks, segment_sum
  34. from ..zamba.modeling_zamba import (
  35. ZambaAttention,
  36. ZambaAttentionDecoderLayer,
  37. ZambaForCausalLM,
  38. ZambaForSequenceClassification,
  39. ZambaHybridLayer,
  40. ZambaMambaDecoderLayer,
  41. ZambaModel,
  42. ZambaRMSNorm,
  43. eager_attention_forward,
  44. )
  45. from .configuration_zamba2 import Zamba2Config
  46. _CONFIG_FOR_DOC = "Zyphra/Zamba2-2.7B"
  47. logger = logging.get_logger(__name__)
  48. class Zamba2RMSNormGated(torch.nn.Module):
  49. def __init__(self, hidden_size, group_size, eps=1e-6):
  50. super().__init__()
  51. self.weight = nn.Parameter(torch.ones(hidden_size))
  52. self.variance_epsilon = eps
  53. self.group_size = group_size
  54. def forward(self, hidden_states, gate=None):
  55. input_dtype = hidden_states.dtype
  56. hidden_states = hidden_states.to(torch.float32)
  57. if gate is not None:
  58. hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
  59. *prefix_dims, last_dim = hidden_states.shape
  60. group_count = last_dim // self.group_size
  61. hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
  62. variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
  63. hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
  64. hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)
  65. return self.weight * hidden_states.to(input_dtype)
  66. class Zamba2RMSNorm(ZambaRMSNorm):
  67. pass
  68. class Zamba2RotaryEmbedding(LlamaRotaryEmbedding):
  69. pass
  70. class Zamba2Attention(ZambaAttention):
  71. """
  72. Multi-headed attention from 'Attention Is All You Need' paper.
  73. Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
  74. The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
  75. The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
  76. (see fig. 2 in https://huggingface.co/papers/2405.16712).
  77. Additionally, replaced
  78. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
  79. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
  80. Finally, this attention layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this
  81. layer is tied, un-tied adapters (formally the same as LoRA but used in the base model) modules are added to the q, k, v projectors to increase
  82. expressivity with a small memory overhead (see Fig. 2 of https://huggingface.co/papers/2411.15242).
  83. """
  84. def __init__(
  85. self,
  86. config: Zamba2Config,
  87. layer_idx: int | None = None,
  88. num_fwd_mem_blocks: int | None = None,
  89. block_id: int | None = None,
  90. ):
  91. super().__init__(config, layer_idx)
  92. self.num_fwd_mem_blocks = num_fwd_mem_blocks
  93. self.layer_block_map = config.hybrid_layer_ids
  94. self.block_id = block_id
  95. if config.use_shared_attention_adapter:
  96. self.linear_q_adapter_list = nn.ModuleList([])
  97. self.linear_k_adapter_list = nn.ModuleList([])
  98. self.linear_v_adapter_list = nn.ModuleList([])
  99. for i in range(self.num_fwd_mem_blocks):
  100. if i % config.num_mem_blocks == block_id:
  101. linear_q_adapter = nn.Sequential(
  102. nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
  103. nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
  104. )
  105. linear_k_adapter = nn.Sequential(
  106. nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
  107. nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
  108. )
  109. linear_v_adapter = nn.Sequential(
  110. nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
  111. nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
  112. )
  113. else:
  114. linear_q_adapter = nn.Identity()
  115. linear_k_adapter = nn.Identity()
  116. linear_v_adapter = nn.Identity()
  117. self.linear_q_adapter_list.append(linear_q_adapter)
  118. self.linear_k_adapter_list.append(linear_k_adapter)
  119. self.linear_v_adapter_list.append(linear_v_adapter)
  120. self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)}
  121. def forward(
  122. self,
  123. hidden_states: torch.Tensor,
  124. layer_idx: int,
  125. attention_mask: torch.Tensor | None = None,
  126. past_key_values: Cache | None = None,
  127. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  128. **kwargs: Unpack[TransformersKwargs],
  129. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  130. input_shape = hidden_states.shape[:-1]
  131. hidden_shape = (*input_shape, -1, self.head_dim)
  132. query_states = self.q_proj(hidden_states)
  133. key_states = self.k_proj(hidden_states)
  134. value_states = self.v_proj(hidden_states)
  135. if self.config.use_shared_attention_adapter:
  136. adapter_layer_idx = self.layer_dic[layer_idx]
  137. query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states)
  138. key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states)
  139. value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states)
  140. query_states = query_states.view(hidden_shape).transpose(1, 2)
  141. key_states = key_states.view(hidden_shape).transpose(1, 2)
  142. value_states = value_states.view(hidden_shape).transpose(1, 2)
  143. if self.config.use_mem_rope:
  144. cos, sin = position_embeddings
  145. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  146. if past_key_values is not None:
  147. key_states, value_states = past_key_values.update(key_states, value_states, layer_idx)
  148. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  149. self.config._attn_implementation, eager_attention_forward
  150. )
  151. attn_output, attn_weights = attention_interface(
  152. self,
  153. query_states,
  154. key_states,
  155. value_states,
  156. attention_mask,
  157. dropout=0.0 if not self.training else self.attention_dropout,
  158. scaling=self.scaling,
  159. **kwargs,
  160. )
  161. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  162. attn_output = self.o_proj(attn_output)
  163. return attn_output, attn_weights
class Zamba2MambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)

    Two implementations are provided and selected in `forward`:
    - `cuda_kernels_forward`: uses the `causal-conv1d` and `mamba-ssm` kernels (loaded lazily in `__init__`);
    - `torch_forward`: a pure-PyTorch reference (naive chunked SSD scan for prefill, explicit recurrence for decoding).

    [1] https://huggingface.co/papers/2312.00752
    """

    def __init__(self, config: Zamba2Config, layer_idx: int | None = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.mamba_d_state
        self.conv_kernel_size = config.mamba_d_conv
        self.intermediate_size = int(config.mamba_expand * self.hidden_size)
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        self.activation = "silu"
        self.act = nn.SiLU()
        self.use_mem_eff_path = config.use_mem_eff_path

        self.n_groups = config.mamba_ngroups
        self.head_dim = config.mamba_headdim
        self.num_heads = self.config.n_mamba_heads
        self.chunk_size = config.chunk_size

        self.time_step_limit = config.time_step_limit
        self.time_step_min = config.time_step_min
        # NOTE(review): time_step_max is stored but not applied by either forward path in this class.
        self.time_step_max = config.time_step_max

        # The depthwise conv runs over x, B and C together: intermediate_size + 2 * n_groups * state_size channels.
        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=True,
            kernel_size=config.mamba_d_conv,
            groups=self.conv_dim,
            padding=config.mamba_d_conv - 1,
        )

        # projection of the input hidden states
        # in_proj output layout: [gate (intermediate_size) | x, B, C (conv_dim) | dt (num_heads)]
        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
        self.in_proj = nn.Linear(
            self.hidden_size,
            projection_size,
            bias=config.add_bias_linear,
        )
        # dt, B and C are made input dependent by being slices of in_proj's output (see the splits in the forward paths)

        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        # (the ones() value here is a placeholder; the actual init presumably happens in _init_weights — not visible here)
        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))

        # S4D real initialization. These are not discretized!
        # Only log(A) is stored; the forward paths use A = -exp(A_log), i.e. a strictly negative per-head decay.
        A = torch.arange(1, self.num_heads + 1)
        self.A_log = nn.Parameter(torch.log(A))
        # Gated group norm applied to the scan output (gate multiplies before normalizing).
        self.norm = Zamba2RMSNormGated(
            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=1e-5
        )
        # Per-head skip-connection weight (y += D * x).
        self.D = nn.Parameter(torch.ones(self.num_heads))

        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear)

        # Kernels are resolved lazily here and published as module-level globals so that
        # cuda_kernels_forward / forward can reference them by name.
        global causal_conv1d_update, causal_conv1d_fn
        causal_conv1d = lazy_load_kernel("causal-conv1d")
        causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
        causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)

        global selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
        mamba_ssm = lazy_load_kernel("mamba-ssm")
        selective_state_update = resolve_internal_import(
            mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
        )
        mamba_chunk_scan_combined = resolve_internal_import(
            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_chunk_scan_combined"
        )
        mamba_split_conv1d_scan_combined = resolve_internal_import(
            mamba_ssm, chained_path="ops.triton.ssd_combined.mamba_split_conv1d_scan_combined"
        )

        # The CUDA path is used only if all five kernels resolved.
        global is_fast_path_available
        is_fast_path_available = all(
            (
                selective_state_update,
                mamba_chunk_scan_combined,
                mamba_split_conv1d_scan_combined,
                causal_conv1d_fn,
                causal_conv1d_update,
            )
        )

        if not is_fast_path_available:
            # NOTE(review): the message names only three of the five kernels checked above.
            logger.warning_once(
                "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
                " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Kernel-based forward.

        - If `cache_params` already holds state for this layer: single-step decoding via
          `causal_conv1d_update` + `selective_state_update` (cache conv/recurrent states updated in place
          by the kernels); returns (batch, 1, hidden_size).
        - Otherwise: full-sequence processing, either the fused `mamba_split_conv1d_scan_combined`
          (training, no cache, no padding, `use_mem_eff_path`) or conv + `mamba_chunk_scan_combined`,
          writing conv/SSM states to `cache_params` when given; returns (batch, seq_len, hidden_size).
        """
        # set up dimensions for reshapes later

        batch_size, seq_len, _ = hidden_states.shape
        groups_time_state_size = self.n_groups * self.ssm_state_size
        # Equals intermediate_size + conv_dim + num_heads, i.e. the full in_proj width, so d_mlp below is 0
        # for this module's own in_proj (the split keeps a Mamba2-style layout with optional extra MLP channels).
        d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads

        # getting projected states from cache if it exists
        if cache_params is not None and cache_params.has_previous_state(self.layer_idx):
            # decoding step: assumes seq_len == 1 so squeeze(1) yields (batch, projection_size)
            in_projected_states = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
            d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
            split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads]
            _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1)

            # Rolls the cached conv window forward by one token and applies conv + activation.
            hidden_states_B_C = causal_conv1d_update(
                hidden_states_B_C,
                cache_params.layers[self.layer_idx].conv_states,
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )

            hidden_states, B, C = torch.split(
                hidden_states_B_C,
                [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                dim=-1,
            )
            A = -torch.exp(self.A_log.float())  # (nheads,)

            # Broadcast per-head parameters to the per-(head, head_dim[, state]) shapes passed to selective_state_update.
            A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D = self.D[:, None, ...].expand(-1, self.head_dim)
            B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
            C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
            hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
            hidden_states = selective_state_update(
                cache_params.layers[self.layer_idx].recurrent_states,
                hidden_states_reshaped,
                dt,
                A,
                B,
                C,
                D,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
            hidden_states = self.norm(hidden_states, gate)
            # restore the sequence axis: (batch, 1, hidden_size)
            out = self.out_proj(hidden_states)[:, None, ...]
        # if no cache is found, calling the kernel
        else:
            if attention_mask is not None and not torch.all(attention_mask == 1):
                # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
                # (assumes a (batch, seq_len) mask)
                dtype = hidden_states.dtype
                hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
            # 1. Gated MLP's linear projection
            projected_states = self.in_proj(hidden_states)
            A = -torch.exp(self.A_log.float())  # (num_heads) or (intermediate_size, state_size)
            dt_limit_kwargs = {} if self.time_step_limit is None else {"dt_limit": self.time_step_limit}
            if attention_mask is not None:
                input_not_masked = torch.all(attention_mask == 1)
            else:
                input_not_masked = True

            if self.use_mem_eff_path and self.training and cache_params is None and input_not_masked:
                # Fully fused conv + scan + gated norm + out_proj; the final SSM state is not stored
                # because there is no cache in this branch.
                out, ssm_state = mamba_split_conv1d_scan_combined(
                    projected_states,
                    self.conv1d.weight.squeeze(1),
                    self.conv1d.bias,
                    self.dt_bias,
                    A,
                    D=self.D,
                    chunk_size=self.chunk_size,
                    seq_idx=None,
                    activation=self.activation,
                    rmsnorm_weight=self.norm.weight,
                    rmsnorm_eps=self.norm.variance_epsilon,
                    outproj_weight=self.out_proj.weight,
                    outproj_bias=self.out_proj.bias,
                    headdim=self.head_dim,
                    ngroups=self.n_groups,
                    # gate is applied before normalization, matching Zamba2RMSNormGated
                    norm_before_gate=False,
                    return_final_states=True,
                    **dt_limit_kwargs,
                )

            else:
                gate, hidden_states_B_C, time_step = torch.split(
                    projected_states,
                    [self.intermediate_size, self.conv_dim, self.num_heads],
                    dim=-1,
                )

                # 1D Convolution
                if cache_params is not None:
                    # Store the last conv_kernel_size inputs (left-padded with zeros, or left-truncated
                    # when longer) as the conv state for subsequent decoding steps.
                    hidden_states_B_C_t = hidden_states_B_C.transpose(1, 2)
                    conv_state = nn.functional.pad(
                        hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0)
                    )
                    conv_state = cache_params.update_conv_state(conv_state, self.layer_idx)
                if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                    hidden_states_B_C = self.act(
                        self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len]
                    )  # (B, L, self.d_inner + 2 * ngroups * d_state)
                else:
                    hidden_states_B_C = causal_conv1d_fn(
                        x=hidden_states_B_C.transpose(1, 2),
                        weight=self.conv1d.weight.squeeze(1),
                        bias=self.conv1d.bias,
                        activation=self.activation,
                    ).transpose(1, 2)[:, :seq_len]
                hidden_states, B, C = torch.split(
                    hidden_states_B_C,
                    [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                    dim=-1,
                )
                if attention_mask is not None and not torch.all(attention_mask == 1):
                    # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
                    dtype = hidden_states.dtype
                    hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
                scan_output, ssm_state = mamba_chunk_scan_combined(
                    hidden_states.view(batch_size, seq_len, -1, self.head_dim),
                    time_step,
                    A,
                    B.view(batch_size, seq_len, self.n_groups, -1),
                    C.view(batch_size, seq_len, self.n_groups, -1),
                    chunk_size=self.chunk_size,
                    D=self.D,
                    z=None,
                    seq_idx=None,
                    return_final_states=True,
                    dt_bias=self.dt_bias,
                    dt_softplus=True,
                    **dt_limit_kwargs,
                )
                if ssm_state is not None and cache_params is not None:
                    cache_params.update_recurrent_state(ssm_state, self.layer_idx)
                scan_output = scan_output.view(batch_size, seq_len, -1)
                # Multiply "gate" branch and apply extra normalization layer
                scan_output = self.norm(scan_output, gate)
                out = self.out_proj(scan_output)
        return out

    # fmt: off
    def torch_forward(self, input_states, cache_params: Cache | None=None, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        """Pure-PyTorch forward (no custom kernels).

        - Decoding (cache already holds state for this layer): explicit one-step recurrence
          state = state * exp(dt * A) + dt * B * x, y = C . state + D * x.
        - Otherwise: naive chunked SSD scan over the whole sequence; the final state is written
          to `cache_params` when given.

        Returns the output projected back to hidden_size, cast to `dtype` (see note below).
        """
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype
        # Gated MLP's linear projection
        if cache_params is not None and cache_params.has_previous_state(self.layer_idx):
            projected_states = self.in_proj(input_states)
        else:
            if attention_mask is not None:
                # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
                input_states = (input_states * attention_mask[:, :, None]).to(dtype)
            projected_states = self.in_proj(input_states)
        # 0 for this module's in_proj (its width is exactly intermediate_size + conv_dim + num_heads)
        d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2
        _, _, gate, hidden_states, dt = projected_states.split(
            [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
        )
        # channels-first for the conv: (batch, conv_dim, seq_len)
        hidden_states = hidden_states.transpose(1, 2)
        use_precomputed_state = cache_params is not None and cache_params.has_previous_state(self.layer_idx)

        # Convolution sequence transformation
        if use_precomputed_state:
            # Conv over the cached window (as returned by update_conv_state) using the depthwise kernel weights.
            conv_state = cache_params.update_conv_state(hidden_states, self.layer_idx)
            hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
            if self.use_conv_bias:
                hidden_states += self.conv1d.bias
            hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding
        else:
            if cache_params is not None:
                # Keep the last conv_kernel_size inputs (zero left-padded, or left-truncated) as the conv state.
                conv_state = nn.functional.pad(
                    hidden_states,
                    (self.conv_kernel_size - hidden_states.shape[-1], 0)
                )
                conv_state = cache_params.update_conv_state(conv_state, self.layer_idx)
            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len].transpose(1, 2))
            if attention_mask is not None:
                # NOTE: rebinds `dtype`, which is reused for the final out_proj cast below.
                dtype = hidden_states.dtype
                # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
                hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
        hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1)
        A = -torch.exp(self.A_log.float())                            # [num_heads]
        if use_precomputed_state:
            # Note: there is no need to pad parameter matrices here, as there is just one new token
            # for batched generation
            dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
            dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
            # [num_heads] -> [num_heads, head_dim]
            dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)

            dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
            # Only the lower bound is applied; the upper clamp to time_step_max is disabled in this path.
            dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max)
            A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            # [bsz, num_heads, head_dim, state_size]
            dA = torch.exp(dt[..., None] * A)

            # Discretize B
            # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
            # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
            B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
            B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
            B = B.reshape(batch_size, -1, B.shape[-1])
            # [bsz, num_heads, head_dim, state_size]
            dB = dt[..., None] * B[..., None, :]

            # Discretize x into dB
            # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
            hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
            dBx = dB * hidden_states[..., None]

            # State calculation: h_t = h_{t-1} * dA + dBx, then written back to the cache
            ssm_states = cache_params.layers[self.layer_idx].recurrent_states.clone()
            ssm_states = ssm_states * dA + dBx
            ssm_states = cache_params.update_recurrent_state(ssm_states, self.layer_idx)

            # Subsequent output
            # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
            C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
            C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
            C = C.reshape(batch_size, -1, C.shape[-1])
            # [bsz, num_heads, head_dim]

            ssm_states = ssm_states.to(C.dtype)  # Shape: [b, h, d, n]
            # Reshape ssm_states to merge the first two dimensions
            ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size)  # Shape: [b*h, d, n]
            C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1)  # Shape: [b*h, n, 1]
            y = torch.bmm(ssm_states_reshaped, C_reshaped)
            y = y.view(batch_size, self.num_heads, self.head_dim)

            # D skip connection
            # [num_heads] -> [num_heads, head_dim]
            D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
            y = (y + hidden_states * D).to(y.dtype)

            # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
            y = y.reshape(batch_size, -1)[:, None, ...]
        else:
            # begin ssd naive implementation without einsums
            dt = nn.functional.softplus(dt + self.dt_bias)
            # Only the lower bound is applied (no time_step_max clamp), as in the decoding branch.
            dt = torch.clamp(dt, self.time_step_min)
            hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
            B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
            C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
            # Share each group's B/C across the num_heads // n_groups heads of that group.
            B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
            C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
            # Pad the sequence up to a multiple of chunk_size; padded positions are cut off at the end.
            pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

            D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)

            # Discretize x and A
            hidden_states = hidden_states * dt[..., None]
            A = A.to(hidden_states.dtype) * dt

            # Rearrange into blocks/chunks
            hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]

            # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
            A = A.permute(0, 3, 1, 2)
            A_cumsum = torch.cumsum(A, dim=-1)

            # 1. Compute the output for each intra-chunk (diagonal blocks)
            # This is the analog of a causal mask
            L = torch.exp(segment_sum(A))

            # First, contraction of C and B to get G (attention-weights like)
            G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:]  # shape: (b, c, l, s, h, n)
            G = G_intermediate.sum(dim=-1)  # shape: (b, c, l, s, h)

            # Step 2: Compute M, equivalent to applying attention mask to weights
            M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
            M = M_intermediate.sum(dim=-1)

            # Step 3: Compute Y_diag (apply to values)
            Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)

            # (right term of low-rank factorization of off-diagonal blocks; B terms)
            decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
            B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
            # permute back B * decay states
            states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
            # Prepend a zero initial state, then propagate states across chunks with the per-chunk decay.
            previous_states = torch.zeros_like(states[:, :1])
            states = torch.cat([previous_states, states], dim=1)
            decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))

            states_permuted = states.permute(0, 2, 1, 3, 4)
            result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
            new_states = result.permute(0, 2, 1, 3, 4)
            # states entering each chunk, and the final state after the last chunk
            states, ssm_state = new_states[:, :-1], new_states[:, -1]

            # Compute state -> output conversion per chunk
            # (left term of low-rank factorization of off-diagonal blocks; C terms)
            state_decay_out = torch.exp(A_cumsum)
            # compute Yoff
            C_times_states = (C[..., None, :] * states[:, :, None, ...])
            state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
            Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
            # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)

            y = Y_diag + Y_off
            # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
            y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)

            y = y + D_residual
            # Cutting off padded chunks
            if pad_size > 0:
                y = y[:, :seq_len, :, :]
            y = y.reshape(batch_size, seq_len, -1)
            if ssm_state is not None and cache_params is not None:
                cache_params.update_recurrent_state(ssm_state, self.layer_idx)

        # gate and normalize (both branches)
        scan_output = self.norm(y, gate)

        # end ssd naive

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_output.to(dtype))  # [batch, seq_len, hidden_size]
        return contextualized_states
    # fmt: on

    def forward(
        self,
        hidden_states,
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Dispatch to the kernel path when all kernels are available, the weights are on a CUDA
        device and we are not being traced by torchdynamo; otherwise use the PyTorch reference path.
        Extra `kwargs` are accepted and ignored."""
        if is_fast_path_available and "cuda" in self.in_proj.weight.device.type and not is_torchdynamo_compiling():
            return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)

        return self.torch_forward(hidden_states, cache_params, attention_mask)
  552. class Zamba2MLP(nn.Module):
  553. def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int | None = None):
  554. """
  555. This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer
  556. is tied, un-tied adapter modules (formally same as LoRA, but used in the base model) are added to the up and gate projectors to increase expressivity with a small memory overhead.
  557. """
  558. super().__init__()
  559. self.config = config
  560. self.hidden_size = config.hidden_size
  561. self.intermediate_size = config.intermediate_size
  562. self.num_fwd_mem_blocks = num_fwd_mem_blocks
  563. self.block_id = block_id
  564. self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=config.add_bias_linear)
  565. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear)
  566. self.act_fn = ACT2FN[config.hidden_act]
  567. self.gate_up_proj_adapter_list = nn.ModuleList([])
  568. for i in range(self.num_fwd_mem_blocks):
  569. if i % config.num_mem_blocks == block_id:
  570. gate_up_proj_adapter = nn.Sequential(
  571. nn.Linear(self.config.hidden_size, self.config.adapter_rank, bias=False),
  572. nn.Linear(self.config.adapter_rank, 2 * self.intermediate_size, bias=False),
  573. )
  574. else:
  575. gate_up_proj_adapter = nn.Identity()
  576. self.gate_up_proj_adapter_list.append(gate_up_proj_adapter)
  577. layer_block_map = config.hybrid_layer_ids
  578. self.layer_dic = {value: index for index, value in enumerate(layer_block_map)}
  579. def forward(self, hidden_state, layer_idx=None):
  580. gate_up_state = self.gate_up_proj(hidden_state)
  581. layer_idx = self.layer_dic[layer_idx]
  582. gate_up_state = gate_up_state + self.gate_up_proj_adapter_list[layer_idx](hidden_state)
  583. gate_up_state = torch.chunk(gate_up_state, 2, dim=-1)
  584. hidden_state = self.act_fn(gate_up_state[0]) * gate_up_state[1]
  585. output = self.down_proj(hidden_state)
  586. return output
  587. class Zamba2AttentionDecoderLayer(ZambaAttentionDecoderLayer):
  588. def __init__(self, config: Zamba2Config, block_id: int | None = None, layer_idx: int | None = None):
  589. self.block_id = block_id
  590. num_gs = len(config.hybrid_layer_ids)
  591. super().__init__(config, layer_idx)
  592. self.self_attn = Zamba2Attention(config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id)
  593. self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id)
  594. def forward(
  595. self,
  596. hidden_states: torch.Tensor,
  597. original_hidden_states: torch.Tensor,
  598. layer_idx: int,
  599. attention_mask: torch.Tensor | None = None,
  600. past_key_values: Cache | None = None,
  601. position_embeddings: torch.LongTensor | None = None,
  602. **kwargs: Unpack[TransformersKwargs],
  603. ) -> tuple[torch.FloatTensor]:
  604. """
  605. Args:
  606. hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)`
  607. original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`.
  608. This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The
  609. concatenated tensor is then used as input of the pre-attention RMSNorm
  610. (see fig. 2 in https://huggingface.co/papers/2405.16712).
  611. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  612. `(batch, sequence_length)` where padding elements are indicated by 0.
  613. past_key_values (`Cache`, *optional*): cached past key and value projection states
  614. use_cache (`bool`, *optional*):
  615. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  616. (see `past_key_values`).
  617. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  618. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  619. with `head_dim` being the embedding dimension of each attention head.
  620. """
  621. hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1)
  622. hidden_states = self.input_layernorm(hidden_states)
  623. hidden_states, _ = self.self_attn(
  624. hidden_states=hidden_states,
  625. layer_idx=layer_idx,
  626. attention_mask=attention_mask,
  627. past_key_values=past_key_values,
  628. position_embeddings=position_embeddings,
  629. **kwargs,
  630. )
  631. hidden_states = self.pre_ff_layernorm(hidden_states)
  632. hidden_states = self.feed_forward(hidden_states, layer_idx)
  633. return hidden_states
  634. class Zamba2MambaDecoderLayer(ZambaMambaDecoderLayer):
  635. def __init__(self, config: Zamba2Config, layer_idx: int):
  636. super().__init__(config, layer_idx)
  637. self.mamba = Zamba2MambaMixer(config=config, layer_idx=layer_idx)
  638. self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  639. class Zamba2HybridLayer(ZambaHybridLayer):
  640. def __init__(
  641. self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer
  642. ):
  643. super().__init__(shared_transformer, linear, mamba)
  644. del self.shared_transf
  645. self.shared_transformer = shared_transformer
  646. def forward(
  647. self,
  648. hidden_states: torch.Tensor,
  649. original_hidden_states: torch.Tensor | None = None,
  650. layer_idx: int | None = None,
  651. attention_mask: torch.Tensor | None = None,
  652. causal_mask: torch.Tensor | None = None,
  653. past_key_values: Cache | None = None,
  654. use_cache: bool | None = False,
  655. position_embeddings: torch.LongTensor | None = None,
  656. position_ids: torch.LongTensor | None = None,
  657. **kwargs: Unpack[TransformersKwargs],
  658. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  659. """
  660. Args:
  661. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  662. original_hidden_states (`torch.FloatTensor`): word embedding output that will be concatenated with
  663. hidden activations to form the input of the shared transformer layer.
  664. layer_idx (`int`): layer number.
  665. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  666. `(batch, sequence_length)` where padding elements are indicated by 0.
  667. past_key_values (`Cache`, *optional*): cached past key and value projection states
  668. use_cache (`bool`, *optional*):
  669. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  670. (see `past_key_values`).
  671. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  672. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  673. with `head_dim` being the embedding dimension of each attention head.
  674. """
  675. transformer_hidden_states = self.shared_transformer(
  676. hidden_states,
  677. original_hidden_states=original_hidden_states,
  678. layer_idx=layer_idx,
  679. attention_mask=causal_mask,
  680. past_key_values=past_key_values,
  681. position_embeddings=position_embeddings,
  682. position_ids=position_ids,
  683. **kwargs,
  684. )
  685. transformer_hidden_states = self.linear(transformer_hidden_states)
  686. hidden_states = self.mamba_decoder(
  687. hidden_states,
  688. transformer_hidden_states=transformer_hidden_states,
  689. attention_mask=attention_mask,
  690. past_key_values=past_key_values,
  691. use_cache=use_cache,
  692. position_embeddings=position_embeddings,
  693. **kwargs,
  694. )
  695. return hidden_states
  696. @auto_docstring
  697. class Zamba2PreTrainedModel(PreTrainedModel):
  698. config: Zamba2Config
  699. base_model_prefix = "model"
  700. supports_gradient_checkpointing = True
  701. _no_split_modules = ["Zamba2HybridLayer", "Zamba2MambaDecoderLayer"]
  702. _skip_keys_device_placement = "past_key_values"
  703. _supports_flash_attn = True
  704. _supports_flex_attn = True
  705. _supports_sdpa = True
  706. _is_stateful = True
  707. _can_record_outputs = {
  708. "hidden_states": Zamba2MambaDecoderLayer,
  709. "attentions": Zamba2Attention,
  710. }
  711. @torch.no_grad()
  712. def _init_weights(self, module):
  713. super()._init_weights(module)
  714. if isinstance(module, Zamba2MambaMixer):
  715. dt = torch.exp(
  716. torch.rand(self.config.n_mamba_heads)
  717. * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
  718. + math.log(self.config.time_step_min)
  719. ).clamp(min=self.config.time_step_floor)
  720. # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
  721. inv_dt = dt + torch.log(-torch.expm1(-dt))
  722. init.copy_(module.dt_bias, inv_dt)
  723. A = torch.arange(1, module.num_heads + 1)
  724. init.copy_(module.A_log, torch.log(A))
  725. init.ones_(module.D)
  726. class Zamba2Model(ZambaModel, Zamba2PreTrainedModel):
  727. """
  728. Model consisting of *config.num_hidden_layers* layers.
  729. Args:
  730. config: Zamba2Config
  731. """
  732. def __init__(self, config: Zamba2Config):
  733. Zamba2PreTrainedModel.__init__(self, config)
  734. self.config = config
  735. self.padding_idx = config.pad_token_id
  736. self.vocab_size = config.vocab_size
  737. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  738. self.layers_block_type = config.layers_block_type
  739. self.layers = self.get_layers()
  740. self._attn_implementation = config._attn_implementation
  741. self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  742. if config.use_mem_rope:
  743. if config.use_long_context:
  744. logger.warning_once(
  745. "`use_long_context` set to `True`: using rescaled `rope_theta` and extended `max_position_embeddings`."
  746. )
  747. self.rotary_emb = Zamba2RotaryEmbedding(config)
  748. self.gradient_checkpointing = False
  749. # Initialize weights and apply final processing
  750. self.post_init()
  751. def get_layers(self):
  752. layers = []
  753. self._tied_weights_keys = {}
  754. self.first_transformer_layer_id = 0
  755. unique_hybrid_blocks = []
  756. for layer_id, layer_type in enumerate(self.layers_block_type):
  757. mamba_layer = Zamba2MambaDecoderLayer(self.config, layer_idx=layer_id)
  758. if layer_type == "hybrid":
  759. prefix_pattern = f"layers.{layer_id}.shared_transformer"
  760. # Zamba ties Hybrid module weights by repeating blocks after every
  761. # `num_mem_blocks`. So if `num_mem_blocks=2`, the blocks looks like
  762. # [1, 2, 1, 2, 1, 2] where all "ones" share the same set of weights.
  763. if (
  764. not isinstance(unique_hybrid_blocks, list)
  765. or len(unique_hybrid_blocks) >= self.config.num_mem_blocks
  766. ):
  767. if isinstance(unique_hybrid_blocks, list):
  768. unique_hybrid_blocks = cycle(unique_hybrid_blocks)
  769. target_pattern = next(unique_hybrid_blocks)
  770. self._tied_weights_keys.update({prefix_pattern: target_pattern})
  771. else:
  772. # Store source patterns to which the subsequent modules will be tied
  773. unique_hybrid_blocks.append(prefix_pattern)
  774. block_id = layer_id % self.config.num_mem_blocks
  775. attn_block = Zamba2AttentionDecoderLayer(self.config, block_id=block_id)
  776. linear_layer = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
  777. layers.append(Zamba2HybridLayer(attn_block, linear_layer, mamba_layer))
  778. else:
  779. layers.append(mamba_layer)
  780. return nn.ModuleList(layers)
  781. @merge_with_config_defaults
  782. @capture_outputs
  783. @auto_docstring
  784. def forward(
  785. self,
  786. input_ids: torch.LongTensor | None = None,
  787. attention_mask: torch.Tensor | None = None,
  788. position_ids: torch.LongTensor | None = None,
  789. past_key_values: Cache | None = None,
  790. inputs_embeds: torch.FloatTensor | None = None,
  791. use_cache: bool | None = None,
  792. **kwargs: Unpack[TransformersKwargs],
  793. ) -> tuple | BaseModelOutputWithPast:
  794. if (input_ids is None) ^ (inputs_embeds is not None):
  795. raise ValueError(
  796. "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
  797. )
  798. if inputs_embeds is None:
  799. inputs_embeds = self.embed_tokens(input_ids)
  800. hidden_states = inputs_embeds
  801. original_hidden_states = torch.clone(inputs_embeds)
  802. # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer
  803. if use_cache and past_key_values is None:
  804. past_key_values = DynamicCache(config=self.config)
  805. if position_ids is None:
  806. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  807. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  808. position_ids = position_ids.unsqueeze(0)
  809. causal_mask = create_causal_mask(
  810. config=self.config,
  811. inputs_embeds=inputs_embeds,
  812. attention_mask=attention_mask,
  813. past_key_values=past_key_values,
  814. position_ids=position_ids,
  815. )
  816. # create position embeddings to be shared across the decoder layers
  817. if self.config.use_mem_rope:
  818. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  819. else:
  820. position_embeddings = None
  821. for layer_idx, layer in enumerate(self.layers):
  822. hidden_states = layer(
  823. hidden_states,
  824. original_hidden_states,
  825. layer_idx,
  826. attention_mask,
  827. causal_mask,
  828. past_key_values=past_key_values,
  829. use_cache=use_cache,
  830. position_embeddings=position_embeddings,
  831. position_ids=position_ids,
  832. **kwargs,
  833. )
  834. hidden_states = self.final_layernorm(hidden_states)
  835. return BaseModelOutputWithPast(
  836. last_hidden_state=hidden_states,
  837. past_key_values=past_key_values if use_cache else None,
  838. )
  839. class Zamba2ForCausalLM(ZambaForCausalLM):
  840. def __init__(self, config: Zamba2Config):
  841. super().__init__(config)
  842. self.model = Zamba2Model(config)
  843. self.post_init()
  844. class Zamba2ForSequenceClassification(ZambaForSequenceClassification):
  845. def __init__(self, config: Zamba2Config):
  846. super().__init__(config)
  847. self.model = Zamba2Model(config)
  848. self.post_init()
  849. @can_return_tuple
  850. @auto_docstring
  851. def forward(
  852. self,
  853. input_ids: torch.LongTensor | None = None,
  854. attention_mask: torch.Tensor | None = None,
  855. position_ids: torch.LongTensor | None = None,
  856. past_key_values: Cache | None = None,
  857. inputs_embeds: torch.FloatTensor | None = None,
  858. labels: torch.LongTensor | None = None,
  859. use_cache: bool | None = None,
  860. logits_to_keep: int | torch.Tensor = 0,
  861. **kwargs: Unpack[TransformersKwargs],
  862. ) -> tuple | SequenceClassifierOutputWithPast:
  863. r"""
  864. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  865. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  866. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  867. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  868. """
  869. transformer_outputs: BaseModelOutputWithPast = self.model(
  870. input_ids,
  871. attention_mask=attention_mask,
  872. position_ids=position_ids,
  873. past_key_values=past_key_values,
  874. inputs_embeds=inputs_embeds,
  875. use_cache=use_cache,
  876. **kwargs,
  877. )
  878. hidden_states = transformer_outputs[0]
  879. logits = self.score(hidden_states)
  880. if input_ids is not None:
  881. batch_size = input_ids.shape[0]
  882. else:
  883. batch_size = inputs_embeds.shape[0]
  884. if self.config.pad_token_id is None and batch_size != 1:
  885. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  886. if self.config.pad_token_id is None:
  887. last_non_pad_token = -1
  888. elif input_ids is not None:
  889. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  890. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  891. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  892. else:
  893. last_non_pad_token = -1
  894. logger.warning_once(
  895. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  896. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  897. )
  898. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  899. loss = None
  900. if labels is not None:
  901. loss = self.loss_function(
  902. logits=pooled_logits, labels=labels, pooled_logits=pooled_logits, config=self.config, **kwargs
  903. )
  904. return SequenceClassifierOutputWithPast(
  905. loss=loss,
  906. logits=pooled_logits,
  907. past_key_values=transformer_outputs.past_key_values,
  908. hidden_states=transformer_outputs.hidden_states,
  909. attentions=transformer_outputs.attentions,
  910. )
  911. __all__ = [
  912. "Zamba2ForCausalLM",
  913. "Zamba2ForSequenceClassification",
  914. "Zamba2Model",
  915. "Zamba2PreTrainedModel",
  916. ]