modeling_rwkv.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758
  1. # Copyright 2023 Bo Peng and HuggingFace Inc. team.
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch RWKV model."""
  16. import math
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import nn
  20. from ... import initialization as init
  21. from ...generation import GenerationMixin
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_utils import PreTrainedModel
  24. from ...utils import (
  25. ModelOutput,
  26. auto_docstring,
  27. is_bitsandbytes_available,
  28. is_kernels_available,
  29. is_ninja_available,
  30. is_torch_cuda_available,
  31. logging,
  32. )
  33. from .configuration_rwkv import RwkvConfig
  34. logger = logging.get_logger(__name__)
  35. rwkv_cuda_kernel = None
  36. def load_wkv_cuda_kernel(context_length):
  37. global rwkv_cuda_kernel
  38. if not is_kernels_available():
  39. raise ImportError("kernels is not installed, please install it with `pip install kernels`")
  40. from ...integrations.hub_kernels import get_kernel
  41. rwkv_cuda_kernel = get_kernel("kernels-community/rwkv")
  42. rwkv_cuda_kernel.max_seq_length = context_length
  43. class RwkvLinearAttention(torch.autograd.Function):
  44. @staticmethod
  45. def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False):
  46. batch_size, seq_len, hidden_size = key.size()
  47. if seq_len > rwkv_cuda_kernel.max_seq_length:
  48. raise ValueError(
  49. f"Cannot process a batch with {seq_len} tokens at the same time, use a maximum of "
  50. f"{rwkv_cuda_kernel.max_seq_length} with this model."
  51. )
  52. if batch_size * hidden_size % min(hidden_size, 32) != 0:
  53. raise ValueError(
  54. f"The product of batch size ({batch_size}) and hidden size ({hidden_size}) needs to be a round "
  55. f"multiple of {min(hidden_size, 32)}."
  56. )
  57. ctx.input_dtype = key.dtype
  58. if (
  59. time_decay.device.type != "cuda"
  60. or time_first.device.type != "cuda"
  61. or key.device.type != "cuda"
  62. or value.device.type != "cuda"
  63. ):
  64. raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.")
  65. time_decay = -torch.exp(time_decay.float().contiguous())
  66. if key.dtype == torch.float16:
  67. time_first = time_first.float()
  68. key = key.float()
  69. value = value.float()
  70. time_first = time_first.contiguous()
  71. key = key.contiguous()
  72. value = value.contiguous()
  73. # The CUDA kernel will fill this tensor.
  74. output = torch.empty_like(key, memory_format=torch.contiguous_format)
  75. if return_state or state is not None:
  76. if state is None:
  77. state = torch.zeros(
  78. batch_size,
  79. hidden_size,
  80. 3,
  81. dtype=torch.float32,
  82. device=key.device,
  83. memory_format=torch.contiguous_format,
  84. )
  85. state[:, :, 2] -= 1e38
  86. else:
  87. state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous()
  88. if key.dtype == torch.bfloat16:
  89. forward_func = rwkv_cuda_kernel.forward_with_state_bf16
  90. else:
  91. forward_func = rwkv_cuda_kernel.forward_with_state
  92. forward_func(time_decay, time_first, key, value, output, state)
  93. else:
  94. forward_func = rwkv_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else rwkv_cuda_kernel.forward
  95. forward_func(time_decay, time_first, key, value, output)
  96. ctx.save_for_backward(time_decay, time_first, key, value, output)
  97. if state is not None:
  98. state = [s.squeeze(2) for s in torch.chunk(state, 3, dim=2)]
  99. return output.to(ctx.input_dtype), state
  100. @staticmethod
  101. # g stands for grad
  102. def backward(ctx, g_output, g_state=None):
  103. input_dtype = ctx.input_dtype
  104. time_decay, time_first, key, value, output = ctx.saved_tensors
  105. # The CUDA kernel will fill those tensors.
  106. g_time_decay = torch.empty_like(
  107. time_decay,
  108. memory_format=torch.contiguous_format,
  109. dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
  110. )
  111. g_time_first = torch.empty_like(time_first, memory_format=torch.contiguous_format)
  112. g_key = torch.empty_like(key, memory_format=torch.contiguous_format)
  113. g_value = torch.empty_like(value, memory_format=torch.contiguous_format)
  114. if input_dtype == torch.float16:
  115. g_output = g_output.float()
  116. backward_func = rwkv_cuda_kernel.backward_bf16 if input_dtype == torch.bfloat16 else rwkv_cuda_kernel.backward
  117. backward_func(
  118. time_decay,
  119. time_first,
  120. key,
  121. value,
  122. output,
  123. g_output.contiguous(),
  124. g_time_decay,
  125. g_time_first,
  126. g_key,
  127. g_value,
  128. )
  129. return (
  130. g_time_decay.to(input_dtype),
  131. g_time_first.to(input_dtype),
  132. g_key.to(input_dtype),
  133. g_value.to(input_dtype),
  134. None,
  135. None,
  136. )
  137. def rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=None, return_state=False):
  138. # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
  139. # within a torch.no_grad.
  140. _, seq_length, _ = key.size()
  141. output = torch.zeros_like(key)
  142. if state is None:
  143. num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
  144. den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
  145. max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
  146. else:
  147. num_state, den_state, max_state = state
  148. # For numerical stability
  149. # real_numerator_state = num_state * torch.exp(max_state)
  150. # real_denominator_state = den_state * torch.exp(max_state)
  151. time_decay = -torch.exp(time_decay)
  152. for current_index in range(seq_length):
  153. current_key = key[:, current_index].float()
  154. current_value = value[:, current_index]
  155. # wkv computation at time t
  156. max_for_output = torch.maximum(max_state, current_key + time_first)
  157. e1 = torch.exp(max_state - max_for_output)
  158. e2 = torch.exp(current_key + time_first - max_for_output)
  159. numerator = e1 * num_state + e2 * current_value
  160. denominator = e1 * den_state + e2
  161. output[:, current_index] = (numerator / denominator).to(output.dtype)
  162. # Update state for next iteration
  163. max_for_state = torch.maximum(max_state + time_decay, current_key)
  164. e1 = torch.exp(max_state + time_decay - max_for_state)
  165. e2 = torch.exp(current_key - max_for_state)
  166. num_state = e1 * num_state + e2 * current_value
  167. den_state = e1 * den_state + e2
  168. max_state = max_for_state
  169. if return_state or state is not None:
  170. state = [num_state, den_state, max_state]
  171. return output, state
  172. def rwkv_linear_attention(time_decay, time_first, key, value, state=None, return_state=False):
  173. no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value])
  174. # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
  175. # in this case).
  176. one_token = key.size(1) == 1
  177. if rwkv_cuda_kernel is None or no_cuda or one_token:
  178. return rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=state, return_state=return_state)
  179. else:
  180. return RwkvLinearAttention.apply(time_decay, time_first, key, value, state, return_state)
  181. class RwkvSelfAttention(nn.Module):
  182. def __init__(self, config, layer_id=0):
  183. super().__init__()
  184. self.config = config
  185. kernel_loaded = rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == config.context_length
  186. if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
  187. try:
  188. load_wkv_cuda_kernel(config.context_length)
  189. except Exception:
  190. logger.info("Could not load the custom CUDA kernel for RWKV attention.")
  191. self.layer_id = layer_id
  192. hidden_size = config.hidden_size
  193. attention_hidden_size = (
  194. config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size
  195. )
  196. self.attention_hidden_size = attention_hidden_size
  197. self.time_decay = nn.Parameter(torch.empty(attention_hidden_size))
  198. self.time_first = nn.Parameter(torch.empty(attention_hidden_size))
  199. self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
  200. self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size))
  201. self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))
  202. self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
  203. self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False)
  204. self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False)
  205. self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
  206. self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
  207. # TODO: maybe jit, otherwise move inside forward
  208. def extract_key_value(self, hidden, state=None):
  209. # Mix hidden with the previous timestep to produce key, value, receptance
  210. if hidden.size(1) == 1 and state is not None:
  211. shifted = state[1][:, :, self.layer_id]
  212. else:
  213. shifted = self.time_shift(hidden)
  214. if state is not None:
  215. shifted[:, 0] = state[1][:, :, self.layer_id]
  216. key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
  217. value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
  218. receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
  219. key = self.key(key)
  220. value = self.value(value)
  221. receptance = torch.sigmoid(self.receptance(receptance))
  222. if state is not None:
  223. state[1][:, :, self.layer_id] = hidden[:, -1]
  224. return receptance, key, value, state
  225. def forward(self, hidden, state=None, use_cache=False):
  226. receptance, key, value, state = self.extract_key_value(hidden, state=state)
  227. layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
  228. rwkv, layer_state = rwkv_linear_attention(
  229. self.time_decay,
  230. self.time_first,
  231. key,
  232. value,
  233. state=layer_state,
  234. return_state=use_cache,
  235. )
  236. if layer_state is not None:
  237. state[2][:, :, self.layer_id] = layer_state[0]
  238. state[3][:, :, self.layer_id] = layer_state[1]
  239. state[4][:, :, self.layer_id] = layer_state[2]
  240. return self.output(receptance * rwkv), state
  241. class RwkvFeedForward(nn.Module):
  242. def __init__(self, config, layer_id=0):
  243. super().__init__()
  244. self.config = config
  245. self.layer_id = layer_id
  246. hidden_size = config.hidden_size
  247. intermediate_size = (
  248. config.intermediate_size if config.intermediate_size is not None else 4 * config.hidden_size
  249. )
  250. self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
  251. self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
  252. self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))
  253. self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
  254. self.receptance = nn.Linear(hidden_size, hidden_size, bias=False)
  255. self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
  256. def forward(self, hidden, state=None):
  257. if hidden.size(1) == 1 and state is not None:
  258. shifted = state[0][:, :, self.layer_id]
  259. else:
  260. shifted = self.time_shift(hidden)
  261. if state is not None:
  262. shifted[:, 0] = state[0][:, :, self.layer_id]
  263. key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
  264. receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
  265. key = torch.square(torch.relu(self.key(key)))
  266. value = self.value(key)
  267. receptance = torch.sigmoid(self.receptance(receptance))
  268. if state is not None:
  269. state[0][:, :, self.layer_id] = hidden[:, -1]
  270. return receptance * value, state
  271. class RwkvBlock(GradientCheckpointingLayer):
  272. def __init__(self, config, layer_id):
  273. super().__init__()
  274. self.config = config
  275. self.layer_id = layer_id
  276. if layer_id == 0:
  277. self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  278. self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  279. self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  280. self.attention = RwkvSelfAttention(config, layer_id)
  281. self.feed_forward = RwkvFeedForward(config, layer_id)
  282. def forward(self, hidden, state=None, use_cache=False, output_attentions=False):
  283. if self.layer_id == 0:
  284. hidden = self.pre_ln(hidden)
  285. attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache)
  286. hidden = hidden + attention
  287. feed_forward, state = self.feed_forward(self.ln2(hidden), state=state)
  288. hidden = hidden + feed_forward
  289. outputs = (hidden, state)
  290. if output_attentions:
  291. outputs += (attention,)
  292. else:
  293. outputs += (None,)
  294. return outputs
  295. @auto_docstring
  296. class RwkvPreTrainedModel(PreTrainedModel):
  297. config: RwkvConfig
  298. base_model_prefix = "rwkv"
  299. _no_split_modules = ["RwkvBlock"]
  300. _keep_in_fp32_modules = ["time_decay", "time_first"]
  301. supports_gradient_checkpointing = True
  302. _is_stateful = True
  303. @torch.no_grad()
  304. def _init_weights(self, module: nn.Module):
  305. """Initialize the weights."""
  306. if isinstance(module, RwkvSelfAttention):
  307. layer_id = module.layer_id
  308. num_hidden_layers = module.config.num_hidden_layers
  309. hidden_size = module.config.hidden_size
  310. attention_hidden_size = module.attention_hidden_size
  311. ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
  312. ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
  313. time_weight = torch.tensor(
  314. [i / hidden_size for i in range(hidden_size)],
  315. dtype=module.time_mix_key.dtype,
  316. device=module.time_mix_key.device,
  317. )
  318. time_weight = time_weight[None, None, :]
  319. decay_speed = [
  320. -5 + 8 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
  321. for h in range(attention_hidden_size)
  322. ]
  323. decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device)
  324. zigzag = (
  325. torch.tensor(
  326. [(i + 1) % 3 - 1 for i in range(attention_hidden_size)],
  327. dtype=module.time_first.dtype,
  328. device=module.time_first.device,
  329. )
  330. * 0.5
  331. )
  332. init.copy_(module.time_decay, decay_speed)
  333. init.copy_(module.time_first, torch.ones_like(module.time_first * math.log(0.3) + zigzag))
  334. init.copy_(module.time_mix_key, torch.pow(time_weight, ratio_1_to_almost0))
  335. init.copy_(module.time_mix_value, torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
  336. init.copy_(module.time_mix_receptance, torch.pow(time_weight, 0.5 * ratio_1_to_almost0))
  337. elif isinstance(module, RwkvFeedForward):
  338. layer_id = module.layer_id
  339. num_hidden_layers = module.config.num_hidden_layers
  340. hidden_size = module.config.hidden_size
  341. ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
  342. time_weight = torch.tensor(
  343. [i / hidden_size for i in range(hidden_size)],
  344. dtype=module.time_mix_key.dtype,
  345. device=module.time_mix_key.device,
  346. )
  347. time_weight = time_weight[None, None, :]
  348. init.copy_(module.time_mix_key, torch.pow(time_weight, ratio_1_to_almost0))
  349. init.copy_(module.time_mix_receptance, torch.pow(time_weight, ratio_1_to_almost0))
  350. elif isinstance(module, nn.Linear):
  351. shape = module.weight.shape
  352. gain = 1.0
  353. scale = 1.0 # extra scale for gain
  354. if module.bias is not None:
  355. init.zeros_(module.bias)
  356. if shape[0] > shape[1]:
  357. gain = math.sqrt(shape[0] / shape[1])
  358. if shape[0] == self.config.vocab_size and shape[1] == self.config.hidden_size: # final projection?
  359. scale = 0.5
  360. gain *= scale
  361. init.orthogonal_(module.weight, gain=gain)
  362. elif isinstance(module, nn.Embedding):
  363. shape = module.weight.shape
  364. gain = 1e-4 * math.sqrt(max(shape[0], shape[1]))
  365. init.orthogonal_(module.weight, gain=gain)
  366. elif isinstance(module, nn.LayerNorm):
  367. init.ones_(module.weight)
  368. init.zeros_(module.bias)
@dataclass
@auto_docstring(
    custom_intro="""
    Class for the RWKV model outputs.
    """
)
class RwkvOutput(ModelOutput):
    r"""
    state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.
    """

    # Field order is part of the behaviour: callers also read this output positionally
    # (e.g. `rwkv_outputs[0]` / `rwkv_outputs[1:]` in `RwkvForCausalLM.forward`).
    last_hidden_state: torch.FloatTensor | None = None
    # Layout as allocated in `RwkvModel.forward`:
    #   [0], [1]: last input seen by the feed-forward / attention token shift (embedding dtype);
    #   [2], [3], [4]: WKV numerator, denominator and running max (float32).
    # It is `None` unless `use_cache` was enabled or a state was passed in.
    state: list[torch.FloatTensor] | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Holds each block's time-mixing output (see `RwkvBlock.forward`), not attention probabilities.
    attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for causal language model (or autoregressive) outputs.
    """
)
class RwkvCausalLMOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.
    """

    # Field order is part of the behaviour: `ModelOutput` also supports positional access.
    loss: torch.FloatTensor | None = None
    # Only the last `logits_to_keep` positions are present when `logits_to_keep` > 0 (see `RwkvForCausalLM.forward`).
    logits: torch.FloatTensor | None = None
    # Same recurrent state as `RwkvOutput.state`; `None` unless caching was used or a state was passed in.
    state: list[torch.FloatTensor] | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Holds each block's time-mixing output (see `RwkvBlock.forward`), not attention probabilities.
    attentions: tuple[torch.FloatTensor, ...] | None = None
  406. @auto_docstring
  407. class RwkvModel(RwkvPreTrainedModel):
  408. def __init__(self, config):
  409. super().__init__(config)
  410. self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
  411. self.blocks = nn.ModuleList([RwkvBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
  412. self.ln_out = nn.LayerNorm(config.hidden_size)
  413. self.layers_are_rescaled = False
  414. self.gradient_checkpointing = False
  415. # Initialize weights and apply final processing
  416. self.post_init()
  417. def get_input_embeddings(self):
  418. return self.embeddings
  419. def set_input_embeddings(self, new_embeddings):
  420. self.embeddings = new_embeddings
  421. @auto_docstring
  422. def forward(
  423. self,
  424. input_ids: torch.LongTensor | None = None,
  425. attention_mask: torch.LongTensor | None = None,
  426. inputs_embeds: torch.FloatTensor | None = None,
  427. state: list[torch.FloatTensor] | None = None,
  428. use_cache: bool | None = None,
  429. output_attentions: bool | None = None,
  430. output_hidden_states: bool | None = None,
  431. return_dict: bool | None = None,
  432. **kwargs,
  433. ) -> tuple | RwkvOutput:
  434. r"""
  435. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  436. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  437. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  438. sequence tokens in the vocabulary.
  439. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  440. `input_ids`.
  441. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  442. [`PreTrainedTokenizer.__call__`] for details.
  443. [What are input IDs?](../glossary#input-ids)
  444. state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
  445. If passed along, the model uses the previous state in all the blocks (which will give the output for the
  446. `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
  447. use_cache (`bool`, *optional*):
  448. If set to `True`, the last state is returned and can be used to quickly generate the next logits.
  449. """
  450. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  451. output_hidden_states = (
  452. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  453. )
  454. use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
  455. return_dict = return_dict if return_dict is not None else self.config.return_dict
  456. if attention_mask is not None:
  457. logger.warning_once("`attention_mask` was passed, but it is unused in this model.")
  458. if self.training == self.layers_are_rescaled:
  459. self._rescale_layers()
  460. if input_ids is not None and inputs_embeds is not None:
  461. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  462. elif input_ids is None and inputs_embeds is None:
  463. raise ValueError("You have to specify either input_ids or inputs_embeds")
  464. if inputs_embeds is None:
  465. inputs_embeds = self.embeddings(input_ids)
  466. if use_cache and state is None:
  467. shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers)
  468. state = [
  469. torch.zeros(
  470. *shape, dtype=inputs_embeds.dtype if i <= 1 else torch.float32, device=inputs_embeds.device
  471. )
  472. for i in range(5)
  473. ]
  474. state[4] -= 1e30
  475. if self.gradient_checkpointing and self.training:
  476. if use_cache:
  477. logger.warning_once(
  478. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  479. )
  480. use_cache = False
  481. hidden_states = inputs_embeds
  482. all_self_attentions = () if output_attentions else None
  483. all_hidden_states = () if output_hidden_states else None
  484. for idx, block in enumerate(self.blocks):
  485. hidden_states, state, attentions = block(
  486. hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
  487. )
  488. if (
  489. self.layers_are_rescaled
  490. and self.config.rescale_every > 0
  491. and (idx + 1) % self.config.rescale_every == 0
  492. ):
  493. hidden_states = hidden_states / 2
  494. if output_hidden_states:
  495. all_hidden_states = all_hidden_states + (hidden_states,)
  496. if output_attentions:
  497. all_self_attentions = all_self_attentions + (attentions,)
  498. hidden_states = self.ln_out(hidden_states)
  499. if output_hidden_states:
  500. all_hidden_states = all_hidden_states + (hidden_states,)
  501. if not return_dict:
  502. return tuple(x for x in [hidden_states, state, all_hidden_states, all_self_attentions] if x is not None)
  503. return RwkvOutput(
  504. last_hidden_state=hidden_states,
  505. state=state,
  506. hidden_states=all_hidden_states,
  507. attentions=all_self_attentions,
  508. )
  509. def _rescale_layers(self):
  510. # Layers should be rescaled for inference only.
  511. if self.layers_are_rescaled == (not self.training):
  512. return
  513. if self.config.rescale_every > 0:
  514. with torch.no_grad():
  515. for block_id, block in enumerate(self.blocks):
  516. if self.training:
  517. block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
  518. block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
  519. else:
  520. # Deal with quantization statistics
  521. if hasattr(block.attention.output.weight, "SCB"):
  522. block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
  523. block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
  524. elif hasattr(block.attention.output.weight, "quant_state"):
  525. self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id)
  526. self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id)
  527. else:
  528. block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
  529. block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
  530. self.layers_are_rescaled = not self.training
  531. def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
  532. r"""
  533. Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
  534. be quantized again.
  535. """
  536. if not is_bitsandbytes_available():
  537. raise ImportError("Please install bitsandbytes to use this method.")
  538. import bitsandbytes as bnb
  539. dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)
  540. dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))
  541. # re-quantize the model:
  542. # we need to put it first on CPU then back to the device
  543. # this will create an overhead :/
  544. # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
  545. # bugs with bnb
  546. quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
  547. setattr(target_layer, "weight", quant_weight)
  548. @auto_docstring(
  549. custom_intro="""
  550. The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
  551. embeddings).
  552. """
  553. )
  554. class RwkvForCausalLM(RwkvPreTrainedModel, GenerationMixin):
  555. _tied_weights_keys = {"head.weight": "rwkv.embeddings.weight"}
  556. def __init__(self, config):
  557. super().__init__(config)
  558. self.rwkv = RwkvModel(config)
  559. self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  560. # Initialize weights and apply final processing
  561. self.post_init()
  562. def get_output_embeddings(self):
  563. return self.head
  564. def set_output_embeddings(self, new_embeddings):
  565. self.head = new_embeddings
  566. @auto_docstring
  567. def forward(
  568. self,
  569. input_ids: torch.LongTensor | None = None,
  570. attention_mask: torch.LongTensor | None = None,
  571. inputs_embeds: torch.FloatTensor | None = None,
  572. state: list[torch.FloatTensor] | None = None,
  573. labels: torch.LongTensor | None = None,
  574. use_cache: bool | None = None,
  575. output_attentions: bool | None = None,
  576. output_hidden_states: bool | None = None,
  577. return_dict: bool | None = None,
  578. logits_to_keep: int | torch.Tensor = 0,
  579. **kwargs,
  580. ) -> tuple | RwkvCausalLMOutput:
  581. r"""
  582. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  583. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  584. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  585. sequence tokens in the vocabulary.
  586. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  587. `input_ids`.
  588. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  589. [`PreTrainedTokenizer.__call__`] for details.
  590. [What are input IDs?](../glossary#input-ids)
  591. state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
  592. If passed along, the model uses the previous state in all the blocks (which will give the output for the
  593. `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
  594. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  595. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  596. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  597. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  598. use_cache (`bool`, *optional*):
  599. If set to `True`, the last state is returned and can be used to quickly generate the next logits.
  600. """
  601. return_dict = return_dict if return_dict is not None else self.config.return_dict
  602. rwkv_outputs = self.rwkv(
  603. input_ids,
  604. inputs_embeds=inputs_embeds,
  605. state=state,
  606. use_cache=use_cache,
  607. output_attentions=output_attentions,
  608. output_hidden_states=output_hidden_states,
  609. return_dict=return_dict,
  610. )
  611. hidden_states = rwkv_outputs[0]
  612. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  613. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  614. logits = self.head(hidden_states[:, slice_indices, :])
  615. loss = None
  616. if labels is not None:
  617. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  618. if not return_dict:
  619. output = (logits,) + rwkv_outputs[1:]
  620. return ((loss,) + output) if loss is not None else output
  621. return RwkvCausalLMOutput(
  622. loss=loss,
  623. logits=logits,
  624. state=rwkv_outputs.state,
  625. hidden_states=rwkv_outputs.hidden_states,
  626. attentions=rwkv_outputs.attentions,
  627. )
  628. __all__ = ["RwkvForCausalLM", "RwkvModel", "RwkvPreTrainedModel"]