modular_janus.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165
  1. # Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Callable
  15. from dataclasses import dataclass
  16. import torch
  17. import torch.nn.functional as F
  18. from huggingface_hub.dataclasses import strict
  19. from torch import nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache
  23. from ...configuration_utils import PreTrainedConfig
  24. from ...generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMixin, GenerationMode, LogitsProcessorList
  25. from ...generation.utils import GenerateDecoderOnlyOutput
  26. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
  27. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  28. from ...processing_utils import Unpack
  29. from ...utils import (
  30. TransformersKwargs,
  31. auto_docstring,
  32. can_return_tuple,
  33. is_vision_available,
  34. logging,
  35. torch_compilable_check,
  36. )
  37. from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel
  38. from ..blip_2.modeling_blip_2 import Blip2VisionModel
  39. from ..chameleon.configuration_chameleon import ChameleonVQVAEConfig
  40. from ..chameleon.modeling_chameleon import (
  41. ChameleonVQVAE,
  42. ChameleonVQVAEEncoderAttnBlock,
  43. ChameleonVQVAEEncoderConvDownsample,
  44. ChameleonVQVAEEncoderResnetBlock,
  45. ChameleonVQVAEVectorQuantizer,
  46. )
  47. from ..idefics.modeling_idefics import IdeficsBaseModelOutputWithPast, IdeficsCausalLMOutputWithPast
  48. from ..llama.modeling_llama import eager_attention_forward
  49. from ..siglip.configuration_siglip import SiglipVisionConfig
  50. from ..siglip.modeling_siglip import SiglipEncoder, SiglipEncoderLayer, SiglipVisionEmbeddings
# NOTE(review): this guard has an empty body. It looks like a leftover placeholder for
# vision-only imports; at runtime it only calls `is_vision_available()`.
if is_vision_available():
    pass

# Module-level logger used by the config classes below (e.g. `JanusConfig.__post_init__`).
logger = logging.get_logger(__name__)
@auto_docstring(checkpoint="deepseek-community/Janus-Pro-1B")
@strict
class JanusVisionConfig(SiglipVisionConfig):
    r"""
    projection_dropout (`float`, *optional*, defaults to 0.0):
        Dropout probability for the projection layer.
    num_image_tokens (`int`, *optional*, defaults to 576):
        Number of image tokens.
    """

    # Configuration of the vision tower used for image understanding (SigLIP-derived).
    hidden_size: int = 1024
    num_hidden_layers: int = 24
    num_attention_heads: int = 16
    image_size: int | list[int] | tuple[int, int] = 384
    # Activation of the encoder MLP and between the aligner MLP's linear layers.
    hidden_act: str = "gelu"
    # The encoder MLP width is int(hidden_size * mlp_ratio) (see JanusVisionMLP); `intermediate_size` is not used.
    mlp_ratio: float | int = 4.0
    attention_bias: bool = True  # bias on the q/k/v projections of JanusVisionAttention
    hidden_dropout_rate: float | int = 0.0  # dropout after each linear layer of JanusVisionMLP
    projection_dim: int = 2048  # output width of JanusVisionAlignerMLP
    projection_dropout: float | int = 0.0  # dropout after the attention output projection (disabled when 0)
    use_qk_norm: bool = False  # adds LayerNorm modules on queries and keys in JanusVisionAttention
    initializer_range: float = 0.02
    depth: int = 2  # number of Linear layers in JanusVisionAlignerMLP
    num_image_tokens: int = 576
    # NOTE(review): assigning an `AttributeError()` instance appears to be the modular-file convention for
    # removing the field inherited from the parent config — confirm against the modular converter.
    intermediate_size = AttributeError()
@auto_docstring(checkpoint="deepseek-community/Janus-Pro-1B")
@strict
class JanusVQVAEConfig(ChameleonVQVAEConfig):
    r"""
    base_channels (`int`, *optional*, defaults to 128):
        Base channel count.
    channel_multiplier (`list[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`):
        Channel multipliers for each resolution.
    num_res_blocks (`int`, *optional*, defaults to 2):
        Number of residual blocks.
    num_patches (`int`, *optional*, defaults to 32):
        Num of patches the input images can be divided into.
    out_channels (`int`, *optional*, defaults to 3):
        Number of out channels.
    image_token_embed_dim (`int`, *optional*, defaults to 2048):
        Dimension of image embeddings. It should be same as the dimensionality of text embeddings.
    """

    # Width of each codebook vector; also the width of JanusModel.generation_embeddings and the
    # input width of JanusVQVAEAlignerMLP.
    embed_dim: int = 8
    # Codebook size; also the output size of JanusVQVAEHead and the vocabulary of generation_embeddings.
    num_embeddings: int = 16384
    # When True the encoder's conv_out emits 2 * latent_channels channels.
    double_latent: bool = False
    latent_channels: int = 256  # encoder output / decoder input channel count
    # Side of the square latent grid (quantizer `quant_state_dims`). JanusConfig overwrites it with
    # vision_config.image_size // vision_config.patch_size.
    num_patches: int = 32
    in_channels: int = 3  # channels of the images fed to the encoder
    out_channels: int = 3  # channels of the images produced by the decoder
    # Level i of the encoder/decoder has base_channels * channel_multiplier[i] channels;
    # len(channel_multiplier) is the number of resolution levels.
    base_channels: int = 128
    channel_multiplier: list[int] | tuple[int, ...] = (1, 1, 2, 2, 4)
    # Resnet blocks per encoder level (the decoder uses num_res_blocks + 1 per level).
    num_res_blocks: int = 2
    # NOTE(review): not read in this file; presumably consumed by the inherited resnet blocks.
    dropout: float | int = 0.0
    initializer_range: float = 0.02
    projection_dim: int = 2048  # hidden width of JanusVQVAEAlignerMLP and JanusVQVAEHead
    num_hidden_layers: int = 2  # number of Linear layers in JanusVQVAEAlignerMLP
    hidden_act: str = "gelu"  # activation used in JanusVQVAEAlignerMLP and JanusVQVAEHead
    image_token_embed_dim: int = 2048  # input width of JanusVQVAEHead
    # NOTE(review): `AttributeError()` instances appear to be the modular-file convention for removing
    # fields inherited from ChameleonVQVAEConfig — confirm against the modular converter.
    resolution = AttributeError()
    attn_resolutions = AttributeError()
    attn_type = AttributeError()
  115. @auto_docstring(checkpoint="deepseek-community/Janus-Pro-1B")
  116. @strict
  117. class JanusConfig(PreTrainedConfig):
  118. r"""
  119. Example:
  120. ```python
  121. >>> from transformers import JanusForConditionalGeneration, JanusConfig, JanusVisionConfig, JanusVQVAEConfig, LlamaConfig
  122. >>> # Initializing a Janus vision config
  123. >>> vision_config = JanusVisionConfig()
  124. >>> # Initializing a Llama config
  125. >>> text_config = LlamaConfig()
  126. >>> # Initializing a VQ config
  127. >>> vq_config = JanusVQVAEConfig()
  128. >>> # Initializing a Janus Pro 1B style configuration
  129. >>> configuration = JanusConfig(vision_config=vision_config, text_config=text_config, vq_config=vq_config)
  130. >>> # Initializing a model from the Janus Pro 1B style configuration
  131. >>> model = JanusForConditionalGeneration(configuration)
  132. >>> # Accessing the model configuration
  133. >>> configuration = model.config
  134. ```"""
  135. model_type = "janus"
  136. sub_configs = {
  137. "text_config": AutoConfig,
  138. "vision_config": JanusVisionConfig,
  139. "vq_config": JanusVQVAEConfig,
  140. }
  141. text_config: dict | PreTrainedConfig | None = None
  142. vision_config: dict | PreTrainedConfig | None = None
  143. vq_config: dict | PreTrainedConfig | None = None
  144. image_token_id: int = 100581
  145. def __post_init__(self, **kwargs):
  146. if isinstance(self.text_config, dict):
  147. self.text_config["model_type"] = self.text_config.get("model_type", "llama")
  148. self.text_config = CONFIG_MAPPING[self.text_config["model_type"]](**self.text_config)
  149. elif self.text_config is None:
  150. logger.info("`text_config` is None. Initializing with default values")
  151. self.text_config = CONFIG_MAPPING["llama"]()
  152. if self.vision_config is None:
  153. logger.info("`vision_config` is None. Initializing with default JanusVisionConfig values")
  154. self.vision_config = JanusVisionConfig()
  155. elif isinstance(self.vision_config, dict):
  156. self.vision_config = JanusVisionConfig(**self.vision_config)
  157. if self.vq_config is None:
  158. logger.info("`vq_config` is None. Initializing with default JanusVQVAEConfig values")
  159. self.vq_config = JanusVQVAEConfig()
  160. elif isinstance(self.vq_config, dict):
  161. self.vq_config = JanusVQVAEConfig(**self.vq_config)
  162. # This dimension is required when decoding discrete image tokens to continuous input.
  163. self.vq_config.num_patches = self.vision_config.image_size // self.vision_config.patch_size
  164. super().__post_init__(**kwargs)
  165. @auto_docstring
  166. class JanusPreTrainedModel(PreTrainedModel):
  167. config: JanusConfig
  168. base_model_prefix = "model"
  169. input_modalities = ("image", "text")
  170. supports_gradient_checkpointing = True
  171. _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
  172. _skip_keys_device_placement = ["past_key_values", "causal_mask"]
  173. _supports_flash_attn = True
  174. _supports_sdpa = True
  175. _can_compile_fullgraph = True
  176. def _init_weights(self, module):
  177. super()._init_weights(module)
  178. if isinstance(module, JanusVisionEmbeddings):
  179. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Janus VQ-VAE mode model outputs.
    """
)
class JanusVQVAEOutput(ModelOutput):
    r"""
    decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        Reconstructed pixel values after encoding and decoding the input.
    embedding_loss (`torch.FloatTensor`):
        Embedding loss.
    """

    # JanusVQVAE.forward constructs this positionally as (decoded_pixel_values, embedding_loss),
    # so the field order below must not change.
    decoded_pixel_values: torch.FloatTensor | None = None
    embedding_loss: torch.FloatTensor | None = None
  195. class JanusBaseModelOutputWithPast(IdeficsBaseModelOutputWithPast):
  196. pass
  197. class JanusCausalLMOutputWithPast(IdeficsCausalLMOutputWithPast):
  198. pass
  199. class JanusVisionEmbeddings(SiglipVisionEmbeddings):
  200. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  201. _, _, height, width = pixel_values.shape
  202. target_dtype = self.patch_embedding.weight.dtype
  203. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  204. embeddings = patch_embeds.flatten(2).transpose(1, 2)
  205. if interpolate_pos_encoding:
  206. pos_embeds = self.interpolate_pos_encoding(embeddings, height, width)
  207. else:
  208. pos_embeds = self.position_embedding(self.position_ids)
  209. embeddings = embeddings + pos_embeds
  210. return embeddings
  211. class JanusVisionAttention(nn.Module):
  212. """Attention Class for Janus Vision Encoder"""
  213. def __init__(self, config: JanusVisionConfig):
  214. super().__init__()
  215. self.config = config
  216. self.embed_dim = config.hidden_size
  217. self.num_heads = config.num_attention_heads
  218. self.head_dim = self.embed_dim // self.num_heads
  219. if self.head_dim * self.num_heads != self.embed_dim:
  220. raise ValueError(
  221. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  222. f" {self.num_heads})."
  223. )
  224. self.scale = self.head_dim**-0.5
  225. self.attention_dropout = config.attention_dropout
  226. proj_dropout = config.projection_dropout
  227. qk_norm = config.use_qk_norm
  228. self.is_causal = False
  229. # Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
  230. self.num_key_value_groups = 1
  231. self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  232. self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  233. self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  234. self.projection_layer = nn.Linear(self.embed_dim, self.embed_dim)
  235. self.projection_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity()
  236. self.q_norm = nn.LayerNorm(self.embed_dim) if qk_norm else nn.Identity()
  237. self.k_norm = nn.LayerNorm(self.embed_dim) if qk_norm else nn.Identity()
  238. def forward(
  239. self,
  240. hidden_states: torch.Tensor,
  241. attention_mask: torch.Tensor | None = None,
  242. **kwargs: Unpack[TransformersKwargs],
  243. ):
  244. batch_size, seq_len, _ = hidden_states.size()
  245. query_states = self.q_proj(hidden_states)
  246. key_states = self.k_proj(hidden_states)
  247. value_states = self.v_proj(hidden_states)
  248. query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
  249. query_states = self.q_norm(query_states)
  250. key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
  251. key_states = self.k_norm(key_states)
  252. query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  253. key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  254. value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  255. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  256. self.config._attn_implementation, eager_attention_forward
  257. )
  258. attn_output, attn_weights = attention_interface(
  259. self,
  260. query_states,
  261. key_states,
  262. value_states,
  263. attention_mask,
  264. dropout=0.0 if not self.training else self.attention_dropout,
  265. scaling=self.scale,
  266. is_causal=self.is_causal,
  267. **kwargs,
  268. )
  269. attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
  270. output = self.projection_layer(attn_output)
  271. output = self.projection_dropout(output)
  272. return output, attn_weights
  273. class JanusVisionMLP(nn.Module):
  274. def __init__(self, config: JanusVisionConfig):
  275. super().__init__()
  276. self.config = config
  277. self.intermediate_size = int(config.hidden_size * config.mlp_ratio)
  278. self.activation_fn = ACT2FN[config.hidden_act] # Gelu act
  279. self.fc1 = nn.Linear(config.hidden_size, self.intermediate_size)
  280. self.fc2 = nn.Linear(self.intermediate_size, config.hidden_size)
  281. self.dropout1 = nn.Dropout(config.hidden_dropout_rate)
  282. self.dropout2 = nn.Dropout(config.hidden_dropout_rate)
  283. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  284. hidden_states = self.fc1(hidden_states)
  285. hidden_states = self.activation_fn(hidden_states)
  286. hidden_states = self.dropout1(hidden_states)
  287. hidden_states = self.fc2(hidden_states)
  288. hidden_states = self.dropout2(hidden_states)
  289. return hidden_states
  290. class JanusVisionEncoderLayer(SiglipEncoderLayer):
  291. def __init__(self, config: JanusVisionConfig):
  292. super().__init__(config)
  293. self.config = config
  294. self.embed_dim = config.hidden_size
  295. self.self_attn = JanusVisionAttention(config)
  296. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  297. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  298. self.mlp = JanusVisionMLP(config)
  299. class JanusVisionEncoder(SiglipEncoder):
  300. def __init__(self, config: JanusVisionConfig):
  301. super().__init__(config)
  302. self.layers = nn.ModuleList([JanusVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  303. class JanusVisionModel(Blip2VisionModel):
  304. _can_record_outputs = {
  305. "hidden_states": JanusVisionEncoderLayer,
  306. "attentions": JanusVisionAttention,
  307. }
  308. def __init__(self, config: JanusVisionConfig):
  309. super().__init__(config)
  310. self.encoder = JanusVisionEncoder(config)
  311. def forward(
  312. self,
  313. pixel_values: torch.FloatTensor | None = None,
  314. interpolate_pos_encoding: bool = False,
  315. **kwargs: Unpack[TransformersKwargs],
  316. ) -> tuple | BaseModelOutputWithPooling:
  317. if pixel_values is None:
  318. raise ValueError("You have to specify pixel_values")
  319. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  320. encoder_outputs: BaseModelOutput = self.encoder(
  321. inputs_embeds=hidden_states,
  322. **kwargs,
  323. )
  324. last_hidden_state = encoder_outputs.last_hidden_state
  325. last_hidden_state = self.post_layernorm(last_hidden_state)
  326. pooled_output = last_hidden_state[:, 0, :]
  327. pooled_output = self.post_layernorm(pooled_output)
  328. return BaseModelOutputWithPooling(
  329. last_hidden_state=last_hidden_state,
  330. pooler_output=pooled_output,
  331. )
  332. class JanusVisionAlignerMLP(nn.Module):
  333. def __init__(self, config: JanusVisionConfig):
  334. super().__init__()
  335. self.fc1 = nn.Linear(config.hidden_size, config.projection_dim)
  336. self.hidden_layers = nn.ModuleList(
  337. [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(1, config.depth)]
  338. )
  339. self.activation_fn = ACT2FN[config.hidden_act]
  340. def forward(self, hidden_states):
  341. hidden_states = self.fc1(hidden_states)
  342. for layer in self.hidden_layers:
  343. hidden_states = self.activation_fn(hidden_states)
  344. hidden_states = layer(hidden_states)
  345. return hidden_states
  346. class JanusVQVAEVectorQuantizer(ChameleonVQVAEVectorQuantizer):
  347. def __init__(self, config: JanusVQVAEConfig):
  348. super().__init__(config)
  349. self.quant_state_dims = [config.num_patches] * 2
  350. def get_codebook_entry(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
  351. batch_size = image_tokens.shape[0]
  352. emb_dim: int = self.embedding.weight.shape[-1]
  353. # get quantized latent vectors
  354. hidden_state_quant = self.embedding(image_tokens)
  355. # l2 normalization on the last dimension
  356. hidden_state_quant = F.normalize(hidden_state_quant, p=2, dim=-1)
  357. # reshape back to match original input shape
  358. hidden_state_quant = hidden_state_quant.view((batch_size, *self.quant_state_dims, emb_dim))
  359. hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
  360. return hidden_state_quant
  361. class JanusVQVAEResnetBlock(ChameleonVQVAEEncoderResnetBlock):
  362. pass
  363. class JanusVQVAEAttnBlock(ChameleonVQVAEEncoderAttnBlock):
  364. pass
  365. class JanusVQVAEConvDownsample(ChameleonVQVAEEncoderConvDownsample):
  366. pass
  367. class JanusVQVAEConvUpsample(nn.Module):
  368. def __init__(self, in_channels):
  369. super().__init__()
  370. self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
  371. def forward(self, hidden_states):
  372. hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
  373. hidden_states = self.conv(hidden_states)
  374. return hidden_states
  375. class JanusVQVAEMidBlock(nn.Module):
  376. def __init__(self, config: JanusVQVAEConfig, channels: int):
  377. super().__init__()
  378. self.block_1 = JanusVQVAEResnetBlock(
  379. config=config,
  380. in_channels=channels,
  381. out_channels=channels,
  382. )
  383. self.attn_1 = JanusVQVAEAttnBlock(channels)
  384. self.block_2 = JanusVQVAEResnetBlock(
  385. config=config,
  386. in_channels=channels,
  387. out_channels=channels,
  388. )
  389. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  390. hidden_states = self.block_1(hidden_states)
  391. hidden_states = self.attn_1(hidden_states)
  392. hidden_states = self.block_2(hidden_states)
  393. return hidden_states
  394. class JanusVQVAEEncoder(nn.Module):
  395. def __init__(self, config):
  396. super().__init__()
  397. self.num_resolutions = len(config.channel_multiplier)
  398. self.num_res_blocks = config.num_res_blocks
  399. base_channels = config.base_channels
  400. in_channels = config.in_channels
  401. double_latent = config.double_latent
  402. latent_channels = config.latent_channels
  403. channel_multiplier = config.channel_multiplier
  404. self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
  405. in_channel_multiplier = (1,) + tuple(channel_multiplier)
  406. self.in_channel_multiplier = in_channel_multiplier
  407. self.down = nn.ModuleList()
  408. for i_level in range(self.num_resolutions):
  409. block = nn.ModuleList()
  410. attn = nn.ModuleList()
  411. block_in = base_channels * in_channel_multiplier[i_level]
  412. block_out = base_channels * channel_multiplier[i_level]
  413. for i_block in range(self.num_res_blocks):
  414. block.append(
  415. JanusVQVAEResnetBlock(
  416. config=config,
  417. in_channels=block_in,
  418. out_channels=block_out,
  419. )
  420. )
  421. block_in = block_out
  422. if i_level == self.num_resolutions - 1:
  423. attn.append(JanusVQVAEAttnBlock(block_in))
  424. down = nn.Module()
  425. down.block = block
  426. down.attn = attn
  427. if i_level != self.num_resolutions - 1:
  428. down.downsample = JanusVQVAEConvDownsample(block_in)
  429. self.down.append(down)
  430. self.mid = JanusVQVAEMidBlock(config, block_in)
  431. self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
  432. self.conv_out = torch.nn.Conv2d(
  433. block_in,
  434. 2 * latent_channels if double_latent else latent_channels,
  435. kernel_size=3,
  436. stride=1,
  437. padding=1,
  438. )
  439. def forward(self, pixel_values: torch.LongTensor):
  440. # downsampling
  441. hidden_states = [self.conv_in(pixel_values)]
  442. for i_level in range(self.num_resolutions):
  443. for i_block in range(self.num_res_blocks):
  444. hidden_state = self.down[i_level].block[i_block](
  445. hidden_states[-1],
  446. )
  447. if len(self.down[i_level].attn) > 0:
  448. hidden_state = self.down[i_level].attn[i_block](hidden_state)
  449. hidden_states.append(hidden_state)
  450. if i_level != self.num_resolutions - 1:
  451. hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))
  452. # middle
  453. last_hidden_state = hidden_states[-1]
  454. last_hidden_state = self.mid(last_hidden_state)
  455. # end
  456. last_hidden_state = self.norm_out(last_hidden_state)
  457. last_hidden_state *= torch.sigmoid(last_hidden_state)
  458. last_hidden_state = self.conv_out(last_hidden_state)
  459. return last_hidden_state
class JanusVQVAEDecoder(nn.Module):
    """Convolutional VQ-VAE decoder mapping latents back to images.

    Layout:
    - `conv_in`;
    - the mid block;
    - one up-stage per entry of `config.channel_multiplier`, deepest level first;
    - GroupNorm, then x * sigmoid(x);
    - `conv_out` to `config.out_channels`.
    """

    def __init__(self, config):
        super().__init__()
        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        latent_channels = config.latent_channels
        out_channels = config.out_channels

        # Channel width of the deepest (lowest-resolution) level; decoding starts there.
        block_in = base_channels * config.channel_multiplier[self.num_resolutions - 1]

        # z to block_in
        self.conv_in = torch.nn.Conv2d(latent_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = JanusVQVAEMidBlock(config, block_in)

        # upsampling
        # Stages are appended deepest level first: self.up[0] is level num_resolutions - 1 and
        # self.up[-1] is level 0.
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = base_channels * config.channel_multiplier[i_level]
            # One more resnet block per stage than the encoder uses.
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    JanusVQVAEResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    )
                )
                block_in = block_out
                # Attention only at the deepest level, one per resnet block.
                if i_level == self.num_resolutions - 1:
                    attn.append(JanusVQVAEAttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            # Every stage except level 0 (the last one appended) doubles the spatial size.
            if i_level != 0:
                up.upsample = JanusVQVAEConvUpsample(block_in)
            self.up.append(up)

        # end
        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_state: torch.FloatTensor) -> torch.FloatTensor:
        """Decode a latent map into an image tensor with `config.out_channels` channels."""
        hidden_state = self.conv_in(hidden_state)

        # middle
        hidden_state = self.mid(hidden_state)

        # upsampling
        # Here i_level indexes self.up in append order (deepest stage first). The final entry
        # (i_level == num_resolutions - 1, i.e. level 0) has no `upsample` module.
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks + 1):
                hidden_state = self.up[i_level].block[i_block](hidden_state)
                if len(self.up[i_level].attn) > 0:
                    hidden_state = self.up[i_level].attn[i_block](hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_state = self.up[i_level].upsample(hidden_state)

        # end: GroupNorm, then x * sigmoid(x) (computed in place), then projection to image channels
        hidden_state = self.norm_out(hidden_state)
        hidden_state *= torch.sigmoid(hidden_state)
        hidden_state = self.conv_out(hidden_state)

        return hidden_state
  516. class JanusVQVAE(ChameleonVQVAE):
  517. _no_split_modules = [
  518. "JanusVQVAEAttnBlock",
  519. "JanusVQVAEResnetBlock",
  520. "JanusVQVAEVectorQuantizer",
  521. ]
  522. _can_record_outputs = {
  523. "hidden_states": JanusVQVAEResnetBlock,
  524. "attentions": JanusVQVAEAttnBlock,
  525. }
  526. main_input_name = "pixel_values"
  527. def __init__(self, config: JanusVQVAEConfig):
  528. super().__init__(config)
  529. self.decoder = JanusVQVAEDecoder(config)
  530. self.gradient_checkpointing = False
  531. # Initialize the VQVAE model.
  532. self.post_init()
  533. def decode(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
  534. """
  535. Decodes quantized token IDs into pixel values.
  536. Args:
  537. image_tokens (torch.LongTensor): Batch of token IDs.
  538. Returns:
  539. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  540. Pixel values decoded from the token IDs.
  541. """
  542. if image_tokens.shape[1] != self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]:
  543. raise ValueError(
  544. f"Expected `image_tokens` to have shape `(batch_size, {self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]})`, "
  545. f"but got shape `{image_tokens.shape}`."
  546. )
  547. codebook_entry = self.quantize.get_codebook_entry(image_tokens)
  548. hidden_states = self.post_quant_conv(codebook_entry)
  549. pixel_values = self.decoder(hidden_states)
  550. return pixel_values
  551. @can_return_tuple
  552. @auto_docstring
  553. def forward(
  554. self,
  555. pixel_values: torch.FloatTensor,
  556. **kwargs,
  557. ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
  558. batch_size = pixel_values.shape[0]
  559. encode_outputs = self.encode(pixel_values, return_dict=True, **kwargs)
  560. decoded_pixel_values = self.decode(encode_outputs.image_tokens.view(batch_size, -1))
  561. return JanusVQVAEOutput(decoded_pixel_values, encode_outputs.embedding_loss)
  562. class JanusVQVAEAlignerMLP(nn.Module):
  563. def __init__(self, config: JanusVQVAEConfig):
  564. super().__init__()
  565. self.fc1 = nn.Linear(config.embed_dim, config.projection_dim)
  566. self.hidden_layers = nn.ModuleList(
  567. [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(1, config.num_hidden_layers)]
  568. )
  569. self.activation_fn = ACT2FN[config.hidden_act]
  570. def forward(self, hidden_states):
  571. hidden_states = self.fc1(hidden_states)
  572. for layer in self.hidden_layers:
  573. hidden_states = self.activation_fn(hidden_states)
  574. hidden_states = layer(hidden_states)
  575. return hidden_states
  576. class JanusVQVAEHead(nn.Module):
  577. """Head used for sampling tokens in image generation, replacing the usual lm head."""
  578. def __init__(self, config: JanusVQVAEConfig):
  579. super().__init__()
  580. self.proj_out = nn.Linear(config.image_token_embed_dim, config.projection_dim)
  581. self.activation_fn = ACT2FN[config.hidden_act]
  582. self.vision_head = nn.Linear(config.projection_dim, config.num_embeddings)
  583. def forward(self, hidden_states: torch.Tensor) -> torch.tensor:
  584. hidden_states = self.proj_out(hidden_states)
  585. hidden_states = self.activation_fn(hidden_states)
  586. hidden_states = self.vision_head(hidden_states)
  587. return hidden_states
  588. @auto_docstring(
  589. custom_intro="""
  590. The Janus model which consists of a siglip vision backbone, a Llama language model and a VQ model.
  591. """
  592. )
  593. class JanusModel(JanusPreTrainedModel):
  594. def __init__(self, config: JanusConfig):
  595. super().__init__(config)
  596. self.config = config
  597. # This is necessary for backward compatibility, see SiglipModel initialization
  598. self.vision_model = JanusVisionModel._from_config(config.vision_config)
  599. self.aligner = JanusVisionAlignerMLP(self.vision_model.config)
  600. self.vqmodel = JanusVQVAE._from_config(config.vq_config)
  601. # Below generation_* modules are used for Image generation.
  602. # Embeddings used for image generation, instead of Janus vision embeddings.
  603. self.generation_embeddings = nn.Embedding(self.vqmodel.config.num_embeddings, self.vqmodel.config.embed_dim)
  604. self.generation_aligner = JanusVQVAEAlignerMLP(self.vqmodel.config)
  605. self.generation_head = JanusVQVAEHead(self.vqmodel.config)
  606. self.language_model = AutoModel.from_config(config=config.text_config)
  607. self.gradient_checkpointing = False
  608. # Initialize weights and apply final processing.
  609. self.post_init()
  610. def get_input_embeddings(self):
  611. return self.language_model.get_input_embeddings()
  612. def set_input_embeddings(self, value):
  613. self.language_model.set_input_embeddings(value)
  614. @can_return_tuple
  615. @auto_docstring
  616. def get_image_features(
  617. self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
  618. ) -> tuple | BaseModelOutputWithPooling:
  619. vision_outputs = self.vision_model(pixel_values, return_dict=True, **kwargs)
  620. vision_outputs.pooler_output = self.aligner(vision_outputs.last_hidden_state)
  621. return vision_outputs
  622. def get_placeholder_mask(
  623. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  624. ):
  625. """
  626. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  627. equal to the length of multimodal features. If the lengths are different, an error is raised.
  628. """
  629. if input_ids is None:
  630. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  631. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  632. )
  633. special_image_mask = special_image_mask.all(-1)
  634. else:
  635. special_image_mask = input_ids == self.config.image_token_id
  636. n_image_tokens = special_image_mask.sum()
  637. n_image_features = image_features.shape[0] * image_features.shape[1]
  638. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  639. torch_compilable_check(
  640. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  641. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
  642. )
  643. return special_image_mask
  644. @can_return_tuple
  645. @auto_docstring
  646. def forward(
  647. self,
  648. input_ids: torch.LongTensor | None = None,
  649. pixel_values: torch.FloatTensor | None = None,
  650. attention_mask: torch.Tensor | None = None,
  651. position_ids: torch.LongTensor | None = None,
  652. past_key_values: Cache | None = None,
  653. inputs_embeds: torch.FloatTensor | None = None,
  654. use_cache: bool | None = None,
  655. logits_to_keep: int | torch.Tensor = 0,
  656. **kwargs,
  657. ) -> JanusBaseModelOutputWithPast:
  658. if (input_ids is None) ^ (inputs_embeds is not None):
  659. raise ValueError(
  660. "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
  661. )
  662. if inputs_embeds is None:
  663. inputs_embeds = self.get_input_embeddings()(input_ids)
  664. if pixel_values is not None:
  665. image_embeds = self.get_image_features(pixel_values, return_dict=True).pooler_output
  666. image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
  667. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  668. image_attention_mask = self.get_placeholder_mask(
  669. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  670. )
  671. inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
  672. lm_output = self.language_model(
  673. inputs_embeds=inputs_embeds,
  674. attention_mask=attention_mask,
  675. position_ids=position_ids,
  676. past_key_values=past_key_values,
  677. use_cache=use_cache,
  678. logits_to_keep=logits_to_keep,
  679. **kwargs,
  680. )
  681. return JanusBaseModelOutputWithPast(
  682. last_hidden_state=lm_output.last_hidden_state,
  683. past_key_values=lm_output.past_key_values,
  684. hidden_states=lm_output.hidden_states,
  685. attentions=lm_output.attentions,
  686. image_hidden_states=image_embeds if pixel_values is not None else None,
  687. )
  688. class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
  689. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  690. output_modalities = ("image", "text")
  691. _can_compile_fullgraph = True
  692. def __init__(self, config: JanusConfig):
  693. super().__init__(config)
  694. self.config = config
  695. self.model = JanusModel(config)
  696. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  697. # Initialize weights and apply final processing.
  698. self.post_init()
  699. def get_input_embeddings(self):
  700. return self.model.language_model.get_input_embeddings()
  701. def set_input_embeddings(self, value):
  702. self.model.language_model.set_input_embeddings(value)
  703. def prepare_embeddings_for_image_generation(self, inputs: torch.Tensor) -> torch.Tensor:
  704. hidden_state = self.model.generation_embeddings(inputs)
  705. hidden_state = self.model.generation_aligner(hidden_state)
  706. return hidden_state
  707. @can_return_tuple
  708. @auto_docstring
  709. def forward(
  710. self,
  711. input_ids: torch.LongTensor | None = None,
  712. pixel_values: torch.FloatTensor | None = None,
  713. attention_mask: torch.Tensor | None = None,
  714. position_ids: torch.LongTensor | None = None,
  715. past_key_values: Cache | None = None,
  716. inputs_embeds: torch.FloatTensor | None = None,
  717. labels: torch.LongTensor | None = None,
  718. use_cache: bool | None = None,
  719. logits_to_keep: int | torch.Tensor = 0,
  720. **kwargs: Unpack[TransformersKwargs],
  721. ) -> JanusCausalLMOutputWithPast:
  722. r"""
  723. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  724. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  725. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  726. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  727. """
  728. outputs = self.model(
  729. input_ids=input_ids,
  730. pixel_values=pixel_values,
  731. attention_mask=attention_mask,
  732. position_ids=position_ids,
  733. past_key_values=past_key_values,
  734. inputs_embeds=inputs_embeds,
  735. use_cache=use_cache,
  736. **kwargs,
  737. )
  738. hidden_states = outputs.last_hidden_state
  739. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  740. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  741. logits = self.lm_head(hidden_states[:, slice_indices, :])
  742. loss = None
  743. if labels is not None:
  744. loss = self.loss_function(
  745. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  746. )
  747. return JanusCausalLMOutputWithPast(
  748. loss=loss,
  749. logits=logits,
  750. past_key_values=outputs.past_key_values,
  751. hidden_states=outputs.hidden_states,
  752. attentions=outputs.attentions,
  753. image_hidden_states=outputs.image_hidden_states,
  754. )
  755. def prepare_inputs_for_generation(
  756. self,
  757. input_ids,
  758. pixel_values=None,
  759. past_key_values=None,
  760. attention_mask=None,
  761. inputs_embeds=None,
  762. logits_to_keep=None,
  763. is_first_iteration=False,
  764. **kwargs,
  765. ):
  766. # Overwritten -- extra custom processing
  767. model_inputs = super().prepare_inputs_for_generation(
  768. input_ids,
  769. past_key_values=past_key_values,
  770. inputs_embeds=inputs_embeds,
  771. attention_mask=attention_mask,
  772. logits_to_keep=logits_to_keep,
  773. is_first_iteration=is_first_iteration,
  774. **kwargs,
  775. )
  776. # Pixel values are used only in the first iteration if available
  777. # In subsequent iterations, they are already merged with text and cached
  778. # NOTE: first iteration doesn't have to be prefill, it can be the first
  779. # iteration with a question and cached system prompt (continue generate from cache)
  780. if is_first_iteration or not kwargs.get("use_cache", True):
  781. model_inputs["pixel_values"] = pixel_values
  782. return model_inputs
  783. def decode_image_tokens(self, image_tokens: torch.Tensor):
  784. """
  785. Decodes generated image tokens from language model to continuous pixel values
  786. with VQGAN module via upsampling.
  787. Args:
  788. image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
  789. The tensors corresponding to the input images.
  790. """
  791. decoded_image = self.model.vqmodel.decode(image_tokens)
  792. decoded_image = decoded_image.permute(0, 2, 3, 1)
  793. return decoded_image
  794. @torch.no_grad()
  795. def generate(
  796. self,
  797. inputs: torch.Tensor | None = None,
  798. attention_mask: torch.LongTensor | None = None,
  799. logits_processor: LogitsProcessorList | None = None,
  800. **kwargs,
  801. ):
  802. # 1. Handle generation config and model kwargs
  803. # Pop generation_mode first since it's specific to Janus
  804. generation_mode = kwargs.pop("generation_mode", "text")
  805. generation_config, model_kwargs = self._prepare_generation_config(
  806. kwargs.pop("generation_config", None), **kwargs
  807. )
  808. # Default to "text" generation if mode isn't provided
  809. if generation_mode == "text":
  810. # Set guidance_scale=None to prevent running UnbatchedCFG processor.
  811. return super().generate(
  812. inputs=inputs,
  813. attention_mask=attention_mask,
  814. generation_config=generation_config,
  815. guidance_scale=None,
  816. **model_kwargs,
  817. )
  818. # Validate generation mode
  819. if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
  820. raise ValueError(
  821. "Got incompatible mode for Image Generation, should be one of greedy or sampling. "
  822. "Ensure that beam search is de-activated by setting `num_beams=1`."
  823. )
  824. # Validate the configuration and model kwargs
  825. generation_config.validate()
  826. self._validate_model_kwargs(model_kwargs.copy())
  827. # 2. Initialize logit processors
  828. logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
  829. # Set `use_cache=True` as we will be using input embeds for generation.
  830. model_kwargs["use_cache"] = True
  831. if generation_config.guidance_scale is None:
  832. logger.warning("`guidance_scale` is required for CFG but not provided. Setting to default value of 5.")
  833. generation_config.guidance_scale = 5
  834. model_kwargs["guidance_scale"] = generation_config.guidance_scale
  835. # 3. Prepare model inputs
  836. input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
  837. inputs, generation_config.bos_token_id, model_kwargs
  838. )
  839. dtype, device = input_ids.dtype, input_ids.device
  840. if len(input_ids.shape) != 2:
  841. raise ValueError(
  842. f"Expected input ids of shape (batch_size, seq_len), but got {input_ids.shape}"
  843. "Passing `inputs embeds` is not supported currently."
  844. )
  845. # Prepare special tokens which will be used generate internally.
  846. kwargs_has_attention_mask = attention_mask is not None
  847. self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
  848. # 4. Add CFG processor along with user passed logit processor.
  849. if generation_config.guidance_scale and generation_config.guidance_scale > 1:
  850. logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
  851. generation_config.guidance_scale = None # Reset to prevent processor duplication.
  852. # 5. Prepare logits processor
  853. logits_processor = self._get_logits_processor(
  854. generation_config=generation_config,
  855. input_ids_seq_length=input_ids.shape[1],
  856. encoder_input_ids=input_ids,
  857. prefix_allowed_tokens_fn=None,
  858. logits_processor=logits_processor,
  859. device=device,
  860. )
  861. # 6. Expand inputs for multiple image generations per prompt.
  862. input_ids, model_kwargs = self._expand_inputs_for_generation(
  863. input_ids=input_ids,
  864. attention_mask=attention_mask,
  865. expand_size=generation_config.num_return_sequences,
  866. **model_kwargs,
  867. )
  868. # 7. Prepare input and model caches
  869. num_image_tokens = self.model.vision_model.config.num_image_tokens
  870. batch_size, seq_len = input_ids.shape
  871. input_tokens = input_ids.repeat(2, 1) # Double batch size for conditional/unconditional logits
  872. attention_mask = model_kwargs.pop("attention_mask", None)
  873. attention_mask = attention_mask.repeat(2, 1)
  874. model_kwargs["attention_mask"] = attention_mask
  875. # Mask all the tokens that are neither BOS nor BOI with pad token in the unconditional logits.
  876. mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & (
  877. input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"]
  878. )
  879. input_tokens[batch_size:, :].masked_fill_(mask, generation_config.pad_token_id)
  880. inputs_embeds = self.get_input_embeddings()(input_tokens)
  881. if model_kwargs.get("past_key_values", None) is None:
  882. # Prepare cache if not provided.
  883. model_kwargs["past_key_values"] = self._prepare_static_cache(
  884. cache_implementation=generation_config.cache_implementation or "static",
  885. # batch_size should account for both conditional/unconditional input; hence multiplied by 2.
  886. batch_size=batch_size * 2,
  887. # we should have at least a cache len of seq_len + num_image_tokens.
  888. max_cache_len=max(generation_config.max_length, num_image_tokens + seq_len),
  889. model_kwargs=model_kwargs,
  890. )
  891. # Placeholder for generated tokens.
  892. generated_tokens = torch.zeros((batch_size, num_image_tokens), dtype=dtype, device=device)
  893. # 8. init attention / hidden states / scores tuples
  894. output_attentions = generation_config.output_attentions
  895. output_hidden_states = generation_config.output_hidden_states
  896. output_scores = generation_config.output_scores
  897. output_logits = generation_config.output_logits
  898. return_dict_in_generate = generation_config.return_dict_in_generate
  899. raw_scores = () if (return_dict_in_generate and output_scores) else None
  900. raw_logits = () if (return_dict_in_generate and output_logits) else None
  901. decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
  902. decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
  903. for i in range(num_image_tokens):
  904. # Set `is_first_iteration=True` to force using `inputs_embeds` instead of `input_ids`.
  905. # Without this, `prepare_inputs_for_generation` would use `input_ids` (the full prompt)
  906. # instead of our prepared `inputs_embeds` (1 new token).
  907. # This causes CUDA error: device-side assert triggered, seen around the call to ` self.self_attn`.
  908. # Set this to `True` is also necessary to match the expected output, see the more detailed comment
  909. # https://github.com/huggingface/transformers/pull/45044#discussion_r3020805374.
  910. model_inputs = self.prepare_inputs_for_generation(
  911. inputs_embeds=inputs_embeds, input_ids=input_tokens, is_first_iteration=True, **model_kwargs
  912. )
  913. if "attention_mask" in model_inputs:
  914. model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
  915. outputs = self.model.language_model(
  916. **model_inputs,
  917. output_attentions=output_attentions,
  918. output_hidden_states=output_hidden_states,
  919. )
  920. # Update model_kwargs like attention_mask for next generation.
  921. model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
  922. hidden_state = outputs.last_hidden_state[:, -1, :].clone()
  923. # Generate scores using the generation head (Not using above defined LM Head)
  924. scores = self.model.generation_head(hidden_state)
  925. next_token_scores = logits_processor(input_ids, scores)
  926. # Sample next token.
  927. if generation_config.do_sample:
  928. probs = torch.softmax(next_token_scores, dim=-1)
  929. next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
  930. else:
  931. next_token = torch.argmax(next_token_scores, dim=-1)
  932. generated_tokens[:, i] = next_token
  933. # Prepare embeddings for the next step.
  934. next_token = torch.cat([next_token, next_token])
  935. next_token = next_token.unsqueeze(-1)
  936. inputs_embeds = self.prepare_embeddings_for_image_generation(next_token)
  937. if return_dict_in_generate:
  938. if output_scores:
  939. raw_scores += (scores,)
  940. if output_logits:
  941. raw_logits += (hidden_state.float(),)
  942. if output_attentions:
  943. decoder_attentions += outputs.attentions
  944. if output_hidden_states:
  945. decoder_hidden_states += outputs.hidden_states
  946. if return_dict_in_generate:
  947. return GenerateDecoderOnlyOutput(
  948. sequences=generated_tokens,
  949. scores=scores,
  950. logits=raw_logits,
  951. attentions=decoder_attentions,
  952. hidden_states=decoder_hidden_states,
  953. past_key_values=outputs.past_key_values,
  954. )
  955. else:
  956. return generated_tokens
# Public names exported by `from <this module> import *`.
__all__ = [
    "JanusPreTrainedModel",
    "JanusForConditionalGeneration",
    "JanusModel",
    "JanusVQVAE",
    "JanusVisionModel",
    "JanusVQVAEConfig",
    "JanusVisionConfig",
    "JanusConfig",
]