modeling_janus.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/janus/modular_janus.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_janus.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from dataclasses import dataclass
  22. import torch
  23. import torch.nn.functional as F
  24. from torch import nn
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache
  28. from ...generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMixin, GenerationMode, LogitsProcessorList
  29. from ...generation.utils import GenerateDecoderOnlyOutput
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check, torch_int
  35. from ...utils.generic import merge_with_config_defaults
  36. from ...utils.output_capturing import capture_outputs
  37. from ..auto import AutoModel
  38. from .configuration_janus import JanusConfig, JanusVisionConfig, JanusVQVAEConfig
  # Module-level logger for this file (the transformers `logging` wrapper imported above, not the stdlib module).
logger = logging.get_logger(__name__)
  # Shared PreTrainedModel base for the Janus model classes (e.g. JanusVisionModel below).
@auto_docstring
class JanusPreTrainedModel(PreTrainedModel):
    config: JanusConfig
    base_model_prefix = "model"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = True

    def _init_weights(self, module):
        """Run the generic initialization, then refill vision `position_ids` with `0..N-1`.

        `JanusVisionEmbeddings.position_ids` is a non-persistent buffer, so it is reset here
        to a `(1, N)` arange rather than being left to the generic initializer.
        """
        super()._init_weights(module)
        if not isinstance(module, JanusVisionEmbeddings):
            return
        num_positions = module.position_ids.shape[-1]
        init.copy_(module.position_ids, torch.arange(num_positions).expand((1, -1)))
  # Output container for the Janus VQ-VAE forward pass (reconstruction + quantization loss).
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Janus VQ-VAE model outputs.
    """
)
class JanusVQVAEOutput(ModelOutput):
    r"""
    decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        Reconstructed pixel values after encoding and decoding the input.
    embedding_loss (`torch.FloatTensor`):
        Embedding loss.
    """

    decoded_pixel_values: torch.FloatTensor | None = None
    embedding_loss: torch.FloatTensor | None = None
  # Output container for the Janus base model: hidden states, KV cache, attentions and image hidden states.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Janus model's outputs that may also contain a past key/values (to speed up sequential decoding).
    """
)
class JanusBaseModelOutputWithPast(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.

        If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
        hidden_size)` is output.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
        sequence_length, hidden_size)`.

        Image hidden states of the model produced by the vision encoder.
    """

    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: tuple[torch.FloatTensor] | None = None
  # Output container for the Janus causal-LM head: loss, logits, KV cache and image hidden states.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Janus causal language model (or autoregressive) outputs.
    """
)
class JanusCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
        sequence_length, hidden_size)`.

        Image hidden states of the model produced by the vision encoder.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: tuple[torch.FloatTensor] | None = None
  # Patch embeddings plus learned absolute position embeddings for the Janus vision tower.
class JanusVisionEmbeddings(nn.Module):
    """Turn pixel values into a sequence of patch embeddings with position embeddings added.

    Patches come from a non-overlapping convolution (`kernel_size == stride == patch_size`), giving
    `(image_size // patch_size) ** 2` patches. Each patch index has its own learned position vector;
    no class token is added.
    """

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # persistent=False: the index buffer is not written to / read from the state dict.
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Resize the learned position table so the model can run on images of a different resolution.

        The square `(sqrt(N), sqrt(N))` grid of position vectors is bicubically resized to
        `(height // patch_size, width // patch_size)`. When the patch count already matches and the
        image is square, the table is used as-is (except under `torch.jit` tracing, where the resize
        path is always taken so exported graphs handle dynamic input shapes). No class embedding is
        assumed.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1]
        num_positions = self.position_embedding.weight.shape[0]

        table_fits = num_patches == num_positions and height == width
        if table_fits and not torch.jit.is_tracing():
            return self.position_embedding(self.position_ids)

        embed_dim = embeddings.shape[-1]
        grid_side = torch_int(num_positions**0.5)
        target_size = (height // self.patch_size, width // self.patch_size)

        # (1, N, D) -> (1, side, side, D) -> (1, D, side, side) so it can be resized like an image.
        pos_grid = self.position_embedding.weight.unsqueeze(0)
        pos_grid = pos_grid.reshape(1, grid_side, grid_side, embed_dim).permute(0, 3, 1, 2)
        pos_grid = nn.functional.interpolate(
            pos_grid,
            size=target_size,
            mode="bicubic",
            align_corners=False,
        )
        # Back to a (1, new_h * new_w, D) sequence.
        return pos_grid.permute(0, 2, 3, 1).view(1, -1, embed_dim)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        weight_dtype = self.patch_embedding.weight.dtype

        # (batch, embed_dim, grid_h, grid_w) -> (batch, grid_h * grid_w, embed_dim)
        patch_grid = self.patch_embedding(pixel_values.to(dtype=weight_dtype))
        patches = patch_grid.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            pos_embeds = self.interpolate_pos_encoding(patches, height, width)
        else:
            pos_embeds = self.position_embedding(self.position_ids)
        return patches + pos_embeds
  # Grouped-query attention helper: tile key/value heads so they line up with the query heads.
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis.

    Same result as `torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)`: a tensor of shape
    `(batch, num_key_value_heads, seqlen, head_dim)` becomes `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`.
    When `n_rep == 1` the very same tensor object is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis (a view), then fold it into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
  # Reference ("eager") attention implementation, used as the fallback attention interface.
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain softmax attention.

    Keys and values are expanded with `repeat_kv` by `module.num_key_value_groups`, so they are
    expected in the `(batch, heads, seq, head_dim)` layout that `repeat_kv` unpacks. `attention_mask`,
    when given, is added to the scaled scores. The softmax is computed in float32 and cast back to
    `query.dtype`. Dropout is active only while `module.training` is true.

    Returns:
        `(attn_output, attn_weights)`: `attn_output` has axes 1 and 2 swapped relative to the
        `attn_weights @ value` product and is made contiguous.
    """
    n_groups = module.num_key_value_groups
    key_states = repeat_kv(key, n_groups)
    value_states = repeat_kv(value, n_groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return context, probs
  # Multi-head self-attention for the Janus vision encoder.
class JanusVisionAttention(nn.Module):
    """Attention Class for Janus Vision Encoder.

    Queries, keys and values each have `num_attention_heads` heads of size
    `hidden_size // num_attention_heads`. When `config.use_qk_norm` is set, a LayerNorm is applied to
    every query/key head (over `head_dim`) before attention.
    """

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        proj_dropout = config.projection_dropout
        qk_norm = config.use_qk_norm
        self.is_causal = False

        # Keys/values have as many heads as queries (no grouped-query attention), so
        # `eager_attention_forward` must not repeat them: `num_key_value_groups` is 1.
        self.num_key_value_groups = 1

        self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.projection_layer = nn.Linear(self.embed_dim, self.embed_dim)
        self.projection_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity()

        # QK-norm is applied in `forward` to tensors whose last dimension is `head_dim`, so the
        # LayerNorm must be sized `head_dim`. Sizing it `embed_dim` raises a shape error whenever
        # `num_heads > 1`.
        self.q_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """Self-attention over `hidden_states` of shape `(batch_size, seq_len, hidden_size)`.

        Returns:
            `(output, attn_weights)` where `output` has shape `(batch_size, seq_len, hidden_size)` and
            `attn_weights` is whatever the selected attention implementation returns.
        """
        batch_size, seq_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Split into heads and normalize each head over its `head_dim` features.
        query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
        query_states = self.q_norm(query_states)

        key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
        key_states = self.k_norm(key_states)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scale,
            is_causal=self.is_causal,
            **kwargs,
        )
        attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)

        output = self.projection_layer(attn_output)
        output = self.projection_dropout(output)

        return output, attn_weights
  # Position-wise feed-forward block of the vision encoder.
class JanusVisionMLP(nn.Module):
    """Feed-forward block: `fc1 -> activation -> dropout1 -> fc2 -> dropout2`.

    The hidden width is `int(config.hidden_size * config.mlp_ratio)`; the activation is
    `ACT2FN[config.hidden_act]`.
    """

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.config = config
        self.intermediate_size = int(config.hidden_size * config.mlp_ratio)
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, self.intermediate_size)
        self.fc2 = nn.Linear(self.intermediate_size, config.hidden_size)
        self.dropout1 = nn.Dropout(config.hidden_dropout_rate)
        self.dropout2 = nn.Dropout(config.hidden_dropout_rate)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        widened = self.dropout1(self.activation_fn(self.fc1(hidden_states)))
        return self.dropout2(self.fc2(widened))
  # Pre-norm transformer block: LayerNorm -> attention -> residual, then LayerNorm -> MLP -> residual.
class JanusVisionEncoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = JanusVisionAttention(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = JanusVisionMLP(config)
        self.config = config

    @auto_docstring
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        # Attention sub-block; the attention weights returned by `self_attn` are not used here.
        attn_out, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out
  # Stack of vision transformer layers applied one after another.
class JanusVisionEncoder(nn.Module):
    """
    Transformer encoder made of `config.num_hidden_layers` [`JanusVisionEncoderLayer`] blocks applied in
    sequence.

    Args:
        config: JanusVisionConfig
    """

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(JanusVisionEncoderLayer(config) for _ in range(config.num_hidden_layers))
        self.gradient_checkpointing = False

    # Ignore copy
    @auto_docstring
    def forward(
        self,
        inputs_embeds,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        # Each layer receives the same `attention_mask`; only the running hidden states change.
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, **kwargs)
        return BaseModelOutput(last_hidden_state=hidden_states)
  # Janus vision tower: patch embeddings -> transformer encoder -> final LayerNorm.
@auto_docstring
class JanusVisionModel(JanusPreTrainedModel):
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    config: JanusVisionConfig
    _can_record_outputs = {
        "hidden_states": JanusVisionEncoderLayer,
        "attentions": JanusVisionAttention,
    }

    def __init__(self, config: JanusVisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = JanusVisionEmbeddings(config)
        self.encoder = JanusVisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.post_init()

    @merge_with_config_defaults
    @capture_outputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            **kwargs,
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        last_hidden_state = self.post_layernorm(last_hidden_state)

        # Slice only; do not re-apply `post_layernorm`. The sequence is already normalized, and a second
        # (affine) LayerNorm pass would make `pooler_output` differ from `last_hidden_state[:, 0]`.
        # NOTE(review): the embeddings add no class token, so position 0 is simply the first patch.
        pooled_output = last_hidden_state[:, 0, :]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
        )

    def get_input_embeddings(self):
        return self.embeddings
  # Projection head mapping `hidden_size` vision features to `projection_dim`.
class JanusVisionAlignerMLP(nn.Module):
    """Stack of `config.depth` linear layers.

    `fc1` maps `hidden_size -> projection_dim`. It is followed by `depth - 1` layers of
    `projection_dim -> projection_dim`, each preceded by the activation. No activation is applied
    after the last layer.
    """

    def __init__(self, config: JanusVisionConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.projection_dim)
        extra_layers = [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(config.depth - 1)]
        self.hidden_layers = nn.ModuleList(extra_layers)
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        for linear in self.hidden_layers:
            hidden_states = linear(self.activation_fn(hidden_states))
        return hidden_states
  # Nearest-neighbour codebook quantizer used by the Janus VQ-VAE.
class JanusVQVAEVectorQuantizer(nn.Module):
    """
    A module for vector quantization using learned embedding vectors.

    This module implements the quantization process similar to the one described in
    the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
    input vectors into discrete codebook vectors, which are learned during training.
    The nearest codebook entry is found with the expansion `|z - e|^2 = |z|^2 + |e|^2 - 2 z.e`,
    and gradients reach the input through a straight-through estimator.
    """

    def __init__(self, config: JanusVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        # Weight of the second loss term in `forward`; defaults to 0.25 when the config has no `beta`.
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        # Spatial grid `(num_patches, num_patches)` that `get_codebook_entry` reshapes tokens into.
        self.quant_state_dims = [config.num_patches] * 2

    def forward(self, hidden_state: torch.Tensor):
        """Quantize a channels-first latent map to its nearest codebook vectors.

        Args:
            hidden_state: latent map laid out as `(batch, embedding_dim, height, width)`. Channels are moved
                last and flattened against `embedding_dim`.

        Returns:
            `(hidden_state_quant, loss, min_encoding_indices)`:
            - the quantized map in the input layout, with straight-through gradients,
            - the scalar quantization loss,
            - the flat `(batch * height * width,)` indices of the chosen codebook entries.
        """
        hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
        hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)

        # NOTE(review): unlike `get_codebook_entry`, neither the inputs nor the codebook are
        # L2-normalized here. Confirm this matches how the checkpoints' tokens were produced.
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        distances = (
            torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, self.embedding.weight.transpose(0, 1))
        )

        min_encoding_indices = torch.argmin(distances, dim=1)
        hidden_state_quant = self.embedding(min_encoding_indices).view(hidden_state.shape)

        # compute loss for embedding
        # NOTE(review): `beta` scales the codebook term (quantized vs. detached input) here. This is the
        # layout taming-transformers keeps under `legacy=True`; the VQ-VAE paper applies beta to the
        # commitment term instead. Confirm which form is intended before training with this loss.
        loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean(
            (hidden_state_quant - hidden_state.detach()) ** 2
        )

        # preserve gradients (straight-through estimator): forward value is the quantized vector,
        # gradient flows to `hidden_state` unchanged.
        hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()

        # reshape back to match original input shape
        hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()

        return hidden_state_quant, loss, min_encoding_indices

    def get_codebook_entry(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
        """Look up codebook vectors for `image_tokens` and lay them out as a channels-first latent map.

        Args:
            image_tokens: indices of shape `(batch_size, ...)` holding `num_patches ** 2` tokens per sample.

        Returns:
            L2-normalized embeddings of shape `(batch_size, embed_dim, num_patches, num_patches)`.
        """
        batch_size = image_tokens.shape[0]
        emb_dim: int = self.embedding.weight.shape[-1]

        # get quantized latent vectors
        hidden_state_quant = self.embedding(image_tokens)
        # l2 normalization on the last dimension
        hidden_state_quant = F.normalize(hidden_state_quant, p=2, dim=-1)

        # reshape back to match original input shape
        hidden_state_quant = hidden_state_quant.view((batch_size, *self.quant_state_dims, emb_dim))
        hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()

        return hidden_state_quant
  # Residual block of the VQ-VAE: (GroupNorm -> SiLU -> conv) twice, plus a shortcut.
class JanusVQVAEResnetBlock(nn.Module):
    """Residual block `norm1 -> x*sigmoid(x) -> conv1 -> norm2 -> x*sigmoid(x) -> dropout -> conv2`.

    Args:
        config: object providing `dropout` (probability for the inner dropout layer).
        in_channels: channels of the input feature map (must be divisible by 32 for GroupNorm).
        out_channels: channels of the output; defaults to `in_channels` when `None`.
        conv_shortcut: when the channel counts differ, project the shortcut with a 3x3 conv instead
            of a 1x1 conv.
    """

    def __init__(
        self,
        config,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # Build every layer from the resolved channel count: the raw argument may be `None`,
        # which previously crashed the conv/norm constructors instead of defaulting to `in_channels`.
        out_channels = self.out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        # x * sigmoid(x) == SiLU, computed in place on the fresh norm output.
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        # Match the shortcut's channel count to the main branch when they differ.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states
  # Single-head spatial self-attention over every position of a feature map, with a residual connection.
class JanusVQVAEAttnBlock(nn.Module):
    """Self-attention across the `height * width` positions of a `(batch, channels, height, width)` map.

    Queries, keys, values and the output projection are 1x1 convolutions. Scores are scaled by
    `channels ** -0.5`. `in_channels` must be divisible by 32 (GroupNorm with 32 groups).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, hidden_states):
        normed = self.norm(hidden_states)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        batch_size, channels, height, width = query.shape
        num_positions = height * width

        # Queries as (B, HW, C), keys as (B, C, HW) -> scores (B, HW, HW).
        query = query.reshape(batch_size, channels, num_positions).permute(0, 2, 1)
        key = key.reshape(batch_size, channels, num_positions)
        scores = torch.bmm(query, key) * (int(channels) ** (-0.5))
        # weights[b, i, j]: how much position i attends to position j (normalized over j).
        weights = F.softmax(scores, dim=2)

        # out[b, c, i] = sum_j value[b, c, j] * weights[b, i, j]
        value = value.reshape(batch_size, channels, num_positions)
        attended = torch.bmm(value, weights.permute(0, 2, 1))
        attended = attended.reshape(batch_size, channels, height, width)

        return hidden_states + self.proj_out(attended)
  # Stride-2 3x3 convolution that halves even spatial sizes.
class JanusVQVAEConvDownsample(nn.Module):
    """Downsample with a 3x3, stride-2, unpadded conv after adding one zero row at the bottom and one zero column at the right."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, hidden_states):
        # Conv2d padding is symmetric, so apply the asymmetric (left=0, right=1, top=0, bottom=1) pad by hand.
        padded = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
  # 2x nearest-neighbour upsampling followed by a same-size 3x3 convolution.
class JanusVQVAEConvUpsample(nn.Module):
    """Double the spatial resolution (nearest-neighbour) and refine it with a 3x3, padding-1 conv."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states):
        upsampled = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)
  533. class JanusVQVAEMidBlock(nn.Module):
  534. def __init__(self, config: JanusVQVAEConfig, channels: int):
  535. super().__init__()
  536. self.block_1 = JanusVQVAEResnetBlock(
  537. config=config,
  538. in_channels=channels,
  539. out_channels=channels,
  540. )
  541. self.attn_1 = JanusVQVAEAttnBlock(channels)
  542. self.block_2 = JanusVQVAEResnetBlock(
  543. config=config,
  544. in_channels=channels,
  545. out_channels=channels,
  546. )
  547. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  548. hidden_states = self.block_1(hidden_states)
  549. hidden_states = self.attn_1(hidden_states)
  550. hidden_states = self.block_2(hidden_states)
  551. return hidden_states
  552. class JanusVQVAEEncoder(nn.Module):
  553. def __init__(self, config):
  554. super().__init__()
  555. self.num_resolutions = len(config.channel_multiplier)
  556. self.num_res_blocks = config.num_res_blocks
  557. base_channels = config.base_channels
  558. in_channels = config.in_channels
  559. double_latent = config.double_latent
  560. latent_channels = config.latent_channels
  561. channel_multiplier = config.channel_multiplier
  562. self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
  563. in_channel_multiplier = (1,) + tuple(channel_multiplier)
  564. self.in_channel_multiplier = in_channel_multiplier
  565. self.down = nn.ModuleList()
  566. for i_level in range(self.num_resolutions):
  567. block = nn.ModuleList()
  568. attn = nn.ModuleList()
  569. block_in = base_channels * in_channel_multiplier[i_level]
  570. block_out = base_channels * channel_multiplier[i_level]
  571. for i_block in range(self.num_res_blocks):
  572. block.append(
  573. JanusVQVAEResnetBlock(
  574. config=config,
  575. in_channels=block_in,
  576. out_channels=block_out,
  577. )
  578. )
  579. block_in = block_out
  580. if i_level == self.num_resolutions - 1:
  581. attn.append(JanusVQVAEAttnBlock(block_in))
  582. down = nn.Module()
  583. down.block = block
  584. down.attn = attn
  585. if i_level != self.num_resolutions - 1:
  586. down.downsample = JanusVQVAEConvDownsample(block_in)
  587. self.down.append(down)
  588. self.mid = JanusVQVAEMidBlock(config, block_in)
  589. self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
  590. self.conv_out = torch.nn.Conv2d(
  591. block_in,
  592. 2 * latent_channels if double_latent else latent_channels,
  593. kernel_size=3,
  594. stride=1,
  595. padding=1,
  596. )
  597. def forward(self, pixel_values: torch.LongTensor):
  598. # downsampling
  599. hidden_states = [self.conv_in(pixel_values)]
  600. for i_level in range(self.num_resolutions):
  601. for i_block in range(self.num_res_blocks):
  602. hidden_state = self.down[i_level].block[i_block](
  603. hidden_states[-1],
  604. )
  605. if len(self.down[i_level].attn) > 0:
  606. hidden_state = self.down[i_level].attn[i_block](hidden_state)
  607. hidden_states.append(hidden_state)
  608. if i_level != self.num_resolutions - 1:
  609. hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))
  610. # middle
  611. last_hidden_state = hidden_states[-1]
  612. last_hidden_state = self.mid(last_hidden_state)
  613. # end
  614. last_hidden_state = self.norm_out(last_hidden_state)
  615. last_hidden_state *= torch.sigmoid(last_hidden_state)
  616. last_hidden_state = self.conv_out(last_hidden_state)
  617. return last_hidden_state
  618. class JanusVQVAEDecoder(nn.Module):
  619. def __init__(self, config):
  620. super().__init__()
  621. self.num_resolutions = len(config.channel_multiplier)
  622. self.num_res_blocks = config.num_res_blocks
  623. base_channels = config.base_channels
  624. latent_channels = config.latent_channels
  625. out_channels = config.out_channels
  626. # compute in_ch_mult, block_in and curr_res at lowest res
  627. block_in = base_channels * config.channel_multiplier[self.num_resolutions - 1]
  628. # z to block_in
  629. self.conv_in = torch.nn.Conv2d(latent_channels, block_in, kernel_size=3, stride=1, padding=1)
  630. # middle
  631. self.mid = JanusVQVAEMidBlock(config, block_in)
  632. # upsampling
  633. self.up = nn.ModuleList()
  634. for i_level in reversed(range(self.num_resolutions)):
  635. block = nn.ModuleList()
  636. attn = nn.ModuleList()
  637. block_out = base_channels * config.channel_multiplier[i_level]
  638. for i_block in range(self.num_res_blocks + 1):
  639. block.append(
  640. JanusVQVAEResnetBlock(
  641. config=config,
  642. in_channels=block_in,
  643. out_channels=block_out,
  644. )
  645. )
  646. block_in = block_out
  647. if i_level == self.num_resolutions - 1:
  648. attn.append(JanusVQVAEAttnBlock(block_in))
  649. up = nn.Module()
  650. up.block = block
  651. up.attn = attn
  652. if i_level != 0:
  653. up.upsample = JanusVQVAEConvUpsample(block_in)
  654. self.up.append(up)
  655. # end
  656. self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
  657. self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
  658. def forward(self, hidden_state: torch.FloatTensor) -> torch.FloatTensor:
  659. hidden_state = self.conv_in(hidden_state)
  660. # middle
  661. hidden_state = self.mid(hidden_state)
  662. # upsampling
  663. for i_level in range(self.num_resolutions):
  664. for i_block in range(self.num_res_blocks + 1):
  665. hidden_state = self.up[i_level].block[i_block](hidden_state)
  666. if len(self.up[i_level].attn) > 0:
  667. hidden_state = self.up[i_level].attn[i_block](hidden_state)
  668. if i_level != self.num_resolutions - 1:
  669. hidden_state = self.up[i_level].upsample(hidden_state)
  670. hidden_state = self.norm_out(hidden_state)
  671. hidden_state *= torch.sigmoid(hidden_state)
  672. hidden_state = self.conv_out(hidden_state)
  673. return hidden_state
  674. @dataclass
  675. @auto_docstring
  676. class JanusVQVAEModelOutput(BaseModelOutputWithPooling):
  677. r"""
  678. quantized_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  679. Quantized last hidden state from the VQ-VAE model.
  680. image_tokens (`torch.FloatTensor` of shape `(batch_size, config.vocab_size`):
  681. Indices of the image tokens predicted by the VQ-VAE model.
  682. embedding_loss (`torch.FloatTensor`):
  683. The embedding loss computed during quantization.
  684. """
  685. quantized_last_hidden_state: torch.FloatTensor | None = None
  686. image_tokens: torch.FloatTensor | None = None
  687. embedding_loss: torch.FloatTensor | None = None
  688. @auto_docstring(
  689. custom_intro="""
  690. The VQ-VAE model used in Janus for encoding/decoding images into discrete tokens.
  691. This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
  692. [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
  693. Taigman](https://huggingface.co/papers/2203.13131).
  694. """
  695. )
  696. class JanusVQVAE(JanusPreTrainedModel):
  697. config: JanusVQVAEConfig
  698. _no_split_modules = [
  699. "JanusVQVAEAttnBlock",
  700. "JanusVQVAEResnetBlock",
  701. "JanusVQVAEVectorQuantizer",
  702. ]
  703. _can_record_outputs = {
  704. "hidden_states": JanusVQVAEResnetBlock,
  705. "attentions": JanusVQVAEAttnBlock,
  706. }
  707. main_input_name = "pixel_values"
  708. def __init__(self, config: JanusVQVAEConfig):
  709. super().__init__(config)
  710. self.encoder = JanusVQVAEEncoder(config)
  711. self.quantize = JanusVQVAEVectorQuantizer(config)
  712. self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
  713. self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
  714. self.eval() # Janus's VQ model is frozen
  715. self.decoder = JanusVQVAEDecoder(config)
  716. self.gradient_checkpointing = False
  717. self.post_init()
  718. @merge_with_config_defaults
  719. @capture_outputs
  720. def encode(self, pixel_values: torch.LongTensor, **kwargs: Unpack[TransformersKwargs]) -> JanusVQVAEModelOutput:
  721. hidden_states = self.encoder(pixel_values)
  722. conv_hidden_states = self.quant_conv(hidden_states)
  723. quantized_last_hidden_state, emb_loss, indices = self.quantize(conv_hidden_states)
  724. return JanusVQVAEModelOutput(
  725. last_hidden_state=hidden_states,
  726. quantized_last_hidden_state=quantized_last_hidden_state,
  727. image_tokens=indices,
  728. embedding_loss=emb_loss,
  729. )
  730. def decode(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
  731. """
  732. Decodes quantized token IDs into pixel values.
  733. Args:
  734. image_tokens (torch.LongTensor): Batch of token IDs.
  735. Returns:
  736. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  737. Pixel values decoded from the token IDs.
  738. """
  739. if image_tokens.shape[1] != self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]:
  740. raise ValueError(
  741. f"Expected `image_tokens` to have shape `(batch_size, {self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]})`, "
  742. f"but got shape `{image_tokens.shape}`."
  743. )
  744. codebook_entry = self.quantize.get_codebook_entry(image_tokens)
  745. hidden_states = self.post_quant_conv(codebook_entry)
  746. pixel_values = self.decoder(hidden_states)
  747. return pixel_values
  748. @can_return_tuple
  749. @auto_docstring
  750. def forward(
  751. self,
  752. pixel_values: torch.FloatTensor,
  753. **kwargs,
  754. ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
  755. batch_size = pixel_values.shape[0]
  756. encode_outputs = self.encode(pixel_values, return_dict=True, **kwargs)
  757. decoded_pixel_values = self.decode(encode_outputs.image_tokens.view(batch_size, -1))
  758. return JanusVQVAEOutput(decoded_pixel_values, encode_outputs.embedding_loss)
  759. class JanusVQVAEAlignerMLP(nn.Module):
  760. def __init__(self, config: JanusVQVAEConfig):
  761. super().__init__()
  762. self.fc1 = nn.Linear(config.embed_dim, config.projection_dim)
  763. self.hidden_layers = nn.ModuleList(
  764. [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(1, config.num_hidden_layers)]
  765. )
  766. self.activation_fn = ACT2FN[config.hidden_act]
  767. def forward(self, hidden_states):
  768. hidden_states = self.fc1(hidden_states)
  769. for layer in self.hidden_layers:
  770. hidden_states = self.activation_fn(hidden_states)
  771. hidden_states = layer(hidden_states)
  772. return hidden_states
  773. class JanusVQVAEHead(nn.Module):
  774. """Head used for sampling tokens in image generation, replacing the usual lm head."""
  775. def __init__(self, config: JanusVQVAEConfig):
  776. super().__init__()
  777. self.proj_out = nn.Linear(config.image_token_embed_dim, config.projection_dim)
  778. self.activation_fn = ACT2FN[config.hidden_act]
  779. self.vision_head = nn.Linear(config.projection_dim, config.num_embeddings)
  780. def forward(self, hidden_states: torch.Tensor) -> torch.tensor:
  781. hidden_states = self.proj_out(hidden_states)
  782. hidden_states = self.activation_fn(hidden_states)
  783. hidden_states = self.vision_head(hidden_states)
  784. return hidden_states
  785. @auto_docstring(
  786. custom_intro="""
  787. The Janus model which consists of a siglip vision backbone, a Llama language model and a VQ model.
  788. """
  789. )
  790. class JanusModel(JanusPreTrainedModel):
  791. def __init__(self, config: JanusConfig):
  792. super().__init__(config)
  793. self.config = config
  794. # This is necessary for backward compatibility, see SiglipModel initialization
  795. self.vision_model = JanusVisionModel._from_config(config.vision_config)
  796. self.aligner = JanusVisionAlignerMLP(self.vision_model.config)
  797. self.vqmodel = JanusVQVAE._from_config(config.vq_config)
  798. # Below generation_* modules are used for Image generation.
  799. # Embeddings used for image generation, instead of Janus vision embeddings.
  800. self.generation_embeddings = nn.Embedding(self.vqmodel.config.num_embeddings, self.vqmodel.config.embed_dim)
  801. self.generation_aligner = JanusVQVAEAlignerMLP(self.vqmodel.config)
  802. self.generation_head = JanusVQVAEHead(self.vqmodel.config)
  803. self.language_model = AutoModel.from_config(config=config.text_config)
  804. self.gradient_checkpointing = False
  805. # Initialize weights and apply final processing.
  806. self.post_init()
  807. def get_input_embeddings(self):
  808. return self.language_model.get_input_embeddings()
  809. def set_input_embeddings(self, value):
  810. self.language_model.set_input_embeddings(value)
  811. @can_return_tuple
  812. @auto_docstring
  813. def get_image_features(
  814. self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
  815. ) -> tuple | BaseModelOutputWithPooling:
  816. vision_outputs = self.vision_model(pixel_values, return_dict=True, **kwargs)
  817. vision_outputs.pooler_output = self.aligner(vision_outputs.last_hidden_state)
  818. return vision_outputs
  819. def get_placeholder_mask(
  820. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  821. ):
  822. """
  823. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  824. equal to the length of multimodal features. If the lengths are different, an error is raised.
  825. """
  826. if input_ids is None:
  827. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  828. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  829. )
  830. special_image_mask = special_image_mask.all(-1)
  831. else:
  832. special_image_mask = input_ids == self.config.image_token_id
  833. n_image_tokens = special_image_mask.sum()
  834. n_image_features = image_features.shape[0] * image_features.shape[1]
  835. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  836. torch_compilable_check(
  837. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  838. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
  839. )
  840. return special_image_mask
  841. @can_return_tuple
  842. @auto_docstring
  843. def forward(
  844. self,
  845. input_ids: torch.LongTensor | None = None,
  846. pixel_values: torch.FloatTensor | None = None,
  847. attention_mask: torch.Tensor | None = None,
  848. position_ids: torch.LongTensor | None = None,
  849. past_key_values: Cache | None = None,
  850. inputs_embeds: torch.FloatTensor | None = None,
  851. use_cache: bool | None = None,
  852. logits_to_keep: int | torch.Tensor = 0,
  853. **kwargs,
  854. ) -> JanusBaseModelOutputWithPast:
  855. if (input_ids is None) ^ (inputs_embeds is not None):
  856. raise ValueError(
  857. "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
  858. )
  859. if inputs_embeds is None:
  860. inputs_embeds = self.get_input_embeddings()(input_ids)
  861. if pixel_values is not None:
  862. image_embeds = self.get_image_features(pixel_values, return_dict=True).pooler_output
  863. image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
  864. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  865. image_attention_mask = self.get_placeholder_mask(
  866. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  867. )
  868. inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
  869. lm_output = self.language_model(
  870. inputs_embeds=inputs_embeds,
  871. attention_mask=attention_mask,
  872. position_ids=position_ids,
  873. past_key_values=past_key_values,
  874. use_cache=use_cache,
  875. logits_to_keep=logits_to_keep,
  876. **kwargs,
  877. )
  878. return JanusBaseModelOutputWithPast(
  879. last_hidden_state=lm_output.last_hidden_state,
  880. past_key_values=lm_output.past_key_values,
  881. hidden_states=lm_output.hidden_states,
  882. attentions=lm_output.attentions,
  883. image_hidden_states=image_embeds if pixel_values is not None else None,
  884. )
  885. class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
  886. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  887. output_modalities = ("image", "text")
  888. _can_compile_fullgraph = True
  889. def __init__(self, config: JanusConfig):
  890. super().__init__(config)
  891. self.config = config
  892. self.model = JanusModel(config)
  893. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  894. # Initialize weights and apply final processing.
  895. self.post_init()
  896. def get_input_embeddings(self):
  897. return self.model.language_model.get_input_embeddings()
  898. def set_input_embeddings(self, value):
  899. self.model.language_model.set_input_embeddings(value)
  900. def prepare_embeddings_for_image_generation(self, inputs: torch.Tensor) -> torch.Tensor:
  901. hidden_state = self.model.generation_embeddings(inputs)
  902. hidden_state = self.model.generation_aligner(hidden_state)
  903. return hidden_state
  904. @can_return_tuple
  905. @auto_docstring
  906. def forward(
  907. self,
  908. input_ids: torch.LongTensor | None = None,
  909. pixel_values: torch.FloatTensor | None = None,
  910. attention_mask: torch.Tensor | None = None,
  911. position_ids: torch.LongTensor | None = None,
  912. past_key_values: Cache | None = None,
  913. inputs_embeds: torch.FloatTensor | None = None,
  914. labels: torch.LongTensor | None = None,
  915. use_cache: bool | None = None,
  916. logits_to_keep: int | torch.Tensor = 0,
  917. **kwargs: Unpack[TransformersKwargs],
  918. ) -> JanusCausalLMOutputWithPast:
  919. r"""
  920. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  921. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  922. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  923. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  924. """
  925. outputs = self.model(
  926. input_ids=input_ids,
  927. pixel_values=pixel_values,
  928. attention_mask=attention_mask,
  929. position_ids=position_ids,
  930. past_key_values=past_key_values,
  931. inputs_embeds=inputs_embeds,
  932. use_cache=use_cache,
  933. **kwargs,
  934. )
  935. hidden_states = outputs.last_hidden_state
  936. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  937. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  938. logits = self.lm_head(hidden_states[:, slice_indices, :])
  939. loss = None
  940. if labels is not None:
  941. loss = self.loss_function(
  942. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  943. )
  944. return JanusCausalLMOutputWithPast(
  945. loss=loss,
  946. logits=logits,
  947. past_key_values=outputs.past_key_values,
  948. hidden_states=outputs.hidden_states,
  949. attentions=outputs.attentions,
  950. image_hidden_states=outputs.image_hidden_states,
  951. )
  952. def prepare_inputs_for_generation(
  953. self,
  954. input_ids,
  955. pixel_values=None,
  956. past_key_values=None,
  957. attention_mask=None,
  958. inputs_embeds=None,
  959. logits_to_keep=None,
  960. is_first_iteration=False,
  961. **kwargs,
  962. ):
  963. # Overwritten -- extra custom processing
  964. model_inputs = super().prepare_inputs_for_generation(
  965. input_ids,
  966. past_key_values=past_key_values,
  967. inputs_embeds=inputs_embeds,
  968. attention_mask=attention_mask,
  969. logits_to_keep=logits_to_keep,
  970. is_first_iteration=is_first_iteration,
  971. **kwargs,
  972. )
  973. # Pixel values are used only in the first iteration if available
  974. # In subsequent iterations, they are already merged with text and cached
  975. # NOTE: first iteration doesn't have to be prefill, it can be the first
  976. # iteration with a question and cached system prompt (continue generate from cache)
  977. if is_first_iteration or not kwargs.get("use_cache", True):
  978. model_inputs["pixel_values"] = pixel_values
  979. return model_inputs
  980. def decode_image_tokens(self, image_tokens: torch.Tensor):
  981. """
  982. Decodes generated image tokens from language model to continuous pixel values
  983. with VQGAN module via upsampling.
  984. Args:
  985. image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
  986. The tensors corresponding to the input images.
  987. """
  988. decoded_image = self.model.vqmodel.decode(image_tokens)
  989. decoded_image = decoded_image.permute(0, 2, 3, 1)
  990. return decoded_image
  991. @torch.no_grad()
  992. def generate(
  993. self,
  994. inputs: torch.Tensor | None = None,
  995. attention_mask: torch.LongTensor | None = None,
  996. logits_processor: LogitsProcessorList | None = None,
  997. **kwargs,
  998. ):
  999. # 1. Handle generation config and model kwargs
  1000. # Pop generation_mode first since it's specific to Janus
  1001. generation_mode = kwargs.pop("generation_mode", "text")
  1002. generation_config, model_kwargs = self._prepare_generation_config(
  1003. kwargs.pop("generation_config", None), **kwargs
  1004. )
  1005. # Default to "text" generation if mode isn't provided
  1006. if generation_mode == "text":
  1007. # Set guidance_scale=None to prevent running UnbatchedCFG processor.
  1008. return super().generate(
  1009. inputs=inputs,
  1010. attention_mask=attention_mask,
  1011. generation_config=generation_config,
  1012. guidance_scale=None,
  1013. **model_kwargs,
  1014. )
  1015. # Validate generation mode
  1016. if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
  1017. raise ValueError(
  1018. "Got incompatible mode for Image Generation, should be one of greedy or sampling. "
  1019. "Ensure that beam search is de-activated by setting `num_beams=1`."
  1020. )
  1021. # Validate the configuration and model kwargs
  1022. generation_config.validate()
  1023. self._validate_model_kwargs(model_kwargs.copy())
  1024. # 2. Initialize logit processors
  1025. logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
  1026. # Set `use_cache=True` as we will be using input embeds for generation.
  1027. model_kwargs["use_cache"] = True
  1028. if generation_config.guidance_scale is None:
  1029. logger.warning("`guidance_scale` is required for CFG but not provided. Setting to default value of 5.")
  1030. generation_config.guidance_scale = 5
  1031. model_kwargs["guidance_scale"] = generation_config.guidance_scale
  1032. # 3. Prepare model inputs
  1033. input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
  1034. inputs, generation_config.bos_token_id, model_kwargs
  1035. )
  1036. dtype, device = input_ids.dtype, input_ids.device
  1037. if len(input_ids.shape) != 2:
  1038. raise ValueError(
  1039. f"Expected input ids of shape (batch_size, seq_len), but got {input_ids.shape}"
  1040. "Passing `inputs embeds` is not supported currently."
  1041. )
  1042. # Prepare special tokens which will be used generate internally.
  1043. kwargs_has_attention_mask = attention_mask is not None
  1044. self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
  1045. # 4. Add CFG processor along with user passed logit processor.
  1046. if generation_config.guidance_scale and generation_config.guidance_scale > 1:
  1047. logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
  1048. generation_config.guidance_scale = None # Reset to prevent processor duplication.
  1049. # 5. Prepare logits processor
  1050. logits_processor = self._get_logits_processor(
  1051. generation_config=generation_config,
  1052. input_ids_seq_length=input_ids.shape[1],
  1053. encoder_input_ids=input_ids,
  1054. prefix_allowed_tokens_fn=None,
  1055. logits_processor=logits_processor,
  1056. device=device,
  1057. )
  1058. # 6. Expand inputs for multiple image generations per prompt.
  1059. input_ids, model_kwargs = self._expand_inputs_for_generation(
  1060. input_ids=input_ids,
  1061. attention_mask=attention_mask,
  1062. expand_size=generation_config.num_return_sequences,
  1063. **model_kwargs,
  1064. )
  1065. # 7. Prepare input and model caches
  1066. num_image_tokens = self.model.vision_model.config.num_image_tokens
  1067. batch_size, seq_len = input_ids.shape
  1068. input_tokens = input_ids.repeat(2, 1) # Double batch size for conditional/unconditional logits
  1069. attention_mask = model_kwargs.pop("attention_mask", None)
  1070. attention_mask = attention_mask.repeat(2, 1)
  1071. model_kwargs["attention_mask"] = attention_mask
  1072. # Mask all the tokens that are neither BOS nor BOI with pad token in the unconditional logits.
  1073. mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & (
  1074. input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"]
  1075. )
  1076. input_tokens[batch_size:, :].masked_fill_(mask, generation_config.pad_token_id)
  1077. inputs_embeds = self.get_input_embeddings()(input_tokens)
  1078. if model_kwargs.get("past_key_values", None) is None:
  1079. # Prepare cache if not provided.
  1080. model_kwargs["past_key_values"] = self._prepare_static_cache(
  1081. cache_implementation=generation_config.cache_implementation or "static",
  1082. # batch_size should account for both conditional/unconditional input; hence multiplied by 2.
  1083. batch_size=batch_size * 2,
  1084. # we should have at least a cache len of seq_len + num_image_tokens.
  1085. max_cache_len=max(generation_config.max_length, num_image_tokens + seq_len),
  1086. model_kwargs=model_kwargs,
  1087. )
  1088. # Placeholder for generated tokens.
  1089. generated_tokens = torch.zeros((batch_size, num_image_tokens), dtype=dtype, device=device)
  1090. # 8. init attention / hidden states / scores tuples
  1091. output_attentions = generation_config.output_attentions
  1092. output_hidden_states = generation_config.output_hidden_states
  1093. output_scores = generation_config.output_scores
  1094. output_logits = generation_config.output_logits
  1095. return_dict_in_generate = generation_config.return_dict_in_generate
  1096. raw_scores = () if (return_dict_in_generate and output_scores) else None
  1097. raw_logits = () if (return_dict_in_generate and output_logits) else None
  1098. decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
  1099. decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
  1100. for i in range(num_image_tokens):
  1101. # Set `is_first_iteration=True` to force using `inputs_embeds` instead of `input_ids`.
  1102. # Without this, `prepare_inputs_for_generation` would use `input_ids` (the full prompt)
  1103. # instead of our prepared `inputs_embeds` (1 new token).
  1104. # This causes CUDA error: device-side assert triggered, seen around the call to ` self.self_attn`.
  1105. # Set this to `True` is also necessary to match the expected output, see the more detailed comment
  1106. # https://github.com/huggingface/transformers/pull/45044#discussion_r3020805374.
  1107. model_inputs = self.prepare_inputs_for_generation(
  1108. inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs
  1109. )
  1110. if "attention_mask" in model_inputs:
  1111. model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
  1112. outputs = self.model.language_model(
  1113. **model_inputs,
  1114. output_attentions=output_attentions,
  1115. output_hidden_states=output_hidden_states,
  1116. )
  1117. # Update model_kwargs like attention_mask for next generation.
  1118. model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
  1119. hidden_state = outputs.last_hidden_state[:, -1, :].clone()
  1120. # Generate scores using the generation head (Not using above defined LM Head)
  1121. scores = self.model.generation_head(hidden_state)
  1122. next_token_scores = logits_processor(input_ids, scores)
  1123. # Sample next token.
  1124. if generation_config.do_sample:
  1125. probs = torch.softmax(next_token_scores, dim=-1)
  1126. next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
  1127. else:
  1128. next_token = torch.argmax(next_token_scores, dim=-1)
  1129. generated_tokens[:, i] = next_token
  1130. # Prepare embeddings for the next step.
  1131. next_token = torch.cat([next_token, next_token])
  1132. next_token = next_token.unsqueeze(-1)
  1133. inputs_embeds = self.prepare_embeddings_for_image_generation(next_token)
  1134. if return_dict_in_generate:
  1135. if output_scores:
  1136. raw_scores += (scores,)
  1137. if output_logits:
  1138. raw_logits += (hidden_state.float(),)
  1139. if output_attentions:
  1140. decoder_attentions += outputs.attentions
  1141. if output_hidden_states:
  1142. decoder_hidden_states += outputs.hidden_states
  1143. if return_dict_in_generate:
  1144. return GenerateDecoderOnlyOutput(
  1145. sequences=generated_tokens,
  1146. scores=scores,
  1147. logits=raw_logits,
  1148. attentions=decoder_attentions,
  1149. hidden_states=decoder_hidden_states,
  1150. past_key_values=outputs.past_key_values,
  1151. )
  1152. else:
  1153. return generated_tokens
  1154. __all__ = ["JanusPreTrainedModel", "JanusForConditionalGeneration", "JanusModel", "JanusVQVAE", "JanusVisionModel"]