modular_emu3.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230
  1. # Copyright 2024 HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import math
  16. from dataclasses import dataclass
  17. from functools import cached_property
  18. import torch
  19. import torch.nn as nn
  20. import torch.nn.functional as F
  21. from ... import initialization as init
  22. from ...cache_utils import Cache
  23. from ...generation import GenerationMixin
  24. from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
  25. from ...modeling_utils import PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...utils import auto_docstring, can_return_tuple, logging, torch_compilable_check
  28. from ...utils.generic import merge_with_config_defaults
  29. from ...utils.output_capturing import capture_outputs
  30. from ..chameleon.modeling_chameleon import (
  31. ChameleonPreTrainedModel,
  32. ChameleonVQVAEEncoderConvDownsample,
  33. )
  34. from ..llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, TransformersKwargs
  35. from ..siglip.modeling_siglip import SiglipAttention
  36. from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig
# Module-level logger obtained from the library's logging utilities.
logger = logging.get_logger(__name__)
  38. @dataclass
  39. @auto_docstring
  40. class Emu3VQVAEModelOutput(BaseModelOutputWithPooling):
  41. r"""
  42. image_tokens (`torch.LongTensor` of shape `(batch_size, config.vocab_size`):
  43. Indices of the image tokens predicted by the VQ-VAE model.
  44. """
  45. image_tokens: torch.LongTensor | None = None
  46. class Emu3Attention(LlamaAttention):
  47. pass
  48. # Has extra dropout which no other model in the library has
  49. class Emu3DecoderLayer(LlamaDecoderLayer):
  50. def __init__(self, config: Emu3Config, layer_idx: int):
  51. super().__init__(config, layer_idx)
  52. self.dropout = nn.Dropout(config.attention_dropout)
  53. def forward(
  54. self,
  55. hidden_states: torch.Tensor,
  56. attention_mask: torch.Tensor | None = None,
  57. position_ids: torch.LongTensor | None = None,
  58. past_key_values: Cache | None = None,
  59. use_cache: bool | None = False,
  60. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  61. **kwargs: Unpack[TransformersKwargs],
  62. ) -> torch.Tensor:
  63. residual = hidden_states
  64. hidden_states = self.input_layernorm(hidden_states)
  65. hidden_states, _ = self.self_attn(
  66. hidden_states=hidden_states,
  67. attention_mask=attention_mask,
  68. position_ids=position_ids,
  69. past_key_values=past_key_values,
  70. use_cache=use_cache,
  71. position_embeddings=position_embeddings,
  72. **kwargs,
  73. )
  74. hidden_states = residual + self.dropout(hidden_states)
  75. residual = hidden_states
  76. hidden_states = self.post_attention_layernorm(hidden_states)
  77. hidden_states = self.mlp(hidden_states)
  78. hidden_states = residual + self.dropout(hidden_states)
  79. return hidden_states
  80. class Emu3VQVAEVectorQuantizer(nn.Module):
  81. """
  82. A module for vector quantization using learned embedding vectors.
  83. This module implements the quantization process similar to te one described in
  84. the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
  85. input vectors into discrete codebook vectors, which are learned during training.
  86. Current implementation improves over previous ones by avoiding costly matrix multiplications
  87. and allowing for post-hoc remapping of indices.
  88. """
  89. def __init__(self, config: Emu3VQVAEConfig):
  90. super().__init__()
  91. self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
  92. self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size)
  93. def forward(self, hidden_state: torch.Tensor):
  94. batch_size, temporal, channels, height, width = hidden_state.shape
  95. hidden_state = hidden_state.permute(0, 1, 3, 4, 2).contiguous()
  96. hidden_state_flattened = hidden_state.view(-1, channels)
  97. # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
  98. hidden_state_sum = torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
  99. embedding_sum = torch.sum(self.embedding.weight**2, dim=1)
  100. # "bd,dn->bn",
  101. distances = 2 * torch.matmul(hidden_state_flattened, self.embedding.weight.transpose(0, 1))
  102. distances = hidden_state_sum + embedding_sum - distances
  103. min_encoding_indices = torch.argmin(distances, dim=1)
  104. min_encoding_indices = min_encoding_indices.view(batch_size, temporal, height, width)
  105. return min_encoding_indices
  106. class Emu3VQVAEEncoderConvDownsample(ChameleonVQVAEEncoderConvDownsample):
  107. pass
  108. class Emu3VQVAEEncoderConvUpsample(nn.Module):
  109. def __init__(self, in_channels):
  110. super().__init__()
  111. self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
  112. def forward(self, hidden_states):
  113. hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
  114. hidden_states = self.conv(hidden_states)
  115. return hidden_states
  116. class Emu3VQVAEConv3d(nn.Module):
  117. def __init__(
  118. self,
  119. in_channel: int,
  120. out_channel: int,
  121. kernel_size: tuple[int],
  122. stride: tuple[int],
  123. ):
  124. super().__init__()
  125. padding_sizes = [one_kernel - one_stride for one_kernel, one_stride in zip(kernel_size[1:], stride[1:])]
  126. self.padding = ()
  127. for pad_size in padding_sizes[::-1]:
  128. self.padding += (pad_size // 2 + pad_size % 2, pad_size // 2)
  129. self.padding += (2, 0)
  130. self.conv = nn.Conv3d(
  131. in_channel,
  132. out_channel,
  133. kernel_size,
  134. stride=stride,
  135. )
  136. def forward(self, hidden_states: torch.Tensor):
  137. hidden_states = F.pad(hidden_states, self.padding)
  138. hidden_states = self.conv(hidden_states)
  139. return hidden_states
  140. class Emu3VQVAESpatialNorm(nn.Module):
  141. def __init__(
  142. self,
  143. in_channels: int,
  144. out_channels: int,
  145. ):
  146. super().__init__()
  147. self.norm_layer = nn.GroupNorm(
  148. num_channels=out_channels,
  149. num_groups=32,
  150. eps=1e-6,
  151. affine=True,
  152. )
  153. self.conv_y = nn.Conv2d(
  154. in_channels,
  155. out_channels,
  156. kernel_size=1,
  157. stride=1,
  158. padding=0,
  159. )
  160. self.conv_b = nn.Conv2d(
  161. in_channels,
  162. out_channels,
  163. kernel_size=1,
  164. stride=1,
  165. padding=0,
  166. )
  167. def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor):
  168. quant_states = F.interpolate(quant_states, size=hidden_states.shape[-2:], mode="nearest")
  169. hidden_states = self.norm_layer(hidden_states)
  170. hidden_states = hidden_states * self.conv_y(quant_states) + self.conv_b(quant_states)
  171. return hidden_states
  172. class Emu3VQVAETemporalUpsample(nn.Module):
  173. def __init__(
  174. self,
  175. in_channel: int,
  176. out_channel: int,
  177. ):
  178. super().__init__()
  179. self.conv = Emu3VQVAEConv3d(
  180. in_channel,
  181. out_channel,
  182. kernel_size=(3, 3, 3),
  183. stride=(1, 1, 1),
  184. )
  185. def forward(self, hidden_states: torch.Tensor):
  186. batch_size, channels, temporal, height, width = hidden_states.shape
  187. hidden_states = hidden_states.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, temporal)
  188. hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
  189. hidden_states = hidden_states.view(batch_size, channels, height, width, -1).permute(0, 1, 4, 2, 3).contiguous()
  190. hidden_states = self.conv(hidden_states)
  191. return hidden_states
  192. class Emu3VQVAETemporalDownsample(nn.Module):
  193. def __init__(
  194. self,
  195. in_channel: int,
  196. out_channel: int,
  197. ):
  198. super().__init__()
  199. self.conv = Emu3VQVAEConv3d(
  200. in_channel,
  201. out_channel,
  202. kernel_size=(4, 3, 3),
  203. stride=(2, 1, 1),
  204. )
  205. def forward(self, hidden_states: torch.Tensor):
  206. hidden_states = self.conv(hidden_states)
  207. return hidden_states
  208. class Emu3VQVAETemporalResnetBlock(nn.Module):
  209. def __init__(
  210. self,
  211. in_channels,
  212. out_channels=None,
  213. ):
  214. super().__init__()
  215. self.in_channels = in_channels
  216. self.out_channels = in_channels if out_channels is None else out_channels
  217. self.norm1 = nn.BatchNorm3d(in_channels)
  218. self.conv1 = Emu3VQVAEConv3d(
  219. in_channels,
  220. out_channels,
  221. kernel_size=(3, 3, 3),
  222. stride=(1, 1, 1),
  223. )
  224. self.norm2 = nn.BatchNorm3d(out_channels)
  225. self.conv2 = Emu3VQVAEConv3d(
  226. out_channels,
  227. out_channels,
  228. kernel_size=(3, 3, 3),
  229. stride=(1, 1, 1),
  230. )
  231. if self.in_channels != self.out_channels:
  232. self.nin_shortcut = nn.Conv3d(
  233. in_channels,
  234. out_channels,
  235. kernel_size=1,
  236. stride=1,
  237. padding=0,
  238. )
  239. def forward(self, hidden_states):
  240. residual = hidden_states
  241. hidden_states = self.norm1(hidden_states)
  242. hidden_states *= torch.sigmoid(hidden_states)
  243. hidden_states = self.conv1(hidden_states)
  244. hidden_states = self.norm2(hidden_states)
  245. hidden_states *= torch.sigmoid(hidden_states)
  246. hidden_states = self.conv2(hidden_states)
  247. if self.in_channels != self.out_channels:
  248. residual = self.nin_shortcut(residual)
  249. return residual + hidden_states
  250. class Emu3VQVAEResnetBlock(nn.Module):
  251. def __init__(
  252. self,
  253. in_channels: int,
  254. out_channels: int | None = None,
  255. quant_channels: int | None = None,
  256. ):
  257. super().__init__()
  258. self.in_channels = in_channels
  259. out_channels = in_channels if out_channels is None else out_channels
  260. self.out_channels = out_channels
  261. self.quant_channels = quant_channels
  262. if quant_channels is None:
  263. self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
  264. self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True)
  265. else:
  266. self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels)
  267. self.norm2 = Emu3VQVAESpatialNorm(quant_channels, out_channels)
  268. self.conv1 = nn.Conv2d(
  269. in_channels,
  270. out_channels,
  271. kernel_size=3,
  272. stride=1,
  273. padding=1,
  274. )
  275. self.conv2 = nn.Conv2d(
  276. out_channels,
  277. out_channels,
  278. kernel_size=3,
  279. stride=1,
  280. padding=1,
  281. )
  282. if self.in_channels != self.out_channels:
  283. self.nin_shortcut = nn.Conv2d(
  284. in_channels,
  285. out_channels,
  286. kernel_size=1,
  287. stride=1,
  288. padding=0,
  289. )
  290. def forward(self, hidden_states: torch.Tensor, quant_channels: torch.Tensor | None = None):
  291. norm_args = () if self.quant_channels is None else (quant_channels,)
  292. residual = hidden_states
  293. hidden_states = self.norm1(hidden_states, *norm_args)
  294. hidden_states *= torch.sigmoid(hidden_states)
  295. hidden_states = self.conv1(hidden_states)
  296. hidden_states = self.norm2(hidden_states, *norm_args)
  297. hidden_states *= torch.sigmoid(hidden_states)
  298. hidden_states = self.conv2(hidden_states)
  299. if self.in_channels != self.out_channels:
  300. residual = self.nin_shortcut(residual)
  301. return residual + hidden_states
  302. class Emu3VQVAEAttentionBlock(SiglipAttention):
  303. def __init__(self, config: Emu3VQVAEConfig):
  304. super().__init__(config)
  305. # for compatibility with the attention interface
  306. self.num_key_value_groups = 1
  307. class Emu3VQVAEGroupNorm(nn.GroupNorm):
  308. """
  309. Same as the torch GroupNorm with the only difference that this ones accepts
  310. an optional kwarg `quant_states` which is not used. This class makes it easier to
  311. use SpatialNorm or GroupNorm without conditionals
  312. """
  313. def __init__(self, **kwargs):
  314. super().__init__(**kwargs)
  315. def forward(self, input, quant_states=None):
  316. return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
  317. class Emu3VQVAEMiddleBlock(nn.Module):
  318. def __init__(self, config, in_channels, quant_channels=None):
  319. super().__init__()
  320. self.block_1 = Emu3VQVAEResnetBlock(
  321. in_channels=in_channels,
  322. out_channels=in_channels,
  323. quant_channels=quant_channels,
  324. )
  325. self.attn_1 = Emu3VQVAEAttentionBlock(config)
  326. if quant_channels is None:
  327. self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
  328. else:
  329. self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels)
  330. self.block_2 = Emu3VQVAEResnetBlock(
  331. in_channels=in_channels,
  332. out_channels=in_channels,
  333. quant_channels=quant_channels,
  334. )
  335. def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor | None = None):
  336. hidden_states = self.block_1(hidden_states, quant_states)
  337. residual = hidden_states
  338. hidden_states = self.attn_norm(hidden_states, quant_states)
  339. batch_size, channels, height, width = hidden_states.shape
  340. hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
  341. hidden_states = self.attn_1(hidden_states)[0]
  342. hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
  343. hidden_states = residual + hidden_states
  344. hidden_states = self.block_2(hidden_states, quant_states)
  345. return hidden_states
  346. class Emu3VQVAEDownBlock(nn.Module):
  347. def __init__(self, config):
  348. super().__init__()
  349. self.num_resolutions = len(config.channel_multiplier)
  350. self.num_res_blocks = config.num_res_blocks
  351. base_channels = config.base_channels
  352. channel_multiplier = config.channel_multiplier
  353. in_channel_multiplier = (1,) + tuple(channel_multiplier)
  354. self.in_channel_multiplier = in_channel_multiplier
  355. self.down = nn.ModuleList()
  356. for i_level in range(self.num_resolutions):
  357. block = nn.ModuleList()
  358. attn = nn.ModuleList()
  359. attn_norms = nn.ModuleList()
  360. block_in = base_channels * in_channel_multiplier[i_level]
  361. block_out = base_channels * channel_multiplier[i_level]
  362. for i_block in range(self.num_res_blocks):
  363. block.append(
  364. Emu3VQVAEResnetBlock(
  365. in_channels=block_in,
  366. out_channels=block_out,
  367. )
  368. )
  369. block_in = block_out
  370. if config.attn_resolutions is not None and i_level in config.attn_resolutions:
  371. attn.append(Emu3VQVAEAttentionBlock(config))
  372. attn_norms.append(nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True))
  373. down = nn.Module()
  374. down.block = block
  375. down.attn = attn
  376. down.attn_norms = attn_norms
  377. if i_level != self.num_resolutions - 1:
  378. down.downsample = Emu3VQVAEEncoderConvDownsample(block_in)
  379. self.down.append(down)
  380. def forward(self, hidden_states: torch.FloatTensor):
  381. for i_level, blocks in enumerate(self.down):
  382. for i_block in range(self.num_res_blocks):
  383. hidden_states = blocks.block[i_block](hidden_states)
  384. if len(blocks.attn) > 0:
  385. residual = hidden_states
  386. hidden_states = blocks.attn_norms[i_block](hidden_states)
  387. batch_size, channels, height, width = hidden_states.shape
  388. hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
  389. hidden_states = blocks.attn[i_block](hidden_states)[0]
  390. hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
  391. hidden_states = residual + hidden_states
  392. if i_level != self.num_resolutions - 1:
  393. hidden_states = blocks.downsample(hidden_states)
  394. return hidden_states
  395. class Emu3VQVAEUpBlock(nn.Module):
  396. def __init__(self, config):
  397. super().__init__()
  398. self.num_resolutions = len(config.channel_multiplier)
  399. self.num_res_blocks = config.num_res_blocks
  400. quant_channels = config.embed_dim
  401. block_in = config.base_channels * config.channel_multiplier[-1]
  402. self.up = nn.ModuleList()
  403. for i_level in reversed(range(self.num_resolutions)):
  404. block = nn.ModuleList()
  405. attn = nn.ModuleList()
  406. attn_norms = nn.ModuleList()
  407. block_out = config.base_channels * config.channel_multiplier[i_level]
  408. for i_block in range(self.num_res_blocks + 1):
  409. block.append(
  410. Emu3VQVAEResnetBlock(
  411. in_channels=block_in,
  412. out_channels=block_out,
  413. quant_channels=quant_channels,
  414. )
  415. )
  416. block_in = block_out
  417. if i_level in config.attn_resolutions:
  418. attn.append(Emu3VQVAEAttentionBlock(config))
  419. attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in))
  420. up = nn.Module()
  421. up.block = block
  422. up.attn = attn
  423. up.attn_norms = attn_norms
  424. if i_level != 0:
  425. up.upsample = Emu3VQVAEEncoderConvUpsample(block_in)
  426. self.up.insert(0, up)
  427. def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor):
  428. for i_level, blocks in enumerate(self.up[::-1]):
  429. for i_block in range(self.num_res_blocks + 1):
  430. hidden_states = blocks.block[i_block](hidden_states, quant_states)
  431. if len(blocks.attn) > 0:
  432. residual = hidden_states
  433. hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states)
  434. batch_size, channels, height, width = hidden_states.shape
  435. hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
  436. hidden_states = blocks.attn[i_block](hidden_states)[0]
  437. hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
  438. hidden_states = residual + hidden_states
  439. if i_level != len(self.up) - 1:
  440. hidden_states = blocks.upsample(hidden_states)
  441. return hidden_states
  442. class Emu3VQVAEEncoder(nn.Module):
  443. def __init__(self, config):
  444. super().__init__()
  445. base_channels = config.base_channels
  446. in_channels = config.in_channels
  447. double_latent = config.double_latent
  448. latent_channels = config.latent_channels
  449. channel_multiplier = config.channel_multiplier
  450. out_channels = 2 * latent_channels if double_latent else latent_channels
  451. block_in = base_channels * channel_multiplier[-1]
  452. self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
  453. self.down_block = Emu3VQVAEDownBlock(config)
  454. self.middle_block = Emu3VQVAEMiddleBlock(config, block_in)
  455. self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
  456. self.conv_out = torch.nn.Conv2d(
  457. block_in,
  458. out_channels,
  459. kernel_size=3,
  460. stride=1,
  461. padding=1,
  462. )
  463. temporal_down_blocks = int(math.log2(config.temporal_downsample_factor))
  464. self.time_conv = nn.ModuleList()
  465. self.time_res_stack = nn.ModuleList()
  466. for i in range(temporal_down_blocks):
  467. conv = Emu3VQVAETemporalDownsample(out_channels, out_channels)
  468. self.time_conv.append(conv)
  469. for _ in range(config.num_res_blocks):
  470. time_res_conv = Emu3VQVAETemporalResnetBlock(
  471. in_channels=out_channels,
  472. out_channels=out_channels,
  473. )
  474. self.time_res_stack.append(time_res_conv)
  475. def forward(self, pixel_values: torch.LongTensor):
  476. temporal_dim = pixel_values.shape[1]
  477. pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:])
  478. # downsampling & middle
  479. hidden_states = self.conv_in(pixel_values)
  480. hidden_states = self.down_block(hidden_states)
  481. hidden_states = self.middle_block(hidden_states)
  482. # end
  483. hidden_states = self.norm_out(hidden_states)
  484. hidden_states *= torch.sigmoid(hidden_states)
  485. hidden_states = self.conv_out(hidden_states)
  486. hidden_states = hidden_states.reshape(-1, temporal_dim, *hidden_states.shape[1:])
  487. hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
  488. # temporal convs
  489. for conv in self.time_conv:
  490. hidden_states = conv(hidden_states)
  491. hidden_states *= torch.sigmoid(hidden_states)
  492. for layer in self.time_res_stack:
  493. hidden_states = layer(hidden_states)
  494. hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
  495. return hidden_states
  496. class Emu3VQVAEDecoder(nn.Module):
  497. def __init__(self, config: Emu3VQVAEConfig):
  498. super().__init__()
  499. quant_channels = config.embed_dim
  500. block_in = config.base_channels * config.channel_multiplier[-1]
  501. self.time_res_stack = nn.ModuleList()
  502. for _ in range(config.num_res_blocks):
  503. time_res_conv = Emu3VQVAETemporalResnetBlock(
  504. in_channels=config.latent_channels, out_channels=config.latent_channels
  505. )
  506. self.time_res_stack.append(time_res_conv)
  507. temp_upsample_block_num = int(math.log2(config.temporal_downsample_factor))
  508. self.time_conv = nn.ModuleList()
  509. for i in range(temp_upsample_block_num):
  510. conv = Emu3VQVAETemporalUpsample(config.latent_channels, config.latent_channels)
  511. self.time_conv.append(conv)
  512. self.conv_in = nn.Conv2d(
  513. config.latent_channels,
  514. block_in,
  515. kernel_size=3,
  516. stride=1,
  517. padding=1,
  518. )
  519. self.middle_block = Emu3VQVAEMiddleBlock(config, block_in, quant_channels=quant_channels)
  520. self.up_block = Emu3VQVAEUpBlock(config)
  521. block_in = config.base_channels * config.channel_multiplier[0]
  522. self.norm_out = Emu3VQVAESpatialNorm(quant_channels, block_in)
  523. self.conv_out = nn.Conv2d(
  524. block_in,
  525. config.out_channels,
  526. kernel_size=3,
  527. stride=1,
  528. padding=1,
  529. )
  530. def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor):
  531. hidden_quant_states = torch.cat((hidden_states, quant_states), dim=0)
  532. hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4)
  533. # temporal convs
  534. for layer in self.time_res_stack:
  535. hidden_quant_states = layer(hidden_quant_states)
  536. for layer in self.time_conv:
  537. hidden_quant_states = layer(hidden_quant_states)
  538. hidden_quant_states *= torch.sigmoid(hidden_quant_states)
  539. hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4)
  540. hidden_states, quant_states = torch.chunk(hidden_quant_states, 2, dim=0)
  541. hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:])
  542. quant_states = quant_states.reshape(-1, *quant_states.shape[2:])
  543. hidden_states = self.conv_in(hidden_states)
  544. # middle & upsampling
  545. hidden_states = self.middle_block(hidden_states, quant_states)
  546. hidden_states = self.up_block(hidden_states, quant_states)
  547. hidden_states = self.norm_out(hidden_states, quant_states)
  548. hidden_states *= torch.sigmoid(hidden_states)
  549. hidden_states = self.conv_out(hidden_states)
  550. return hidden_states
@auto_docstring(
    custom_intro="""
    The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens.
    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
    [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
    Taigman](https://huggingface.co/papers/2203.13131).
    """
)
class Emu3VQVAE(PreTrainedModel):
    # Image/video tokenizer: `encode` runs encoder -> `quant_conv` -> vector quantizer to get codebook
    # indices; `decode` embeds indices, runs `post_quant_conv` and the decoder to get pixels back.
    config: Emu3VQVAEConfig
    base_model_prefix = "emuvideovq"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _no_split_modules = [
        "Emu3VQVAETemporalResnetBlock",
        "Emu3VQVAEAttentionBlock",
        "Emu3VQVAEResnetBlock",
        "Emu3VQVAEVectorQuantizer",
    ]
    # Output name -> module classes whose outputs are recorded (presumably consumed by `@capture_outputs`
    # on `encode`; confirm against `capture_outputs`).
    _can_record_outputs = {
        "hidden_states": [Emu3VQVAEResnetBlock, Emu3VQVAETemporalResnetBlock],
        "attentions": Emu3VQVAEAttentionBlock,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        # Per-module weight initialization, dispatched on module type.
        if isinstance(module, (nn.Conv2d, nn.Conv3d)):
            # Kaiming-normal weights (fan_out, ReLU gain); bias uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in)
                init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, nn.Linear):
            # Kaiming-uniform weights with a=sqrt(5); same bias bound, or 0 when fan_in is 0.
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
            # Identity affine transform; reset running statistics when the layer tracks them.
            init.constant_(module.weight, 1.0)
            init.constant_(module.bias, 0.0)
            if getattr(module, "running_mean", None) is not None:
                init.zeros_(module.running_mean)
                init.ones_(module.running_var)
                init.zeros_(module.num_batches_tracked)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight)
            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                init.zeros_(module.weight[module.padding_idx])

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__(config)
        self.config = config
        self.encoder = Emu3VQVAEEncoder(config)
        self.decoder = Emu3VQVAEDecoder(config)
        self.quantize = Emu3VQVAEVectorQuantizer(config)
        # Spatial reduction of the encoder: one downsample per resolution level except the last
        # (assumed 2x each, matching this formula).
        self.vision_spatial_factor = 2 ** (len(config.channel_multiplier) - 1)
        # Kernel (3, 1, 1), stride 1: convolves along time only.
        self.quant_conv = Emu3VQVAEConv3d(
            config.latent_channels, config.embed_dim, kernel_size=(3, 1, 1), stride=(1, 1, 1)
        )
        self.post_quant_conv = Emu3VQVAEConv3d(
            config.embed_dim, config.latent_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1)
        )
        # Same value as `vision_spatial_factor`; `decode` uses it to compute the output resolution.
        self.spatial_scale_factor = 2 ** (len(config.channel_multiplier) - 1)
        self.eval()  # Emu3's VQ model is frozen
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def encode(
        self, pixel_values: torch.Tensor, image_sizes: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
    ) -> Emu3VQVAEModelOutput:
        """
        Encode images or videos into codebook indices.

        Args:
            pixel_values: `(batch, channels, height, width)` images or `(batch, temporal, channels, height, width)`
                videos.
            image_sizes: one entry per sample; `size[0]` and `size[1]` are divided by `vision_spatial_factor` to
                crop the token grid (presumably the original `(height, width)` in pixels; confirm with the processor).

        Returns:
            `Emu3VQVAEModelOutput` whose `last_hidden_state` is the encoder output and whose `image_tokens` is a
            list with one cropped index tensor per sample.
        """
        is_image = pixel_values.ndim == 4
        if is_image:
            # Repeat each image `temporal_downsample_factor` times along a new time axis; the encoder's
            # temporal downsampling then reduces it to a single latent frame (squeezed out below).
            temporal = self.config.temporal_downsample_factor
            batch_size, channels, height, width = pixel_values.shape
            pixel_values = pixel_values.unsqueeze(1).repeat(1, temporal, 1, 1, 1)
        else:
            batch_size, temporal, channels, height, width = pixel_values.shape

        hidden_states = self.encoder(pixel_values)

        # b t c h w -> b c t h w
        conv_hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
        conv_hidden_states = self.quant_conv(conv_hidden_states)

        # b c t h w -> b t c h w
        conv_hidden_states = conv_hidden_states.permute(0, 2, 1, 3, 4)
        codes = self.quantize(conv_hidden_states)

        image_tokens = codes.squeeze(1) if is_image else codes

        # NOTE(review): for video inputs each `single_image` is (temporal, h, w), so this crop slices the
        # temporal and height axes rather than height and width. Confirm whether videos are supported here.
        image_tokens = [
            single_image[: int(size[0] / self.vision_spatial_factor), : int(size[1] / self.vision_spatial_factor)]
            for single_image, size in zip(image_tokens, image_sizes)
        ]

        return Emu3VQVAEModelOutput(
            last_hidden_state=hidden_states,
            image_tokens=image_tokens,
        )

    def decode(self, hidden_states: torch.Tensor):
        """
        Decode codebook indices back to pixels.

        Args:
            hidden_states: integer codebook indices of shape `(batch, height, width)` for images or
                `(batch, temporal, height, width)` for videos.

        Returns:
            `(batch, out_channels, H, W)` for images (the first decoded frame), or
            `(batch, temporal * temporal_downsample_factor, out_channels, H, W)` for videos, with
            `H = height * spatial_scale_factor` and `W = width * spatial_scale_factor`.
        """
        is_image = hidden_states.ndim == 3
        if is_image:
            hidden_states = hidden_states.unsqueeze(1)

        batch_size, temporal, height, width = hidden_states.shape
        # Look up the codebook vectors, then lay them out as (b, c, t, h, w) for the temporal conv.
        quant = self.quantize.embedding(hidden_states.flatten())

        channels = quant.shape[-1]
        quant = quant.view(batch_size, temporal, height, width, channels).permute(0, 4, 1, 2, 3).contiguous()
        post_quant = self.post_quant_conv(quant)

        # The decoder expects (b, t, c, h, w) for both streams.
        quant = quant.permute(0, 2, 1, 3, 4)
        post_quant = post_quant.permute(0, 2, 1, 3, 4)

        video = self.decoder(post_quant, quant)
        # The decoder returns frames folded into the batch axis; unfold them.
        video = video.reshape(
            batch_size,
            temporal * self.config.temporal_downsample_factor,
            self.config.out_channels,
            height * self.spatial_scale_factor,
            width * self.spatial_scale_factor,
        )
        return video[:, 0] if is_image else video
  668. class Emu3ImageVocabularyMapping:
  669. """
  670. A class for mapping discrete image tokens from VQGAN to BPE tokens.
  671. """
  672. def __init__(self, vocab_map):
  673. self.vocab_map = vocab_map
  674. self.eol_token_id = vocab_map.get("<|extra_200|>")
  675. self.image_token_id = vocab_map.get("<image>")
  676. @cached_property
  677. def image_tokens(self):
  678. return sorted([val for name, val in self.vocab_map.items() if name.startswith("<|visual token")])
  679. @cached_property
  680. def image_tokens_str(self):
  681. return sorted([name for name, val in self.vocab_map.items() if name.startswith("<|visual token")])
  682. @cached_property
  683. def img2bpe(self):
  684. return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str}
  685. @cached_property
  686. def bpe2img(self):
  687. return {v: k for k, v in self.img2bpe.items()}
  688. @cached_property
  689. def bpe2img_mapping_tensor(self):
  690. mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int)
  691. for k, v in self.bpe2img.items():
  692. mapping[k] = v
  693. return mapping
  694. @cached_property
  695. def img2bpe_mapping_tensor(self):
  696. mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
  697. for k, v in self.img2bpe.items():
  698. mapping[k] = v
  699. return mapping
  700. def convert_img2bpe(self, img_batch: list[torch.Tensor]) -> torch.Tensor:
  701. device = img_batch.device
  702. eol_row = torch.ones((img_batch.shape[0], 1), dtype=torch.int) * self.eol_token_id
  703. img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
  704. img_tokens = torch.cat([img_tokens, eol_row], dim=-1)
  705. return img_tokens.to(device)
  706. def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor:
  707. device = img_batch.device
  708. img_batch = img_batch[..., :-1] # remove last row of EOL tokens
  709. img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")]
  710. return img_tokens.to(device)
  711. class Emu3PreTrainedModel(ChameleonPreTrainedModel):
  712. _no_split_modules = [
  713. "Emu3DecoderLayer",
  714. ]
  715. _supports_flex_attn = True
  716. _supports_attention_backend = True
  717. _can_record_outputs = {
  718. "hidden_states": Emu3DecoderLayer,
  719. "attentions": Emu3Attention,
  720. }
  721. class Emu3TextModel(LlamaModel, Emu3PreTrainedModel):
  722. config: Emu3TextConfig
  723. def __init__(self, config: Emu3TextConfig):
  724. super().__init__(config)
  725. self.layers = nn.ModuleList(
  726. [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  727. )
class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):
    # Text-only causal LM: the Llama causal-LM class with its backbone swapped for `Emu3TextModel`.
    config: Emu3TextConfig

    def __init__(self, config):
        super().__init__(config)
        # Replace the inherited backbone with the Emu3 text model.
        self.model = Emu3TextModel(config)

    # NOTE(review): this signature has no `self`, and `super().forward()` forwards nothing, so the method is not
    # callable as plain Python. In a `modular_*.py` file this looks like the modular-converter convention for
    # "reuse the parent's signature and body, override only the docstring". Confirm before editing.
    def forward(**super_kwargs):
        r"""
        Example:

        ```python
        >>> from transformers import Emu3Processor, Emu3ForCausalLM
        >>> import torch

        >>> model = Emu3ForCausalLM.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16)
        >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")

        >>> inputs = processor(text=["Can you write me a poem about winter."], return_tensors="pt").to(model.device)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        super().forward()
  749. class Emu3Model(Emu3PreTrainedModel):
  750. def __init__(self, config):
  751. super().__init__(config)
  752. self.text_model = Emu3TextModel._from_config(config.text_config)
  753. self.vqmodel = Emu3VQVAE(config.vq_config)
  754. self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)
  755. # Initialize weights and apply final processing
  756. self.post_init()
  757. def get_input_embeddings(self):
  758. return self.text_model.get_input_embeddings()
  759. def set_input_embeddings(self, value):
  760. self.text_model.set_input_embeddings(value)
  761. def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor) -> torch.LongTensor:
  762. """
  763. Tokenizes images into discrete tokens with VQGAN module. Converts
  764. obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
  765. special tokens.
  766. Args:
  767. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  768. The tensors corresponding to the input images.
  769. image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
  770. The sizes of the images in the batch, being (height, width) for each image.
  771. """
  772. vqmodel_outputs: Emu3VQVAEModelOutput = self.vqmodel.encode(pixel_values, image_sizes, return_dict=True)
  773. bpe_tokens_list = [
  774. self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in vqmodel_outputs.image_tokens
  775. ]
  776. bpe_tokens = torch.cat(bpe_tokens_list)
  777. return bpe_tokens
  778. @can_return_tuple
  779. @auto_docstring(
  780. custom_intro="Tokenizes images into discrete tokens with VQGAN module and embeds them with text embeddings layer"
  781. )
  782. def get_image_features(
  783. self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor, **kwargs: Unpack[TransformersKwargs]
  784. ) -> tuple | Emu3VQVAEModelOutput:
  785. r"""
  786. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
  787. The tensors corresponding to the input images.
  788. """
  789. vqmodel_outputs: Emu3VQVAEModelOutput = self.vqmodel.encode(
  790. pixel_values, image_sizes, return_dict=True, **kwargs
  791. )
  792. split_sizes = [
  793. (height // self.vqmodel.vision_spatial_factor) * (width // self.vqmodel.vision_spatial_factor + 1)
  794. for height, width in image_sizes
  795. ]
  796. bpe_tokens_list = [
  797. self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in vqmodel_outputs.image_tokens
  798. ]
  799. bpe_tokens = torch.cat(bpe_tokens_list)
  800. image_embeddings = self.get_input_embeddings()(bpe_tokens)
  801. image_features = torch.split(image_embeddings, split_sizes)
  802. vqmodel_outputs.pooler_output = image_features
  803. return vqmodel_outputs
  804. @torch.no_grad()
  805. def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
  806. """
  807. Decodes generated image tokens from language model to continuous pixel values
  808. with VQGAN module via upsampling.
  809. Args:
  810. image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
  811. The tensors corresponding to the input images.
  812. height (`int`):
  813. Height of the generated image before upsampling.
  814. width (`int`):
  815. Width of the generated image before upsampling.
  816. """
  817. sequences = image_tokens[:, :-3].view(-1, height, width + 1)
  818. image_tokens = self.vocabulary_mapping.convert_bpe2img(sequences)
  819. image = self.vqmodel.decode(image_tokens)
  820. return image
  821. def get_placeholder_mask(
  822. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  823. ):
  824. """
  825. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  826. equal to the length of multimodal features. If the lengths are different, an error is raised.
  827. """
  828. if input_ids is None:
  829. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  830. torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  831. )
  832. special_image_mask = special_image_mask.all(-1)
  833. else:
  834. special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
  835. n_image_tokens = special_image_mask.sum()
  836. n_image_features = image_features.shape[0] * image_features.shape[1]
  837. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  838. torch_compilable_check(
  839. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  840. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
  841. )
  842. return special_image_mask
  843. @can_return_tuple
  844. @auto_docstring
  845. def forward(
  846. self,
  847. input_ids: torch.LongTensor | None = None,
  848. pixel_values: torch.FloatTensor | None = None,
  849. image_sizes: torch.Tensor | None = None,
  850. attention_mask: torch.Tensor | None = None,
  851. position_ids: torch.LongTensor | None = None,
  852. past_key_values: Cache | None = None,
  853. inputs_embeds: torch.FloatTensor | None = None,
  854. use_cache: bool | None = None,
  855. **kwargs: Unpack[TransformersKwargs],
  856. ) -> tuple | CausalLMOutputWithPast:
  857. r"""
  858. image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
  859. The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
  860. [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses
  861. [`Emu3ImageProcessor`] for processing images).
  862. """
  863. if (input_ids is None) ^ (inputs_embeds is not None):
  864. raise ValueError(
  865. "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
  866. )
  867. if inputs_embeds is None:
  868. inputs_embeds = self.get_input_embeddings()(input_ids)
  869. if pixel_values is not None:
  870. image_features = self.get_image_features(pixel_values, image_sizes).pooler_output
  871. image_features = torch.cat(image_features, dim=0)
  872. special_image_mask = self.get_placeholder_mask(
  873. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  874. )
  875. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  876. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  877. outputs = self.text_model(
  878. attention_mask=attention_mask,
  879. position_ids=position_ids,
  880. past_key_values=past_key_values,
  881. inputs_embeds=inputs_embeds,
  882. use_cache=use_cache,
  883. **kwargs,
  884. )
  885. return outputs
  886. class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
  887. output_modalities = ("image", "text")
  888. _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"}
  889. def __init__(self, config):
  890. super().__init__(config)
  891. self.model = Emu3Model(config)
  892. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  893. self.post_init()
  894. def get_input_embeddings(self):
  895. return self.model.get_input_embeddings()
  896. def set_input_embeddings(self, value):
  897. self.model.set_input_embeddings(value)
  898. def get_output_embeddings(self) -> nn.Module:
  899. return self.lm_head
  900. def decode_image_tokens(self, **kwargs):
  901. return self.model.decode_image_tokens(**kwargs)
  902. @can_return_tuple
  903. @auto_docstring
  904. def forward(
  905. self,
  906. input_ids: torch.LongTensor | None = None,
  907. pixel_values: torch.FloatTensor | None = None,
  908. image_sizes: torch.Tensor | None = None,
  909. attention_mask: torch.Tensor | None = None,
  910. position_ids: torch.LongTensor | None = None,
  911. past_key_values: Cache | None = None,
  912. inputs_embeds: torch.FloatTensor | None = None,
  913. use_cache: bool | None = None,
  914. labels: torch.LongTensor | None = None,
  915. logits_to_keep: int | torch.Tensor = 0,
  916. **kwargs: Unpack[TransformersKwargs],
  917. ) -> tuple | CausalLMOutputWithPast:
  918. r"""
  919. image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
  920. The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
  921. [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses
  922. [`Emu3ImageProcessor`] for processing images).
  923. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  924. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  925. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  926. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  927. Example:
  928. ```python
  929. >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration
  930. >>> import torch
  931. >>> import httpx
  932. >>> from io import BytesIO
  933. >>> from PIL import Image
  934. >>> model = Emu3ForConditionalGeneration.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16)
  935. >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")
  936. >>> conversation = [
  937. ... {
  938. ... "role": "system",
  939. ... "content": [
  940. ... {"type": "text", "text": "You are a helpful assistant."},
  941. ... ],
  942. ... },
  943. ... {
  944. ... "role": "user",
  945. ... "content": [
  946. ... {"type": "image"},
  947. ... {"type": "text", "text": "Please describe the image."},
  948. ... ],
  949. ... },
  950. ... ]
  951. >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
  952. >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
  953. >>> with httpx.stream("GET", url) as response:
  954. ... image = Image.open(BytesIO(response.read()))
  955. >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16)
  956. >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
  957. >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  958. ```"""
  959. outputs = self.model(
  960. input_ids=input_ids,
  961. attention_mask=attention_mask,
  962. position_ids=position_ids,
  963. past_key_values=past_key_values,
  964. inputs_embeds=inputs_embeds,
  965. use_cache=use_cache,
  966. **kwargs,
  967. )
  968. hidden_states = outputs[0]
  969. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  970. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  971. logits = self.lm_head(hidden_states[:, slice_indices, :])
  972. loss = None
  973. if labels is not None:
  974. loss = self.loss_function(
  975. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  976. )
  977. return CausalLMOutputWithPast(
  978. loss=loss,
  979. logits=logits,
  980. past_key_values=outputs.past_key_values,
  981. hidden_states=outputs.hidden_states,
  982. attentions=outputs.attentions,
  983. )
  984. def prepare_inputs_for_generation(
  985. self,
  986. input_ids,
  987. past_key_values=None,
  988. attention_mask=None,
  989. inputs_embeds=None,
  990. position_ids=None,
  991. use_cache=True,
  992. pixel_values=None,
  993. is_first_iteration=False,
  994. **kwargs,
  995. ):
  996. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  997. model_inputs = super().prepare_inputs_for_generation(
  998. input_ids,
  999. past_key_values=past_key_values,
  1000. attention_mask=attention_mask,
  1001. inputs_embeds=inputs_embeds,
  1002. position_ids=position_ids,
  1003. pixel_values=pixel_values,
  1004. use_cache=use_cache,
  1005. is_first_iteration=is_first_iteration,
  1006. **kwargs,
  1007. )
  1008. if not is_first_iteration and use_cache:
  1009. model_inputs["pixel_values"] = None
  1010. return model_inputs
  1011. __all__ = [
  1012. "Emu3ForConditionalGeneration",
  1013. "Emu3ForCausalLM",
  1014. "Emu3TextModel",
  1015. "Emu3PreTrainedModel",
  1016. "Emu3VQVAE",
  1017. "Emu3Model",
  1018. ]