modular_aria.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156
  1. # Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. from huggingface_hub.dataclasses import strict
  16. from torch import nn
  17. from torchvision.transforms.v2 import functional as tvF
  18. from ... import initialization as init
  19. from ...activations import ACT2FN
  20. from ...cache_utils import Cache
  21. from ...configuration_utils import PreTrainedConfig
  22. from ...image_processing_backends import TorchvisionBackend
  23. from ...image_processing_utils import BatchFeature, get_patch_output_size, select_best_resolution
  24. from ...image_transforms import divide_to_patches
  25. from ...image_utils import (
  26. ChannelDimension,
  27. ImageInput,
  28. PILImageResampling,
  29. SizeDict,
  30. get_image_size,
  31. )
  32. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  33. from ...modeling_outputs import BaseModelOutputWithPooling
  34. from ...modeling_utils import PreTrainedModel
  35. from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
  36. from ...tokenization_python import PreTokenizedInput, TextInput
  37. from ...utils import (
  38. TensorType,
  39. TransformersKwargs,
  40. auto_docstring,
  41. can_return_tuple,
  42. logging,
  43. )
  44. from ..auto import CONFIG_MAPPING, AutoConfig, AutoTokenizer
  45. from ..llama.configuration_llama import LlamaConfig
  46. from ..llama.modeling_llama import (
  47. LlamaAttention,
  48. LlamaDecoderLayer,
  49. LlamaForCausalLM,
  50. LlamaMLP,
  51. LlamaModel,
  52. LlamaPreTrainedModel,
  53. LlamaRMSNorm,
  54. )
  55. from ..llava.modeling_llava import (
  56. LlavaCausalLMOutputWithPast,
  57. LlavaForConditionalGeneration,
  58. LlavaModel,
  59. LlavaModelOutputWithPast,
  60. )
# Module-level logger for this file, obtained from the library's `logging.get_logger`.
logger = logging.get_logger(__name__)
  62. def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
  63. """
  64. Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.
  65. Args:
  66. token_states (torch.Tensor): Input tensor of shape (num_tokens, in_features).
  67. expert_weights (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features).
  68. tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
  69. Returns:
  70. torch.Tensor: Output tensor of shape (num_tokens, out_features).
  71. """
  72. num_tokens = token_states.shape[0]
  73. out_features = expert_weights.shape[-1]
  74. output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device)
  75. cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
  76. # Insert zero at the beginning for offset index's convenience
  77. zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
  78. cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
  79. for expert_num in range(expert_weights.shape[0]):
  80. start = cumsum_num_tokens[expert_num]
  81. end = cumsum_num_tokens[expert_num + 1]
  82. tokens = token_states[start:end]
  83. out = torch.matmul(tokens, expert_weights[expert_num])
  84. output[start:end] = out
  85. return output
@auto_docstring(checkpoint="rhymes-ai/Aria")
@strict
class AriaTextConfig(LlamaConfig):
    r"""
    moe_num_experts (`int`, *optional*, defaults to 8):
        The number of experts in the MoE layer.
    moe_topk (`int`, *optional*, defaults to 2):
        The number of top experts to route to for each token.
    moe_num_shared_experts (`int`, *optional*, defaults to 2):
        The number of shared experts.
    """

    # Configuration of Aria's text backbone: the Llama configuration extended with MoE settings.
    model_type = "aria_text"
    # Presumably the attribute name of this sub-config inside the composite config
    # (`AriaConfig.text_config`) — confirm against the framework's `base_config_key` handling.
    base_config_key = "text_config"
    # Tensor-parallel plan: attention projections and the shared-expert MLP projections are split
    # column-wise on the way in and row-wise on the way out; routed experts have no entry here.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.shared_experts.gate_proj": "colwise",
        "layers.*.mlp.shared_experts.up_proj": "colwise",
        "layers.*.mlp.shared_experts.down_proj": "rowwise",
    }

    # Hidden width of a single expert's feed-forward block (shared experts scale it by
    # `moe_num_shared_experts`).
    intermediate_size: int = 4096
    moe_num_experts: int = 8
    moe_topk: int = 2
    moe_num_shared_experts: int = 2
    pad_token_id: int | None = 2
  113. @auto_docstring(checkpoint="rhymes-ai/Aria")
  114. @strict
  115. class AriaConfig(PreTrainedConfig):
  116. r"""
  117. projector_patch_to_query_dict (`dict`, *optional*):
  118. Mapping of patch sizes to query dimensions.
  119. """
  120. model_type = "aria"
  121. attribute_map = {
  122. "image_token_id": "image_token_index",
  123. }
  124. sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig}
  125. vision_config: dict | PreTrainedConfig | None = None
  126. text_config: dict | AriaTextConfig | None = None
  127. vision_feature_layer: int | list[int] = -1
  128. projector_patch_to_query_dict: dict | None = None
  129. image_token_index: int = 9
  130. initializer_range: float = 0.02
  131. tie_word_embeddings: bool = False
  132. def __post_init__(self, **kwargs):
  133. # Convert the keys and values of projector_patch_to_query_dict to integers
  134. # This ensures consistency even if they were provided as strings
  135. if self.projector_patch_to_query_dict is None:
  136. self.projector_patch_to_query_dict = {
  137. 1225: 128,
  138. 4900: 256,
  139. }
  140. self.projector_patch_to_query_dict = {int(k): int(v) for k, v in self.projector_patch_to_query_dict.items()}
  141. self.max_value_projector_patch_to_query_dict = max(self.projector_patch_to_query_dict.values())
  142. if isinstance(self.vision_config, dict):
  143. self.vision_config["model_type"] = "idefics3_vision"
  144. self.vision_config = CONFIG_MAPPING[self.vision_config["model_type"]](**self.vision_config)
  145. elif self.vision_config is None:
  146. self.vision_config = CONFIG_MAPPING["idefics3_vision"]()
  147. if isinstance(self.text_config, dict) and "model_type" in self.text_config:
  148. self.text_config = AriaTextConfig(**self.text_config)
  149. elif self.text_config is None:
  150. self.text_config = AriaTextConfig()
  151. super().__post_init__(**kwargs)
  152. class AriaTextRMSNorm(LlamaRMSNorm):
  153. pass
  154. class AriaProjectorMLP(nn.Module):
  155. """
  156. Feed-Forward Network module for the Aria Projector.
  157. Args:
  158. in_features (`int`):
  159. Input embedding dimension.
  160. hidden_features (`int`):
  161. Hidden dimension of the feed-forward network.
  162. output_dim (`int`):
  163. Output dimension.
  164. """
  165. def __init__(self, in_features, hidden_features, output_dim):
  166. super().__init__()
  167. self.linear_in = nn.Linear(in_features, hidden_features, bias=False)
  168. self.linear_out = nn.Linear(hidden_features, output_dim, bias=False)
  169. self.act = ACT2FN["gelu_new"]
  170. def forward(self, hidden_states):
  171. hidden_states = self.act(self.linear_in(hidden_states))
  172. hidden_states = self.linear_out(hidden_states)
  173. return hidden_states
  174. class AriaCrossAttention(nn.Module):
  175. """
  176. Aria Cross-Attention module.
  177. Args:
  178. config (`AriaConfig`):
  179. The configuration to use.
  180. """
  181. def __init__(self, config: AriaConfig, dropout_rate: float = 0):
  182. super().__init__()
  183. hidden_size = config.vision_config.hidden_size
  184. num_heads = config.vision_config.num_attention_heads
  185. self.num_heads = num_heads
  186. self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
  187. self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
  188. self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
  189. # Original code here: https://github.com/rhymes-ai/Aria/blob/719ff4e52b727443cba3793b0e27fe64e0244fe1/aria/model/projector.py#L48
  190. self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
  191. self.linear = nn.Linear(hidden_size, hidden_size)
  192. self.dropout = nn.Dropout(dropout_rate)
  193. self.layer_norm = nn.LayerNorm(hidden_size)
  194. self.layer_norm_kv = nn.LayerNorm(hidden_size)
  195. def forward(self, key_value_states, hidden_states, attn_mask=None):
  196. """
  197. Forward pass of the AriaCrossAttention module.
  198. Args:
  199. key_value_states (`torch.Tensor`):
  200. Input tensor for key and value.
  201. hidden_states (`torch.Tensor`):
  202. Input tensor for query.
  203. attn_mask (`torch.Tensor`, *optional*, defaults to None):
  204. Attention mask.
  205. Returns:
  206. torch.Tensor:
  207. Output tensor after cross-attention.
  208. """
  209. query = self.q_proj(self.layer_norm(hidden_states))
  210. key_value_states = self.layer_norm_kv(key_value_states)
  211. key = self.k_proj(key_value_states)
  212. value = self.v_proj(key_value_states)
  213. attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask)
  214. attn_output = self.dropout(self.linear(attn_output))
  215. return attn_output
  216. class AriaProjector(nn.Module):
  217. """
  218. Aria Projector module.
  219. This module projects vision features into the language model's embedding space, enabling interaction between vision and language components.
  220. Args:
  221. config (`AriaConfig`):
  222. Configuration object for the model.
  223. """
  224. def __init__(
  225. self,
  226. config: AriaConfig,
  227. ):
  228. super().__init__()
  229. self.patch_to_query_dict = config.projector_patch_to_query_dict
  230. self.in_features = config.vision_config.hidden_size
  231. self.num_heads = config.vision_config.num_attention_heads
  232. self.kv_dim = config.vision_config.hidden_size
  233. self.hidden_features = config.text_config.hidden_size
  234. self.output_dim = config.text_config.hidden_size
  235. self.query = nn.Parameter(torch.zeros(config.max_value_projector_patch_to_query_dict, self.in_features))
  236. self.cross_attn = AriaCrossAttention(config)
  237. self.layer_norm = nn.LayerNorm(self.in_features)
  238. self.feed_forward = AriaProjectorMLP(self.in_features, self.hidden_features, self.output_dim)
  239. def forward(self, key_value_states: torch.Tensor, attn_mask: torch.Tensor | None = None):
  240. """
  241. Forward pass of the Projector module.
  242. Args:
  243. key_value_states (`torch.Tensor`):
  244. Input tensor of shape (batch_size, num_patches, kv_dim).
  245. attn_mask (`torch.Tensor`, *optional*, default is None):
  246. Attention mask.
  247. Returns:
  248. `torch.Tensor`: Output tensor of shape (batch_size, query_number, output_dim).
  249. """
  250. batch_size, num_patches = key_value_states.shape[0], key_value_states.shape[1]
  251. if num_patches not in self.patch_to_query_dict:
  252. raise KeyError(
  253. f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}."
  254. )
  255. query_num = self.patch_to_query_dict[num_patches]
  256. queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)
  257. if attn_mask is not None:
  258. attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
  259. attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)
  260. attention_out = self.cross_attn(key_value_states, queries, attn_mask=attn_mask)
  261. out = self.feed_forward(self.layer_norm(attention_out))
  262. return out
class AriaImageProcessorKwargs(ImagesKwargs, total=False):
    r"""
    max_image_size (`int`, *optional*, defaults to `self.max_image_size`):
        Side length (in pixels) of every processed crop: each crop is resized so that its longer side equals
        this value and is then padded to a square of this size. Must be either 490 or 980.
    min_image_size (`int`, *optional*, defaults to `self.min_image_size`):
        Minimum length (in pixels) of a crop's shorter side after the crop has been resized so that its longer
        side equals `max_image_size`.
    split_resolutions (`list[list[int]]`, *optional*, defaults to `self.split_resolutions`):
        A list of possible resolutions as (height, width) pairs for splitting high-resolution images into patches.
    split_image (`bool`, *optional*, defaults to `self.split_image`):
        Whether to split the image into patches using the best matching resolution from `split_resolutions`.
    """

    # Image-processing options accepted by `AriaImageProcessor` on top of the base `ImagesKwargs`.
    max_image_size: int
    min_image_size: int
    split_resolutions: list[list[int]]
    split_image: bool
  278. @auto_docstring
  279. class AriaImageProcessor(TorchvisionBackend):
  280. model_input_names = ["pixel_values", "pixel_mask", "num_crops"]
  281. valid_kwargs = AriaImageProcessorKwargs
  282. resample = PILImageResampling.BICUBIC
  283. image_mean = [0.5, 0.5, 0.5]
  284. image_std = [0.5, 0.5, 0.5]
  285. max_image_size = 980
  286. min_image_size = 336
  287. split_image = False
  288. split_resolutions = None
  289. do_convert_rgb = True
  290. do_rescale = True
  291. do_normalize = True
  292. def __init__(self, **kwargs: Unpack[AriaImageProcessorKwargs]):
  293. if kwargs.get("split_resolutions") is None:
  294. default_resolutions = [(1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (2, 4), (2, 3), (2, 2), (2, 1), (3, 1), (3, 2), (4, 1), (4, 2), (5, 1), (6, 1), (7, 1), (8, 1)] # fmt: skip
  295. kwargs["split_resolutions"] = [[el[0] * 490, el[1] * 490] for el in default_resolutions]
  296. super().__init__(**kwargs)
  297. def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple) -> list[int]:
  298. """Get padding size for patching, returns [left, top, right, bottom] for tvF.pad."""
  299. original_height, original_width = original_resolution
  300. target_height, target_width = target_resolution
  301. paste_x, r_x = divmod(target_width - original_width, 2)
  302. paste_y, r_y = divmod(target_height - original_height, 2)
  303. return [paste_x, paste_y, paste_x + r_x, paste_y + r_y]
  304. def _resize_for_patching(
  305. self,
  306. image: "torch.Tensor",
  307. target_resolution: tuple,
  308. resample: "PILImageResampling | tvF.InterpolationMode | int | None",
  309. ) -> "torch.Tensor":
  310. """Resize an image to a target resolution while maintaining aspect ratio."""
  311. new_height, new_width = get_patch_output_size(
  312. image, target_resolution, input_data_format=ChannelDimension.FIRST
  313. )
  314. return self.resize(image, SizeDict(height=new_height, width=new_width), resample)
  315. def _pad_for_patching(
  316. self,
  317. image: "torch.Tensor",
  318. target_resolution: tuple,
  319. ) -> "torch.Tensor":
  320. """Pad an image to a target resolution while maintaining aspect ratio."""
  321. new_resolution = get_patch_output_size(image, target_resolution, input_data_format=ChannelDimension.FIRST)
  322. padding = self._get_padding_size(new_resolution, target_resolution)
  323. return tvF.pad(image, padding=padding)
  324. def get_image_patches(
  325. self,
  326. image: "torch.Tensor",
  327. grid_pinpoints: list[list[int]],
  328. patch_size: int,
  329. resample: "PILImageResampling | tvF.InterpolationMode | int | None",
  330. ) -> list["torch.Tensor"]:
  331. """
  332. Process an image with variable resolutions by dividing it into patches.
  333. Args:
  334. image (`torch.Tensor`):
  335. The input image to be processed (channels-first format).
  336. grid_pinpoints (`list[list[int]]`):
  337. A list of possible resolutions as (height, width) pairs.
  338. patch_size (`int`):
  339. Size of each square patch to divide the image into.
  340. resample (`PILImageResampling | tvF.InterpolationMode | int | None`):
  341. Resampling filter to use when resizing.
  342. Returns:
  343. `list[torch.Tensor]`: A list of image patches in channels-first format.
  344. """
  345. if not isinstance(grid_pinpoints, list):
  346. raise TypeError("grid_pinpoints must be a list of possible resolutions.")
  347. image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
  348. best_resolution = select_best_resolution(image_size, grid_pinpoints)
  349. resized_image = self._resize_for_patching(image, best_resolution, resample)
  350. padded_image = self._pad_for_patching(resized_image, best_resolution)
  351. patches = divide_to_patches(padded_image, patch_size=patch_size)
  352. return patches
  353. def _preprocess(
  354. self,
  355. images: list["torch.Tensor"],
  356. do_rescale: bool,
  357. rescale_factor: float,
  358. do_normalize: bool,
  359. image_mean: float | list[float] | None,
  360. image_std: float | list[float] | None,
  361. disable_grouping: bool | None,
  362. return_tensors: str | TensorType | None,
  363. max_image_size: int = 980,
  364. min_image_size: int = 336,
  365. split_resolutions: list[list[int]] | None = None,
  366. split_image: bool = False,
  367. resample: "PILImageResampling | tvF.InterpolationMode | int | None" = None,
  368. **kwargs,
  369. ) -> BatchFeature:
  370. if max_image_size not in [490, 980]:
  371. raise ValueError("max_image_size must be either 490 or 980")
  372. pixel_masks = []
  373. processed_crops = []
  374. num_crops = None
  375. for image in images:
  376. if split_image:
  377. crop_images = self.get_image_patches(image, split_resolutions, max_image_size, resample)
  378. else:
  379. crop_images = [image]
  380. if num_crops is None or len(crop_images) > num_crops:
  381. num_crops = len(crop_images)
  382. for crop_image in crop_images:
  383. h, w = crop_image.shape[-2], crop_image.shape[-1]
  384. scale = max_image_size / max(h, w)
  385. if w >= h:
  386. new_h = max(int(h * scale), min_image_size)
  387. new_w = max_image_size
  388. else:
  389. new_h = max_image_size
  390. new_w = max(int(w * scale), min_image_size)
  391. crop_image = self.resize(crop_image, SizeDict(height=new_h, width=new_w), resample)
  392. padding_bottom = max_image_size - new_h
  393. padding_right = max_image_size - new_w
  394. crop_image = tvF.pad(crop_image, [0, 0, padding_right, padding_bottom])
  395. pixel_mask = torch.zeros((max_image_size, max_image_size), dtype=torch.bool)
  396. pixel_mask[:new_h, :new_w] = True
  397. pixel_masks.append(pixel_mask)
  398. processed_crops.append(crop_image)
  399. stacked_images = torch.stack(processed_crops, dim=0)
  400. stacked_images = self.rescale_and_normalize(
  401. stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
  402. )
  403. stacked_masks = torch.stack(pixel_masks, dim=0)
  404. return BatchFeature(
  405. data={
  406. "pixel_values": stacked_images,
  407. "pixel_mask": stacked_masks,
  408. "num_crops": num_crops,
  409. },
  410. tensor_type=return_tensors,
  411. )
  412. def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
  413. """
  414. A utility that returns number of image patches for a given image size.
  415. Args:
  416. height (`int`):
  417. Height of the input image.
  418. width (`int`):
  419. Width of the input image.
  420. images_kwargs (`dict`, *optional*):
  421. Any kwargs to override defaults of the image processor.
  422. Returns:
  423. `int`: Number of patches per image.
  424. """
  425. split_image = images_kwargs.get("split_image", self.split_image)
  426. max_image_size = images_kwargs.get("max_image_size", self.max_image_size)
  427. resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
  428. num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
  429. return num_patches
class AriaImagesKwargs(ImagesKwargs, total=False):
    """
    split_image (`bool`, *optional*, defaults to `False`):
        Whether to split each image into multiple crops. When enabled, the image is resized and padded to the
        best-matching resolution from the image processor's `split_resolutions` and then cut into tiles of side
        `max_image_size`; each tile is processed as a separate crop.
    max_image_size (`int`, *optional*, defaults to `980`):
        Side length (in pixels) of every processed crop. Each crop is resized so that its longer side equals
        this value (images may be scaled up or down) and then padded to a `max_image_size` x `max_image_size`
        square. Must be either 490 or 980.
    min_image_size (`int`, *optional*):
        Minimum length (in pixels) of a crop's shorter side after the crop has been resized to `max_image_size`
        on its longer side. If not specified, the image processor's default is used (`336` for
        `AriaImageProcessor`).
    """

    split_image: bool
    max_image_size: int
    min_image_size: int
class AriaProcessorKwargs(ProcessingKwargs, total=False):
    # Typed image options accepted by `AriaProcessor.__call__`.
    images_kwargs: AriaImagesKwargs
    # Per-modality defaults; presumably merged with call-time kwargs by `ProcessorMixin._merge_kwargs`.
    _defaults = {
        "text_kwargs": {
            "padding": False,
            # When True, `AriaProcessor.__call__` adds `mm_token_type_ids` to its output.
            "return_mm_token_type_ids": False,
        },
        "images_kwargs": {
            "max_image_size": 980,
            "split_image": False,
        },
        "return_tensors": TensorType.PYTORCH,
    }
  461. @auto_docstring
  462. class AriaProcessor(ProcessorMixin):
  463. def __init__(
  464. self,
  465. image_processor=None,
  466. tokenizer: AutoTokenizer | str = None,
  467. chat_template: str | None = None,
  468. size_conversion: dict[float | int, int] | None = None,
  469. ):
  470. r"""
  471. size_conversion (`Dict`, *optional*):
  472. A dictionary indicating size conversions for images.
  473. """
  474. if size_conversion is None:
  475. size_conversion = {490: 128, 980: 256}
  476. self.size_conversion = {int(k): v for k, v in size_conversion.items()}
  477. self.image_token = tokenizer.image_token
  478. self.image_token_id = tokenizer.image_token_id
  479. if tokenizer is not None and tokenizer.pad_token is None:
  480. tokenizer.pad_token = tokenizer.unk_token
  481. super().__init__(image_processor, tokenizer, chat_template=chat_template)
  482. @auto_docstring
  483. def __call__(
  484. self,
  485. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput],
  486. images: ImageInput | None = None,
  487. **kwargs: Unpack[AriaProcessorKwargs],
  488. ) -> BatchFeature:
  489. r"""
  490. Returns:
  491. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  492. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  493. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  494. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  495. `None`).
  496. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  497. - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
  498. """
  499. output_kwargs = self._merge_kwargs(
  500. AriaProcessorKwargs,
  501. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  502. **kwargs,
  503. )
  504. if isinstance(text, str):
  505. text = [text]
  506. elif not isinstance(text, list) and not isinstance(text[0], str):
  507. raise TypeError("Invalid input text. Please provide a string, or a list of strings")
  508. if images is not None:
  509. image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
  510. # expand the image_token according to the num_crops and tokens per image
  511. tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
  512. prompt_strings = []
  513. num_crops = image_inputs.pop("num_crops") * tokens_per_image
  514. for sample in text:
  515. sample = sample.replace(self.tokenizer.image_token, self.tokenizer.image_token * num_crops)
  516. prompt_strings.append(sample)
  517. else:
  518. image_inputs = {}
  519. prompt_strings = text
  520. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  521. return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
  522. text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
  523. self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
  524. if return_mm_token_type_ids:
  525. text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
  526. return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
  527. def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
  528. """
  529. Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
  530. Args:
  531. image_sizes (`list[list[int]]`, *optional*):
  532. The input sizes formatted as (height, width) per each image.
  533. Returns:
  534. `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
  535. input modalities, along with other useful data.
  536. """
  537. vision_data = {}
  538. if image_sizes is not None:
  539. images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
  540. images_kwargs.update(kwargs)
  541. max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
  542. num_image_patches = [
  543. self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
  544. for image_size in image_sizes
  545. ]
  546. num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
  547. vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
  548. return MultiModalData(**vision_data)
  549. @property
  550. def model_input_names(self):
  551. tokenizer_input_names = self.tokenizer.model_input_names
  552. image_processor_input_names = self.image_processor.model_input_names
  553. # Remove `num_crops`, it is popped and used only when processing. Make a copy of list when removing
  554. # otherwise `self.image_processor.model_input_names` is also modified
  555. image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
  556. return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
class AriaSharedExpertsMLP(LlamaMLP):
    """
    Shared Expert MLP for shared experts.

    Unlike routed experts, shared experts process all tokens without routing.
    This class reconfigures the intermediate size in comparison to the LlamaMLP: it is meant to be
    `config.intermediate_size * config.moe_num_shared_experts`.

    Args:
        config (`AriaTextConfig`): Configuration object for the Aria language model.
    """

    def __init__(self, config: AriaTextConfig):
        # NOTE(review): executed as plain Python, `LlamaMLP.__init__` presumably builds its projections from
        # `config.intermediate_size` before the assignment below runs, so this assignment alone would not
        # resize them. This file is a modular definition (`modular_aria.py`); the modular converter is
        # expected to inline the parent `__init__` with this override in place. Keep the statement order
        # and confirm against the generated modeling file.
        super().__init__(config)
        self.intermediate_size = config.intermediate_size * config.moe_num_shared_experts
  568. class AriaGroupedExpertsGemm(nn.Module):
  569. """
  570. Grouped GEMM (General Matrix Multiplication) module for efficient expert computation.
  571. This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm)
  572. for optimized performance. If the grouped_gemm library is not installed, it gracefully
  573. falls back to a sequential GEMM implementation, which may be slower but ensures
  574. functionality.
  575. Args:
  576. in_features (`int`):
  577. Number of input features.
  578. out_features (`int`):
  579. Number of output features.
  580. groups (`int`):
  581. Number of expert groups.
  582. """
  583. def __init__(self, in_features, out_features, groups):
  584. super().__init__()
  585. self.in_features = in_features
  586. self.out_features = out_features
  587. self.groups = groups
  588. self.weight = nn.Parameter(torch.empty(groups, in_features, out_features))
  589. def forward(self, input, tokens_per_expert):
  590. """
  591. Perform grouped matrix multiplication.
  592. Args:
  593. input (`torch.Tensor`):
  594. Input tensor of shape (num_tokens, in_features).
  595. tokens_per_expert (`torch.Tensor`):
  596. Number of tokens assigned to each expert.
  597. Returns:
  598. torch.Tensor: Output tensor of shape (num_tokens, out_features).
  599. """
  600. return sequential_experts_gemm(
  601. input,
  602. self.weight,
  603. tokens_per_expert.cpu(),
  604. )
  605. class AriaExperts(nn.Module):
  606. def __init__(self, config: AriaTextConfig) -> None:
  607. super().__init__()
  608. self.config = config
  609. self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts)
  610. self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts)
  611. def route_tokens_to_experts(self, router_logits):
  612. top_logits, top_indices = torch.topk(router_logits, k=self.config.moe_topk, dim=1)
  613. scores = nn.functional.softmax(top_logits, dim=-1)
  614. return top_indices, scores
  615. def forward(self, hidden_states, router_logits) -> torch.Tensor:
  616. top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits)
  617. original_dtype = top_k_index.dtype
  618. tokens_per_expert = torch.histc(
  619. top_k_index.flatten().to(torch.float32),
  620. bins=self.config.moe_num_experts,
  621. min=0,
  622. max=self.config.moe_num_experts - 1,
  623. ).to(original_dtype)
  624. indices = top_k_index
  625. flatten_indices = indices.view(-1)
  626. sorted_indices = torch.argsort(flatten_indices)
  627. permuted_tokens = hidden_states.index_select(0, sorted_indices // self.config.moe_topk)
  628. fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
  629. projection, gate = torch.chunk(fc1_output, 2, dim=-1)
  630. fc1_output = nn.functional.silu(projection) * gate
  631. expert_output = self.fc2(fc1_output, tokens_per_expert)
  632. unpermuted_tokens = torch.zeros(
  633. (top_k_weights.shape[0] * self.config.moe_topk, expert_output.size(1)),
  634. dtype=expert_output.dtype,
  635. device=expert_output.device,
  636. )
  637. unpermuted_tokens.index_copy_(0, sorted_indices, expert_output)
  638. unpermuted_tokens = unpermuted_tokens.view(-1, self.config.moe_topk, expert_output.size(1))
  639. output = (unpermuted_tokens * top_k_weights.unsqueeze(-1)).sum(dim=1)
  640. return output
  641. class AriaTextMoELayer(nn.Module):
  642. def __init__(self, config: AriaTextConfig):
  643. super().__init__()
  644. self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
  645. self.experts = AriaExperts(config)
  646. self.shared_experts = AriaSharedExpertsMLP(config)
  647. self.config = config
  648. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  649. original_shape = hidden_states.shape
  650. hidden_states = hidden_states.view(-1, hidden_states.size(-1))
  651. router_logits = self.router(hidden_states)
  652. expert_output = self.experts(hidden_states, router_logits).view(original_shape)
  653. shared_expert_output = self.shared_experts(hidden_states.view(original_shape))
  654. return expert_output + shared_expert_output
class AriaTextAttention(LlamaAttention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Reuses LlamaAttention unchanged. The subclass gives Aria its own named attention module, which is
    # referenced by `AriaTextPreTrainedModel._can_record_outputs`.
class AriaTextDecoderLayer(LlamaDecoderLayer):
    """
    Aria Text Decoder Layer.

    This class defines a single decoder layer in the language model, incorporating self-attention and
    Mixture of Experts (MoE) feed-forward network.

    Args:
        config (`AriaTextConfig`):
            Configuration object for the text component of the model.
        layer_idx (`int`):
            Index of the layer.
    """

    def __init__(self, config: AriaTextConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Override `self.mlp` (presumably the parent's dense MLP) with the MoE feed-forward block.
        # Must stay after `super().__init__` so it replaces the parent's assignment.
        self.mlp = AriaTextMoELayer(config)
  670. @auto_docstring
  671. class AriaTextPreTrainedModel(PreTrainedModel):
  672. config: AriaTextConfig
  673. base_model_prefix = "model"
  674. input_modalities = ("image", "text")
  675. _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
  676. supports_gradient_checkpointing = True
  677. _skip_keys_device_placement = "past_key_values"
  678. _supports_flash_attn = True
  679. _supports_sdpa = True
  680. _supports_attention_backend = True
  681. _can_record_outputs = {
  682. "hidden_states": AriaTextDecoderLayer,
  683. "attentions": AriaTextAttention,
  684. }
  685. @torch.no_grad()
  686. def _init_weights(self, module):
  687. super()._init_weights(module)
  688. if isinstance(module, AriaGroupedExpertsGemm):
  689. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  690. class AriaPreTrainedModel(LlamaPreTrainedModel):
  691. config: AriaConfig
  692. base_model_prefix = "model"
  693. _can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing)
  694. _supports_attention_backend = True
  695. @torch.no_grad()
  696. def _init_weights(self, module):
  697. PreTrainedModel._init_weights(self, module)
  698. if isinstance(module, AriaProjector):
  699. init.trunc_normal_(module.query, std=self.config.initializer_range)
class AriaTextModel(LlamaModel):
    # Llama-style text backbone whose decoder layers use the MoE feed-forward block.

    def __init__(self, config: AriaTextConfig):
        super().__init__(config)
        # Rebuild the decoder stack with Aria decoder layers, replacing whatever the parent assigned.
        self.layers = nn.ModuleList(
            [AriaTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False
        # Called after the layers are replaced, presumably so the new layers also go through weight
        # initialisation / final processing.
        self.post_init()
  708. class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
  709. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  710. def __init__(self, config: AriaTextConfig):
  711. super().__init__(config)
  712. self.model = AriaTextModel(config)
  713. self.vocab_size = config.vocab_size
  714. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  715. # Initialize weights and apply final processing
  716. self.post_init()
  717. @auto_docstring
  718. def forward(self, **super_kwargs):
  719. super().forward(self, **super_kwargs)
class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
    # Output container of `AriaForConditionalGeneration.forward`; fields are inherited unchanged.
    pass
class AriaModelOutputWithPast(LlavaModelOutputWithPast):
    # Output container of `AriaModel.forward`; fields are inherited unchanged.
    pass
  724. class AriaModel(LlavaModel):
  725. def __init__(self, config: AriaConfig):
  726. super().__init__(config)
  727. self.multi_modal_projector = AriaProjector(config)
  728. def _create_patch_attention_mask(self, pixel_mask):
  729. if pixel_mask is None:
  730. return None
  731. patches_subgrid = pixel_mask.unfold(
  732. dimension=1,
  733. size=self.vision_tower.config.patch_size,
  734. step=self.vision_tower.config.patch_size,
  735. )
  736. patches_subgrid = patches_subgrid.unfold(
  737. dimension=2,
  738. size=self.vision_tower.config.patch_size,
  739. step=self.vision_tower.config.patch_size,
  740. )
  741. return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
  742. def get_image_features(
  743. self,
  744. pixel_values: torch.FloatTensor,
  745. pixel_mask: torch.FloatTensor | None = None,
  746. vision_feature_layer: int | list[int] = -1,
  747. output_hidden_states: bool | None = None,
  748. **kwargs: Unpack[TransformersKwargs],
  749. ) -> tuple | BaseModelOutputWithPooling:
  750. patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
  751. image_outputs = self.vision_tower(
  752. pixel_values,
  753. patch_attention_mask=patch_attention_mask,
  754. output_hidden_states=True, # Ignore arg on purpose
  755. return_dict=True,
  756. **kwargs,
  757. )
  758. image_attn_mask = None
  759. if patch_attention_mask is not None:
  760. flattened_mask = patch_attention_mask.flatten(1)
  761. image_attn_mask = torch.logical_not(flattened_mask)
  762. selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
  763. image_outputs.pooler_output = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
  764. return image_outputs
  765. def forward(
  766. self,
  767. input_ids: torch.LongTensor | None = None,
  768. pixel_values: torch.FloatTensor | None = None,
  769. pixel_mask: torch.LongTensor | None = None,
  770. attention_mask: torch.Tensor | None = None,
  771. position_ids: torch.LongTensor | None = None,
  772. past_key_values: Cache | None = None,
  773. inputs_embeds: torch.FloatTensor | None = None,
  774. use_cache: bool | None = None,
  775. **kwargs: Unpack[FlashAttentionKwargs],
  776. ) -> tuple | AriaModelOutputWithPast:
  777. if inputs_embeds is None:
  778. inputs_embeds = self.get_input_embeddings()(input_ids)
  779. # 2. Merge text and images
  780. if pixel_values is not None and inputs_embeds.shape[1] != 1:
  781. image_features = self.get_image_features(
  782. pixel_values=pixel_values,
  783. pixel_mask=pixel_mask,
  784. vision_feature_layer=self.config.vision_feature_layer,
  785. return_dict=True,
  786. ).pooler_output
  787. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  788. special_image_mask = self.get_placeholder_mask(
  789. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  790. )
  791. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  792. outputs = self.language_model(
  793. attention_mask=attention_mask,
  794. position_ids=position_ids,
  795. past_key_values=past_key_values,
  796. inputs_embeds=inputs_embeds,
  797. use_cache=use_cache,
  798. **kwargs,
  799. )
  800. return AriaModelOutputWithPast(
  801. last_hidden_state=outputs.last_hidden_state,
  802. past_key_values=outputs.past_key_values if use_cache else None,
  803. hidden_states=outputs.hidden_states,
  804. attentions=outputs.attentions,
  805. image_hidden_states=image_features if pixel_values is not None else None,
  806. )
  807. @auto_docstring(
  808. custom_intro="""
  809. Aria model for conditional generation tasks.
  810. This model combines a vision tower, a multi-modal projector, and a language model
  811. to perform tasks that involve both image and text inputs.
  812. """
  813. )
  814. class AriaForConditionalGeneration(LlavaForConditionalGeneration):
  815. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  816. @auto_docstring
  817. def get_image_features(
  818. self,
  819. pixel_values: torch.FloatTensor,
  820. pixel_mask: torch.FloatTensor | None = None,
  821. vision_feature_layer: int | list[int] = -1,
  822. **kwargs: Unpack[TransformersKwargs],
  823. ) -> tuple | BaseModelOutputWithPooling:
  824. return self.model.get_image_features(
  825. pixel_values=pixel_values,
  826. pixel_mask=pixel_mask,
  827. vision_feature_layer=vision_feature_layer,
  828. **kwargs,
  829. )
  830. @can_return_tuple
  831. @auto_docstring
  832. def forward(
  833. self,
  834. input_ids: torch.LongTensor | None = None,
  835. pixel_values: torch.FloatTensor | None = None,
  836. pixel_mask: torch.LongTensor | None = None,
  837. attention_mask: torch.Tensor | None = None,
  838. position_ids: torch.LongTensor | None = None,
  839. past_key_values: Cache | None = None,
  840. inputs_embeds: torch.FloatTensor | None = None,
  841. labels: torch.LongTensor | None = None,
  842. use_cache: bool | None = None,
  843. logits_to_keep: int | torch.Tensor = 0,
  844. **kwargs: Unpack[TransformersKwargs],
  845. ) -> tuple | AriaCausalLMOutputWithPast:
  846. r"""
  847. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  848. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  849. config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `AriaForConditionalGeneration`).
  850. Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
  851. computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  852. Example:
  853. ```python
  854. >>> import httpx
  855. >>> from io import BytesIO
  856. >>> import torch
  857. >>> from PIL import Image
  858. >>> from io import BytesIO
  859. >>> from transformers import AutoProcessor, AutoModel
  860. >>> from transformers.image_utils import load_image
  861. >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
  862. >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
  863. >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
  864. >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
  865. >>> processor = AutoProcessor.from_pretrained("Rhymes-AI/Aria")
  866. >>> model = AutoModel.from_pretrained("Rhymes-AI/Aria", dtype=torch.bfloat16, device_map="auto")
  867. >>> # Create inputs
  868. >>> messages = [
  869. ... {
  870. ... "role": "user",
  871. ... "content": [
  872. ... {"type": "image"},
  873. ... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
  874. ... {"type": "image"},
  875. ... {"type": "text", "text": "What can we see in this image?"},
  876. ... ]
  877. ... },
  878. ... {
  879. ... "role": "user",
  880. ... "content": [
  881. ... {"type": "image"},
  882. ... {"type": "text", "text": "In which city is that bridge located?"},
  883. ... ]
  884. ... }
  885. ... ]
  886. >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
  887. >>> images = [[image1, image2], [image3]]
  888. >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)
  889. >>> # Generate
  890. >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
  891. >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
  892. >>> print(generated_texts[0])
  893. Assistant: There are buildings, trees, lights, and water visible in this image.
  894. >>> print(generated_texts[1])
  895. Assistant: The bridge is in San Francisco.
  896. ```"""
  897. outputs = self.model(
  898. input_ids=input_ids,
  899. pixel_values=pixel_values,
  900. pixel_mask=pixel_mask,
  901. attention_mask=attention_mask,
  902. position_ids=position_ids,
  903. past_key_values=past_key_values,
  904. inputs_embeds=inputs_embeds,
  905. use_cache=use_cache,
  906. **kwargs,
  907. )
  908. hidden_states = outputs[0]
  909. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  910. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  911. logits = self.lm_head(hidden_states[:, slice_indices, :])
  912. loss = None
  913. if labels is not None:
  914. loss = self.loss_function(
  915. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  916. )
  917. return AriaCausalLMOutputWithPast(
  918. loss=loss,
  919. logits=logits,
  920. past_key_values=outputs.past_key_values,
  921. hidden_states=outputs.hidden_states,
  922. attentions=outputs.attentions,
  923. )
  924. def prepare_inputs_for_generation(
  925. self,
  926. input_ids,
  927. past_key_values=None,
  928. inputs_embeds=None,
  929. pixel_values=None,
  930. pixel_mask=None,
  931. attention_mask=None,
  932. logits_to_keep=None,
  933. is_first_iteration=False,
  934. **kwargs,
  935. ):
  936. model_inputs = super().prepare_inputs_for_generation(
  937. input_ids,
  938. past_key_values=past_key_values,
  939. inputs_embeds=inputs_embeds,
  940. attention_mask=attention_mask,
  941. logits_to_keep=logits_to_keep,
  942. is_first_iteration=is_first_iteration,
  943. **kwargs,
  944. )
  945. if is_first_iteration or not kwargs.get("use_cache", True):
  946. # Pixel values are used only in the first iteration if available
  947. # In subsequent iterations, they are already merged with text and cached
  948. # NOTE: first iteration doesn't have to be prefill, it can be the first
  949. # iteration with a question and cached system prompt (continue generate from cache)
  950. model_inputs["pixel_values"] = pixel_values
  951. model_inputs["pixel_mask"] = pixel_mask
  952. return model_inputs
# Public API of this module: configuration, processing and model classes.
__all__ = [
    "AriaConfig",
    "AriaTextConfig",
    "AriaImageProcessor",
    "AriaProcessor",
    "AriaForConditionalGeneration",
    "AriaPreTrainedModel",
    "AriaTextPreTrainedModel",
    "AriaTextModel",
    "AriaModel",
    "AriaTextForCausalLM",
]