| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758 |
- # Copyright 2023 Bo Peng and HuggingFace Inc. team.
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch RWKV model."""
- import math
- from dataclasses import dataclass
- import torch
- from torch import nn
- from ... import initialization as init
- from ...generation import GenerationMixin
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_utils import PreTrainedModel
- from ...utils import (
- ModelOutput,
- auto_docstring,
- is_bitsandbytes_available,
- is_kernels_available,
- is_ninja_available,
- is_torch_cuda_available,
- logging,
- )
- from .configuration_rwkv import RwkvConfig
# Module-level logger used by the helpers and classes below.
logger = logging.get_logger(__name__)

# Handle to the optional custom WKV CUDA kernel. It stays `None` until `load_wkv_cuda_kernel` assigns it, and the
# attention dispatch falls back to the pure PyTorch implementation while it is `None`.
rwkv_cuda_kernel = None
def load_wkv_cuda_kernel(context_length):
    """
    Fetch the RWKV kernel through `get_kernel` and publish it as the module-level `rwkv_cuda_kernel`.

    Args:
        context_length: Value stored on the kernel object as `max_seq_length`.

    Raises:
        ImportError: If `is_kernels_available()` reports that the `kernels` package is missing.
    """
    global rwkv_cuda_kernel

    if not is_kernels_available():
        raise ImportError("kernels is not installed, please install it with `pip install kernels`")

    # Imported lazily so the module can still be imported when `kernels` is absent.
    from ...integrations.hub_kernels import get_kernel

    rwkv_cuda_kernel = get_kernel("kernels-community/rwkv")
    # Remember the longest sequence this kernel instance is allowed to process.
    setattr(rwkv_cuda_kernel, "max_seq_length", context_length)
class RwkvLinearAttention(torch.autograd.Function):
    """
    Autograd function that evaluates the WKV attention with the custom CUDA kernel stored in the module-level
    `rwkv_cuda_kernel`. The kernel has to be loaded before this function is applied.
    """

    @staticmethod
    def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False):
        """
        Run the forward CUDA kernel.

        Args:
            ctx: Autograd context; receives `input_dtype` and the tensors needed by `backward`.
            time_decay: Decay parameter; the kernel is given `-exp(time_decay)` computed in float32.
            time_first: Bonus parameter applied to the current token.
            key: Tensor unpacked as `(batch_size, seq_len, hidden_size)`.
            value: Tensor passed to the kernel alongside `key`.
            state: Optional sequence of three tensors that are stacked on a new last dimension for the kernel.
            return_state: When `True` (or when `state` is given) the stateful kernel is used and a state is returned.

        Returns:
            A pair `(output, state)`: `output` is cast back to the dtype `key` had on entry, `state` is a list of
            three tensors or `None`.

        Raises:
            ValueError: If `seq_len` exceeds the kernel's `max_seq_length`, if `batch_size * hidden_size` is not a
                multiple of `min(hidden_size, 32)`, or if any of the four input tensors is not on a CUDA device.
        """
        batch_size, seq_len, hidden_size = key.size()
        if seq_len > rwkv_cuda_kernel.max_seq_length:
            raise ValueError(
                f"Cannot process a batch with {seq_len} tokens at the same time, use a maximum of "
                f"{rwkv_cuda_kernel.max_seq_length} with this model."
            )
        if batch_size * hidden_size % min(hidden_size, 32) != 0:
            raise ValueError(
                f"The product of batch size ({batch_size}) and hidden size ({hidden_size}) needs to be a round "
                f"multiple of {min(hidden_size, 32)}."
            )

        # Remembered so both the output and the gradients can be cast back to what the caller used.
        ctx.input_dtype = key.dtype

        if (
            time_decay.device.type != "cuda"
            or time_first.device.type != "cuda"
            or key.device.type != "cuda"
            or value.device.type != "cuda"
        ):
            raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.")

        time_decay = -torch.exp(time_decay.float().contiguous())
        if key.dtype == torch.float16:
            # float16 inputs are upcast to float32; bfloat16 inputs go to the dedicated bf16 kernel entry points.
            time_first = time_first.float()
            key = key.float()
            value = value.float()
        time_first = time_first.contiguous()
        key = key.contiguous()
        value = value.contiguous()

        # The CUDA kernel will fill this tensor.
        output = torch.empty_like(key, memory_format=torch.contiguous_format)
        if return_state or state is not None:
            if state is None:
                # `torch.zeros` has no `memory_format` keyword (passing one raises TypeError); a freshly allocated
                # zeros tensor is contiguous already.
                state = torch.zeros(
                    batch_size,
                    hidden_size,
                    3,
                    dtype=torch.float32,
                    device=key.device,
                )
                # Very negative start value for the third slot of the state.
                state[:, :, 2] -= 1e38
            else:
                state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous()
            if key.dtype == torch.bfloat16:
                forward_func = rwkv_cuda_kernel.forward_with_state_bf16
            else:
                forward_func = rwkv_cuda_kernel.forward_with_state
            forward_func(time_decay, time_first, key, value, output, state)
        else:
            forward_func = rwkv_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else rwkv_cuda_kernel.forward
            forward_func(time_decay, time_first, key, value, output)

        ctx.save_for_backward(time_decay, time_first, key, value, output)

        if state is not None:
            # Split the stacked `(batch, hidden, 3)` tensor back into three `(batch, hidden)` tensors.
            state = [s.squeeze(2) for s in torch.chunk(state, 3, dim=2)]

        return output.to(ctx.input_dtype), state

    @staticmethod
    # g stands for grad
    def backward(ctx, g_output, g_state=None):
        """
        Run the backward CUDA kernel.

        Args:
            ctx: Autograd context populated by `forward`.
            g_output: Gradient with respect to the attention output.
            g_state: Gradient with respect to the returned state; it is not used.

        Returns:
            Gradients for `time_decay`, `time_first`, `key` and `value` cast to the input dtype, followed by `None`
            for `state` and `return_state`.
        """
        input_dtype = ctx.input_dtype

        time_decay, time_first, key, value, output = ctx.saved_tensors
        # The CUDA kernel will fill those tensors.
        g_time_decay = torch.empty_like(
            time_decay,
            memory_format=torch.contiguous_format,
            dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
        )
        g_time_first = torch.empty_like(time_first, memory_format=torch.contiguous_format)
        g_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        g_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        if input_dtype == torch.float16:
            # Mirrors the upcast done on the inputs in `forward`.
            g_output = g_output.float()
        backward_func = rwkv_cuda_kernel.backward_bf16 if input_dtype == torch.bfloat16 else rwkv_cuda_kernel.backward
        backward_func(
            time_decay,
            time_first,
            key,
            value,
            output,
            g_output.contiguous(),
            g_time_decay,
            g_time_first,
            g_key,
            g_value,
        )

        return (
            g_time_decay.to(input_dtype),
            g_time_first.to(input_dtype),
            g_key.to(input_dtype),
            g_value.to(input_dtype),
            None,
            None,
        )
def rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=None, return_state=False):
    """
    Pure PyTorch WKV attention, evaluated one time step at a time.

    For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
    within a torch.no_grad.

    Args:
        time_decay: Decay parameter; `-exp(time_decay)` is added to the running maximum at every step.
        time_first: Bonus added to the key of the token currently being produced.
        key: Tensor of shape `(batch, seq_length, hidden)`.
        value: Tensor indexed the same way as `key`.
        state: Optional `(numerator, denominator, max)` triple carried over from a previous call.
        return_state: Whether to return the updated state even when no `state` was passed in.

    Returns:
        `(output, state)` where `output` has the shape and dtype of `key` and `state` is the updated
        `[numerator, denominator, max]` list, or `None` when no state was given or requested.
    """
    wants_state = return_state or state is not None
    _, num_steps, _ = key.size()
    output = torch.zeros_like(key)

    if state is not None:
        num_state, den_state, max_state = state
    else:
        first_step = key[:, 0]
        num_state = torch.zeros_like(first_step, dtype=torch.float32)
        den_state = torch.zeros_like(first_step, dtype=torch.float32)
        # Start from a very negative maximum so the first real key always wins.
        max_state = torch.zeros_like(first_step, dtype=torch.float32) - 1e38

    # The running sums are stored relative to `max_state` for numerical stability:
    #    real_numerator_state = num_state * torch.exp(max_state)
    #    real_denominator_state = den_state * torch.exp(max_state)
    decay = -torch.exp(time_decay)

    for step in range(num_steps):
        key_t = key[:, step].float()
        value_t = value[:, step]

        # Output at this step: blend the carried state with the current token (boosted by `time_first`).
        boosted_key = key_t + time_first
        output_max = torch.maximum(max_state, boosted_key)
        past_weight = torch.exp(max_state - output_max)
        present_weight = torch.exp(boosted_key - output_max)
        wkv = (past_weight * num_state + present_weight * value_t) / (past_weight * den_state + present_weight)
        output[:, step] = wkv.to(output.dtype)

        # Carry the state forward: decay what was accumulated, then add the current token.
        decayed_max = max_state + decay
        next_max = torch.maximum(decayed_max, key_t)
        past_weight = torch.exp(decayed_max - next_max)
        present_weight = torch.exp(key_t - next_max)
        num_state = past_weight * num_state + present_weight * value_t
        den_state = past_weight * den_state + present_weight
        max_state = next_max

    if wants_state:
        state = [num_state, den_state, max_state]

    return output, state
def rwkv_linear_attention(time_decay, time_first, key, value, state=None, return_state=False):
    """
    Dispatch WKV attention to the CUDA kernel when it can be used, otherwise to the PyTorch fallback.

    The kernel is used only when it has been loaded, all four tensors live on CUDA devices and more than one
    token is processed. Arguments and return value are those of `rwkv_linear_attention_cpu`.
    """
    all_on_cuda = all(t.device.type == "cuda" for t in (time_decay, time_first, key, value))
    # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
    # in this case).
    single_token = key.size(1) == 1

    if rwkv_cuda_kernel is not None and all_on_cuda and not single_token:
        return RwkvLinearAttention.apply(time_decay, time_first, key, value, state, return_state)
    return rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=state, return_state=return_state)
class RwkvSelfAttention(nn.Module):
    """
    RWKV time-mixing layer: mixes every token with the previous one, projects the mix to key / value / receptance,
    runs the WKV attention and gates its result with the receptance before the output projection.
    """

    def __init__(self, config, layer_id=0):
        """
        Args:
            config: Model configuration; `hidden_size`, `attention_hidden_size` and `context_length` are read.
            layer_id: Index of the enclosing block; selects this layer's slice in the recurrent state.
        """
        super().__init__()
        self.config = config
        kernel_loaded = rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == config.context_length
        if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
            try:
                load_wkv_cuda_kernel(config.context_length)
            except Exception:
                # Deliberately best effort: without the kernel the PyTorch implementation is used instead.
                logger.info("Could not load the custom CUDA kernel for RWKV attention.")
        self.layer_id = layer_id
        hidden_size = config.hidden_size
        attention_hidden_size = (
            config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size
        )
        self.attention_hidden_size = attention_hidden_size

        # Allocated uninitialized here; values are set by the weight initialization or a checkpoint.
        self.time_decay = nn.Parameter(torch.empty(attention_hidden_size))
        self.time_first = nn.Parameter(torch.empty(attention_hidden_size))

        self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
        self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))

        # Shifts the sequence one step to the right: a zero row is prepended and the last row dropped.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False)
        self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False)
        self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
        self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)

    # TODO: maybe jit, otherwise move inside forward
    def extract_key_value(self, hidden, state=None):
        """
        Mix `hidden` with the previous time step and project the result to receptance, key and value.

        Args:
            hidden: Tensor of shape `(batch, seq_len, hidden_size)`.
            state: Optional recurrent state; `state[1][:, :, layer_id]` holds the last hidden vector seen by this
                layer and is overwritten in place with `hidden[:, -1]`.

        Returns:
            `(receptance, key, value, state)`; `receptance` has already gone through a sigmoid.
        """
        if hidden.size(1) == 1 and state is not None:
            # The cached slice is `(batch, hidden)`. Give it an explicit time axis so it lines up with the
            # `(batch, 1, hidden)` input; without it the mix broadcasts to `(batch, batch, hidden)` when batch > 1.
            shifted = state[1][:, :, self.layer_id].unsqueeze(1)
        else:
            shifted = self.time_shift(hidden)
            if state is not None:
                # The first position has no predecessor in this chunk; take it from the carried state.
                shifted[:, 0] = state[1][:, :, self.layer_id]
        key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
        value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
        receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)

        key = self.key(key)
        value = self.value(value)
        receptance = torch.sigmoid(self.receptance(receptance))
        if state is not None:
            state[1][:, :, self.layer_id] = hidden[:, -1]
        return receptance, key, value, state

    def forward(self, hidden, state=None, use_cache=False):
        """
        Args:
            hidden: Tensor of shape `(batch, seq_len, hidden_size)`.
            state: Optional recurrent state; slots 1 to 4 are read and updated in place for this layer.
            use_cache: Forwarded to the attention as `return_state`.

        Returns:
            `(output, state)` where `output` is the projected, receptance-gated attention result.
        """
        receptance, key, value, state = self.extract_key_value(hidden, state=state)
        # Slots 2..4 hold the attention's numerator, denominator and running maximum for every layer.
        layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
        rwkv, layer_state = rwkv_linear_attention(
            self.time_decay,
            self.time_first,
            key,
            value,
            state=layer_state,
            return_state=use_cache,
        )

        if layer_state is not None:
            state[2][:, :, self.layer_id] = layer_state[0]
            state[3][:, :, self.layer_id] = layer_state[1]
            state[4][:, :, self.layer_id] = layer_state[2]

        return self.output(receptance * rwkv), state
class RwkvFeedForward(nn.Module):
    """
    RWKV channel-mixing layer: mixes every token with the previous one, applies a squared-ReLU projection and
    gates the result with a sigmoid receptance.
    """

    def __init__(self, config, layer_id=0):
        """
        Args:
            config: Model configuration; `hidden_size` and `intermediate_size` are read. When `intermediate_size`
                is `None`, `4 * hidden_size` is used.
            layer_id: Index of the enclosing block; selects this layer's slice in the recurrent state.
        """
        super().__init__()
        self.config = config
        self.layer_id = layer_id
        hidden_size = config.hidden_size
        intermediate_size = (
            config.intermediate_size if config.intermediate_size is not None else 4 * config.hidden_size
        )

        # Shifts the sequence one step to the right: a zero row is prepended and the last row dropped.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        # Allocated uninitialized here; values are set by the weight initialization or a checkpoint.
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))

        self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.receptance = nn.Linear(hidden_size, hidden_size, bias=False)
        self.value = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, hidden, state=None):
        """
        Args:
            hidden: Tensor of shape `(batch, seq_len, hidden_size)`.
            state: Optional recurrent state; `state[0][:, :, layer_id]` holds the last hidden vector seen by this
                layer and is overwritten in place with `hidden[:, -1]`.

        Returns:
            `(output, state)` where `output` is `sigmoid(receptance) * value`.
        """
        if hidden.size(1) == 1 and state is not None:
            # The cached slice is `(batch, hidden)`. Give it an explicit time axis so it lines up with the
            # `(batch, 1, hidden)` input; without it the mix broadcasts to `(batch, batch, hidden)` when batch > 1.
            shifted = state[0][:, :, self.layer_id].unsqueeze(1)
        else:
            shifted = self.time_shift(hidden)
            if state is not None:
                # The first position has no predecessor in this chunk; take it from the carried state.
                shifted[:, 0] = state[0][:, :, self.layer_id]
        key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
        receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)

        key = torch.square(torch.relu(self.key(key)))
        value = self.value(key)
        receptance = torch.sigmoid(self.receptance(receptance))

        if state is not None:
            state[0][:, :, self.layer_id] = hidden[:, -1]

        return receptance * value, state
class RwkvBlock(GradientCheckpointingLayer):
    """
    One RWKV layer: a layer-normed attention sub-layer followed by a layer-normed feed-forward sub-layer, both
    added back to their input. The first block (`layer_id == 0`) additionally normalizes its input.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        norm_eps = config.layer_norm_epsilon
        if layer_id == 0:
            # Only the very first block owns the extra input normalization.
            self.pre_ln = nn.LayerNorm(config.hidden_size, eps=norm_eps)

        self.ln1 = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=norm_eps)

        self.attention = RwkvSelfAttention(config, layer_id)
        self.feed_forward = RwkvFeedForward(config, layer_id)

    def forward(self, hidden, state=None, use_cache=False, output_attentions=False):
        """
        Returns:
            A 3-tuple `(hidden, state, attention)`; the last entry is the attention sub-layer output when
            `output_attentions` is set and `None` otherwise.
        """
        if self.layer_id == 0:
            hidden = self.pre_ln(hidden)

        attention_out, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache)
        hidden = hidden + attention_out

        ffn_out, state = self.feed_forward(self.ln2(hidden), state=state)
        hidden = hidden + ffn_out

        return (hidden, state, attention_out if output_attentions else None)
@auto_docstring
class RwkvPreTrainedModel(PreTrainedModel):
    # Shared base of the RWKV models: class-level loading settings plus the RWKV weight initialization scheme.
    config: RwkvConfig
    base_model_prefix = "rwkv"
    _no_split_modules = ["RwkvBlock"]
    _keep_in_fp32_modules = ["time_decay", "time_first"]
    supports_gradient_checkpointing = True
    _is_stateful = True

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        """
        Initialize the weights of `module` in place.

        Attention and feed-forward layers get position- and depth-dependent values for their time-mix / decay
        parameters, linear and embedding layers get an orthogonal init with a shape-dependent gain, and layer
        norms are reset to weight one / bias zero. Other module types are left untouched.
        """
        if isinstance(module, RwkvSelfAttention):
            layer_id = module.layer_id
            num_hidden_layers = module.config.num_hidden_layers
            hidden_size = module.config.hidden_size
            attention_hidden_size = module.attention_hidden_size

            # `max(..., 1)` keeps single-layer models from dividing by zero (the ratio is then simply 0).
            ratio_0_to_1 = layer_id / max(num_hidden_layers - 1, 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers)  # 1 to ~0

            time_weight = torch.tensor(
                [i / hidden_size for i in range(hidden_size)],
                dtype=module.time_mix_key.dtype,
                device=module.time_mix_key.device,
            )
            time_weight = time_weight[None, None, :]

            # Same guard for a single attention channel.
            last_channel = max(attention_hidden_size - 1, 1)
            decay_speed = [
                -5 + 8 * (h / last_channel) ** (0.7 + 1.3 * ratio_0_to_1) for h in range(attention_hidden_size)
            ]
            decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device)
            # Repeating 0, +0.5, -0.5 pattern across channels.
            zigzag = (
                torch.tensor(
                    [(i + 1) % 3 - 1 for i in range(attention_hidden_size)],
                    dtype=module.time_first.dtype,
                    device=module.time_first.device,
                )
                * 0.5
            )

            init.copy_(module.time_decay, decay_speed)
            # `ones_like` must wrap the parameter only: wrapping the whole expression yields all ones and silently
            # discards both the log(0.3) scale and the zigzag offset.
            init.copy_(module.time_first, torch.ones_like(module.time_first) * math.log(0.3) + zigzag)

            init.copy_(module.time_mix_key, torch.pow(time_weight, ratio_1_to_almost0))
            init.copy_(module.time_mix_value, torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            init.copy_(module.time_mix_receptance, torch.pow(time_weight, 0.5 * ratio_1_to_almost0))
        elif isinstance(module, RwkvFeedForward):
            layer_id = module.layer_id
            num_hidden_layers = module.config.num_hidden_layers
            hidden_size = module.config.hidden_size

            ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers)  # 1 to ~0

            time_weight = torch.tensor(
                [i / hidden_size for i in range(hidden_size)],
                dtype=module.time_mix_key.dtype,
                device=module.time_mix_key.device,
            )
            time_weight = time_weight[None, None, :]

            init.copy_(module.time_mix_key, torch.pow(time_weight, ratio_1_to_almost0))
            init.copy_(module.time_mix_receptance, torch.pow(time_weight, ratio_1_to_almost0))
        elif isinstance(module, nn.Linear):
            shape = module.weight.shape
            gain = 1.0
            scale = 1.0  # extra scale for gain

            if module.bias is not None:
                init.zeros_(module.bias)
            if shape[0] > shape[1]:
                gain = math.sqrt(shape[0] / shape[1])
            if shape[0] == self.config.vocab_size and shape[1] == self.config.hidden_size:  # final projection?
                scale = 0.5

            gain *= scale
            init.orthogonal_(module.weight, gain=gain)
        elif isinstance(module, nn.Embedding):
            shape = module.weight.shape
            gain = 1e-4 * math.sqrt(max(shape[0], shape[1]))
            init.orthogonal_(module.weight, gain=gain)
        elif isinstance(module, nn.LayerNorm):
            init.ones_(module.weight)
            init.zeros_(module.bias)
@dataclass
@auto_docstring(
    custom_intro="""
    Class for the RWKV model outputs.
    """
)
class RwkvOutput(ModelOutput):
    r"""
    state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
        Model state after the last processed time step. Feed it to the next forward call together with the new
        `input_ids` so that the earlier `input_ids` do not have to be provided again.
    """

    # Field order is part of the interface: outputs can also be consumed positionally.
    last_hidden_state: torch.FloatTensor | None = None
    state: list[torch.FloatTensor] | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for causal language model (or autoregressive) outputs.
    """
)
class RwkvCausalLMOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Next-token prediction (language modeling) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Scores produced by the language modeling head for every vocabulary token, before SoftMax.
    state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
        Model state after the last processed time step. Feed it to the next forward call together with the new
        `input_ids` so that the earlier `input_ids` do not have to be provided again.
    """

    # Field order is part of the interface: outputs can also be consumed positionally.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    state: list[torch.FloatTensor] | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@auto_docstring
class RwkvModel(RwkvPreTrainedModel):
    # Bare RWKV backbone: token embeddings, a stack of `RwkvBlock` layers and a final layer norm.
    def __init__(self, config):
        super().__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList(RwkvBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers))
        self.ln_out = nn.LayerNorm(config.hidden_size)

        # Tracks whether the output projections currently carry the inference-time rescaling.
        self.layers_are_rescaled = False

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        state: list[torch.FloatTensor] | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | RwkvOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            Indices of input sequence tokens in the vocabulary. When a `state` is passed along, only the
            `input_ids` that are not already accounted for by that state should be given.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model had `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the last state is returned and can be used to quickly generate the next logits.
        """
        # Fall back to the configuration for every flag the caller left unset.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            # Caching defaults to off while training.
            use_cache = False if self.training else self.config.use_cache
        if return_dict is None:
            return_dict = self.config.return_dict

        if attention_mask is not None:
            logger.warning_once("`attention_mask` was passed, but it is unused in this model.")

        # Rescaled weights are wanted for inference only; switch when the mode and the flag disagree.
        if self.training == self.layers_are_rescaled:
            self._rescale_layers()

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        if use_cache and state is None:
            state_shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers)
            # The first two slots follow the embedding dtype, the three attention slots are kept in float32.
            slot_dtypes = [inputs_embeds.dtype] * 2 + [torch.float32] * 3
            state = [torch.zeros(*state_shape, dtype=dtype, device=inputs_embeds.device) for dtype in slot_dtypes]
            # Very negative start value for the running maximum.
            state[4] -= 1e30

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        hidden_states = inputs_embeds

        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for block_idx, block in enumerate(self.blocks):
            hidden_states, state, attentions = block(
                hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
            )

            rescale_every = self.config.rescale_every
            if self.layers_are_rescaled and rescale_every > 0 and (block_idx + 1) % rescale_every == 0:
                hidden_states = hidden_states / 2

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if output_attentions:
                all_self_attentions = all_self_attentions + (attentions,)

        hidden_states = self.ln_out(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            candidates = [hidden_states, state, all_hidden_states, all_self_attentions]
            return tuple(item for item in candidates if item is not None)

        return RwkvOutput(
            last_hidden_state=hidden_states,
            state=state,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def _rescale_layers(self):
        # Layers should be rescaled for inference only.
        if self.layers_are_rescaled == (not self.training):
            return
        if self.config.rescale_every > 0:
            with torch.no_grad():
                for block_id, block in enumerate(self.blocks):
                    factor = 2 ** int(block_id // self.config.rescale_every)
                    attention_out = block.attention.output
                    feed_forward_out = block.feed_forward.value
                    if self.training:
                        # Undo the inference-time division.
                        attention_out.weight.mul_(factor)
                        feed_forward_out.weight.mul_(factor)
                    elif hasattr(attention_out.weight, "SCB"):
                        # Deal with quantization statistics
                        attention_out.weight.SCB.div_(factor)
                        feed_forward_out.weight.SCB.div_(factor)
                    elif hasattr(attention_out.weight, "quant_state"):
                        self._bnb_4bit_dequantize_and_rescale(attention_out, block_id)
                        self._bnb_4bit_dequantize_and_rescale(feed_forward_out, block_id)
                    else:
                        attention_out.weight.div_(factor)
                        feed_forward_out.weight.div_(factor)

        self.layers_are_rescaled = not self.training

    def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
        r"""
        Dequantize the 4-bit weight of `target_layer`, divide it by this block's rescaling factor and install a
        freshly created 4-bit parameter in its place.
        """
        if not is_bitsandbytes_available():
            raise ImportError("Please install bitsandbytes to use this method.")
        import bitsandbytes as bnb

        weight = target_layer.weight
        dequant_weights = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)

        dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))

        # re-quantize the model:
        # we need to put it first on CPU then back to the device
        # this will create an overhead :/
        # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
        # bugs with bnb
        target_device = dequant_weights.device
        target_layer.weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(target_device)
@auto_docstring(
    custom_intro="""
    The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """
)
class RwkvForCausalLM(RwkvPreTrainedModel, GenerationMixin):
    # The head shares its weight with the backbone's input embeddings.
    _tied_weights_keys = {"head.weight": "rwkv.embeddings.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.rwkv = RwkvModel(config)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.head

    def set_output_embeddings(self, new_embeddings):
        self.head = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        state: list[torch.FloatTensor] | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> tuple | RwkvCausalLMOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            Indices of input sequence tokens in the vocabulary. When a `state` is passed along, only the
            `input_ids` that are not already accounted for by that state should be given.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model had `state_input_ids + input_ids` as context).
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
            `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        use_cache (`bool`, *optional*):
            If set to `True`, the last state is returned and can be used to quickly generate the next logits.
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        # Note: `attention_mask` is accepted for API compatibility but is not forwarded to the backbone.
        rwkv_outputs = self.rwkv(
            input_ids,
            inputs_embeds=inputs_embeds,
            state=state,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = rwkv_outputs[0]

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        if isinstance(logits_to_keep, int):
            # An int keeps the last `logits_to_keep` positions (0 keeps everything).
            slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = logits_to_keep
        logits = self.head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        if not return_dict:
            output = (logits,) + rwkv_outputs[1:]
            if loss is None:
                return output
            return (loss,) + output

        return RwkvCausalLMOutput(
            loss=loss,
            logits=logits,
            state=rwkv_outputs.state,
            hidden_states=rwkv_outputs.hidden_states,
            attentions=rwkv_outputs.attentions,
        )
# Names exported by `from ... import *`.
__all__ = [
    "RwkvForCausalLM",
    "RwkvModel",
    "RwkvPreTrainedModel",
]
|