| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156 |
- # Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import torch
- from huggingface_hub.dataclasses import strict
- from torch import nn
- from torchvision.transforms.v2 import functional as tvF
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache
- from ...configuration_utils import PreTrainedConfig
- from ...image_processing_backends import TorchvisionBackend
- from ...image_processing_utils import BatchFeature, get_patch_output_size, select_best_resolution
- from ...image_transforms import divide_to_patches
- from ...image_utils import (
- ChannelDimension,
- ImageInput,
- PILImageResampling,
- SizeDict,
- get_image_size,
- )
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
- from ...modeling_outputs import BaseModelOutputWithPooling
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
- from ...tokenization_python import PreTokenizedInput, TextInput
- from ...utils import (
- TensorType,
- TransformersKwargs,
- auto_docstring,
- can_return_tuple,
- logging,
- )
- from ..auto import CONFIG_MAPPING, AutoConfig, AutoTokenizer
- from ..llama.configuration_llama import LlamaConfig
- from ..llama.modeling_llama import (
- LlamaAttention,
- LlamaDecoderLayer,
- LlamaForCausalLM,
- LlamaMLP,
- LlamaModel,
- LlamaPreTrainedModel,
- LlamaRMSNorm,
- )
- from ..llava.modeling_llava import (
- LlavaCausalLMOutputWithPast,
- LlavaForConditionalGeneration,
- LlavaModel,
- LlavaModelOutputWithPast,
- )
- logger = logging.get_logger(__name__)
def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
    """
    Multiply each expert's contiguous slice of tokens by that expert's weight matrix, one expert at a time.

    Rows of `token_states` are consumed in expert order: the first `tokens_per_expert[0]` rows go through
    expert 0, the next `tokens_per_expert[1]` rows through expert 1, and so on. The Python-level loop makes
    this slow when there are many experts.

    Args:
        token_states (torch.Tensor): Input tensor of shape (num_tokens, in_features).
        expert_weights (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features).
        tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.

    Returns:
        torch.Tensor: Output tensor of shape (num_tokens, out_features).
    """
    result = token_states.new_zeros((token_states.shape[0], expert_weights.shape[-1]))

    # Prefix sums with a leading 0, so expert `i` owns rows [boundaries[i], boundaries[i + 1]).
    running_counts = torch.cumsum(tokens_per_expert, dim=0)
    leading_zero = torch.zeros(1, dtype=torch.long, device=running_counts.device)
    boundaries = torch.cat((leading_zero, running_counts))

    for expert_idx, weight in enumerate(expert_weights):
        lo, hi = boundaries[expert_idx], boundaries[expert_idx + 1]
        result[lo:hi] = torch.matmul(token_states[lo:hi], weight)
    return result
@auto_docstring(checkpoint="rhymes-ai/Aria")
@strict
class AriaTextConfig(LlamaConfig):
    r"""
    moe_num_experts (`int`, *optional*, defaults to 8):
        The number of experts in the MoE layer.
    moe_topk (`int`, *optional*, defaults to 2):
        The number of top experts to route to for each token.
    moe_num_shared_experts (`int`, *optional*, defaults to 2):
        The number of shared experts.
    """

    # Text (language model) configuration: a `LlamaConfig` extended with the MoE fields declared below.
    model_type = "aria_text"
    # Key under which this config is nested inside the composite `AriaConfig` (see `AriaConfig.sub_configs`).
    base_config_key = "text_config"
    # Per-module sharding styles, keyed by module-name pattern. Only the attention projections and the
    # shared-experts MLP are listed; the routed experts have no entry here.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.shared_experts.gate_proj": "colwise",
        "layers.*.mlp.shared_experts.up_proj": "colwise",
        "layers.*.mlp.shared_experts.down_proj": "rowwise",
    }

    # Per-expert hidden width: `AriaExperts` sizes each routed expert with it and `AriaSharedExpertsMLP`
    # multiplies it by `moe_num_shared_experts`.
    intermediate_size: int = 4096
    moe_num_experts: int = 8
    moe_topk: int = 2
    moe_num_shared_experts: int = 2
    pad_token_id: int | None = 2
@auto_docstring(checkpoint="rhymes-ai/Aria")
@strict
class AriaConfig(PreTrainedConfig):
    r"""
    projector_patch_to_query_dict (`dict`, *optional*):
        Mapping of patch sizes to query dimensions.
    """

    model_type = "aria"
    attribute_map = {
        "image_token_id": "image_token_index",
    }
    sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig}

    vision_config: dict | PreTrainedConfig | None = None
    text_config: dict | AriaTextConfig | None = None
    vision_feature_layer: int | list[int] = -1
    projector_patch_to_query_dict: dict | None = None
    image_token_index: int = 9
    initializer_range: float = 0.02
    tie_word_embeddings: bool = False

    def __post_init__(self, **kwargs):
        # Fall back to the default patch-count -> query-count table, then force integer keys and values
        # so a table provided with string keys/values ends up in the same form.
        patch_to_query = self.projector_patch_to_query_dict
        if patch_to_query is None:
            patch_to_query = {
                1225: 128,
                4900: 256,
            }
        self.projector_patch_to_query_dict = {
            int(num_patches): int(num_queries) for num_patches, num_queries in patch_to_query.items()
        }
        self.max_value_projector_patch_to_query_dict = max(self.projector_patch_to_query_dict.values())

        # The vision tower is always built as "idefics3_vision"; a dict input has its `model_type`
        # overwritten in place before being expanded into the config class.
        vision_config = self.vision_config
        if vision_config is None:
            self.vision_config = CONFIG_MAPPING["idefics3_vision"]()
        elif isinstance(vision_config, dict):
            vision_config["model_type"] = "idefics3_vision"
            self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)

        # A text config dict is only converted here when it carries a "model_type" key.
        text_config = self.text_config
        if text_config is None:
            self.text_config = AriaTextConfig()
        elif isinstance(text_config, dict) and "model_type" in text_config:
            self.text_config = AriaTextConfig(**text_config)

        super().__post_init__(**kwargs)
class AriaTextRMSNorm(LlamaRMSNorm):
    # Aria-named alias of `LlamaRMSNorm`; nothing is overridden.
    pass
class AriaProjectorMLP(nn.Module):
    """
    Two-layer, bias-free feed-forward block used by the Aria projector: linear -> "gelu_new" -> linear.

    Args:
        in_features (`int`):
            Input embedding dimension.
        hidden_features (`int`):
            Hidden dimension of the feed-forward network.
        output_dim (`int`):
            Output dimension.
    """

    def __init__(self, in_features, hidden_features, output_dim):
        super().__init__()
        self.linear_in = nn.Linear(in_features, hidden_features, bias=False)
        self.linear_out = nn.Linear(hidden_features, output_dim, bias=False)
        self.act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        activated = self.act(self.linear_in(hidden_states))
        return self.linear_out(activated)
class AriaCrossAttention(nn.Module):
    """
    Cross-attention block of the Aria projector.

    Queries and key/value states are layer-normalised separately, projected, passed through
    `nn.MultiheadAttention`, then through an output linear layer and dropout.

    Args:
        config (`AriaConfig`):
            The configuration to use.
    """

    def __init__(self, config: AriaConfig, dropout_rate: float = 0):
        super().__init__()
        vision_config = config.vision_config
        embed_dim = vision_config.hidden_size
        self.num_heads = vision_config.num_attention_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        # Original code here: https://github.com/rhymes-ai/Aria/blob/719ff4e52b727443cba3793b0e27fe64e0244fe1/aria/model/projector.py#L48
        self.multihead_attn = nn.MultiheadAttention(embed_dim, self.num_heads, batch_first=True)
        self.linear = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout_rate)

        self.layer_norm = nn.LayerNorm(embed_dim)
        self.layer_norm_kv = nn.LayerNorm(embed_dim)

    def forward(self, key_value_states, hidden_states, attn_mask=None):
        """
        Attend from `hidden_states` (queries) over `key_value_states`.

        Args:
            key_value_states (`torch.Tensor`):
                Input tensor for key and value.
            hidden_states (`torch.Tensor`):
                Input tensor for query.
            attn_mask (`torch.Tensor`, *optional*, defaults to None):
                Attention mask.

        Returns:
            torch.Tensor:
                Output tensor after cross-attention.
        """
        query = self.q_proj(self.layer_norm(hidden_states))

        normed_kv = self.layer_norm_kv(key_value_states)
        key = self.k_proj(normed_kv)
        value = self.v_proj(normed_kv)

        # The attention weights returned alongside the output are not used.
        context, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask)
        return self.dropout(self.linear(context))
class AriaProjector(nn.Module):
    """
    Aria Projector module.

    A learned set of query vectors cross-attends over the vision features; the result is layer-normalised
    and passed through an MLP whose output width is the text model's hidden size.

    Args:
        config (`AriaConfig`):
            Configuration object for the model.
    """

    def __init__(
        self,
        config: AriaConfig,
    ):
        super().__init__()
        vision_config, text_config = config.vision_config, config.text_config

        self.patch_to_query_dict = config.projector_patch_to_query_dict
        self.in_features = vision_config.hidden_size
        self.num_heads = vision_config.num_attention_heads
        self.kv_dim = vision_config.hidden_size
        self.hidden_features = text_config.hidden_size
        self.output_dim = text_config.hidden_size

        # One query bank sized for the largest entry of the table; `forward` slices the prefix it needs.
        self.query = nn.Parameter(torch.zeros(config.max_value_projector_patch_to_query_dict, self.in_features))

        self.cross_attn = AriaCrossAttention(config)

        self.layer_norm = nn.LayerNorm(self.in_features)
        self.feed_forward = AriaProjectorMLP(self.in_features, self.hidden_features, self.output_dim)

    def forward(self, key_value_states: torch.Tensor, attn_mask: torch.Tensor | None = None):
        """
        Project vision features onto a fixed number of query tokens.

        Args:
            key_value_states (`torch.Tensor`):
                Input tensor of shape (batch_size, num_patches, kv_dim).
            attn_mask (`torch.Tensor`, *optional*, default is None):
                Attention mask.

        Returns:
            `torch.Tensor`: Output tensor of shape (batch_size, query_number, output_dim).
        """
        batch_size, num_patches = key_value_states.shape[:2]

        if num_patches not in self.patch_to_query_dict:
            raise KeyError(
                f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}."
            )
        num_queries = self.patch_to_query_dict[num_patches]

        queries = self.query[:num_queries].unsqueeze(0).repeat(batch_size, 1, 1)

        if attn_mask is not None:
            # Repeat the mask once per head along dim 0, then broadcast it over the query positions.
            per_head_mask = attn_mask.repeat_interleave(self.num_heads, 0)
            attn_mask = per_head_mask.unsqueeze(1).expand(-1, queries.size(1), -1)

        attended = self.cross_attn(key_value_states, queries, attn_mask=attn_mask)
        return self.feed_forward(self.layer_norm(attended))
class AriaImageProcessorKwargs(ImagesKwargs, total=False):
    r"""
    max_image_size (`int`, *optional*, defaults to `self.max_image_size`):
        Maximum image size. Must be either 490 or 980.
    min_image_size (`int`, *optional*, defaults to `self.min_image_size`):
        Minimum image size. Images smaller than this in any dimension will be scaled up.
    split_resolutions (`list[list[int]]`, *optional*, defaults to `self.split_resolutions`):
        A list of possible resolutions as (height, width) pairs for splitting high-resolution images into patches.
    split_image (`bool`, *optional*, defaults to `self.split_image`):
        Whether to split the image into patches using the best matching resolution from `split_resolutions`.
    """

    # Aria-specific keys accepted by `AriaImageProcessor` on top of the generic `ImagesKwargs`;
    # `total=False` makes every key optional.
    max_image_size: int
    min_image_size: int
    split_resolutions: list[list[int]]
    split_image: bool
@auto_docstring
class AriaImageProcessor(TorchvisionBackend):
    # Image processor for Aria: optionally splits an image into square crops, resizes every crop so its
    # longer side equals `max_image_size`, pads to a square, and returns pixel values with a validity mask.
    model_input_names = ["pixel_values", "pixel_mask", "num_crops"]
    valid_kwargs = AriaImageProcessorKwargs

    resample = PILImageResampling.BICUBIC
    image_mean = [0.5, 0.5, 0.5]
    image_std = [0.5, 0.5, 0.5]
    max_image_size = 980
    min_image_size = 336
    split_image = False
    split_resolutions = None
    do_convert_rgb = True
    do_rescale = True
    do_normalize = True

    def __init__(self, **kwargs: Unpack[AriaImageProcessorKwargs]):
        # When no split resolutions are given, build the default grid: (rows, cols) multiples of 490 px.
        if kwargs.get("split_resolutions") is None:
            default_resolutions = [(1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (2, 4), (2, 3), (2, 2), (2, 1), (3, 1), (3, 2), (4, 1), (4, 2), (5, 1), (6, 1), (7, 1), (8, 1)]  # fmt: skip
            kwargs["split_resolutions"] = [[el[0] * 490, el[1] * 490] for el in default_resolutions]
        super().__init__(**kwargs)

    def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple) -> list[int]:
        """Get padding size for patching, returns [left, top, right, bottom] for tvF.pad."""
        original_height, original_width = original_resolution
        target_height, target_width = target_resolution
        # Split the missing pixels evenly; an odd remainder goes to the right / bottom side.
        paste_x, r_x = divmod(target_width - original_width, 2)
        paste_y, r_y = divmod(target_height - original_height, 2)
        return [paste_x, paste_y, paste_x + r_x, paste_y + r_y]

    def _resize_for_patching(
        self,
        image: "torch.Tensor",
        target_resolution: tuple,
        resample: "PILImageResampling | tvF.InterpolationMode | int | None",
    ) -> "torch.Tensor":
        """Resize an image to a target resolution while maintaining aspect ratio."""
        new_height, new_width = get_patch_output_size(
            image, target_resolution, input_data_format=ChannelDimension.FIRST
        )
        return self.resize(image, SizeDict(height=new_height, width=new_width), resample)

    def _pad_for_patching(
        self,
        image: "torch.Tensor",
        target_resolution: tuple,
    ) -> "torch.Tensor":
        """Pad an image to a target resolution while maintaining aspect ratio."""
        new_resolution = get_patch_output_size(image, target_resolution, input_data_format=ChannelDimension.FIRST)
        padding = self._get_padding_size(new_resolution, target_resolution)
        return tvF.pad(image, padding=padding)

    def get_image_patches(
        self,
        image: "torch.Tensor",
        grid_pinpoints: list[list[int]],
        patch_size: int,
        resample: "PILImageResampling | tvF.InterpolationMode | int | None",
    ) -> list["torch.Tensor"]:
        """
        Process an image with variable resolutions by dividing it into patches.

        The image is resized and padded to the best-matching resolution from `grid_pinpoints`, then cut
        into `patch_size` tiles.

        Args:
            image (`torch.Tensor`):
                The input image to be processed (channels-first format).
            grid_pinpoints (`list[list[int]]`):
                A list of possible resolutions as (height, width) pairs.
            patch_size (`int`):
                Size of each square patch to divide the image into.
            resample (`PILImageResampling | tvF.InterpolationMode | int | None`):
                Resampling filter to use when resizing.

        Returns:
            `list[torch.Tensor]`: A list of image patches in channels-first format.

        Raises:
            TypeError: If `grid_pinpoints` is not a list.
        """
        if not isinstance(grid_pinpoints, list):
            raise TypeError("grid_pinpoints must be a list of possible resolutions.")

        image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
        best_resolution = select_best_resolution(image_size, grid_pinpoints)
        resized_image = self._resize_for_patching(image, best_resolution, resample)
        padded_image = self._pad_for_patching(resized_image, best_resolution)
        patches = divide_to_patches(padded_image, patch_size=patch_size)
        return patches

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: float | list[float] | None,
        image_std: float | list[float] | None,
        disable_grouping: bool | None,
        return_tensors: str | TensorType | None,
        max_image_size: int = 980,
        min_image_size: int = 336,
        split_resolutions: list[list[int]] | None = None,
        split_image: bool = False,
        resample: "PILImageResampling | tvF.InterpolationMode | int | None" = None,
        **kwargs,
    ) -> BatchFeature:
        # Turns a list of images into square `max_image_size` crops plus boolean masks.
        # Returns a `BatchFeature` with:
        #   - "pixel_values": all processed crops stacked along dim 0, rescaled/normalised as requested,
        #   - "pixel_mask":   one (max_image_size, max_image_size) bool mask per crop, True on real pixels,
        #   - "num_crops":    the largest number of crops produced by any single image.
        # `disable_grouping` and `**kwargs` are accepted but not used in this method.
        # Raises ValueError when `max_image_size` is neither 490 nor 980.
        if max_image_size not in [490, 980]:
            raise ValueError("max_image_size must be either 490 or 980")

        pixel_masks = []
        processed_crops = []
        num_crops = None

        for image in images:
            if split_image:
                crop_images = self.get_image_patches(image, split_resolutions, max_image_size, resample)
            else:
                crop_images = [image]

            # Track the maximum crop count over the batch.
            if num_crops is None or len(crop_images) > num_crops:
                num_crops = len(crop_images)

            for crop_image in crop_images:
                h, w = crop_image.shape[-2], crop_image.shape[-1]
                # Longer side becomes `max_image_size`; the shorter side is scaled by the same factor
                # but never drops below `min_image_size`.
                scale = max_image_size / max(h, w)
                if w >= h:
                    new_h = max(int(h * scale), min_image_size)
                    new_w = max_image_size
                else:
                    new_h = max_image_size
                    new_w = max(int(w * scale), min_image_size)

                crop_image = self.resize(crop_image, SizeDict(height=new_h, width=new_w), resample)

                # Pad only on the right and bottom so real pixels stay anchored at the top-left corner.
                padding_bottom = max_image_size - new_h
                padding_right = max_image_size - new_w
                crop_image = tvF.pad(crop_image, [0, 0, padding_right, padding_bottom])

                pixel_mask = torch.zeros((max_image_size, max_image_size), dtype=torch.bool)
                pixel_mask[:new_h, :new_w] = True
                pixel_masks.append(pixel_mask)
                processed_crops.append(crop_image)

        stacked_images = torch.stack(processed_crops, dim=0)
        stacked_images = self.rescale_and_normalize(
            stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
        )
        stacked_masks = torch.stack(pixel_masks, dim=0)

        return BatchFeature(
            data={
                "pixel_values": stacked_images,
                "pixel_mask": stacked_masks,
                "num_crops": num_crops,
            },
            tensor_type=return_tensors,
        )

    def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
        """
        A utility that returns number of image patches for a given image size.

        Args:
            height (`int`):
                Height of the input image.
            width (`int`):
                Width of the input image.
            images_kwargs (`dict`, *optional*):
                Any kwargs to override defaults of the image processor.

        Returns:
            `int`: Number of patches per image.
        """
        # The signature allows `images_kwargs=None`; treat that as "no overrides" rather than failing
        # on `None.get`.
        images_kwargs = images_kwargs or {}
        split_image = images_kwargs.get("split_image", self.split_image)
        max_image_size = images_kwargs.get("max_image_size", self.max_image_size)

        # Without splitting there is always exactly one crop, so the resolution search is not needed.
        if not split_image:
            return 1

        resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
        # NOTE(review): evaluated left to right as ((H // m) * W) // m; kept as-is to preserve results.
        return resized_height // max_image_size * resized_width // max_image_size
class AriaImagesKwargs(ImagesKwargs, total=False):
    """
    split_image (`bool`, *optional*, defaults to `False`):
        Whether to split large images into multiple crops. When enabled, images exceeding the maximum size are
        divided into overlapping crops that are processed separately and then combined. This allows processing
        of very high-resolution images that exceed the model's input size limits.
    max_image_size (`int`, *optional*, defaults to `980`):
        Maximum image size (in pixels) for a single image crop. Images larger than this will be split into
        multiple crops when `split_image=True`, or resized if splitting is disabled. This parameter controls
        the maximum resolution of individual image patches processed by the model.
    min_image_size (`int`, *optional*):
        Minimum image size (in pixels) for a single image crop. Images smaller than this will be upscaled to
        meet the minimum requirement. If not specified, images are processed at their original size (subject
        to the maximum size constraint).
    """

    # Image keyword arguments accepted at the processor level (used as `AriaProcessorKwargs.images_kwargs`);
    # `total=False` makes every key optional.
    split_image: bool
    max_image_size: int
    min_image_size: int
class AriaProcessorKwargs(ProcessingKwargs, total=False):
    # Keyword-argument schema for `AriaProcessor.__call__`, with Aria-specific image kwargs.
    images_kwargs: AriaImagesKwargs
    # Class-level defaults. This dict is shared by every user of the class, so callers should copy it
    # before modifying it.
    _defaults = {
        "text_kwargs": {
            "padding": False,
            "return_mm_token_type_ids": False,
        },
        "images_kwargs": {
            "max_image_size": 980,
            "split_image": False,
        },
        "return_tensors": TensorType.PYTORCH,
    }
@auto_docstring
class AriaProcessor(ProcessorMixin):
    # Combines an image processor and a tokenizer: images are preprocessed first, then every image
    # placeholder in the text is expanded to the number of tokens the images will occupy.
    def __init__(
        self,
        image_processor=None,
        tokenizer: AutoTokenizer | str = None,
        chat_template: str | None = None,
        size_conversion: dict[float | int, int] | None = None,
    ):
        r"""
        size_conversion (`Dict`, *optional*):
            A dictionary indicating size conversions for images.
        """
        if size_conversion is None:
            size_conversion = {490: 128, 980: 256}
        # Keys are normalised to int so lookups by pixel size work even if they were provided as floats.
        self.size_conversion = {int(k): v for k, v in size_conversion.items()}

        # NOTE(review): `tokenizer` is dereferenced here before the `is not None` check below, so a
        # `None` tokenizer fails on this line — confirm whether `None` is meant to be supported.
        self.image_token = tokenizer.image_token
        self.image_token_id = tokenizer.image_token_id
        if tokenizer is not None and tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.unk_token

        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    @auto_docstring
    def __call__(
        self,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput],
        images: ImageInput | None = None,
        **kwargs: Unpack[AriaProcessorKwargs],
    ) -> BatchFeature:
        r"""
        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
            `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
            `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            AriaProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if isinstance(text, str):
            text = [text]
        # NOTE(review): with `and`, a list whose items are not strings passes this check — confirm intent.
        elif not isinstance(text, list) and not isinstance(text[0], str):
            raise TypeError("Invalid input text. Please provide a string, or a list of strings")

        if images is not None:
            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
            # expand the image_token according to the num_crops and tokens per image
            tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
            prompt_strings = []
            # `num_crops` is only needed for the expansion, so it is removed from the returned features.
            num_crops = image_inputs.pop("num_crops") * tokens_per_image
            for sample in text:
                sample = sample.replace(self.tokenizer.image_token, self.tokenizer.image_token * num_crops)
                prompt_strings.append(sample)
        else:
            image_inputs = {}
            prompt_strings = text

        # Tokenize without tensor conversion; the final `BatchFeature` converts everything at once.
        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
        self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])

        if return_mm_token_type_ids:
            text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])

        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
        """
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

        Args:
            image_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (height, width) per each image.

        Returns:
            `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
            input modalities, along with other useful data.
        """
        vision_data = {}
        if image_sizes is not None:
            # Merge into a fresh dict: `_defaults` is class-level state, and updating it in place would
            # leak this call's overrides into every later call.
            images_kwargs = {**AriaProcessorKwargs._defaults.get("images_kwargs", {}), **kwargs}

            max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
            num_image_patches = [
                self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
                for image_size in image_sizes
            ]
            num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})

        return MultiModalData(**vision_data)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names

        # Remove `num_crops`, it is popped and used only when processing. Make a copy of list when removing
        # otherwise `self.image_processor.model_input_names` is also modified
        image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
        # `dict.fromkeys` de-duplicates while keeping first-seen order.
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
class AriaSharedExpertsMLP(LlamaMLP):
    """
    Shared Expert MLP for shared experts.

    Unlike routed experts, shared experts process all tokens without routing.
    This class reconfigures the intermediate size in comparison to the LlamaMLP.

    Args:
        config (`AriaTextConfig`): Configuration object for the Aria language model.
    """

    def __init__(self, config: AriaTextConfig):
        super().__init__(config)
        # NOTE(review): this assignment runs after `LlamaMLP.__init__`, so as written here it only
        # overwrites the stored attribute. Whether the parent's projection layers end up with the scaled
        # width depends on code outside this file (presumably a modular-to-modeling expansion step) —
        # confirm before reordering or "fixing" this.
        self.intermediate_size = config.intermediate_size * config.moe_num_shared_experts
class AriaGroupedExpertsGemm(nn.Module):
    """
    Holds one weight matrix per expert and applies them to tokens that are already grouped by expert.

    The computation is delegated to `sequential_experts_gemm`, which loops over the experts one by one.
    The weight is allocated uninitialised (`torch.empty`) and must be initialised by the owning model.

    Args:
        in_features (`int`):
            Number of input features.
        out_features (`int`):
            Number of output features.
        groups (`int`):
            Number of expert groups.
    """

    def __init__(self, in_features, out_features, groups):
        super().__init__()
        self.in_features, self.out_features, self.groups = in_features, out_features, groups
        self.weight = nn.Parameter(torch.empty(groups, in_features, out_features))

    def forward(self, input, tokens_per_expert):
        """
        Perform grouped matrix multiplication.

        Args:
            input (`torch.Tensor`):
                Input tensor of shape (num_tokens, in_features).
            tokens_per_expert (`torch.Tensor`):
                Number of tokens assigned to each expert.

        Returns:
            torch.Tensor: Output tensor of shape (num_tokens, out_features).
        """
        # The per-expert counts are moved to CPU before being used to compute slice boundaries.
        counts_on_cpu = tokens_per_expert.cpu()
        return sequential_experts_gemm(input, self.weight, counts_on_cpu)
class AriaExperts(nn.Module):
    """
    Routed experts of the MoE layer.

    `fc1` produces two halves per token that are combined as `silu(first) * second`; `fc2` maps the
    result back to the hidden size. Each token is processed by its top-k experts and the outputs are
    mixed with softmax-normalised routing weights.
    """

    def __init__(self, config: AriaTextConfig) -> None:
        super().__init__()
        self.config = config
        self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts)
        self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts)

    def route_tokens_to_experts(self, router_logits):
        """Return the top-k expert indices per token and the softmax over those k logits."""
        selection = torch.topk(router_logits, k=self.config.moe_topk, dim=1)
        routing_weights = nn.functional.softmax(selection.values, dim=-1)
        return selection.indices, routing_weights

    def forward(self, hidden_states, router_logits) -> torch.Tensor:
        num_experts = self.config.moe_num_experts
        top_k = self.config.moe_topk

        expert_ids, routing_weights = self.route_tokens_to_experts(router_logits)

        # Count how many (token, slot) pairs each expert receives; the histogram is computed on a
        # float32 copy and cast back to the index dtype.
        assignment_counts = torch.histc(
            expert_ids.flatten().to(torch.float32),
            bins=num_experts,
            min=0,
            max=num_experts - 1,
        ).to(expert_ids.dtype)

        # Sort the flattened (token, slot) pairs by expert id; `order // top_k` recovers the token row.
        order = torch.argsort(expert_ids.view(-1))
        grouped_tokens = hidden_states.index_select(0, order // top_k)

        up, gate = self.fc1(grouped_tokens, assignment_counts).chunk(2, dim=-1)
        grouped_output = self.fc2(nn.functional.silu(up) * gate, assignment_counts)

        # Scatter the expert outputs back to their original (token, slot) positions.
        scattered = grouped_output.new_zeros((routing_weights.shape[0] * top_k, grouped_output.size(1)))
        scattered.index_copy_(0, order, grouped_output)
        scattered = scattered.view(-1, top_k, grouped_output.size(1))

        # Weighted sum over the k experts of each token.
        return (scattered * routing_weights.unsqueeze(-1)).sum(dim=1)
class AriaTextMoELayer(nn.Module):
    """
    Mixture-of-experts feed-forward layer: the sum of the routed experts' output and the shared experts'
    output for the same hidden states.
    """

    def __init__(self, config: AriaTextConfig):
        super().__init__()
        self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
        self.experts = AriaExperts(config)
        self.shared_experts = AriaSharedExpertsMLP(config)
        self.config = config

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_shape = hidden_states.shape
        # Routing works on a flat (num_tokens, hidden) view; the result is reshaped back afterwards.
        flat_states = hidden_states.view(-1, hidden_states.size(-1))

        routed = self.experts(flat_states, self.router(flat_states)).view(input_shape)
        shared = self.shared_experts(flat_states.view(input_shape))
        return routed + shared
class AriaTextAttention(LlamaAttention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # No overrides: behaviour is inherited unchanged from `LlamaAttention`.
class AriaTextDecoderLayer(LlamaDecoderLayer):
    """
    Aria Text Decoder Layer.

    This class defines a single decoder layer in the language model, incorporating self-attention and Mixture of Experts (MoE) feed-forward network.

    Args:
        config (`AriaTextConfig`):
            Configuration object for the text component of the model.
        layer_idx (`int`):
            Index of the layer.
    """

    def __init__(self, config: AriaTextConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Replace the `mlp` attribute set up by the parent layer with the MoE layer.
        self.mlp = AriaTextMoELayer(config)
@auto_docstring
class AriaTextPreTrainedModel(PreTrainedModel):
    config: AriaTextConfig
    base_model_prefix = "model"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
    _skip_keys_device_placement = "past_key_values"

    # Attention implementations this model declares support for.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_attention_backend = True

    # Module classes associated with each recordable output name.
    _can_record_outputs = {
        "hidden_states": AriaTextDecoderLayer,
        "attentions": AriaTextAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Run the inherited initialization, then re-draw `AriaGroupedExpertsGemm` weights from a normal distribution."""
        super()._init_weights(module)
        if not isinstance(module, AriaGroupedExpertsGemm):
            return
        # Zero-mean normal with the standard deviation taken from the config.
        init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
class AriaPreTrainedModel(LlamaPreTrainedModel):
    config: AriaConfig
    base_model_prefix = "model"
    _supports_attention_backend = True
    # MoE models don't work with torch.compile (dynamic slicing)
    _can_compile_fullgraph = False

    @torch.no_grad()
    def _init_weights(self, module):
        """Apply `PreTrainedModel`'s initialization, then initialize the `AriaProjector` query."""
        # Called on `PreTrainedModel` explicitly rather than through `super()`.
        PreTrainedModel._init_weights(self, module)
        if not isinstance(module, AriaProjector):
            return
        init.trunc_normal_(module.query, std=self.config.initializer_range)
class AriaTextModel(LlamaModel):
    def __init__(self, config: AriaTextConfig):
        super().__init__(config)
        # One Aria decoder layer per hidden layer declared in the config.
        decoder_layers = [AriaTextDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.gradient_checkpointing = False
        self.post_init()
class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
    # `lm_head.weight` is declared as tied to the input embedding matrix.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: AriaTextConfig):
        super().__init__(config)
        self.model = AriaTextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(self, **super_kwargs):
        # Pure pass-through to the inherited `forward`. `super()` already binds `self`, so it must
        # not be passed a second time, and the result has to be handed back to the caller.
        return super().forward(**super_kwargs)
# Aria-named output type: adds nothing to `LlavaCausalLMOutputWithPast`.
class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
    pass
# Aria-named output type: adds nothing to `LlavaModelOutputWithPast`.
class AriaModelOutputWithPast(LlavaModelOutputWithPast):
    pass
class AriaModel(LlavaModel):
    def __init__(self, config: AriaConfig):
        super().__init__(config)
        # Assigned after the parent constructor so that `multi_modal_projector` is an `AriaProjector`.
        self.multi_modal_projector = AriaProjector(config)

    def _create_patch_attention_mask(self, pixel_mask):
        """
        Reduce a pixel-level mask to a patch-level boolean mask.

        The mask is cut into non-overlapping windows of `patch_size` along dimensions 1 and 2
        (`patch_size` is read from the vision tower config). A patch is `True` when the sum of
        its entries is positive. Returns `None` when `pixel_mask` is `None`.
        """
        if pixel_mask is None:
            return None

        patch_size = self.vision_tower.config.patch_size
        patches_subgrid = pixel_mask.unfold(dimension=1, size=patch_size, step=patch_size)
        patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
        return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: torch.FloatTensor | None = None,
        vision_feature_layer: int | list[int] = -1,
        output_hidden_states: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        # Runs the vision tower, selects the hidden state at `vision_feature_layer`, projects it with
        # `multi_modal_projector` and stores the projection in `pooler_output` of the returned output.
        # `output_hidden_states` is accepted but not used: hidden states are always requested.

        # `return_dict=True` is forced below; drop any caller-supplied value (this class's `forward`
        # passes one) so the keyword is not handed to the vision tower twice.
        kwargs.pop("return_dict", None)

        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
        image_outputs = self.vision_tower(
            pixel_values,
            patch_attention_mask=patch_attention_mask,
            output_hidden_states=True,  # Ignore arg on purpose
            return_dict=True,
            **kwargs,
        )

        # The projector receives the inverted, flattened patch mask (True where the patch mask is False).
        image_attn_mask = None
        if patch_attention_mask is not None:
            flattened_mask = patch_attention_mask.flatten(1)
            image_attn_mask = torch.logical_not(flattened_mask)

        selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
        image_outputs.pooler_output = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
        return image_outputs

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_mask: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple | AriaModelOutputWithPast:
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Stays `None` when no image features are computed in this call, so the return statement
        # below never reads an unassigned name (e.g. pixel values given with a single position).
        image_features = None

        # Merge text and images (skipped when the input holds a single position)
        if pixel_values is not None and inputs_embeds.shape[1] != 1:
            image_features = self.get_image_features(
                pixel_values=pixel_values,
                pixel_mask=pixel_mask,
                vision_feature_layer=self.config.vision_feature_layer,
                return_dict=True,
            ).pooler_output
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            # Write the projected image features into the positions selected by the placeholder mask.
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        return AriaModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values if use_cache else None,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features,
        )
@auto_docstring(
    custom_intro="""
    Aria model for conditional generation tasks.
    This model combines a vision tower, a multi-modal projector, and a language model
    to perform tasks that involve both image and text inputs.
    """
)
class AriaForConditionalGeneration(LlavaForConditionalGeneration):
    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: torch.FloatTensor | None = None,
        vision_feature_layer: int | list[int] = -1,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        # Thin delegation to the wrapped base model.
        image_outputs = self.model.get_image_features(
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            vision_feature_layer=vision_feature_layer,
            **kwargs,
        )
        return image_outputs

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        pixel_mask: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | AriaCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `AriaForConditionalGeneration`).
            Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
            computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, AriaForConditionalGeneration
        >>> from transformers.image_utils import load_image

        >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
        >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
        >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
        >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")

        >>> processor = AutoProcessor.from_pretrained("Rhymes-AI/Aria")
        >>> model = AriaForConditionalGeneration.from_pretrained("Rhymes-AI/Aria", dtype=torch.bfloat16, device_map="auto")

        >>> # Create inputs
        >>> messages = [
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {"type": "image"},
        ...             {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
        ...             {"type": "image"},
        ...             {"type": "text", "text": "What can we see in this image?"},
        ...         ]
        ...     },
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {"type": "image"},
        ...             {"type": "text", "text": "In which city is that bridge located?"},
        ...         ]
        ...     }
        ... ]

        >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
        >>> images = [[image1, image2], [image3]]
        >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)

        >>> # Generate
        >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
        >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

        >>> print(generated_texts[0])
        Assistant: There are buildings, trees, lights, and water visible in this image.

        >>> print(generated_texts[1])
        Assistant: The bridge is in San Francisco.
        ```"""
        model_outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        last_hidden_state = model_outputs[0]

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # An int selects the trailing positions; anything else is used as the index directly.
        if isinstance(logits_to_keep, int):
            kept_positions = slice(-logits_to_keep, None)
        else:
            kept_positions = logits_to_keep
        logits = self.lm_head(last_hidden_state[:, kept_positions, :])

        loss = None
        if labels is not None:
            text_vocab_size = self.config.text_config.vocab_size
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=text_vocab_size, **kwargs)

        return AriaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        pixel_mask=None,
        attention_mask=None,
        logits_to_keep=None,
        is_first_iteration=False,
        **kwargs,
    ):
        """Build the inherited generation inputs and attach the image inputs when they are still needed."""
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            logits_to_keep=logits_to_keep,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

        # Pixel values are used only in the first iteration if available
        # In subsequent iterations, they are already merged with text and cached
        # NOTE: first iteration doesn't have to be prefill, it can be the first
        # iteration with a question and cached system prompt (continue generate from cache)
        caching_disabled = not kwargs.get("use_cache", True)
        if is_first_iteration or caching_disabled:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["pixel_mask"] = pixel_mask

        return model_inputs
# Names this module declares as its public API.
__all__ = [
    "AriaConfig",
    "AriaTextConfig",
    "AriaImageProcessor",
    "AriaProcessor",
    "AriaForConditionalGeneration",
    "AriaPreTrainedModel",
    "AriaTextPreTrainedModel",
    "AriaTextModel",
    "AriaModel",
    "AriaTextForCausalLM",
]
|