| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/emu3/modular_emu3.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_emu3.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2024 HuggingFace Inc. team. All rights reserved.
- #
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from collections.abc import Callable
- from dataclasses import dataclass
- from functools import cached_property
- from typing import Optional
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache
- from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
- from ...masking_utils import create_causal_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
- from ...utils.generic import maybe_autocast, merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig
@dataclass
@auto_docstring
class Emu3VQVAEModelOutput(BaseModelOutputWithPooling):
    r"""
    image_tokens (`torch.LongTensor` of shape `(batch_size, config.vocab_size)`):
        Indices of the image tokens predicted by the VQ-VAE model.
    """

    # Optional field; stays `None` unless the producer of this output fills it in.
    image_tokens: torch.LongTensor | None = None
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half that is moved to the front."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate the query and key tensors with Rotary Position Embedding.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos` and `sin` so that they broadcast against `q` and `k`. With `cos`/`sin`
            of shape [batch_size, seq_len, head_dim], use `unsqueeze_dim=1` when `q` and `k` are laid out as
            [batch_size, heads, seq_len, head_dim] and `unsqueeze_dim=2` when they are laid out as
            [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)` holding the rotated query tensor followed by the rotated key tensor.
    """
    # Insert the broadcast axis once; both tensors reuse the same cos/sin tables.
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        return (states * cos) + (rotate_half(states) * sin)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis, matching
    torch.repeat_interleave(x, dim=1, repeats=n_rep). The input of shape (batch, num_key_value_heads, seqlen,
    head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim). With `n_rep == 1` the input tensor
    itself is returned.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep != 1:
        # Add a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
        expanded = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        hidden_states = expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
    return hidden_states
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain matmul/softmax attention.

    Keys and values are repeated `module.num_key_value_groups` times along the head axis before the
    scores are computed. `attention_mask`, when given, is added to the scaled scores. The softmax runs
    in float32 and is cast back to the query dtype; dropout is applied according to `module.training`.
    Extra keyword arguments are accepted and ignored.

    Returns:
        A tuple `(attn_output, attn_weights)`. `attn_output` has its head and sequence axes swapped
        (`transpose(1, 2)`) and is contiguous; `attn_weights` are the probabilities after dropout.
    """
    groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, groups)
    expanded_values = repeat_kv(value, groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    probabilities = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probabilities = nn.functional.dropout(probabilities, p=dropout, training=module.training)

    context = torch.matmul(probabilities, expanded_values).transpose(1, 2).contiguous()
    return context, probabilities
@use_kernelized_func(apply_rotary_pos_emb)
class Emu3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Emu3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # Queries span every attention head; keys and values span the key/value heads only.
        query_width = config.num_attention_heads * self.head_dim
        key_value_width = config.num_key_value_heads * self.head_dim
        use_bias = config.attention_bias
        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=use_bias)
        self.k_proj = nn.Linear(config.hidden_size, key_value_width, bias=use_bias)
        self.v_proj = nn.Linear(config.hidden_size, key_value_width, bias=use_bias)
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=use_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project, rotate, optionally extend with cached keys/values, attend, and project back.

        `position_embeddings` is unpacked as `(cos, sin)`, so despite the `None` default it must be
        supplied. Returns `(attn_output, attn_weights)` as produced by the selected attention interface,
        with `attn_output` passed through `o_proj`.
        """
        input_shape = hidden_states.shape[:-1]
        # Split the projected feature axis into (heads, head_dim); the head count is inferred.
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states, key_states, value_states = [
            projection(hidden_states).view(hidden_shape).transpose(1, 2)
            for projection in (self.q_proj, self.k_proj, self.v_proj)
        ]

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        # Attention dropout is only active in training mode.
        dropout = self.attention_dropout if self.training else 0.0
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout,
            scaling=self.scaling,
            **kwargs,
        )

        merged = attn_output.reshape(*input_shape, -1).contiguous()
        return self.o_proj(merged), attn_weights
@use_kernel_forward_from_hub("RMSNorm")
class Emu3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        Emu3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Scale `hidden_states` by the reciprocal root-mean-square of its last axis, then by `weight`."""
        original_dtype = hidden_states.dtype
        # The statistics are computed in float32 and the result is cast back before the learned scale.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Emu3MLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        use_bias = config.mlp_bias
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        # Activation is looked up by name from the config.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
class Emu3DecoderLayer(GradientCheckpointingLayer):
    """Pre-norm decoder layer: self-attention and MLP sub-blocks, each with dropout and a residual add."""

    def __init__(self, config: Emu3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Emu3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Emu3MLP(config)

        norm_eps = config.rms_norm_eps
        self.input_layernorm = Emu3RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = Emu3RMSNorm(config.hidden_size, eps=norm_eps)
        # The same dropout module (rate `config.attention_dropout`) is applied to both sub-block outputs.
        self.dropout = nn.Dropout(config.attention_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # Attention sub-block; the attention weights returned alongside the output are discarded.
        attn_output, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + self.dropout(attn_output)

        # Feed-forward sub-block.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + self.dropout(mlp_output)
class Emu3VQVAEVectorQuantizer(nn.Module):
    """
    A module for vector quantization using learned embedding vectors.

    This module implements the quantization step similar to the one described in the VQ-VAE (Vector Quantized
    Variational AutoEncoder) paper: every input vector is mapped to the index of its nearest codebook
    embedding, where the codebook is a learned `nn.Embedding`.
    """

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__()
        self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
        # Start the codebook uniformly inside +/- 1 / codebook_size.
        bound = 1.0 / config.codebook_size
        self.embedding.weight.data.uniform_(-bound, bound)

    def forward(self, hidden_state: torch.Tensor):
        """Return, for every position of a 5-D input unpacked as (batch, temporal, channels, height,
        width), the index of the closest codebook entry, shaped (batch, temporal, height, width)."""
        batch_size, temporal, channels, height, width = hidden_state.shape
        # Move channels last so each row of the flattened view is one vector to quantize.
        flat = hidden_state.permute(0, 1, 3, 4, 2).contiguous().view(-1, channels)
        codebook = self.embedding.weight

        # Squared distance expanded as z^2 + e^2 - 2 * (z . e).
        flat_sq = torch.sum(flat**2, dim=1, keepdim=True)
        code_sq = torch.sum(codebook**2, dim=1)
        cross = 2 * torch.matmul(flat, codebook.transpose(0, 1))
        distances = flat_sq + code_sq - cross

        nearest = torch.argmin(distances, dim=1)
        return nearest.view(batch_size, temporal, height, width)
class Emu3VQVAEEncoderConvDownsample(nn.Module):
    """Stride-2 3x3 convolution applied after zero-padding one pixel on the right and bottom edges."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, hidden_states):
        # The conv layer itself has padding=0; the one-sided padding is applied by hand here.
        padded = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class Emu3VQVAEEncoderConvUpsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a shape-preserving 3x3 convolution."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states):
        enlarged = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged)
class Emu3VQVAEConv3d(nn.Module):
    """3-D convolution preceded by explicit padding.

    For each non-leading kernel axis the padding totals `kernel - stride`, with the extra element of
    an odd total placed on the leading side. The leading kernel axis is always padded by `(2, 0)`.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: tuple[int],
        stride: tuple[int],
    ):
        super().__init__()

        excess = [one_kernel - one_stride for one_kernel, one_stride in zip(kernel_size[1:], stride[1:])]
        # F.pad lists the last tensor axis first, hence the reversed walk over the kernel axes.
        pads = []
        for amount in reversed(excess):
            half, remainder = divmod(amount, 2)
            pads.extend((half + remainder, half))
        pads.extend((2, 0))
        self.padding = tuple(pads)

        self.conv = nn.Conv3d(
            in_channel,
            out_channel,
            kernel_size,
            stride=stride,
        )

    def forward(self, hidden_states: torch.Tensor):
        return self.conv(F.pad(hidden_states, self.padding))
class Emu3VQVAESpatialNorm(nn.Module):
    """Group normalization whose output is scaled and shifted by 1x1 convolutions of `quant_states`.

    `in_channels` is the channel count of `quant_states`; `out_channels` is the channel count of the
    tensor being normalized.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(
            num_channels=out_channels,
            num_groups=32,
            eps=1e-6,
            affine=True,
        )

        # Two identically configured pointwise convolutions: conv_y gives the scale, conv_b the shift.
        pointwise = {"kernel_size": 1, "stride": 1, "padding": 0}
        self.conv_y = nn.Conv2d(in_channels, out_channels, **pointwise)
        self.conv_b = nn.Conv2d(in_channels, out_channels, **pointwise)

    def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor):
        # Resize the conditioning tensor to the spatial size of the tensor being normalized.
        target_size = hidden_states.shape[-2:]
        quant_states = F.interpolate(quant_states, size=target_size, mode="nearest")

        normalized = self.norm_layer(hidden_states)
        return normalized * self.conv_y(quant_states) + self.conv_b(quant_states)
class Emu3VQVAETemporalUpsample(nn.Module):
    """Doubles the temporal axis by nearest-neighbour interpolation, then applies a 3x3x3 `Emu3VQVAEConv3d`."""

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
    ):
        super().__init__()
        self.conv = Emu3VQVAEConv3d(
            in_channel,
            out_channel,
            kernel_size=(3, 3, 3),
            stride=(1, 1, 1),
        )

    def forward(self, hidden_states: torch.Tensor):
        batch_size, channels, temporal, height, width = hidden_states.shape

        # Put time last and merge (channels, height, width) so the interpolation acts on time only.
        time_last = hidden_states.permute(0, 1, 3, 4, 2).contiguous()
        flattened = time_last.view(batch_size, -1, temporal)
        stretched = F.interpolate(flattened, scale_factor=2.0, mode="nearest")

        # Undo the merge and move time back to axis 2.
        unfolded = stretched.view(batch_size, channels, height, width, -1)
        restored = unfolded.permute(0, 1, 4, 2, 3).contiguous()
        return self.conv(restored)
class Emu3VQVAETemporalDownsample(nn.Module):
    """Thin wrapper around an `Emu3VQVAEConv3d` with kernel (4, 3, 3) and stride (2, 1, 1)."""

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
    ):
        super().__init__()
        # Stride 2 on the first kernel axis only; the other two axes keep stride 1.
        kernel_size, stride = (4, 3, 3), (2, 1, 1)
        self.conv = Emu3VQVAEConv3d(
            in_channel,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, hidden_states: torch.Tensor):
        return self.conv(hidden_states)
class Emu3VQVAETemporalResnetBlock(nn.Module):
    """Residual block of two (BatchNorm3d -> x * sigmoid(x) -> Emu3VQVAEConv3d) stages.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block. Defaults to `in_channels` when `None`.

    When the input and output channel counts differ, the residual branch goes through a 1x1x1
    convolution (`nin_shortcut`) so that it can be added to the main branch.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
    ):
        super().__init__()
        # Resolve the default up front so that every layer below receives a concrete channel count;
        # passing the raw `None` to BatchNorm3d / the convolutions would fail at construction time.
        out_channels = in_channels if out_channels is None else out_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.norm1 = nn.BatchNorm3d(in_channels)
        self.conv1 = Emu3VQVAEConv3d(
            in_channels,
            out_channels,
            kernel_size=(3, 3, 3),
            stride=(1, 1, 1),
        )
        self.norm2 = nn.BatchNorm3d(out_channels)
        self.conv2 = Emu3VQVAEConv3d(
            out_channels,
            out_channels,
            kernel_size=(3, 3, 3),
            stride=(1, 1, 1),
        )
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )

    def forward(self, hidden_states):
        """Return `shortcut(hidden_states) + main_branch(hidden_states)`."""
        residual = hidden_states

        hidden_states = self.norm1(hidden_states)
        # x * sigmoid(x), applied in place on the freshly produced norm output.
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            residual = self.nin_shortcut(residual)

        return residual + hidden_states
- class Emu3VQVAEResnetBlock(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int | None = None,
- quant_channels: int | None = None,
- ):
- super().__init__()
- self.in_channels = in_channels
- out_channels = in_channels if out_channels is None else out_channels
- self.out_channels = out_channels
- self.quant_channels = quant_channels
- if quant_channels is None:
- self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
- self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True)
- else:
- self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels)
- self.norm2 = Emu3VQVAESpatialNorm(quant_channels, out_channels)
- self.conv1 = nn.Conv2d(
- in_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1,
- )
- self.conv2 = nn.Conv2d(
- out_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1,
- )
- if self.in_channels != self.out_channels:
- self.nin_shortcut = nn.Conv2d(
- in_channels,
- out_channels,
- kernel_size=1,
- stride=1,
- padding=0,
- )
- def forward(self, hidden_states: torch.Tensor, quant_channels: torch.Tensor | None = None):
- norm_args = () if self.quant_channels is None else (quant_channels,)
- residual = hidden_states
- hidden_states = self.norm1(hidden_states, *norm_args)
- hidden_states *= torch.sigmoid(hidden_states)
- hidden_states = self.conv1(hidden_states)
- hidden_states = self.norm2(hidden_states, *norm_args)
- hidden_states *= torch.sigmoid(hidden_states)
- hidden_states = self.conv2(hidden_states)
- if self.in_channels != self.out_channels:
- residual = self.nin_shortcut(residual)
- return residual + hidden_states
class Emu3VQVAEAttentionBlock(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        width = self.embed_dim
        self.k_proj = nn.Linear(width, width)
        self.v_proj = nn.Linear(width, width)
        self.q_proj = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, width)

        # for compatibility with the attention interface
        self.num_key_value_groups = 1

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Input shape: Batch x Time x Channel"""
        input_shape = hidden_states.shape[:-1]
        # Split the feature axis into (heads, head_dim); the head count is inferred.
        hidden_shape = (*input_shape, -1, self.head_dim)

        queries, keys, values = [
            projection(hidden_states).view(hidden_shape).transpose(1, 2)
            for projection in (self.q_proj, self.k_proj, self.v_proj)
        ]

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        # Dropout is only active in training mode. Extra `kwargs` are not forwarded to the interface.
        dropout = self.dropout if self.training else 0.0
        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=dropout,
        )

        merged = attn_output.reshape(*input_shape, -1).contiguous()
        return self.out_proj(merged), attn_weights
class Emu3VQVAEGroupNorm(nn.GroupNorm):
    """
    Drop-in torch GroupNorm whose `forward` additionally accepts an optional `quant_states`
    argument and ignores it. This lets callers invoke SpatialNorm and GroupNorm layers with the
    same two-argument call, without conditionals.
    """

    def __init__(self, **kwargs):
        # Keyword-only construction; everything is handed to nn.GroupNorm unchanged.
        super().__init__(**kwargs)

    def forward(self, input, quant_states=None):
        return F.group_norm(
            input,
            self.num_groups,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
class Emu3VQVAEMiddleBlock(nn.Module):
    """ResNet block -> (norm, self-attention over flattened spatial positions, residual add) -> ResNet block.

    With `quant_channels=None` the attention norm is an `Emu3VQVAEGroupNorm`; otherwise it is an
    `Emu3VQVAESpatialNorm` conditioned on `quant_states`.
    """

    def __init__(self, config, in_channels, quant_channels=None):
        super().__init__()
        resnet_kwargs = {
            "in_channels": in_channels,
            "out_channels": in_channels,
            "quant_channels": quant_channels,
        }

        self.block_1 = Emu3VQVAEResnetBlock(**resnet_kwargs)
        self.attn_1 = Emu3VQVAEAttentionBlock(config)
        if quant_channels is not None:
            self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels)
        else:
            self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
        self.block_2 = Emu3VQVAEResnetBlock(**resnet_kwargs)

    def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor | None = None):
        hidden_states = self.block_1(hidden_states, quant_states)

        normed = self.attn_norm(hidden_states, quant_states)
        batch_size, channels, height, width = normed.shape
        # Flatten the spatial grid into a sequence of height * width tokens with `channels` features.
        tokens = normed.view(batch_size, channels, height * width).transpose(1, 2)
        attended = self.attn_1(tokens)[0]
        # Back to (batch, channels, height, width).
        attended = attended.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)

        return self.block_2(hidden_states + attended, quant_states)
class Emu3VQVAEDownBlock(nn.Module):
    """Stack of resolution levels for the encoder.

    Each level holds `config.num_res_blocks` ResNet blocks. Levels listed in `config.attn_resolutions`
    additionally pair every ResNet block with a GroupNorm + attention block. Every level except the
    last ends with an `Emu3VQVAEEncoderConvDownsample`.
    """

    def __init__(self, config):
        super().__init__()
        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        channel_multiplier = config.channel_multiplier

        # Level i reads channels scaled by entry i of this tuple and writes channels scaled by
        # channel_multiplier[i]; the leading 1 makes the first level read `base_channels`.
        in_channel_multiplier = (1,) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier

        self.down = nn.ModuleList()
        last_level = self.num_resolutions - 1
        for i_level in range(self.num_resolutions):
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            with_attention = config.attn_resolutions is not None and i_level in config.attn_resolutions

            stage = nn.Module()
            stage.block = nn.ModuleList()
            stage.attn = nn.ModuleList()
            stage.attn_norms = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                stage.block.append(Emu3VQVAEResnetBlock(in_channels=block_in, out_channels=block_out))
                # Only the first block of a level changes the channel count.
                block_in = block_out
                if with_attention:
                    stage.attn.append(Emu3VQVAEAttentionBlock(config))
                    stage.attn_norms.append(
                        nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
                    )

            if i_level != last_level:
                stage.downsample = Emu3VQVAEEncoderConvDownsample(block_in)
            self.down.append(stage)

    def forward(self, hidden_states: torch.FloatTensor):
        final_level = self.num_resolutions - 1
        for i_level, stage in enumerate(self.down):
            use_attention = len(stage.attn) > 0
            for i_block in range(self.num_res_blocks):
                hidden_states = stage.block[i_block](hidden_states)
                if not use_attention:
                    continue

                normed = stage.attn_norms[i_block](hidden_states)
                batch_size, channels, height, width = normed.shape
                # Attention runs over the flattened spatial grid, then the grid layout is restored.
                tokens = normed.view(batch_size, channels, height * width).transpose(1, 2)
                attended = stage.attn[i_block](tokens)[0]
                attended = attended.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
                hidden_states = hidden_states + attended

            if i_level != final_level:
                hidden_states = stage.downsample(hidden_states)

        return hidden_states
class Emu3VQVAEUpBlock(nn.Module):
    """Stack of resolution levels for the decoder.

    Each level holds `config.num_res_blocks + 1` ResNet blocks conditioned on the quantized states
    (`quant_channels = config.embed_dim`). Levels listed in `config.attn_resolutions` additionally
    pair every ResNet block with a spatial norm + attention block. Every level except level 0 ends
    with an `Emu3VQVAEEncoderConvUpsample`. `self.up` is stored in ascending level order and walked
    from the highest level down to level 0 in `forward`.
    """

    def __init__(self, config):
        super().__init__()
        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks

        quant_channels = config.embed_dim
        block_in = config.base_channels * config.channel_multiplier[-1]
        # `attn_resolutions` may be None, which means "no attention at any level". The down block
        # guards against this; do the same here instead of failing on `in None`.
        attn_resolutions = config.attn_resolutions if config.attn_resolutions is not None else ()

        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            attn_norms = nn.ModuleList()
            block_out = config.base_channels * config.channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    Emu3VQVAEResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        quant_channels=quant_channels,
                    )
                )
                # Only the first block of a level changes the channel count.
                block_in = block_out
                if i_level in attn_resolutions:
                    attn.append(Emu3VQVAEAttentionBlock(config))
                    attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in))

            up = nn.Module()
            up.block = block
            up.attn = attn
            up.attn_norms = attn_norms
            if i_level != 0:
                up.upsample = Emu3VQVAEEncoderConvUpsample(block_in)

            # Levels are built from highest to lowest; inserting at the front keeps index == level.
            self.up.insert(0, up)

    def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor):
        """Run all levels from the highest index down to level 0, conditioning on `quant_states`."""
        for i_level, blocks in enumerate(self.up[::-1]):
            for i_block in range(self.num_res_blocks + 1):
                hidden_states = blocks.block[i_block](hidden_states, quant_states)
                if len(blocks.attn) > 0:
                    residual = hidden_states
                    hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states)

                    batch_size, channels, height, width = hidden_states.shape
                    # Attention runs over the flattened spatial grid, then the grid layout is restored.
                    hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
                    hidden_states = blocks.attn[i_block](hidden_states)[0]

                    hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
                    hidden_states = residual + hidden_states
            # The last entry of the reversed walk is level 0, which has no upsample layer.
            if i_level != len(self.up) - 1:
                hidden_states = blocks.upsample(hidden_states)

        return hidden_states
class Emu3VQVAEEncoder(nn.Module):
    """
    Encoder half of the Emu3 VQ-VAE.

    Frames are first processed independently by a 2D stack (input conv, down block, middle block, group
    norm, output conv). The result is regrouped per sample and passed through temporal down-sampling
    convolutions followed by temporal residual blocks.
    """

    def __init__(self, config):
        super().__init__()

        latent_channels = config.latent_channels
        # With `double_latent` the encoder emits twice as many channels as the latent size.
        out_channels = 2 * latent_channels if config.double_latent else latent_channels
        block_in = config.base_channels * config.channel_multiplier[-1]

        # Per-frame (2D) path. Submodules are created in a fixed order on purpose.
        self.conv_in = nn.Conv2d(config.in_channels, config.base_channels, kernel_size=3, stride=1, padding=1)
        self.down_block = Emu3VQVAEDownBlock(config)
        self.middle_block = Emu3VQVAEMiddleBlock(config, block_in)
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

        # Temporal path: log2(temporal_downsample_factor) down-sampling convs, then the residual stack.
        num_temporal_convs = int(math.log2(config.temporal_downsample_factor))
        self.time_conv = nn.ModuleList(
            [Emu3VQVAETemporalDownsample(out_channels, out_channels) for _ in range(num_temporal_convs)]
        )
        self.time_res_stack = nn.ModuleList(
            [
                Emu3VQVAETemporalResnetBlock(in_channels=out_channels, out_channels=out_channels)
                for _ in range(config.num_res_blocks)
            ]
        )

    def forward(self, pixel_values: torch.LongTensor):
        """
        Encode a batch of frame sequences.

        Args:
            pixel_values: Input whose axis 1 is used as the temporal axis; the first two axes are merged
                before the 2D layers and split again afterwards.
                NOTE(review): annotated as `LongTensor` but fed straight into `nn.Conv2d`, so a floating
                point tensor is presumably expected - confirm.

        Returns:
            The encoded tensor, with the temporal axis back at position 1.
        """
        num_frames = pixel_values.shape[1]
        frames = pixel_values.reshape(-1, *pixel_values.shape[2:])

        # downsampling & middle
        hidden_states = self.middle_block(self.down_block(self.conv_in(frames)))

        # end: norm -> x * sigmoid(x), applied in place -> output conv
        hidden_states = self.norm_out(hidden_states)
        hidden_states.mul_(torch.sigmoid(hidden_states))
        hidden_states = self.conv_out(hidden_states)

        # Un-merge the frame axis, then swap axes 1 and 2 for the temporal layers.
        hidden_states = hidden_states.reshape(-1, num_frames, *hidden_states.shape[1:])
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

        # temporal convs, each followed by an in-place x * sigmoid(x)
        for temporal_conv in self.time_conv:
            hidden_states = temporal_conv(hidden_states)
            hidden_states.mul_(torch.sigmoid(hidden_states))

        for res_block in self.time_res_stack:
            hidden_states = res_block(hidden_states)

        # Swap axes 1 and 2 back.
        return hidden_states.permute(0, 2, 1, 3, 4)
class Emu3VQVAEDecoder(nn.Module):
    """
    Decoder half of the Emu3 VQ-VAE.

    The post-quantization features and the raw quantized features share the temporal layers (residual
    stack, then up-sampling convs). They are then separated: the former is decoded by the 2D stack while
    the latter conditions the middle block, the up block and the final norm.
    """

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__()

        quant_channels = config.embed_dim
        latent_channels = config.latent_channels
        deepest_channels = config.base_channels * config.channel_multiplier[-1]
        shallowest_channels = config.base_channels * config.channel_multiplier[0]

        # Temporal path (submodules are created in a fixed order on purpose).
        self.time_res_stack = nn.ModuleList(
            [
                Emu3VQVAETemporalResnetBlock(in_channels=latent_channels, out_channels=latent_channels)
                for _ in range(config.num_res_blocks)
            ]
        )
        num_temporal_convs = int(math.log2(config.temporal_downsample_factor))
        self.time_conv = nn.ModuleList(
            [Emu3VQVAETemporalUpsample(latent_channels, latent_channels) for _ in range(num_temporal_convs)]
        )

        # Per-frame (2D) path.
        self.conv_in = nn.Conv2d(latent_channels, deepest_channels, kernel_size=3, stride=1, padding=1)
        self.middle_block = Emu3VQVAEMiddleBlock(config, deepest_channels, quant_channels=quant_channels)
        self.up_block = Emu3VQVAEUpBlock(config)
        self.norm_out = Emu3VQVAESpatialNorm(quant_channels, shallowest_channels)
        self.conv_out = nn.Conv2d(shallowest_channels, config.out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor):
        """
        Decode features back to frames.

        Args:
            hidden_states: Features to decode (5-D; axes 1 and 2 are swapped for the temporal layers).
            quant_states: Conditioning features. They are concatenated with `hidden_states` along axis 0,
                so the two must be concatenable there.

        Returns:
            The output of the final convolution, with the first two axes of the temporally processed
            tensor merged into one.
        """
        # Stack both inputs along the batch axis so the temporal layers process them in one pass.
        stacked = torch.cat((hidden_states, quant_states), dim=0)
        stacked = stacked.permute(0, 2, 1, 3, 4)

        # temporal convs
        for res_block in self.time_res_stack:
            stacked = res_block(stacked)

        for temporal_conv in self.time_conv:
            stacked = temporal_conv(stacked)
            # x * sigmoid(x), applied in place
            stacked.mul_(torch.sigmoid(stacked))

        stacked = stacked.permute(0, 2, 1, 3, 4)

        # Undo the stacking, then merge the first two axes of each half for the 2D layers.
        hidden_states, quant_states = torch.chunk(stacked, 2, dim=0)
        hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:])
        quant_states = quant_states.reshape(-1, *quant_states.shape[2:])

        hidden_states = self.conv_in(hidden_states)

        # middle & upsampling
        hidden_states = self.middle_block(hidden_states, quant_states)
        hidden_states = self.up_block(hidden_states, quant_states)

        hidden_states = self.norm_out(hidden_states, quant_states)
        hidden_states.mul_(torch.sigmoid(hidden_states))
        return self.conv_out(hidden_states)
@auto_docstring(
    custom_intro="""
    The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens.
    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
    [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
    Taigman](https://huggingface.co/papers/2203.13131).
    """
)
class Emu3VQVAE(PreTrainedModel):
    config: Emu3VQVAEConfig
    base_model_prefix = "emuvideovq"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _no_split_modules = [
        "Emu3VQVAETemporalResnetBlock",
        "Emu3VQVAEAttentionBlock",
        "Emu3VQVAEResnetBlock",
        "Emu3VQVAEVectorQuantizer",
    ]
    _can_record_outputs = {
        "hidden_states": [Emu3VQVAEResnetBlock, Emu3VQVAETemporalResnetBlock],
        "attentions": Emu3VQVAEAttentionBlock,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize `module` according to its type.

        Handles 2D/3D convolutions, linear layers, batch/group norms and embeddings; any other module type
        is left untouched.
        """
        if isinstance(module, (nn.Conv2d, nn.Conv3d)):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is None:
                return
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
            limit = 1 / math.sqrt(fan_in)
            init.uniform_(module.bias, -limit, limit)
            return

        if isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is None:
                return
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
            # Unlike the conv branch, a zero fan-in is guarded here.
            limit = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(module.bias, -limit, limit)
            return

        if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
            init.constant_(module.weight, 1.0)
            init.constant_(module.bias, 0.0)
            # Running statistics are reset only when the module actually carries them.
            if getattr(module, "running_mean", None) is None:
                return
            init.zeros_(module.running_mean)
            init.ones_(module.running_var)
            init.zeros_(module.num_batches_tracked)
            return

        if isinstance(module, nn.Embedding):
            init.normal_(module.weight)
            # The flag has to be checked explicitly: slicing the weight for `zeros_` loses it.
            already_initialized = getattr(module.weight, "_is_hf_initialized", False)
            if module.padding_idx is not None and not already_initialized:
                init.zeros_(module.weight[module.padding_idx])

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__(config)
        self.config = config

        # Both factor attributes below hold the same value: 2 ** (number of channel multipliers - 1).
        spatial_factor = 2 ** (len(config.channel_multiplier) - 1)
        temporal_kernel = (3, 1, 1)
        unit_stride = (1, 1, 1)

        self.encoder = Emu3VQVAEEncoder(config)
        self.decoder = Emu3VQVAEDecoder(config)
        self.quantize = Emu3VQVAEVectorQuantizer(config)
        self.vision_spatial_factor = spatial_factor

        self.quant_conv = Emu3VQVAEConv3d(
            config.latent_channels, config.embed_dim, kernel_size=temporal_kernel, stride=unit_stride
        )
        self.post_quant_conv = Emu3VQVAEConv3d(
            config.embed_dim, config.latent_channels, kernel_size=temporal_kernel, stride=unit_stride
        )

        self.spatial_scale_factor = spatial_factor
        self.eval()  # Emu3's VQ model is frozen
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def encode(
        self, pixel_values: torch.Tensor, image_sizes: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
    ) -> Emu3VQVAEModelOutput:
        """
        Encode pixels into discrete codes.

        Args:
            pixel_values: 4-D input is treated as a batch of images: a new axis 1 is inserted and each
                image is repeated `config.temporal_downsample_factor` times along it. Any other input
                must be 5-D.
            image_sizes: One `(height, width)` pair per sample, used to crop each code map to
                `size / vision_spatial_factor` rows and columns.

        Returns:
            `Emu3VQVAEModelOutput` with the encoder output as `last_hidden_state` and the list of cropped
            code maps as `image_tokens`.
        """
        is_image = pixel_values.ndim == 4
        if is_image:
            repeats = self.config.temporal_downsample_factor
            pixel_values = pixel_values.unsqueeze(1).repeat(1, repeats, 1, 1, 1)
        else:
            # The unpacking doubles as a rank check: anything that is not 5-D fails here.
            _batch, _frames, _channels, _height, _width = pixel_values.shape

        hidden_states = self.encoder(pixel_values)

        # b t c h w -> b c t h w -> (quant conv) -> b t c h w
        conv_hidden_states = self.quant_conv(hidden_states.permute(0, 2, 1, 3, 4))
        conv_hidden_states = conv_hidden_states.permute(0, 2, 1, 3, 4)
        codes = self.quantize(conv_hidden_states)

        if is_image:
            codes = codes.squeeze(1)

        factor = self.vision_spatial_factor
        image_tokens = []
        for single_image, size in zip(codes, image_sizes):
            rows = int(size[0] / factor)
            cols = int(size[1] / factor)
            image_tokens.append(single_image[:rows, :cols])

        return Emu3VQVAEModelOutput(
            last_hidden_state=hidden_states,
            image_tokens=image_tokens,
        )

    def decode(self, hidden_states: torch.Tensor):
        """
        Decode discrete codes back to pixels.

        Args:
            hidden_states: Code indices. 3-D input is treated as a batch of images and gets a new axis 1;
                the result is then unpacked as `(batch, temporal, height, width)`.

        Returns:
            The decoded tensor reshaped to `(batch, temporal * temporal_downsample_factor, out_channels,
            height * spatial_scale_factor, width * spatial_scale_factor)`; for image input only index 0
            of axis 1 is returned.
        """
        is_image = hidden_states.ndim == 3
        if is_image:
            hidden_states = hidden_states.unsqueeze(1)

        batch_size, temporal, height, width = hidden_states.shape

        # Look up the codebook vectors and move the embedding axis to position 1.
        embedded = self.quantize.embedding(hidden_states.flatten())
        channels = embedded.shape[-1]
        quant = embedded.view(batch_size, temporal, height, width, channels)
        quant = quant.permute(0, 4, 1, 2, 3).contiguous()

        post_quant = self.post_quant_conv(quant)

        # Swap axes 1 and 2 of both tensors before handing them to the decoder.
        quant = quant.permute(0, 2, 1, 3, 4)
        post_quant = post_quant.permute(0, 2, 1, 3, 4)

        video = self.decoder(post_quant, quant)
        video = video.reshape(
            batch_size,
            temporal * self.config.temporal_downsample_factor,
            self.config.out_channels,
            height * self.spatial_scale_factor,
            width * self.spatial_scale_factor,
        )
        if is_image:
            return video[:, 0]
        return video
class Emu3ImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.

    Visual tokens are the vocabulary entries whose name starts with `"<|visual token"`. The image code of
    such an entry is parsed from the six characters that precede the trailing `"|>"` of its name.

    Args:
        vocab_map: Mapping from token string to BPE token id. `"<|extra_200|>"` is used as the end-of-line
            token and `"<image>"` as the image placeholder; either id is `None` when the key is missing.
    """

    # Name prefix shared by every visual token in the vocabulary.
    _VISUAL_TOKEN_PREFIX = "<|visual token"

    def __init__(self, vocab_map):
        self.vocab_map = vocab_map
        self.eol_token_id = vocab_map.get("<|extra_200|>")
        self.image_token_id = vocab_map.get("<image>")

    @staticmethod
    def _build_lookup(mapping: dict) -> torch.Tensor:
        """Return a dense int tensor `t` with `t[key] = value`; slots without a key stay 0."""
        lookup = torch.zeros(max(mapping.keys()) + 1, dtype=torch.int)
        for key, value in mapping.items():
            lookup[key] = value
        return lookup

    @cached_property
    def image_tokens(self):
        """Sorted BPE ids of all visual tokens."""
        return sorted(
            [val for name, val in self.vocab_map.items() if name.startswith(self._VISUAL_TOKEN_PREFIX)]
        )

    @cached_property
    def image_tokens_str(self):
        """Sorted names of all visual tokens."""
        return sorted(
            [name for name, val in self.vocab_map.items() if name.startswith(self._VISUAL_TOKEN_PREFIX)]
        )

    @cached_property
    def img2bpe(self):
        """Dict mapping image code -> BPE id."""
        return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str}

    @cached_property
    def bpe2img(self):
        """Dict mapping BPE id -> image code (inverse of `img2bpe`)."""
        return {v: k for k, v in self.img2bpe.items()}

    @cached_property
    def bpe2img_mapping_tensor(self):
        """CPU lookup tensor indexed by BPE id, holding the image code."""
        return self._build_lookup(self.bpe2img)

    @cached_property
    def img2bpe_mapping_tensor(self):
        """CPU lookup tensor indexed by image code, holding the BPE id."""
        return self._build_lookup(self.img2bpe)

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        """
        Map image codes to BPE ids and append one end-of-line token to every row.

        Args:
            img_batch: Integer tensor of image codes. The end-of-line column has `img_batch.shape[0]`
                rows and is concatenated along the last axis.

        Returns:
            Int tensor on the device of `img_batch`, one element longer along the last axis.
        """
        device = img_batch.device
        # The lookup table lives on CPU, so the lookup and the concatenation happen there.
        eol_row = torch.full((img_batch.shape[0], 1), self.eol_token_id, dtype=torch.int)
        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        img_tokens = torch.cat([img_tokens, eol_row], dim=-1)
        return img_tokens.to(device)

    def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor:
        """
        Map BPE ids back to image codes after dropping the last entry along the final axis.

        Args:
            img_batch: Integer tensor of BPE ids whose last entry on the final axis is the end-of-line
                token added by `convert_img2bpe`.

        Returns:
            Int tensor of image codes on the device of `img_batch`.
        """
        device = img_batch.device
        img_batch = img_batch[..., :-1]  # remove last row of EOL tokens
        img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")]
        return img_tokens.to(device)
@auto_docstring
class Emu3PreTrainedModel(PreTrainedModel):
    # Shared declarative settings for the Emu3 model classes; this class defines no behavior of its own.
    config: Emu3Config
    base_model_prefix = "model"
    input_modalities = ("image", "text")

    # Attention backends and compilation
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_compile_fullgraph = True

    # Training and device placement
    supports_gradient_checkpointing = True
    _no_split_modules = ["Emu3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values", "causal_mask"]

    # Module classes whose outputs are recorded under each key
    _can_record_outputs = {
        "hidden_states": Emu3DecoderLayer,
        "attentions": Emu3Attention,
    }
class Emu3RotaryEmbedding(nn.Module):
    """Computes the rotary position embedding cos/sin tables for a set of position ids."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Emu3Config, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_type = self.config.rope_parameters["rope_type"]

        # The "default" type is computed locally; every other type comes from the shared registry.
        if self.rope_type == "default":
            rope_init_fn: Callable = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Neither buffer is persistent, so neither is saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: Emu3Config | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        # Prefer an explicit `head_dim`; otherwise derive it from the hidden size and head count.
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        # Compute the inverse frequencies: 1 / base ** (even index / dim)
        even_indices = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
        inv_freq = 1.0 / (base ** (even_indices / dim))

        attention_factor = 1.0  # Unused in this type of RoPE
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """
        Return the `(cos, sin)` tables for `position_ids`, cast to the dtype of `x`.

        `x` is only consulted for its device and dtype.
        """
        num_rows = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(num_rows, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # "mps" (and any non-string device type) falls back to "cpu" for the autocast context.
        raw_device_type = x.device.type
        use_raw = isinstance(raw_device_type, str) and raw_device_type != "mps"
        device_type = raw_device_type if use_raw else "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@auto_docstring
class Emu3TextModel(Emu3PreTrainedModel):
    config: Emu3TextConfig

    def __init__(self, config: Emu3TextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        decoder_layers = [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Emu3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of the two inputs must be given: both present or both absent is an error.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Continue counting from whatever the cache already holds.
            past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
            new_positions = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
            position_ids = (new_positions + past_seen_tokens).unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # The rotary tables are computed once and shared by every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for decoder_layer in active_layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )
@auto_docstring
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config: Emu3TextConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = Emu3TextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import Emu3Processor, Emu3ForCausalLM
        >>> import torch

        >>> model = Emu3ForCausalLM.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16)
        >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")

        >>> inputs = processor(text=["Can you write me a poem about winter."], return_tensors="pt").to(model.device)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # An int keeps the last `logits_to_keep` positions (0 keeps all, since `slice(-0, None)` is the
        # full range); anything else is used directly as the index along the sequence axis.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class Emu3Model(Emu3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.text_model = Emu3TextModel._from_config(config.text_config)
        self.vqmodel = Emu3VQVAE(config.vq_config)
        self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.text_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.text_model.set_input_embeddings(value)

    def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor) -> torch.LongTensor:
        """
        Tokenizes images into discrete tokens with the VQGAN module, converts each image's tokens into
        BPE tokens and returns all of them flattened and concatenated into a single 1-D tensor.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.
            image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
                The sizes of the images in the batch, being (height, width) for each image.
        """
        vqmodel_outputs: Emu3VQVAEModelOutput = self.vqmodel.encode(pixel_values, image_sizes, return_dict=True)
        to_bpe = self.vocabulary_mapping.convert_img2bpe
        per_image_bpe = [to_bpe(tokens).flatten() for tokens in vqmodel_outputs.image_tokens]
        return torch.cat(per_image_bpe)

    @can_return_tuple
    @auto_docstring(
        custom_intro="Tokenizes images into discrete tokens with VQGAN module and embeds them with text embeddings layer"
    )
    def get_image_features(
        self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | Emu3VQVAEModelOutput:
        r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images.
        """
        vqmodel_outputs: Emu3VQVAEModelOutput = self.vqmodel.encode(
            pixel_values, image_sizes, return_dict=True, **kwargs
        )

        # Tokens per image: one row per down-scaled line, each row carrying one extra slot (the `+ 1`).
        factor = self.vqmodel.vision_spatial_factor
        split_sizes = [(height // factor) * (width // factor + 1) for height, width in image_sizes]

        to_bpe = self.vocabulary_mapping.convert_img2bpe
        bpe_tokens = torch.cat([to_bpe(tokens).flatten() for tokens in vqmodel_outputs.image_tokens])

        image_embeddings = self.get_input_embeddings()(bpe_tokens)
        # Cut the flat embedding sequence back into one chunk per image.
        vqmodel_outputs.pooler_output = torch.split(image_embeddings, split_sizes)
        return vqmodel_outputs

    @torch.no_grad()
    def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
        """
        Decodes generated image tokens from language model to continuous pixel values
        with VQGAN module via upsampling.

        Args:
            image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
                The tensors corresponding to the input images.
            height (`int`):
                Height of the generated image before upsampling.
            width (`int`):
                Width of the generated image before upsampling.
        """
        # Drop the last three tokens, then view each sequence as `height` rows of `width + 1` tokens.
        trimmed = image_tokens[:, :-3]
        sequences = trimmed.view(-1, height, width + 1)
        codes = self.vocabulary_mapping.convert_bpe2img(sequences)
        return self.vqmodel.decode(codes)

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Builds the mask of image placeholder positions from `input_ids` (or, when they are missing, by
        comparing `inputs_embeds` with the placeholder token's embedding), expands it to the shape of
        `inputs_embeds` and checks that the number of masked elements equals the number of elements in
        `image_features`.
        """
        image_token_id = self.vocabulary_mapping.image_token_id
        if input_ids is not None:
            special_image_mask = input_ids == image_token_id
        else:
            # Without ids, a position counts as a placeholder only if its whole embedding matches.
            placeholder = torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
            special_image_mask = inputs_embeds == self.get_input_embeddings()(placeholder)
            special_image_mask = special_image_mask.all(-1)

        # Both counts are only used in the error message below.
        n_image_tokens = special_image_mask.sum()
        n_image_features = image_features.shape[0] * image_features.shape[1]

        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        torch_compilable_check(
            inputs_embeds[special_image_mask].numel() == image_features.numel(),
            f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
        )
        return special_image_mask

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        image_sizes: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | CausalLMOutputWithPast:
        r"""
        image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
            The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
            [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
            [`Emu3ImageProcessor`] for processing images).
        """
        # Exactly one of the two inputs must be given: both present or both absent is an error.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            # Embed the images and write the embeddings over the placeholder positions.
            per_image_features = self.get_image_features(pixel_values, image_sizes).pooler_output
            image_features = torch.cat(per_image_features, dim=0)
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        # The text model only ever receives embeddings, never the raw ids.
        return self.text_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
    output_modalities = ("image", "text")
    _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.model = Emu3Model(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def decode_image_tokens(self, **kwargs):
        # Thin pass-through to the wrapped model; keyword arguments only.
        return self.model.decode_image_tokens(**kwargs)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        image_sizes: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        labels: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | CausalLMOutputWithPast:
        r"""
        image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
            The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
            [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
            [`Emu3ImageProcessor`] for processing images).
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration
        >>> import torch
        >>> import httpx
        >>> from io import BytesIO
        >>> from PIL import Image

        >>> model = Emu3ForConditionalGeneration.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16)
        >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")

        >>> conversation = [
        ...     {
        ...     "role": "system",
        ...     "content": [
        ...         {"type": "text", "text": "You are a helpful assistant."},
        ...         ],
        ...     },
        ...     {
        ...     "role": "user",
        ...     "content": [
        ...         {"type": "image"},
        ...         {"type": "text", "text": "Please describe the image."},
        ...         ],
        ...     },
        ... ]

        >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        # `pixel_values` and `image_sizes` are named parameters, so they never reach `**kwargs`; they
        # must be forwarded explicitly or the wrapped model never sees the images.
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # An int keeps the last `logits_to_keep` positions (0 keeps all); anything else is used directly
        # as the index along the sequence axis.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        is_first_iteration=False,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            pixel_values=pixel_values,
            use_cache=use_cache,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

        # With caching enabled, pixel values are only passed on the first iteration and cleared on
        # every later step.
        if not is_first_iteration and use_cache:
            model_inputs["pixel_values"] = None

        return model_inputs
- __all__ = [
- "Emu3ForConditionalGeneration",
- "Emu3ForCausalLM",
- "Emu3TextModel",
- "Emu3PreTrainedModel",
- "Emu3VQVAE",
- "Emu3Model",
- ]
|