# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/slanext/modular_slanext.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_slanext.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from ... import initialization as init
from ...activations import ACT2CLS, ACT2FN
from ...backbone_utils import filter_output_hidden_states
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from .configuration_slanext import SLANeXtConfig, SLANeXtVisionConfig


class SLANeXtVisionAttention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(self, config, window_size):
        super().__init__()
        # Global-attention layers (window_size == 0) attend over the full patch grid; windowed layers only over
        # a (window_size, window_size) tile. This determines the size of the relative position tables.
        input_size = (
            (config.image_size // config.patch_size, config.image_size // config.patch_size)
            if window_size == 0
            else (window_size, window_size)
        )

        self.num_attention_heads = config.num_attention_heads
        head_dim = config.hidden_size // config.num_attention_heads
        self.scale = head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)

        self.use_rel_pos = config.use_rel_pos
        if self.use_rel_pos:
            # initialize relative positional embeddings (one row per possible relative offset along each axis)
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
        """
        Get relative positional embeddings according to the relative positions of query and key sizes.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`torch.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings according to relative positions, with shape (q_size, k_size, channel).
        """
        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        # Interpolate rel pos to the number of relative offsets needed for these sizes.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)

        # Scale the coords with short length if shapes for q and k are different.
        # The index tensors are created on the embedding's device so indexing needs no host-to-device copy.
        q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0)
        k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0)
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

        return rel_pos_resized[relative_coords.long()]

    def get_decomposed_rel_pos(
        self,
        query: torch.Tensor,
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
        q_size: tuple[int, int],
        k_size: tuple[int, int],
    ) -> torch.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            query (`torch.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
            rel_pos_h (`torch.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`torch.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            decomposed_rel_pos (`torch.Tensor`):
                decomposed relative position embeddings of shape
                (batch_size, query_height, query_width, key_height, key_width).
        """
        query_height, query_width = q_size
        key_height, key_width = k_size
        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)

        batch_size, _, dim = query.shape
        reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
        # Broadcast the per-axis terms into a full (key_height, key_width) bias for every query position.
        decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
        return decomposed_rel_pos

    def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states (`torch.Tensor`): input of shape (batch_size, height, width, hidden_size).
            output_attentions: accepted for interface compatibility; not read here.

        Returns:
            `(attn_output, attn_weights)`. `attn_output` has shape (batch_size, height, width, hidden_size).
            `attn_weights` has shape (batch_size * num_heads, height * width, height * width).
        """
        batch_size, height, width, _ = hidden_states.shape
        # qkv with shape (3, batch_size, nHead, height * width, channel)
        qkv = (
            self.qkv(hidden_states)
            .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
            .permute(2, 0, 3, 1, 4)
        )
        # q, k, v with shape (batch_size * nHead, height * width, channel)
        query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            decomposed_rel_pos = self.get_decomposed_rel_pos(
                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )
            decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
            attn_weights = attn_weights + decomposed_rel_pos

        # Softmax in fp32 for numerical stability, then cast back to the activation dtype.
        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)

        attn_output = self.proj(attn_output)
        return attn_output, attn_weights


class SLANeXtAttentionGRUCell(nn.Module):
    """
    One decoding step: additive attention over the encoder features, conditioned on the previous hidden state,
    followed by a GRU cell fed with the attended context concatenated with the previous token's one-hot vector.
    """

    def __init__(self, input_size, hidden_size, num_embeddings):
        super().__init__()
        self.input_to_hidden = nn.Linear(input_size, hidden_size, bias=False)
        self.hidden_to_hidden = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)
        # The GRU input is the attended context (input_size) concatenated with the token one-hot (num_embeddings).
        self.rnn = nn.GRUCell(input_size + num_embeddings, hidden_size)

    def forward(
        self,
        prev_hidden: torch.FloatTensor,
        batch_hidden: torch.FloatTensor,
        char_onehots: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Args:
            prev_hidden: previous GRU state, (batch_size, hidden_size) as created by `SLANeXtSLAHead`.
            batch_hidden: encoder features, presumably (batch_size, seq_len, input_size) -- TODO confirm.
            char_onehots: one-hot encoding of the previously predicted token, (batch_size, num_embeddings).

        Returns:
            `(hidden_states, attn_weights)`: the new GRU state, and the attention weights with shape
            (batch_size, 1, seq_len) (softmax over dim 1, then transposed).
        """
        batch_hidden_proj = self.input_to_hidden(batch_hidden)
        prev_hidden_proj = self.hidden_to_hidden(prev_hidden).unsqueeze(1)

        attention_scores = batch_hidden_proj + prev_hidden_proj
        attention_scores = torch.tanh(attention_scores)
        attention_scores = self.score(attention_scores)

        attn_weights = F.softmax(attention_scores, dim=1, dtype=torch.float32).to(attention_scores.dtype)
        attn_weights = attn_weights.transpose(1, 2)
        context = torch.matmul(attn_weights, batch_hidden).squeeze(1)
        concat_context = torch.cat([context, char_onehots], 1)

        hidden_states = self.rnn(concat_context, prev_hidden)

        return hidden_states, attn_weights


class SLANeXtMLP(nn.Module):
    """Two stacked linear layers followed by an optional activation (identity when `activation` is None)."""

    def __init__(self, hidden_size, out_channels, activation=None):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, out_channels)
        self.act_fn = nn.Identity() if activation is None else ACT2CLS[activation]()

    def forward(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        # NOTE(review): there is no non-linearity between fc1 and fc2. This is presumably deliberate, to match
        # the pretrained checkpoint layout -- confirm before "fixing".
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        return hidden_states


class SLANeXtPreTrainedModel(PreTrainedModel):
    config: SLANeXtConfig
    base_model_prefix = "backbone"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    supports_gradient_checkpointing = True
    # The recurrent decoder is always kept in fp32.
    _keep_in_fp32_modules_strict = ["structure_attention_cell", "structure_generator"]

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights"""
        super()._init_weights(module)

        # Initialize positional embeddings to zero (SLANeXtVisionEncoder holds pos_embed)
        if isinstance(module, SLANeXtVisionEncoder):
            if module.pos_embed is not None:
                init.constant_(module.pos_embed, 0.0)

        # Initialize relative positional embeddings to zero (SLANeXtVisionAttention holds rel_pos_h/w)
        if isinstance(module, SLANeXtVisionAttention):
            if module.use_rel_pos:
                init.constant_(module.rel_pos_h, 0.0)
                init.constant_(module.rel_pos_w, 0.0)

        # Initialize GRUCell (replicates PyTorch default reset_parameters)
        if isinstance(module, nn.GRUCell):
            std = 1.0 / math.sqrt(module.hidden_size) if module.hidden_size > 0 else 0
            init.uniform_(module.weight_ih, -std, std)
            init.uniform_(module.weight_hh, -std, std)
            if module.bias_ih is not None:
                init.uniform_(module.bias_ih, -std, std)
            if module.bias_hh is not None:
                init.uniform_(module.bias_hh, -std, std)

        # Initialize SLAHead layers
        if isinstance(module, SLANeXtSLAHead):
            std = 1.0 / math.sqrt(self.config.hidden_size * 1.0)
            # Initialize the linear layers of structure_generator uniformly in [-std, std]
            for layer in module.structure_generator.children():
                if isinstance(layer, nn.Linear):
                    init.uniform_(layer.weight, -std, std)
                    if layer.bias is not None:
                        init.uniform_(layer.bias, -std, std)


class SLANeXtMLPBlock(nn.Module):
    """Feed-forward block of a vision layer: hidden_size -> mlp_dim -> activation -> hidden_size."""

    def __init__(self, config):
        super().__init__()
        self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
        self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
        self.act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.lin1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.lin2(hidden_states)
        return hidden_states


class SLANeXtVisionLayer(GradientCheckpointingLayer):
    """
    Pre-norm transformer layer. Attention is windowed when `window_size > 0` and global when `window_size == 0`.
    """

    def __init__(self, config, window_size):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attn = SLANeXtVisionAttention(config, window_size)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SLANeXtMLPBlock(config)
        self.window_size = window_size

    def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]:
        """
        Partition into non-overlapping windows with padding if needed.

        Args:
            hidden_states (tensor): input tokens with [batch_size, height, width, channel].
            window_size (int): window size.

        Returns:
            windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel].
            (pad_height, pad_width): padded height and width before partition
        """
        batch_size, height, width, channel = hidden_states.shape

        # Pad bottom/right so both spatial dims become multiples of window_size.
        pad_h = (window_size - height % window_size) % window_size
        pad_w = (window_size - width % window_size) % window_size
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
        pad_height, pad_width = height + pad_h, width + pad_w

        hidden_states = hidden_states.reshape(
            batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
        )
        windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel)
        return windows, (pad_height, pad_width)

    def window_unpartition(
        self, windows: torch.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int]
    ) -> torch.Tensor:
        """
        Window unpartition into original sequences and removing padding.

        Args:
            windows (tensor):
                input tokens with [batch_size * num_windows, window_size, window_size, channel].
            window_size (int):
                window size.
            padding_shape (Tuple):
                padded height and width (pad_height, pad_width).
            original_shape (Tuple):
                original height and width (height, width) before padding.

        Returns:
            hidden_states: unpartitioned sequences with [batch_size, height, width, channel].
        """
        pad_height, pad_width = padding_shape
        height, width = original_shape
        batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
        hidden_states = windows.reshape(
            batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1
        )
        hidden_states = (
            hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
        )

        # Drop the padding added by window_partition.
        hidden_states = hidden_states[:, :height, :width, :].contiguous()
        return hidden_states

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        # Window partition
        if self.window_size > 0:
            height, width = hidden_states.shape[1], hidden_states.shape[2]
            hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)

        # Attention weights are not used here; they can be recorded via `_can_record_outputs` on the encoder.
        hidden_states, _ = self.attn(hidden_states=hidden_states)
        # Reverse window partition
        if self.window_size > 0:
            hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))

        hidden_states = residual + hidden_states
        layernorm_output = self.layer_norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(layernorm_output)
        return hidden_states


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for slanext vision model's outputs that also contains image embeddings obtained by applying the projection
    layer to the pooler_output.
    """
)
class SLANeXtVisionEncoderOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    # NOTE(review): SLANeXtVisionEncoder.forward in this file never sets image_embeds, so it stays None.
    image_embeds: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None


class SLANeXtPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values):
        """Project pixels to patches; returns a channels-last grid (batch_size, grid_h, grid_w, hidden_size)."""
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
        return embeddings


class SLANeXtLayerNorm(nn.LayerNorm):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
        super().__init__(normalized_shape, eps=eps, **kwargs)
        if data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {data_format}")
        self.data_format = data_format

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
        """
        if self.data_format == "channels_first":
            # nn.LayerNorm normalizes over the last dim, so move channels last and back.
            features = features.permute(0, 2, 3, 1)
            features = super().forward(features)
            features = features.permute(0, 3, 1, 2)
        else:
            features = super().forward(features)
        return features


class SLANeXtVisionNeck(nn.Module):
    """Maps the channels-last encoder grid to a channels-first map with `config.output_channels` channels."""

    def __init__(self, config: SLANeXtVisionConfig):
        super().__init__()
        self.config = config

        self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
        self.layer_norm1 = SLANeXtLayerNorm(config.output_channels, data_format="channels_first")
        self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False)
        self.layer_norm2 = SLANeXtLayerNorm(config.output_channels, data_format="channels_first")

    def forward(self, hidden_states):
        # (batch, height, width, hidden) -> (batch, hidden, height, width) for the convolutions.
        hidden_states = hidden_states.permute(0, 3, 1, 2)
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.layer_norm1(hidden_states)

        hidden_states = self.conv2(hidden_states)
        hidden_states = self.layer_norm2(hidden_states)
        return hidden_states


class SLANeXtVisionEncoder(SLANeXtPreTrainedModel):
    """ViT-style image encoder: patch embedding, optional absolute position embedding, vision layers, then a neck."""

    _can_record_outputs = {"hidden_states": SLANeXtVisionLayer, "attentions": SLANeXtVisionAttention}
    input_modalities = ("image",)

    def __init__(self, config: SLANeXtVisionConfig):
        super().__init__(config)
        self.config = config
        self.image_size = config.image_size
        self.patch_embed = SLANeXtPatchEmbeddings(config)

        self.pos_embed = None
        if config.use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(
                    1,
                    config.image_size // config.patch_size,
                    config.image_size // config.patch_size,
                    config.hidden_size,
                )
            )

        self.layers = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            # Layers listed in global_attn_indexes use global attention (window_size=0); the rest are windowed.
            layer = SLANeXtVisionLayer(
                config,
                window_size=config.window_size if i not in config.global_attn_indexes else 0,
            )
            self.layers.append(layer)

        self.neck = SLANeXtVisionNeck(config)

        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.patch_embed

    @merge_with_config_defaults
    @capture_outputs(tie_last_hidden_states=False)
    def forward(
        self, pixel_values: torch.FloatTensor | None = None, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | SLANeXtVisionEncoderOutput:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.patch_embed(pixel_values)
        if self.pos_embed is not None:
            hidden_states = hidden_states + self.pos_embed

        for layer_module in self.layers:
            hidden_states = layer_module(hidden_states)

        # Neck output is channels-first: (batch_size, output_channels, grid_h, grid_w).
        hidden_states = self.neck(hidden_states)

        return SLANeXtVisionEncoderOutput(
            last_hidden_state=hidden_states,
        )


class SLANeXtBackbone(SLANeXtPreTrainedModel):
    """Vision encoder followed by a stride-2 convolution; emits a flattened token sequence for the SLA head."""

    def __init__(
        self,
        config: SLANeXtConfig | None = None,
        **kwargs,
    ):
        super().__init__(config)
        self.vision_tower = SLANeXtVisionEncoder(config.vision_config)
        # NOTE(review): post_conv_in_channels presumably equals vision_config.output_channels -- confirm in config.
        self.post_conv = nn.Conv2d(
            config.post_conv_in_channels, config.post_conv_out_channels, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.post_init()

    def forward(self, hidden_states: torch.Tensor, **kwargs: Unpack[TransformersKwargs]):
        """
        Args:
            hidden_states (`torch.Tensor`): the pixel values forwarded to the vision tower.

        Returns:
            `BaseModelOutput` whose `last_hidden_state` has shape (batch_size, seq_len, post_conv_out_channels).
        """
        vision_output = self.vision_tower(hidden_states, **kwargs)
        hidden_states = self.post_conv(vision_output.last_hidden_state)
        # (batch, channels, h, w) -> (batch, h * w, channels)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=vision_output.hidden_states,
            attentions=vision_output.attentions,
        )


class SLANeXtSLAHead(SLANeXtPreTrainedModel):
    """Autoregressive (greedy) structure decoder built on an attention GRU cell."""

    _can_record_outputs = {
        "attentions": SLANeXtAttentionGRUCell,
    }

    def __init__(
        self,
        config: SLANeXtConfig | None = None,
        **kwargs,
    ):
        super().__init__(config)
        self.structure_attention_cell = SLANeXtAttentionGRUCell(
            config.post_conv_out_channels, config.hidden_size, config.out_channels
        )
        self.structure_generator = SLANeXtMLP(config.hidden_size, config.out_channels)
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @filter_output_hidden_states
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        targets: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        r"""
        Greedily decode a structure-token sequence from backbone features.

        Args:
            hidden_states (`torch.FloatTensor`):
                Backbone features. When produced by `SLANeXtBackbone` they have shape
                `(batch_size, seq_len, config.post_conv_out_channels)`.
            targets (`torch.Tensor`, *optional*):
                Accepted for interface compatibility; not read by this implementation.

        Returns:
            `BaseModelOutput` whose `last_hidden_state` holds the per-step softmax probabilities, with shape
            `(batch_size, num_steps, config.out_channels)`, and whose `hidden_states` holds the per-step logits.
            `num_steps` is at most `config.max_text_length + 1`; decoding stops early once every sample has
            predicted the last class id (`config.out_channels - 1`) at some step.
        """
        batch_size = hidden_states.shape[0]
        device = hidden_states.device
        end_token_id = self.config.out_channels - 1

        features = torch.zeros((batch_size, self.config.hidden_size), dtype=torch.float32, device=device)
        predicted_chars = torch.zeros(size=[batch_size], dtype=torch.long, device=device)
        # The decoder runs in fp32; cast the loop-invariant encoder features once rather than on every step.
        batch_features = hidden_states.float()
        # Per-sample "has emitted the stop token at any step" flag. This is equivalent to stacking all
        # predictions so far on every step, without the quadratic cost.
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        structure_preds_list = []
        for _ in range(self.config.max_text_length + 1):
            embedding_feature = F.one_hot(predicted_chars, self.config.out_channels).float()
            features, _ = self.structure_attention_cell(features, batch_features, embedding_feature)
            structure_step = self.structure_generator(features)
            predicted_chars = structure_step.argmax(dim=1)
            structure_preds_list.append(structure_step)
            finished |= predicted_chars.eq(end_token_id)
            if finished.all():
                break

        structure_preds = F.softmax(torch.stack(structure_preds_list, dim=1), dim=-1, dtype=torch.float32).to(
            hidden_states.dtype
        )

        return BaseModelOutput(last_hidden_state=structure_preds, hidden_states=structure_preds_list)


@dataclass
@auto_docstring
class SLANeXtForTableRecognitionOutput(BaseModelOutput):
    r"""
    head_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Hidden-states of the SLANeXtSLAHead at each prediction step, varies up to `config.max_text_length + 1` states (depending on early exits).
    head_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Attentions of the SLANeXtSLAHead at each prediction step, varies up to `config.max_text_length + 1` attentions (depending on early exits).
    """

    head_hidden_states: torch.FloatTensor | None = None
    head_attentions: torch.FloatTensor | None = None


@auto_docstring(
    custom_intro="""
    SLANeXt Table Recognition model for table recognition tasks. Wraps the core SLANeXtPreTrainedModel
    and returns outputs compatible with the Transformers table recognition API.
    """
)
class SLANeXtForTableRecognition(SLANeXtPreTrainedModel):
    def __init__(self, config: SLANeXtConfig):
        super().__init__(config)
        self.backbone = SLANeXtBackbone(config=config)
        self.head = SLANeXtSLAHead(config=config)
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple[torch.FloatTensor] | SLANeXtForTableRecognitionOutput:
        backbone_outputs = self.backbone(pixel_values, **kwargs)
        head_outputs = self.head(backbone_outputs.last_hidden_state, **kwargs)

        return SLANeXtForTableRecognitionOutput(
            last_hidden_state=head_outputs.last_hidden_state,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
            head_hidden_states=head_outputs.hidden_states,
            head_attentions=head_outputs.attentions,
        )


__all__ = ["SLANeXtSLAHead", "SLANeXtBackbone", "SLANeXtForTableRecognition", "SLANeXtPreTrainedModel"]