modeling_persimmon.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. # Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  4. # and OPT implementations in this library. It has been modified from its
  5. # original forms to accommodate minor architectural differences compared
  6. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """PyTorch Persimmon model."""
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. from torch import nn
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...generation import GenerationMixin
  27. from ...masking_utils import create_causal_mask
  28. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  29. from ...modeling_layers import (
  30. GenericForSequenceClassification,
  31. GenericForTokenClassification,
  32. GradientCheckpointingLayer,
  33. )
  34. from ...modeling_outputs import (
  35. BaseModelOutputWithPast,
  36. CausalLMOutputWithPast,
  37. )
  38. from ...modeling_rope_utils import (
  39. ROPE_INIT_FUNCTIONS,
  40. dynamic_rope_update,
  41. )
  42. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  43. from ...processing_utils import Unpack
  44. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  45. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  46. from ...utils.output_capturing import capture_outputs
  47. from .configuration_persimmon import PersimmonConfig
# Module-level logger from transformers' logging utilities; PersimmonAttention uses it
# (`logger.warning_once`) to warn when it is constructed without a `layer_idx`.
logger = logging.get_logger(__name__)
  49. # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon
class PersimmonRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) table generator.

    Builds the RoPE inverse frequencies (`inv_freq`) once from `config` and, in `forward`, turns
    position ids into the `(cos, sin)` tables that `apply_rotary_pos_emb` applies to queries/keys.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: PersimmonConfig, device=None):
        super().__init__()
        # Both start at the configured maximum. NOTE(review): presumably read/updated by
        # `dynamic_rope_update` for dynamic RoPE variants; that decorator is not visible here.
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        # "default" uses the local implementation below; any other RoPE type is looked up in
        # ROPE_INIT_FUNCTIONS. Either way the init function returns (inv_freq, attention_scaling).
        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffers move with the module (`.to()`, `.cuda()`) but are not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Independent copy (clone) of the initial frequencies.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    # Ignore copy
    def compute_default_rope_parameters(
        config: PersimmonConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.

        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        # Only a fraction of each head is rotated (partial rotary); defaults to the full head.
        partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        dim = int(head_dim * partial_rotary_factor)

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies: inv_freq[i] = 1 / base ** (2 * i / dim) for i in [0, dim // 2)
        # (arange is built as int64 and then cast to float).
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Compute the RoPE `(cos, sin)` tables for `position_ids` (no gradients are tracked).

        Args:
            x (`torch.Tensor`): Only its `device` and `dtype` are used.
            position_ids (`torch.Tensor`): Positions, indexed as `(batch, seq_len)` (`position_ids[:, None, :]`).

        Returns:
            `tuple(torch.Tensor, torch.Tensor)`: `cos` and `sin`, each of shape `(batch, seq_len, 2 * len(inv_freq))`,
            multiplied by `attention_scaling` and cast to `x.dtype`.
        """
        # (batch, len(inv_freq), 1)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        # (batch, 1, seq_len)
        position_ids_expanded = position_ids[:, None, :].float()

        # "mps" is mapped to "cpu" for the autocast context. NOTE(review): presumably because autocast does not
        # accept "mps" as a device_type; confirm.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            # (batch, len(inv_freq), seq_len) -> (batch, seq_len, len(inv_freq))
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate along the last dim to match the (first half, second half) split used by rotate_half.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  106. # Copied from transformers.models.llama.modeling_llama.rotate_half
  107. def rotate_half(x):
  108. """Rotates half the hidden dims of the input."""
  109. x1 = x[..., : x.shape[-1] // 2]
  110. x2 = x[..., x.shape[-1] // 2 :]
  111. return torch.cat((-x2, x1), dim=-1)
  112. # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
  113. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  114. """Applies Rotary Position Embedding to the query and key tensors.
  115. Args:
  116. q (`torch.Tensor`): The query tensor.
  117. k (`torch.Tensor`): The key tensor.
  118. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  119. sin (`torch.Tensor`): The sine part of the rotary embedding.
  120. unsqueeze_dim (`int`, *optional*, defaults to 1):
  121. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  122. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  123. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  124. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  125. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  126. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  127. Returns:
  128. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  129. """
  130. cos = cos.unsqueeze(unsqueeze_dim)
  131. sin = sin.unsqueeze(unsqueeze_dim)
  132. q_embed = (q * cos) + (rotate_half(q) * sin)
  133. k_embed = (k * cos) + (rotate_half(k) * sin)
  134. return q_embed, k_embed
  135. # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXMLP with GPTNeoX->Persimmon
  136. class PersimmonMLP(nn.Module):
  137. def __init__(self, config):
  138. super().__init__()
  139. self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size)
  140. self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size)
  141. self.act = ACT2FN[config.hidden_act]
  142. def forward(self, hidden_states):
  143. hidden_states = self.dense_h_to_4h(hidden_states)
  144. hidden_states = self.act(hidden_states)
  145. hidden_states = self.dense_4h_to_h(hidden_states)
  146. return hidden_states
  147. def eager_attention_forward(
  148. module: nn.Module,
  149. query: torch.Tensor,
  150. key: torch.Tensor,
  151. value: torch.Tensor,
  152. attention_mask: torch.Tensor | None,
  153. scaling: float,
  154. dropout: float = 0.0,
  155. **kwargs,
  156. ):
  157. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  158. if attention_mask is not None:
  159. attn_weights = attn_weights + attention_mask
  160. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  161. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  162. attn_output = torch.matmul(attn_weights, value)
  163. attn_output = attn_output.transpose(1, 2).contiguous()
  164. return attn_output, attn_weights
  165. class PersimmonAttention(nn.Module):
  166. """Multi-headed attention from 'Attention Is All You Need' paper"""
  167. def __init__(self, config: PersimmonConfig, layer_idx: int | None = None):
  168. super().__init__()
  169. self.config = config
  170. self.layer_idx = layer_idx
  171. if layer_idx is None:
  172. logger.warning_once(
  173. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  174. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  175. "when creating this class."
  176. )
  177. self.hidden_size = config.hidden_size
  178. self.num_heads = config.num_attention_heads
  179. self.head_dim = self.hidden_size // self.num_heads
  180. self.rotary_ndims = int(self.head_dim * config.rope_parameters["partial_rotary_factor"])
  181. self.is_causal = True
  182. if (self.head_dim * self.num_heads) != self.hidden_size:
  183. raise ValueError(
  184. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  185. f" and `num_heads`: {self.num_heads})."
  186. )
  187. self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
  188. self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
  189. self.qk_layernorm = config.qk_layernorm
  190. self.scaling = self.head_dim**-0.5
  191. if self.qk_layernorm:
  192. self.q_layernorm = nn.LayerNorm(
  193. config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
  194. )
  195. self.k_layernorm = nn.LayerNorm(
  196. config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
  197. )
  198. self.attention_dropout = nn.Dropout(config.attention_dropout)
  199. def _split_heads(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  200. """
  201. Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
  202. storage as `fused_qkv`
  203. Args:
  204. fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]
  205. Returns:
  206. query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
  207. value: [batch_size, seq_length, num_heads, head_dim]
  208. """
  209. batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
  210. fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
  211. return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
  212. def forward(
  213. self,
  214. hidden_states: torch.Tensor,
  215. attention_mask: torch.Tensor | None = None,
  216. position_ids: torch.LongTensor | None = None,
  217. past_key_values: Cache | None = None,
  218. output_attentions: bool = False,
  219. use_cache: bool = False,
  220. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  221. **kwargs: Unpack[FlashAttentionKwargs],
  222. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  223. bsz, q_len, _ = hidden_states.size()
  224. # [batch_size, seq_length, 3 x hidden_size]
  225. fused_qkv = self.query_key_value(hidden_states)
  226. # 3 x [batch_size, seq_length, num_heads, head_dim]
  227. (query_states, key_states, value_states) = self._split_heads(fused_qkv)
  228. if self.qk_layernorm:
  229. query_states = self.q_layernorm(query_states)
  230. key_states = self.k_layernorm(key_states)
  231. # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim]
  232. query_states = query_states.transpose(1, 2)
  233. value_states = value_states.transpose(1, 2)
  234. key_states = key_states.transpose(1, 2)
  235. cos, sin = position_embeddings
  236. query_rot, query_pass = (
  237. query_states[..., : self.rotary_ndims],
  238. query_states[..., self.rotary_ndims :],
  239. )
  240. key_rot, key_pass = (
  241. key_states[..., : self.rotary_ndims],
  242. key_states[..., self.rotary_ndims :],
  243. )
  244. # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
  245. query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
  246. # [batch_size, seq_length, num_heads, head_dim]
  247. query_states = torch.cat((query_rot, query_pass), dim=-1)
  248. key_states = torch.cat((key_rot, key_pass), dim=-1)
  249. if past_key_values is not None:
  250. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  251. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  252. self.config._attn_implementation, eager_attention_forward
  253. )
  254. attn_output, attn_weights = attention_interface(
  255. self,
  256. query_states,
  257. key_states,
  258. value_states,
  259. attention_mask,
  260. dropout=0.0 if not self.training else self.config.attention_dropout,
  261. scaling=self.scaling,
  262. **kwargs,
  263. )
  264. attn_output = attn_output.reshape(bsz, q_len, -1)
  265. attn_output = self.dense(attn_output)
  266. return attn_output, attn_weights
  267. class PersimmonDecoderLayer(GradientCheckpointingLayer):
  268. def __init__(self, config: PersimmonConfig, layer_idx: int):
  269. super().__init__()
  270. self.hidden_size = config.hidden_size
  271. self.self_attn = PersimmonAttention(config=config, layer_idx=layer_idx)
  272. self.mlp = PersimmonMLP(config)
  273. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  274. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  275. self.dropout = nn.Dropout(config.hidden_dropout)
  276. def forward(
  277. self,
  278. hidden_states: torch.Tensor,
  279. attention_mask: torch.Tensor | None = None,
  280. position_ids: torch.LongTensor | None = None,
  281. past_key_values: Cache | None = None,
  282. use_cache: bool | None = False,
  283. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  284. **kwargs: Unpack[FlashAttentionKwargs],
  285. ) -> torch.Tensor:
  286. residual = hidden_states
  287. hidden_states = self.input_layernorm(hidden_states)
  288. # Self Attention
  289. hidden_states, _ = self.self_attn(
  290. hidden_states=hidden_states,
  291. attention_mask=attention_mask,
  292. position_ids=position_ids,
  293. past_key_values=past_key_values,
  294. use_cache=use_cache,
  295. position_embeddings=position_embeddings,
  296. **kwargs,
  297. )
  298. hidden_states = residual + hidden_states
  299. # Fully Connected
  300. residual = hidden_states
  301. hidden_states = self.post_attention_layernorm(hidden_states)
  302. hidden_states = self.mlp(hidden_states)
  303. hidden_states = self.dropout(hidden_states)
  304. hidden_states = hidden_states + residual
  305. return hidden_states
  306. @auto_docstring
  307. class PersimmonPreTrainedModel(PreTrainedModel):
  308. config: PersimmonConfig
  309. base_model_prefix = "model"
  310. supports_gradient_checkpointing = True
  311. _no_split_modules = ["PersimmonDecoderLayer"]
  312. _skip_keys_device_placement = "past_key_values"
  313. _can_compile_fullgraph = True
  314. _supports_sdpa = True
  315. _supports_flash_attn = True
  316. _supports_attention_backend = True
  317. _can_record_outputs = {
  318. "hidden_states": PersimmonDecoderLayer,
  319. "attentions": PersimmonAttention,
  320. }
  321. @auto_docstring
  322. class PersimmonModel(PersimmonPreTrainedModel):
  323. """
  324. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PersimmonDecoderLayer`]
  325. Args:
  326. config: PersimmonConfig
  327. """
  328. def __init__(self, config: PersimmonConfig):
  329. super().__init__(config)
  330. self.padding_idx = config.pad_token_id
  331. self.vocab_size = config.vocab_size
  332. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  333. self.layers = nn.ModuleList(
  334. [PersimmonDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  335. )
  336. self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  337. self.rotary_emb = PersimmonRotaryEmbedding(config=self.config)
  338. self.gradient_checkpointing = False
  339. # Initialize weights and apply final processing
  340. self.post_init()
  341. @merge_with_config_defaults
  342. @capture_outputs
  343. @auto_docstring
  344. def forward(
  345. self,
  346. input_ids: torch.LongTensor | None = None,
  347. attention_mask: torch.Tensor | None = None,
  348. position_ids: torch.LongTensor | None = None,
  349. past_key_values: Cache | None = None,
  350. inputs_embeds: torch.FloatTensor | None = None,
  351. use_cache: bool | None = None,
  352. **kwargs: Unpack[TransformersKwargs],
  353. ) -> BaseModelOutputWithPast:
  354. if (input_ids is None) ^ (inputs_embeds is not None):
  355. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  356. if use_cache and past_key_values is None:
  357. past_key_values = DynamicCache(config=self.config)
  358. if inputs_embeds is None:
  359. inputs_embeds = self.embed_tokens(input_ids)
  360. if position_ids is None:
  361. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  362. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  363. position_ids = position_ids.unsqueeze(0)
  364. causal_mask = create_causal_mask(
  365. config=self.config,
  366. inputs_embeds=inputs_embeds,
  367. attention_mask=attention_mask,
  368. past_key_values=past_key_values,
  369. position_ids=position_ids,
  370. )
  371. hidden_states = inputs_embeds
  372. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  373. for decoder_layer in self.layers:
  374. hidden_states = decoder_layer(
  375. hidden_states,
  376. attention_mask=causal_mask,
  377. position_ids=position_ids,
  378. past_key_values=past_key_values,
  379. use_cache=use_cache,
  380. position_embeddings=position_embeddings,
  381. **kwargs,
  382. )
  383. hidden_states = self.final_layernorm(hidden_states)
  384. return BaseModelOutputWithPast(
  385. last_hidden_state=hidden_states,
  386. past_key_values=past_key_values,
  387. )
  388. class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin):
  389. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  390. # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->PERSIMMON,Llama->Persimmon
  391. def __init__(self, config):
  392. super().__init__(config)
  393. self.model = PersimmonModel(config)
  394. self.vocab_size = config.vocab_size
  395. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  396. # Initialize weights and apply final processing
  397. self.post_init()
  398. @can_return_tuple
  399. @auto_docstring
  400. def forward(
  401. self,
  402. input_ids: torch.LongTensor | None = None,
  403. attention_mask: torch.Tensor | None = None,
  404. position_ids: torch.LongTensor | None = None,
  405. past_key_values: Cache | None = None,
  406. inputs_embeds: torch.FloatTensor | None = None,
  407. labels: torch.LongTensor | None = None,
  408. use_cache: bool | None = None,
  409. logits_to_keep: int | torch.Tensor = 0,
  410. **kwargs: Unpack[TransformersKwargs],
  411. ) -> CausalLMOutputWithPast:
  412. r"""
  413. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  414. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  415. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  416. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  417. Example:
  418. ```python
  419. >>> from transformers import AutoTokenizer, PersimmonForCausalLM
  420. >>> model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-base")
  421. >>> tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base")
  422. >>> prompt = "human: Hey, what should I eat for dinner?"
  423. >>> inputs = tokenizer(prompt, return_tensors="pt")
  424. >>> # Generate
  425. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  426. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  427. 'human: Hey, what should I eat for dinner?\n\ncat: 🐱\n\nhuman: 😐\n\n'
  428. ```"""
  429. outputs: BaseModelOutputWithPast = self.model(
  430. input_ids=input_ids,
  431. attention_mask=attention_mask,
  432. position_ids=position_ids,
  433. past_key_values=past_key_values,
  434. inputs_embeds=inputs_embeds,
  435. use_cache=use_cache,
  436. **kwargs,
  437. )
  438. hidden_states = outputs.last_hidden_state
  439. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  440. logits = self.lm_head(hidden_states[:, slice_indices, :])
  441. loss = None
  442. if labels is not None:
  443. loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
  444. return CausalLMOutputWithPast(
  445. loss=loss,
  446. logits=logits,
  447. past_key_values=outputs.past_key_values,
  448. hidden_states=outputs.hidden_states,
  449. attentions=outputs.attentions,
  450. )
  451. class PersimmonForSequenceClassification(GenericForSequenceClassification, PersimmonPreTrainedModel): ...
  452. class PersimmonForTokenClassification(GenericForTokenClassification, PersimmonPreTrainedModel): ...
# Public API of this module: the names bound by `from ... import *`.
__all__ = [
    "PersimmonForCausalLM",
    "PersimmonModel",
    "PersimmonPreTrainedModel",
    "PersimmonForSequenceClassification",
    "PersimmonForTokenClassification",
]