modeling_chameleon.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132
  1. # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Chameleon model."""
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. from functools import cached_property
  18. from typing import Optional
  19. import torch
  20. import torch.nn.functional as F
  21. from torch import nn
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache
  24. from ...generation import GenerationMixin
  25. from ...masking_utils import create_causal_mask
  26. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
  29. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...utils import (
  33. TransformersKwargs,
  34. auto_docstring,
  35. can_return_tuple,
  36. logging,
  37. torch_compilable_check,
  38. )
  39. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  40. from ...utils.output_capturing import capture_outputs
  41. from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig
# Module-level logger (from the library's `logging` utilities); used e.g. by ChameleonAttention to warn
# when it is constructed without a `layer_idx`.
logger = logging.get_logger(__name__)
  43. @dataclass
  44. @auto_docstring
  45. class ChameleonVQVAEModelOutput(BaseModelOutputWithPooling):
  46. r"""
  47. quantized_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  48. Quantized last hidden state from the VQ-VAE model.
  49. image_tokens (`torch.FloatTensor` of shape `(batch_size, config.vocab_size`):
  50. Indices of the image tokens predicted by the VQ-VAE model.
  51. embedding_loss (`torch.FloatTensor`):
  52. The embedding loss computed during quantization.
  53. """
  54. quantized_last_hidden_state: torch.FloatTensor | None = None
  55. image_tokens: torch.FloatTensor | None = None
  56. embedding_loss: torch.FloatTensor | None = None
  57. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Chameleon
  58. class ChameleonRMSNorm(nn.Module):
  59. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  60. """
  61. ChameleonRMSNorm is equivalent to T5LayerNorm
  62. """
  63. super().__init__()
  64. self.weight = nn.Parameter(torch.ones(hidden_size))
  65. self.variance_epsilon = eps
  66. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  67. input_dtype = hidden_states.dtype
  68. hidden_states = hidden_states.to(torch.float32)
  69. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  70. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  71. return self.weight * hidden_states.to(input_dtype)
  72. def extra_repr(self):
  73. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  74. # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon
  75. class ChameleonRotaryEmbedding(nn.Module):
  76. inv_freq: torch.Tensor # fix linting for `register_buffer`
  77. def __init__(self, config: ChameleonConfig, device=None):
  78. super().__init__()
  79. self.max_seq_len_cached = config.max_position_embeddings
  80. self.original_max_seq_len = config.max_position_embeddings
  81. self.config = config
  82. self.rope_type = self.config.rope_parameters["rope_type"]
  83. rope_init_fn: Callable = self.compute_default_rope_parameters
  84. if self.rope_type != "default":
  85. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  86. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  87. self.register_buffer("inv_freq", inv_freq, persistent=False)
  88. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  89. @staticmethod
  90. def compute_default_rope_parameters(
  91. config: ChameleonConfig | None = None,
  92. device: Optional["torch.device"] = None,
  93. seq_len: int | None = None,
  94. ) -> tuple["torch.Tensor", float]:
  95. """
  96. Computes the inverse frequencies according to the original RoPE implementation
  97. Args:
  98. config ([`~transformers.PreTrainedConfig`]):
  99. The model configuration.
  100. device (`torch.device`):
  101. The device to use for initialization of the inverse frequencies.
  102. seq_len (`int`, *optional*):
  103. The current sequence length. Unused for this type of RoPE.
  104. Returns:
  105. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  106. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  107. """
  108. base = config.rope_parameters["rope_theta"]
  109. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  110. attention_factor = 1.0 # Unused in this type of RoPE
  111. # Compute the inverse frequencies
  112. inv_freq = 1.0 / (
  113. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  114. )
  115. return inv_freq, attention_factor
  116. @torch.no_grad()
  117. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  118. def forward(self, x, position_ids):
  119. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  120. position_ids_expanded = position_ids[:, None, :].float()
  121. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  122. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  123. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  124. emb = torch.cat((freqs, freqs), dim=-1)
  125. cos = emb.cos() * self.attention_scaling
  126. sin = emb.sin() * self.attention_scaling
  127. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  128. # Copied from transformers.models.llama.modeling_llama.rotate_half
  129. def rotate_half(x):
  130. """Rotates half the hidden dims of the input."""
  131. x1 = x[..., : x.shape[-1] // 2]
  132. x2 = x[..., x.shape[-1] // 2 :]
  133. return torch.cat((-x2, x1), dim=-1)
  134. # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
  135. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  136. """Applies Rotary Position Embedding to the query and key tensors.
  137. Args:
  138. q (`torch.Tensor`): The query tensor.
  139. k (`torch.Tensor`): The key tensor.
  140. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  141. sin (`torch.Tensor`): The sine part of the rotary embedding.
  142. unsqueeze_dim (`int`, *optional*, defaults to 1):
  143. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  144. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  145. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  146. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  147. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  148. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  149. Returns:
  150. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  151. """
  152. cos = cos.unsqueeze(unsqueeze_dim)
  153. sin = sin.unsqueeze(unsqueeze_dim)
  154. q_embed = (q * cos) + (rotate_half(q) * sin)
  155. k_embed = (k * cos) + (rotate_half(k) * sin)
  156. return q_embed, k_embed
  157. # Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Chameleon
  158. class ChameleonMLP(nn.Module):
  159. def __init__(self, config):
  160. super().__init__()
  161. self.config = config
  162. self.hidden_size = config.hidden_size
  163. self.intermediate_size = config.intermediate_size
  164. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  165. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  166. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  167. self.act_fn = ACT2FN[config.hidden_act]
  168. # Ignore copy
  169. def forward(self, x):
  170. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  171. return down_proj
  172. class ChameleonLayerNorm(nn.LayerNorm):
  173. """
  174. LayerNorm but computes stats only over the last dim because Chameleon applies gamma and beta
  175. from each shard separately to each head, instead of reducing. We can apply each head's own
  176. gamma/beta by repeat-interleaving weights from each shard, but the stats have to be computed
  177. in the last dimension. This module applies gamma/beta manually to fulfill this requirement.
  178. """
  179. def __init__(self, hidden_size, *args, **kwargs):
  180. super().__init__(hidden_size, *args, **kwargs)
  181. self.normalized_shape = (hidden_size[-1],)
  182. def forward(self, hidden_states):
  183. hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=1e-5)
  184. hidden_states = hidden_states * self.weight + self.bias
  185. return hidden_states
  186. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  187. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  188. """
  189. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  190. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  191. """
  192. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  193. if n_rep == 1:
  194. return hidden_states
  195. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  196. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  197. # Copied from transformers.models.llama.modeling_llama.eager_attention_forward
  198. def eager_attention_forward(
  199. module: nn.Module,
  200. query: torch.Tensor,
  201. key: torch.Tensor,
  202. value: torch.Tensor,
  203. attention_mask: torch.Tensor | None,
  204. scaling: float,
  205. dropout: float = 0.0,
  206. **kwargs: Unpack[TransformersKwargs],
  207. ):
  208. key_states = repeat_kv(key, module.num_key_value_groups)
  209. value_states = repeat_kv(value, module.num_key_value_groups)
  210. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  211. if attention_mask is not None:
  212. attn_weights = attn_weights + attention_mask
  213. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  214. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  215. attn_output = torch.matmul(attn_weights, value_states)
  216. attn_output = attn_output.transpose(1, 2).contiguous()
  217. return attn_output, attn_weights
  218. class ChameleonAttention(nn.Module):
  219. """Multi-headed attention from 'Attention Is All You Need' paper"""
  220. def __init__(self, config: ChameleonConfig, layer_idx: int | None = None):
  221. super().__init__()
  222. self.config = config
  223. self.layer_idx = layer_idx
  224. if layer_idx is None:
  225. logger.warning_once(
  226. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  227. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  228. "when creating this class."
  229. )
  230. self.attention_dropout = config.attention_dropout
  231. self.hidden_size = config.hidden_size
  232. self.num_heads = config.num_attention_heads
  233. self.head_dim = self.hidden_size // self.num_heads
  234. self.num_key_value_heads = config.num_key_value_heads
  235. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  236. self.max_position_embeddings = config.max_position_embeddings
  237. self.is_causal = True
  238. self.model_parallel_size = config.model_parallel_size
  239. self.scaling = self.head_dim**-0.5
  240. if (self.head_dim * self.num_heads) != self.hidden_size:
  241. raise ValueError(
  242. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  243. f" and `num_heads`: {self.num_heads})."
  244. )
  245. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
  246. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  247. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  248. self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
  249. self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
  250. self.k_norm = ChameleonLayerNorm((self.num_key_value_heads, self.head_dim))
  251. def forward(
  252. self,
  253. hidden_states: torch.Tensor,
  254. attention_mask: torch.Tensor | None = None,
  255. position_ids: torch.LongTensor | None = None,
  256. past_key_values: Cache | None = None,
  257. output_attentions: bool = False,
  258. use_cache: bool = False,
  259. position_embeddings: torch.Tensor | None = None,
  260. **kwargs,
  261. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  262. bsz, q_len, _ = hidden_states.size()
  263. query_states = self.q_proj(hidden_states)
  264. key_states = self.k_proj(hidden_states)
  265. value_states = self.v_proj(hidden_states)
  266. query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
  267. query_states = self.q_norm(query_states)
  268. key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
  269. key_states = self.k_norm(key_states)
  270. query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  271. key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  272. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  273. cos, sin = position_embeddings
  274. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  275. if past_key_values is not None:
  276. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  277. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  278. self.config._attn_implementation, eager_attention_forward
  279. )
  280. attn_output, attn_weights = attention_interface(
  281. self,
  282. query_states,
  283. key_states,
  284. value_states,
  285. attention_mask,
  286. dropout=0.0 if not self.training else self.attention_dropout,
  287. scaling=self.scaling,
  288. **kwargs,
  289. )
  290. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  291. attn_output = self.o_proj(attn_output)
  292. return attn_output, attn_weights
  293. # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
  294. class ChameleonDecoderLayer(GradientCheckpointingLayer):
  295. def __init__(self, config: ChameleonConfig, layer_idx: int):
  296. super().__init__()
  297. self.hidden_size = config.hidden_size
  298. self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx)
  299. self.mlp = ChameleonMLP(config)
  300. self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  301. self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  302. def forward(
  303. self,
  304. hidden_states: torch.Tensor,
  305. attention_mask: torch.Tensor | None = None,
  306. position_ids: torch.LongTensor | None = None,
  307. past_key_values: Cache | None = None,
  308. output_attentions: bool | None = False,
  309. use_cache: bool | None = False,
  310. position_embeddings: torch.Tensor | None = None,
  311. **kwargs,
  312. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  313. """
  314. Args:
  315. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  316. attention_mask (`torch.FloatTensor`, *optional*):
  317. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  318. query_sequence_length, key_sequence_length)` if default attention is used.
  319. output_attentions (`bool`, *optional*):
  320. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  321. returned tensors for more detail.
  322. use_cache (`bool`, *optional*):
  323. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  324. (see `past_key_values`).
  325. past_key_values (`Cache`, *optional*): cached past key and value projection states
  326. kwargs (`dict`, *optional*):
  327. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  328. into the model
  329. """
  330. residual = hidden_states
  331. hidden_states = self.input_layernorm(hidden_states)
  332. # Self Attention
  333. hidden_states, self_attn_weights = self.self_attn(
  334. hidden_states=hidden_states,
  335. attention_mask=attention_mask,
  336. position_ids=position_ids,
  337. past_key_values=past_key_values,
  338. output_attentions=output_attentions,
  339. use_cache=use_cache,
  340. position_embeddings=position_embeddings,
  341. **kwargs,
  342. )
  343. hidden_states = residual + hidden_states
  344. # Fully Connected
  345. residual = hidden_states
  346. hidden_states = self.post_attention_layernorm(hidden_states)
  347. hidden_states = self.mlp(hidden_states)
  348. hidden_states = residual + hidden_states
  349. outputs = (hidden_states,)
  350. if output_attentions:
  351. outputs += (self_attn_weights,)
  352. return outputs
  353. class ChameleonSwinDecoderLayer(GradientCheckpointingLayer):
  354. def __init__(self, config: ChameleonConfig, layer_idx: int):
  355. super().__init__()
  356. self.hidden_size = config.hidden_size
  357. self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx)
  358. self.mlp = ChameleonMLP(config)
  359. self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  360. self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  361. def forward(
  362. self,
  363. hidden_states: torch.Tensor,
  364. attention_mask: torch.Tensor | None = None,
  365. position_ids: torch.LongTensor | None = None,
  366. past_key_values: Cache | None = None,
  367. output_attentions: bool | None = False,
  368. use_cache: bool | None = False,
  369. position_embeddings: torch.Tensor | None = None,
  370. **kwargs,
  371. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  372. """
  373. Args:
  374. hidden_states (`torch.FloatTensor`):
  375. input to the layer of shape `(batch, seq_len, embed_dim)`
  376. attention_mask (`torch.FloatTensor`, *optional*):
  377. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  378. query_sequence_length, key_sequence_length)` if default attention is used.
  379. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  380. Indices of positions of each input sequence tokens in the position embeddings
  381. past_key_values (`Cache`, *optional*): cached past key and value projection states
  382. output_attentions (`bool`, *optional*):
  383. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  384. returned tensors for more detail.
  385. use_cache (`bool`, *optional*):
  386. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  387. (see `past_key_values`).
  388. """
  389. residual = hidden_states
  390. # Self Attention
  391. hidden_states, self_attn_weights = self.self_attn(
  392. hidden_states=hidden_states,
  393. attention_mask=attention_mask,
  394. position_ids=position_ids,
  395. past_key_values=past_key_values,
  396. output_attentions=output_attentions,
  397. use_cache=use_cache,
  398. position_embeddings=position_embeddings,
  399. **kwargs,
  400. )
  401. hidden_states = self.input_layernorm(hidden_states)
  402. hidden_states = residual + hidden_states
  403. # Fully Connected
  404. residual = hidden_states
  405. hidden_states = self.mlp(hidden_states)
  406. hidden_states = self.post_attention_layernorm(hidden_states)
  407. hidden_states = residual + hidden_states
  408. outputs = (hidden_states,)
  409. if output_attentions:
  410. outputs += (self_attn_weights,)
  411. return outputs
  412. class ChameleonVQVAEVectorQuantizer(nn.Module):
  413. """
  414. A module for vector quantization using learned embedding vectors.
  415. This module implements the quantization process similar to te one described in
  416. the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
  417. input vectors into discrete codebook vectors, which are learned during training.
  418. Current implementation improves over previous ones by avoiding costly matrix multiplications
  419. and allowing for post-hoc remapping of indices.
  420. """
  421. def __init__(self, config):
  422. super().__init__()
  423. self.num_embeddings = config.num_embeddings
  424. self.embedding_dim = config.embed_dim
  425. self.beta = getattr(config, "beta", 0.25)
  426. self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
  427. def forward(self, hidden_state: torch.Tensor):
  428. hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
  429. hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)
  430. # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
  431. distances = (
  432. torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
  433. + torch.sum(self.embedding.weight**2, dim=1)
  434. - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, self.embedding.weight.transpose(0, 1))
  435. )
  436. min_encoding_indices = torch.argmin(distances, dim=1)
  437. hidden_state_quant = self.embedding(min_encoding_indices).view(hidden_state.shape)
  438. # compute loss for embedding
  439. loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean(
  440. (hidden_state_quant - hidden_state.detach()) ** 2
  441. )
  442. # preserve gradients
  443. hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()
  444. # reshape back to match original input shape
  445. hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
  446. return hidden_state_quant, loss, min_encoding_indices
  447. class ChameleonVQVAEEncoderConvDownsample(nn.Module):
  448. def __init__(self, in_channels):
  449. super().__init__()
  450. self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
  451. def forward(self, hidden_states):
  452. # no asymmetric padding in torch conv, must do it ourselves
  453. hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
  454. hidden_states = self.conv(hidden_states)
  455. return hidden_states
  456. class ChameleonVQVAEEncoderResnetBlock(nn.Module):
  457. def __init__(
  458. self,
  459. config,
  460. in_channels,
  461. out_channels=None,
  462. conv_shortcut=False,
  463. ):
  464. super().__init__()
  465. self.in_channels = in_channels
  466. self.out_channels = in_channels if out_channels is None else out_channels
  467. self.use_conv_shortcut = conv_shortcut
  468. self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
  469. self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
  470. self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
  471. self.dropout = torch.nn.Dropout(config.dropout)
  472. self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
  473. if self.in_channels != self.out_channels:
  474. if self.use_conv_shortcut:
  475. self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
  476. else:
  477. self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
  478. def forward(self, hidden_states):
  479. residual = hidden_states
  480. hidden_states = self.norm1(hidden_states)
  481. hidden_states *= torch.sigmoid(hidden_states)
  482. hidden_states = self.conv1(hidden_states)
  483. hidden_states = self.norm2(hidden_states)
  484. hidden_states *= torch.sigmoid(hidden_states)
  485. hidden_states = self.dropout(hidden_states)
  486. hidden_states = self.conv2(hidden_states)
  487. if self.in_channels != self.out_channels:
  488. if self.use_conv_shortcut:
  489. residual = self.conv_shortcut(residual)
  490. else:
  491. residual = self.nin_shortcut(residual)
  492. return residual + hidden_states
  493. class ChameleonVQVAEEncoderAttnBlock(nn.Module):
  494. def __init__(self, in_channels):
  495. super().__init__()
  496. self.in_channels = in_channels
  497. self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
  498. self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  499. self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  500. self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  501. self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  502. def forward(self, hidden_states):
  503. residual = hidden_states
  504. hidden_states = self.norm(hidden_states)
  505. query_states = self.q(hidden_states)
  506. key_states = self.k(hidden_states)
  507. value_states = self.v(hidden_states)
  508. # compute attention
  509. batch_size, channels, height, width = query_states.shape
  510. query_states = query_states.reshape(batch_size, channels, height * width).permute(0, 2, 1)
  511. key_states = key_states.reshape(batch_size, channels, height * width)
  512. attn_weights = torch.bmm(query_states, key_states)
  513. attn_weights = attn_weights * (int(channels) ** (-0.5))
  514. attn_weights = F.softmax(attn_weights, dim=2)
  515. # attend to values
  516. value_states = value_states.reshape(batch_size, channels, height * width)
  517. attn_weights = attn_weights.permute(0, 2, 1)
  518. attn_output = torch.bmm(value_states, attn_weights).reshape(batch_size, channels, height, width)
  519. attn_output = self.proj_out(attn_output)
  520. return residual + attn_output
  521. class ChameleonVQVAEEncoder(nn.Module):
  522. def __init__(self, config):
  523. super().__init__()
  524. self.num_resolutions = len(config.channel_multiplier)
  525. self.num_res_blocks = config.num_res_blocks
  526. base_channels = config.base_channels
  527. resolution = config.resolution
  528. in_channels = config.in_channels
  529. double_latent = config.double_latent
  530. latent_channels = config.latent_channels
  531. channel_multiplier = config.channel_multiplier
  532. self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
  533. curr_res = resolution
  534. in_channel_multiplier = (1,) + tuple(channel_multiplier)
  535. self.in_channel_multiplier = in_channel_multiplier
  536. self.down = nn.ModuleList()
  537. for i_level in range(self.num_resolutions):
  538. block = nn.ModuleList()
  539. attn = nn.ModuleList()
  540. block_in = base_channels * in_channel_multiplier[i_level]
  541. block_out = base_channels * channel_multiplier[i_level]
  542. for i_block in range(self.num_res_blocks):
  543. block.append(
  544. ChameleonVQVAEEncoderResnetBlock(
  545. config=config,
  546. in_channels=block_in,
  547. out_channels=block_out,
  548. )
  549. )
  550. block_in = block_out
  551. if (
  552. config.attn_resolutions is not None
  553. and curr_res in config.attn_resolutions
  554. and config.attn_type == "vanilla"
  555. ):
  556. attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))
  557. down = nn.Module()
  558. down.block = block
  559. down.attn = attn
  560. if i_level != self.num_resolutions - 1:
  561. down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
  562. curr_res = curr_res // 2
  563. self.down.append(down)
  564. self.mid = nn.Module()
  565. self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
  566. config=config,
  567. in_channels=block_in,
  568. out_channels=block_in,
  569. )
  570. self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity()
  571. self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
  572. config=config,
  573. in_channels=block_in,
  574. out_channels=block_in,
  575. )
  576. self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
  577. self.conv_out = torch.nn.Conv2d(
  578. block_in,
  579. 2 * latent_channels if double_latent else latent_channels,
  580. kernel_size=3,
  581. stride=1,
  582. padding=1,
  583. )
  584. def forward(self, pixel_values: torch.LongTensor):
  585. # downsampling
  586. hidden_states = [self.conv_in(pixel_values)]
  587. for i_level in range(self.num_resolutions):
  588. for i_block in range(self.num_res_blocks):
  589. hidden_state = self.down[i_level].block[i_block](
  590. hidden_states[-1],
  591. )
  592. if len(self.down[i_level].attn) > 0:
  593. hidden_state = self.down[i_level].attn[i_block](hidden_state)
  594. hidden_states.append(hidden_state)
  595. if i_level != self.num_resolutions - 1:
  596. hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))
  597. # middle
  598. last_hidden_state = hidden_states[-1]
  599. last_hidden_state = self.mid.block_1(last_hidden_state)
  600. last_hidden_state = self.mid.attn_1(last_hidden_state)
  601. last_hidden_state = self.mid.block_2(last_hidden_state)
  602. # end
  603. last_hidden_state = self.norm_out(last_hidden_state)
  604. last_hidden_state *= torch.sigmoid(last_hidden_state)
  605. last_hidden_state = self.conv_out(last_hidden_state)
  606. return last_hidden_state
  607. class ChameleonImageVocabularyMapping:
  608. """
  609. A class for mapping discrete image tokens from VQGAN to BPE tokens.
  610. """
  611. def __init__(self, vocab_map):
  612. self.vocab_map = vocab_map
  613. self.image_token_id = vocab_map.get("<image>")
  614. @cached_property
  615. def val2name(self):
  616. return {v: k for k, v in self.vocab_map.items()}
  617. @cached_property
  618. def image_tokens(self):
  619. return sorted([val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")])
  620. @cached_property
  621. def bpe2img(self):
  622. img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}
  623. def remap(old_name: str) -> str:
  624. return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1])
  625. return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens}
  626. @cached_property
  627. def img2bpe(self):
  628. return {v: k for k, v in self.bpe2img.items()}
  629. @cached_property
  630. def bpe2img_search_tensors(self):
  631. return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(sorted(self.bpe2img.values()))
  632. @cached_property
  633. def img2bpe_mapping_tensor(self):
  634. mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
  635. for k, v in self.img2bpe.items():
  636. mapping[k] = v
  637. return mapping
  638. def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
  639. device = img_batch.device
  640. img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
  641. return img_tokens.to(device)
  642. @auto_docstring
  643. class ChameleonPreTrainedModel(PreTrainedModel):
  644. config: ChameleonConfig
  645. base_model_prefix = "model"
  646. input_modalities = ("image", "text")
  647. supports_gradient_checkpointing = True
  648. _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"]
  649. _skip_keys_device_placement = ["past_key_values", "causal_mask"]
  650. _supports_flash_attn = True
  651. _supports_sdpa = True
  652. _can_compile_fullgraph = True
  653. _supports_flex_attn = True
  654. _supports_attention_backend = True
  655. _can_record_outputs = {
  656. "hidden_states": [ChameleonDecoderLayer, ChameleonSwinDecoderLayer],
  657. "attentions": ChameleonAttention,
  658. }
  659. @auto_docstring(
  660. custom_intro="""
  661. The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
  662. This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
  663. [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
  664. Taigman](https://huggingface.co/papers/2203.13131).
  665. """
  666. )
  667. class ChameleonVQVAE(ChameleonPreTrainedModel):
  668. config: ChameleonVQVAEConfig
  669. _no_split_modules = [
  670. "ChameleonVQVAEVectorQuantizer",
  671. "ChameleonVQVAEEncoderAttnBlock",
  672. "ChameleonVQVAEEncoderResnetBlock",
  673. ]
  674. _can_record_outputs = {
  675. "hidden_states": ChameleonVQVAEEncoderResnetBlock,
  676. "attentions": ChameleonVQVAEEncoderAttnBlock,
  677. }
  678. def __init__(self, config: ChameleonVQVAEConfig):
  679. super().__init__(config)
  680. self.encoder = ChameleonVQVAEEncoder(config)
  681. self.quantize = ChameleonVQVAEVectorQuantizer(config)
  682. self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
  683. self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
  684. self.eval() # Chameleon's VQ model is frozen
  685. self.post_init()
  686. @merge_with_config_defaults
  687. @capture_outputs
  688. def encode(
  689. self, pixel_values: torch.LongTensor, **kwargs: Unpack[TransformersKwargs]
  690. ) -> ChameleonVQVAEModelOutput:
  691. hidden_states = self.encoder(pixel_values)
  692. conv_hidden_states = self.quant_conv(hidden_states)
  693. quantized_last_hidden_state, emb_loss, indices = self.quantize(conv_hidden_states)
  694. return ChameleonVQVAEModelOutput(
  695. last_hidden_state=hidden_states,
  696. quantized_last_hidden_state=quantized_last_hidden_state,
  697. image_tokens=indices,
  698. embedding_loss=emb_loss,
  699. )
  700. @auto_docstring
  701. class ChameleonModel(ChameleonPreTrainedModel):
  702. def __init__(self, config: ChameleonConfig):
  703. super().__init__(config)
  704. self.padding_idx = config.pad_token_id
  705. self.vocab_size = config.vocab_size
  706. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  707. self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map)
  708. decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm else ChameleonSwinDecoderLayer
  709. self.layers = nn.ModuleList(
  710. [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  711. )
  712. self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  713. self.vqmodel = ChameleonVQVAE._from_config(config.vq_config)
  714. self.rotary_emb = ChameleonRotaryEmbedding(config=config)
  715. self.gradient_checkpointing = False
  716. # Initialize weights and apply final processing
  717. self.post_init()
  718. def get_image_tokens(self, pixel_values: torch.FloatTensor):
  719. """
  720. Tokenizes images into discrete tokens with VQGAN module. Converts
  721. obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
  722. special tokens.
  723. Args:
  724. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
  725. The tensors corresponding to the input images.
  726. """
  727. batch_size = pixel_values.shape[0]
  728. vqmodel_outputs: ChameleonVQVAEModelOutput = self.vqmodel.encode(pixel_values, return_dict=True)
  729. bpe_toks = self.vocabulary_mapping.convert_img2bpe(vqmodel_outputs.image_tokens)
  730. bpe_toks = bpe_toks.view(batch_size, -1)
  731. return bpe_toks
  732. @can_return_tuple
  733. @auto_docstring(
  734. custom_intro="Tokenizes images into discrete tokens with VQGAN module and embeds them with text embeddings layer."
  735. )
  736. def get_image_features(
  737. self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
  738. ) -> tuple | BaseModelOutputWithPooling:
  739. r"""
  740. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  741. The tensors corresponding to the input images.
  742. """
  743. batch_size = pixel_values.shape[0]
  744. vqmodel_outputs: ChameleonVQVAEModelOutput = self.vqmodel.encode(pixel_values, return_dict=True, **kwargs)
  745. bpe_tokens = self.vocabulary_mapping.convert_img2bpe(vqmodel_outputs.image_tokens).view(batch_size, -1)
  746. vqmodel_outputs.pooler_output = self.get_input_embeddings()(bpe_tokens)
  747. return vqmodel_outputs
  748. def get_placeholder_mask(
  749. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  750. ):
  751. """
  752. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  753. equal to the length of multimodal features. If the lengths are different, an error is raised.
  754. """
  755. if input_ids is None:
  756. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  757. torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  758. )
  759. special_image_mask = special_image_mask.all(-1)
  760. else:
  761. special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
  762. n_image_tokens = special_image_mask.sum()
  763. n_image_features = image_features.shape[0] * image_features.shape[1]
  764. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  765. torch_compilable_check(
  766. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  767. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
  768. )
  769. return special_image_mask
  770. @merge_with_config_defaults
  771. @capture_outputs
  772. @auto_docstring
  773. def forward(
  774. self,
  775. input_ids: torch.LongTensor | None = None,
  776. pixel_values: torch.FloatTensor | None = None,
  777. attention_mask: torch.Tensor | None = None,
  778. position_ids: torch.LongTensor | None = None,
  779. past_key_values: Cache | None = None,
  780. inputs_embeds: torch.FloatTensor | None = None,
  781. use_cache: bool | None = None,
  782. **kwargs: Unpack[FlashAttentionKwargs],
  783. ) -> tuple | BaseModelOutputWithPast:
  784. if (input_ids is None) ^ (inputs_embeds is not None):
  785. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  786. if inputs_embeds is None:
  787. inputs_embeds = self.embed_tokens(input_ids)
  788. if pixel_values is not None:
  789. image_features = self.get_image_features(pixel_values, return_dict=True).pooler_output
  790. special_image_mask = self.get_placeholder_mask(
  791. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  792. )
  793. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  794. # torch.jit.trace() doesn't support cache objects in the output
  795. if use_cache and past_key_values is None and not torch.jit.is_tracing():
  796. past_key_values = DynamicCache(config=self.config)
  797. if position_ids is None:
  798. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  799. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  800. position_ids = position_ids.unsqueeze(0)
  801. causal_mask = create_causal_mask(
  802. config=self.config,
  803. inputs_embeds=inputs_embeds,
  804. attention_mask=attention_mask,
  805. past_key_values=past_key_values,
  806. position_ids=position_ids,
  807. )
  808. # embed positions
  809. hidden_states = inputs_embeds
  810. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  811. # decoder layers
  812. for decoder_layer in self.layers:
  813. layer_outputs = decoder_layer(
  814. hidden_states,
  815. attention_mask=causal_mask,
  816. position_ids=position_ids,
  817. past_key_values=past_key_values,
  818. use_cache=use_cache,
  819. position_embeddings=position_embeddings,
  820. **kwargs,
  821. )
  822. hidden_states = layer_outputs[0]
  823. hidden_states = self.norm(hidden_states)
  824. return BaseModelOutputWithPast(
  825. last_hidden_state=hidden_states,
  826. past_key_values=past_key_values,
  827. )
  828. @auto_docstring(
  829. custom_intro="""
  830. Chameleon Model with a head on top used for outputting logits for next token prediction.
  831. """
  832. )
  833. class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin):
  834. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  835. def __init__(self, config):
  836. super().__init__(config)
  837. self.model = ChameleonModel(config)
  838. self.vocab_size = config.vocab_size
  839. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  840. # Initialize weights and apply final processing
  841. self.post_init()
  842. def get_image_tokens(self, pixel_values):
  843. return self.model.get_image_tokens(pixel_values)
  844. @auto_docstring
  845. def get_image_features(
  846. self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
  847. ) -> tuple | BaseModelOutputWithPooling:
  848. return self.model.get_image_features(pixel_values, **kwargs)
  849. @can_return_tuple
  850. @auto_docstring
  851. def forward(
  852. self,
  853. input_ids: torch.LongTensor | None = None,
  854. pixel_values: torch.FloatTensor | None = None,
  855. attention_mask: torch.Tensor | None = None,
  856. position_ids: torch.LongTensor | None = None,
  857. past_key_values: Cache | None = None,
  858. inputs_embeds: torch.FloatTensor | None = None,
  859. labels: torch.LongTensor | None = None,
  860. use_cache: bool | None = None,
  861. logits_to_keep: int | torch.Tensor = 0,
  862. **kwargs: Unpack[TransformersKwargs],
  863. ) -> tuple | CausalLMOutputWithPast:
  864. r"""
  865. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  866. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  867. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  868. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  869. Example:
  870. ```python
  871. >>> from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
  872. >>> import torch
  873. >>> import httpx
  874. >>> from io import BytesIO
  875. >>> from PIL import Image
  876. >>> model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", dtype=torch.bfloat16)
  877. >>> processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
  878. >>> prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
  879. >>> url = "https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg"
  880. >>> with httpx.stream("GET", url) as response:
  881. ... image1 = Image.open(BytesIO(response.read()))
  882. >>> url = "https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg"
  883. >>> with httpx.stream("GET", url) as response:
  884. ... image2 = Image.open(BytesIO(response.read()))
  885. >>> inputs = processor(images=[image1, image2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16)
  886. >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
  887. >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  888. ```"""
  889. outputs: BaseModelOutputWithPast = self.model(
  890. input_ids=input_ids,
  891. pixel_values=pixel_values,
  892. attention_mask=attention_mask,
  893. position_ids=position_ids,
  894. past_key_values=past_key_values,
  895. inputs_embeds=inputs_embeds,
  896. use_cache=use_cache,
  897. **kwargs,
  898. )
  899. hidden_states = outputs[0]
  900. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  901. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  902. logits = self.lm_head(hidden_states[:, slice_indices, :])
  903. # Disallow image tokens which does not include special begin-image and end-image tokens
  904. image_tokens = self.model.vocabulary_mapping.image_tokens
  905. logits[:, :, image_tokens] = torch.finfo(logits.dtype).min
  906. loss = None
  907. if labels is not None:
  908. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  909. return CausalLMOutputWithPast(
  910. loss=loss,
  911. logits=logits,
  912. past_key_values=outputs.past_key_values,
  913. hidden_states=outputs.hidden_states,
  914. attentions=outputs.attentions,
  915. )
  916. def prepare_inputs_for_generation(
  917. self,
  918. input_ids,
  919. pixel_values=None,
  920. past_key_values=None,
  921. attention_mask=None,
  922. inputs_embeds=None,
  923. position_ids=None,
  924. use_cache=True,
  925. is_first_iteration=False,
  926. **kwargs,
  927. ):
  928. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  929. model_inputs = super().prepare_inputs_for_generation(
  930. input_ids,
  931. pixel_values=pixel_values,
  932. past_key_values=past_key_values,
  933. attention_mask=attention_mask,
  934. inputs_embeds=inputs_embeds,
  935. position_ids=position_ids,
  936. use_cache=use_cache,
  937. is_first_iteration=is_first_iteration,
  938. **kwargs,
  939. )
  940. if not is_first_iteration and use_cache:
  941. # Pixel values are used only in the first iteration if available
  942. # In subsequent iterations, they are already merged with text and cached
  943. # NOTE: first iteration doesn't have to be prefill, it can be the first
  944. # iteration with a question and cached system prompt (continue generate from cache)
  945. model_inputs["pixel_values"] = None
  946. return model_inputs
  947. __all__ = ["ChameleonForConditionalGeneration", "ChameleonModel", "ChameleonPreTrainedModel", "ChameleonVQVAE"]