modular_slanext.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. # Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from dataclasses import dataclass
  16. import torch
  17. import torch.nn as nn
  18. import torch.nn.functional as F
  19. import torchvision.transforms.v2.functional as tvF
  20. from huggingface_hub.dataclasses import strict
  21. from ... import initialization as init
  22. from ...activations import ACT2CLS
  23. from ...backbone_utils import filter_output_hidden_states
  24. from ...configuration_utils import PreTrainedConfig
  25. from ...image_processing_backends import TorchvisionBackend
  26. from ...image_processing_utils import BatchFeature
  27. from ...image_transforms import group_images_by_shape, reorder_images
  28. from ...image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, SizeDict
  29. from ...modeling_outputs import BaseModelOutput
  30. from ...modeling_utils import PreTrainedModel
  31. from ...processing_utils import ImagesKwargs, Unpack
  32. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
  33. from ...utils.generic import TensorType, merge_with_config_defaults
  34. from ...utils.import_utils import requires
  35. from ...utils.output_capturing import capture_outputs
  36. from ..got_ocr2.configuration_got_ocr2 import GotOcr2VisionConfig
  37. from ..got_ocr2.modeling_got_ocr2 import (
  38. GotOcr2VisionAttention,
  39. GotOcr2VisionEncoder,
  40. )
# Module-level logger obtained from the transformers `logging` utilities (used for one-time warnings below).
logger = logging.get_logger(__name__)
@auto_docstring(checkpoint="PaddlePaddle/SLANeXt_wired_safetensors")
@strict
class SLANeXtVisionConfig(GotOcr2VisionConfig):
    # Inherits every field from GotOcr2VisionConfig; only the `image_size` default is overridden, to 512,
    # which matches the 512x512 `size`/`pad_size` defaults of SLANeXtImageProcessor.
    image_size: int = 512
class SLANeXtVisionAttention(GotOcr2VisionAttention):
    # Behaviourally identical to GotOcr2VisionAttention. The SLANeXt-named subclass exists so that
    # `SLANeXtPreTrainedModel._init_weights` can match it with `isinstance` and zero its relative
    # positional tables (`rel_pos_h` / `rel_pos_w`) when `use_rel_pos` is set.
    pass
@auto_docstring(checkpoint="PaddlePaddle/SLANeXt_wired_safetensors")
@strict
class SLANeXtConfig(PreTrainedConfig):
    r"""
    vision_config (`dict` or [`SLANeXtVisionConfig`], *optional*):
        Configuration for the vision encoder. A `dict` is converted into a [`SLANeXtVisionConfig`]; if `None`, a
        default [`SLANeXtVisionConfig`] is used.
    post_conv_in_channels (`int`, *optional*, defaults to 256):
        Number of input channels of the convolution applied to the vision encoder output. Must match the channel
        dimension of that output, since the convolution is applied to it directly.
    post_conv_out_channels (`int`, *optional*, defaults to 512):
        Number of output channels of the post-encoder convolution. This is also the feature size of the sequence
        the structure decoder attends over.
    out_channels (`int`, *optional*, defaults to 50):
        Vocabulary size for the table structure token prediction head, i.e., the number of distinct structure
        tokens the model can predict. The last id (`out_channels - 1`) is treated as the end-of-sequence token
        when deciding whether decoding can stop early.
    hidden_size (`int`, *optional*, defaults to 512):
        Dimensionality of the hidden states in the attention GRU cell and the structure prediction head.
    max_text_length (`int`, *optional*, defaults to 500):
        Maximum decoding length of the structure decoder; it runs at most `max_text_length + 1` autoregressive
        steps (fewer when every sequence has produced the end-of-sequence token).
    """

    model_type = "slanext"
    sub_configs = {"vision_config": SLANeXtVisionConfig}
    vision_config: dict | SLANeXtVisionConfig | None = None
    post_conv_in_channels: int = 256
    post_conv_out_channels: int = 512
    out_channels: int = 50
    hidden_size: int = 512
    max_text_length: int = 500

    def __post_init__(self, **kwargs):
        # Normalise `vision_config` into a SLANeXtVisionConfig instance (default when None, built from a dict
        # otherwise) before delegating to the base-class post-init.
        if self.vision_config is None:
            self.vision_config = SLANeXtVisionConfig()
        elif isinstance(self.vision_config, dict):
            self.vision_config = SLANeXtVisionConfig(**self.vision_config)
        super().__post_init__(**kwargs)
  80. class SLANeXtAttentionGRUCell(nn.Module):
  81. def __init__(self, input_size, hidden_size, num_embeddings):
  82. super().__init__()
  83. self.input_to_hidden = nn.Linear(input_size, hidden_size, bias=False)
  84. self.hidden_to_hidden = nn.Linear(hidden_size, hidden_size)
  85. self.score = nn.Linear(hidden_size, 1, bias=False)
  86. self.rnn = nn.GRUCell(input_size + num_embeddings, hidden_size)
  87. def forward(
  88. self,
  89. prev_hidden: torch.FloatTensor,
  90. batch_hidden: torch.FloatTensor,
  91. char_onehots: torch.FloatTensor,
  92. **kwargs: Unpack[TransformersKwargs],
  93. ):
  94. batch_hidden_proj = self.input_to_hidden(batch_hidden)
  95. prev_hidden_proj = self.hidden_to_hidden(prev_hidden).unsqueeze(1)
  96. attention_scores = batch_hidden_proj + prev_hidden_proj
  97. attention_scores = torch.tanh(attention_scores)
  98. attention_scores = self.score(attention_scores)
  99. attn_weights = F.softmax(attention_scores, dim=1, dtype=torch.float32).to(attention_scores.dtype)
  100. attn_weights = attn_weights.transpose(1, 2)
  101. context = torch.matmul(attn_weights, batch_hidden).squeeze(1)
  102. concat_context = torch.cat([context, char_onehots], 1)
  103. hidden_states = self.rnn(concat_context, prev_hidden)
  104. return hidden_states, attn_weights
  105. class SLANeXtMLP(nn.Module):
  106. def __init__(self, hidden_size, out_channels, activation=None):
  107. super().__init__()
  108. self.fc1 = nn.Linear(hidden_size, hidden_size)
  109. self.fc2 = nn.Linear(hidden_size, out_channels)
  110. self.act_fn = nn.Identity() if activation is None else ACT2CLS[activation]()
  111. def forward(self, hidden_states):
  112. hidden_states = self.fc1(hidden_states)
  113. hidden_states = self.fc2(hidden_states)
  114. hidden_states = self.act_fn(hidden_states)
  115. return hidden_states
  116. class SLANeXtPreTrainedModel(PreTrainedModel):
  117. config: SLANeXtConfig
  118. base_model_prefix = "backbone"
  119. main_input_name = "pixel_values"
  120. input_modalities = ("image",)
  121. supports_gradient_checkpointing = True
  122. _keep_in_fp32_modules_strict = ["structure_attention_cell", "structure_generator"]
  123. @torch.no_grad()
  124. def _init_weights(self, module):
  125. """Initialize the weights"""
  126. super()._init_weights(module)
  127. # Initialize positional embeddings to zero (SLANeXtVisionEncoder holds pos_embed)
  128. if isinstance(module, SLANeXtVisionEncoder):
  129. if module.pos_embed is not None:
  130. init.constant_(module.pos_embed, 0.0)
  131. # Initialize relative positional embeddings to zero (SLANeXtVisionAttention holds rel_pos_h/w)
  132. if isinstance(module, SLANeXtVisionAttention):
  133. if module.use_rel_pos:
  134. init.constant_(module.rel_pos_h, 0.0)
  135. init.constant_(module.rel_pos_w, 0.0)
  136. # Initialize GRUCell (replicates PyTorch default reset_parameters)
  137. if isinstance(module, nn.GRUCell):
  138. std = 1.0 / math.sqrt(module.hidden_size) if module.hidden_size > 0 else 0
  139. init.uniform_(module.weight_ih, -std, std)
  140. init.uniform_(module.weight_hh, -std, std)
  141. if module.bias_ih is not None:
  142. init.uniform_(module.bias_ih, -std, std)
  143. if module.bias_hh is not None:
  144. init.uniform_(module.bias_hh, -std, std)
  145. # Initialize SLAHead layers
  146. if isinstance(module, SLANeXtSLAHead):
  147. std = 1.0 / math.sqrt(self.config.hidden_size * 1.0)
  148. # Initialize structure_generator and loc_generator layers
  149. for generator in (module.structure_generator,):
  150. for layer in generator.children():
  151. if isinstance(layer, nn.Linear):
  152. init.uniform_(layer.weight, -std, std)
  153. if layer.bias is not None:
  154. init.uniform_(layer.bias, -std, std)
class SLANeXtVisionEncoder(GotOcr2VisionEncoder):
    # Behaviourally identical to GotOcr2VisionEncoder. The SLANeXt-named subclass exists so that
    # `SLANeXtPreTrainedModel._init_weights` can match it with `isinstance` and zero its `pos_embed`.
    pass
  157. class SLANeXtBackbone(SLANeXtPreTrainedModel):
  158. def __init__(
  159. self,
  160. config: dict | None = None,
  161. **kwargs,
  162. ):
  163. super().__init__(config)
  164. self.vision_tower = SLANeXtVisionEncoder(config.vision_config)
  165. self.post_conv = nn.Conv2d(
  166. config.post_conv_in_channels, config.post_conv_out_channels, kernel_size=3, stride=2, padding=1, bias=False
  167. )
  168. self.post_init()
  169. def forward(self, hidden_states: torch.Tensor, **kwargs: Unpack[TransformersKwargs]):
  170. vision_output = self.vision_tower(hidden_states, **kwargs)
  171. hidden_states = self.post_conv(vision_output.last_hidden_state)
  172. hidden_states = hidden_states.flatten(2).transpose(1, 2)
  173. return BaseModelOutput(
  174. last_hidden_state=hidden_states,
  175. hidden_states=vision_output.hidden_states,
  176. attentions=vision_output.attentions,
  177. )
  178. class SLANeXtSLAHead(SLANeXtPreTrainedModel):
  179. _can_record_outputs = {
  180. "attentions": SLANeXtAttentionGRUCell,
  181. }
  182. def __init__(
  183. self,
  184. config: dict | None = None,
  185. **kwargs,
  186. ):
  187. super().__init__(config)
  188. self.structure_attention_cell = SLANeXtAttentionGRUCell(
  189. config.post_conv_out_channels, config.hidden_size, config.out_channels
  190. )
  191. self.structure_generator = SLANeXtMLP(config.hidden_size, config.out_channels)
  192. self.post_init()
  193. @merge_with_config_defaults
  194. @capture_outputs
  195. @filter_output_hidden_states
  196. def forward(
  197. self,
  198. hidden_states: torch.FloatTensor,
  199. targets: torch.Tensor | None = None,
  200. **kwargs: Unpack[TransformersKwargs],
  201. ):
  202. features = torch.zeros(
  203. (hidden_states.shape[0], self.config.hidden_size), dtype=torch.float32, device=hidden_states.device
  204. )
  205. predicted_chars = torch.zeros(size=[hidden_states.shape[0]], dtype=torch.long, device=hidden_states.device)
  206. structure_preds_list = []
  207. structure_ids_list = []
  208. for _ in range(self.config.max_text_length + 1):
  209. embedding_feature = F.one_hot(predicted_chars, self.config.out_channels).float()
  210. features, _ = self.structure_attention_cell(features, hidden_states.float(), embedding_feature)
  211. structure_step = self.structure_generator(features)
  212. predicted_chars = structure_step.argmax(dim=1)
  213. structure_preds_list.append(structure_step)
  214. structure_ids_list.append(predicted_chars)
  215. if torch.stack(structure_ids_list, dim=1).eq(self.config.out_channels - 1).any(-1).all():
  216. break
  217. structure_preds = F.softmax(torch.stack(structure_preds_list, dim=1), dim=-1, dtype=torch.float32).to(
  218. hidden_states.dtype
  219. )
  220. return BaseModelOutput(last_hidden_state=structure_preds, hidden_states=structure_preds_list)
@dataclass
@auto_docstring
class SLANeXtForTableRecognitionOutput(BaseModelOutput):
    r"""
    head_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Per-step outputs of the SLANeXtSLAHead structure generator, one per decoding step. There are at most
        `config.max_text_length + 1` entries (fewer when decoding stops early).
    head_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Attention weights of the SLANeXtSLAHead attention GRU cell, one per decoding step. There are at most
        `config.max_text_length + 1` entries (fewer when decoding stops early).
    """

    head_hidden_states: torch.FloatTensor | None = None
    head_attentions: torch.FloatTensor | None = None
  232. @auto_docstring(
  233. custom_intro="""
  234. SLANeXt Table Recognition model for table recognition tasks. Wraps the core SLANeXtPreTrainedModel
  235. and returns outputs compatible with the Transformers table recognition API.
  236. """
  237. )
  238. class SLANeXtForTableRecognition(SLANeXtPreTrainedModel):
  239. def __init__(self, config: SLANeXtConfig):
  240. super().__init__(config)
  241. self.backbone = SLANeXtBackbone(config=config)
  242. self.head = SLANeXtSLAHead(config=config)
  243. self.post_init()
  244. @can_return_tuple
  245. @auto_docstring
  246. def forward(
  247. self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
  248. ) -> tuple[torch.FloatTensor] | SLANeXtForTableRecognitionOutput:
  249. backbone_outputs = self.backbone(pixel_values, **kwargs)
  250. head_outputs = self.head(backbone_outputs.last_hidden_state, **kwargs)
  251. return SLANeXtForTableRecognitionOutput(
  252. last_hidden_state=head_outputs.last_hidden_state,
  253. hidden_states=backbone_outputs.hidden_states,
  254. attentions=backbone_outputs.attentions,
  255. head_hidden_states=head_outputs.hidden_states,
  256. head_attentions=head_outputs.attentions,
  257. )
  258. @auto_docstring
  259. @requires(backends=("torch",))
  260. class SLANeXtImageProcessor(TorchvisionBackend):
  261. resample = 2 # PILImageResampling.BILINEAR
  262. image_mean = IMAGENET_DEFAULT_MEAN
  263. image_std = IMAGENET_DEFAULT_STD
  264. size = {"height": 512, "width": 512}
  265. pad_size = {"height": 512, "width": 512}
  266. do_convert_rgb = True
  267. do_resize = True
  268. do_rescale = True
  269. do_normalize = True
  270. do_pad = True
  271. def _resize(
  272. self,
  273. image: "torch.Tensor",
  274. size: SizeDict,
  275. ) -> "torch.Tensor":
  276. batch_size, channels, height, width = image.shape
  277. image = image.view(batch_size * channels, height, width)
  278. device = image.device
  279. scale = max(size.height, size.width) / max(height, width)
  280. target_height = round(height * scale)
  281. target_width = round(width * scale)
  282. target_col = torch.arange(target_width, dtype=torch.float32, device=device)
  283. src_col = (target_col + 0.5) * (float(width) / float(target_width)) - 0.5
  284. src_col_floor = src_col.floor().to(torch.int32)
  285. src_col_frac = src_col - src_col_floor.float()
  286. # boundary handling
  287. src_col_frac = torch.where(src_col_floor < 0, torch.zeros_like(src_col_frac), src_col_frac)
  288. src_col_floor = torch.where(src_col_floor < 0, torch.zeros_like(src_col_floor), src_col_floor)
  289. src_col_frac = torch.where(src_col_floor >= width - 1, torch.ones_like(src_col_frac), src_col_frac)
  290. src_col_floor = torch.where(
  291. src_col_floor >= width - 1, torch.full_like(src_col_floor, width - 2), src_col_floor
  292. )
  293. # fixed-point weights
  294. weight_right = (src_col_frac * 2048 + 0.5).floor().to(torch.int32) # round-to-nearest
  295. weight_left = 2048 - weight_right # (target_w,)
  296. # --- row coordinate tables ---
  297. target_row = torch.arange(target_height, dtype=torch.float32, device=device)
  298. src_row = (target_row + 0.5) * (float(height) / float(target_height)) - 0.5
  299. src_row_floor = src_row.floor().to(torch.int32)
  300. src_row_frac = src_row - src_row_floor.float()
  301. src_row_frac = torch.where(src_row_floor < 0, torch.zeros_like(src_row_frac), src_row_frac)
  302. src_row_floor = torch.where(src_row_floor < 0, torch.zeros_like(src_row_floor), src_row_floor)
  303. src_row_frac = torch.where(src_row_floor >= height - 1, torch.ones_like(src_row_frac), src_row_frac)
  304. src_row_floor = torch.where(
  305. src_row_floor >= height - 1, torch.full_like(src_row_floor, height - 2), src_row_floor
  306. )
  307. weight_bottom = (src_row_frac * 2048 + 0.5).floor().to(torch.int32)
  308. weight_top = 2048 - weight_bottom # (target_h,)
  309. image_uint8 = image.clamp(0, 255).to(torch.uint8) # (C, H, W)
  310. image_int32 = image_uint8.to(torch.int32) # (C, H, W)
  311. col_left = src_col_floor.long() # (target_w,)
  312. col_right = (src_col_floor + 1).long() # (target_w,) safe: src_col_floor <= width-2
  313. row_top = src_row_floor.long() # (target_h,)
  314. row_bottom = (src_row_floor + 1).long() # (target_h,)
  315. # gather 4 neighbours: (C, target_h, target_w)
  316. pixel_top_left = image_int32[:, row_top[:, None], col_left[None, :]]
  317. pixel_top_right = image_int32[:, row_top[:, None], col_right[None, :]]
  318. pixel_bottom_left = image_int32[:, row_bottom[:, None], col_left[None, :]]
  319. pixel_bottom_right = image_int32[:, row_bottom[:, None], col_right[None, :]]
  320. # fixed-point bilinear: weights broadcast over (C, target_h, target_w)
  321. weight_bottom_3d = weight_bottom.view(1, target_height, 1)
  322. weight_top_3d = weight_top.view(1, target_height, 1)
  323. weight_right_3d = weight_right.view(1, 1, target_width)
  324. weight_left_3d = weight_left.view(1, 1, target_width)
  325. interp = weight_top_3d * (
  326. weight_left_3d * pixel_top_left + weight_right_3d * pixel_top_right
  327. ) + weight_bottom_3d * (weight_left_3d * pixel_bottom_left + weight_right_3d * pixel_bottom_right)
  328. interp = (interp + (1 << 21)) >> 22
  329. result = interp.clamp(0, 255).to(torch.uint8) # (B*C, target_h, target_w)
  330. return result.view(batch_size, channels, target_height, target_width).to(dtype=image.dtype)
  331. def _preprocess(
  332. self,
  333. images: list["torch.Tensor"],
  334. do_resize: bool,
  335. size: SizeDict,
  336. resample: "tvF.InterpolationMode | int | None",
  337. do_center_crop: bool,
  338. crop_size: SizeDict,
  339. do_rescale: bool,
  340. rescale_factor: float,
  341. do_normalize: bool,
  342. image_mean: float | list[float] | None,
  343. image_std: float | list[float] | None,
  344. do_pad: bool | None,
  345. pad_size: SizeDict | None,
  346. disable_grouping: bool | None,
  347. return_tensors: str | TensorType | None,
  348. **kwargs,
  349. ) -> BatchFeature:
  350. if resample is not None and not is_torchdynamo_compiling():
  351. logger.warning_once("Resampling is not supported in SLANeXt")
  352. # Group images by size for batched resizing
  353. grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
  354. resized_images_grouped = {}
  355. for shape, stacked_images in grouped_images.items():
  356. if do_resize:
  357. stacked_images = self._resize(image=stacked_images, size=size)
  358. resized_images_grouped[shape] = stacked_images
  359. resized_images = reorder_images(resized_images_grouped, grouped_images_index)
  360. # Group images by size for further processing
  361. # Needed in case do_resize is False, or resize returns images with different sizes
  362. grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
  363. processed_images_grouped = {}
  364. for shape, stacked_images in grouped_images.items():
  365. if do_center_crop:
  366. stacked_images = self.center_crop(stacked_images, crop_size)
  367. # Fused rescale and normalize
  368. stacked_images = self.rescale_and_normalize(
  369. stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
  370. )
  371. processed_images_grouped[shape] = stacked_images
  372. processed_images = reorder_images(processed_images_grouped, grouped_images_index)
  373. if do_pad:
  374. processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)
  375. return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
  376. def __init__(self, **kwargs: Unpack[ImagesKwargs]):
  377. super().__init__(**kwargs)
  378. self.init_decoder()
  379. def init_decoder(self):
  380. """
  381. Initialize the decoder vocabulary for table structure recognition.
  382. Builds a character dictionary mapping HTML table structure tokens (e.g., `<thead>`, `<tr>`, `<td>`, colspan/
  383. rowspan attributes) to integer indices. The dictionary includes special `"sos"` (start-of-sequence) and
  384. `"eos"` (end-of-sequence) tokens. Merged `<td></td>` tokens are used in place of standalone `<td>` tokens
  385. when applicable.
  386. """
  387. dict_character = [
  388. "<thead>",
  389. "</thead>",
  390. "<tbody>",
  391. "</tbody>",
  392. "<tr>",
  393. "</tr>",
  394. "<td>",
  395. "<td",
  396. ">",
  397. "</td>",
  398. ]
  399. dict_character += [f' colspan="{i + 2}"' for i in range(19)]
  400. dict_character += [f' rowspan="{i + 2}"' for i in range(19)]
  401. if "<td></td>" not in dict_character:
  402. dict_character.append("<td></td>")
  403. if "<td>" in dict_character:
  404. dict_character.remove("<td>")
  405. dict_character = ["sos"] + dict_character + ["eos"]
  406. self.dict = {char: i for i, char in enumerate(dict_character)}
  407. self.character = dict_character
  408. self.td_token = ["<td>", "<td", "<td></td>"]
  409. self.bos_id = self.dict["sos"]
  410. self.eos_id = self.dict["eos"]
  411. def post_process_table_recognition(self, outputs):
  412. """
  413. Post-process the raw model outputs to decode the predicted table structure into an HTML token sequence.
  414. Converts the model's predicted probability distributions over the structure vocabulary into a sequence of
  415. HTML tokens representing the table structure. The decoded tokens are wrapped with `<html>`, `<body>`, and
  416. `<table>` tags to form a complete HTML table structure.
  417. Args:
  418. outputs ([`SLANeXtForTableRecognitionOutput`]):
  419. Raw outputs from the SLANeXt model. The `last_hidden_state` field contains the predicted probability
  420. distributions over the structure vocabulary at each decoding step, with shape
  421. `(batch_size, max_text_length, num_classes)`.
  422. Returns:
  423. `dict`: A dictionary containing:
  424. - **structure** (`list[str]`): The predicted HTML table structure as a list of tokens, wrapped with
  425. `<html>`, `<body>`, and `<table>` tags.
  426. - **structure_score** (`float`): The mean confidence score across all predicted tokens.
  427. """
  428. self.pred = outputs.last_hidden_state
  429. structure_probs = self.pred[0:1]
  430. ignored_tokens = [int(self.bos_id), int(self.eos_id)]
  431. end_idx = int(self.eos_id)
  432. structure_idx = structure_probs.argmax(dim=2)
  433. structure_probs = structure_probs.max(dim=2).values
  434. structure_str_list = []
  435. batch_size = structure_idx.shape[0]
  436. for batch_index in range(batch_size):
  437. structure_list = []
  438. score_list = []
  439. for position in range(structure_idx.shape[1]):
  440. char_idx = int(structure_idx[batch_index, position])
  441. if position > 0 and char_idx == end_idx:
  442. break
  443. if char_idx in ignored_tokens:
  444. continue
  445. text = self.character[char_idx]
  446. structure_list.append(text)
  447. score_list.append(structure_probs[batch_index, position])
  448. structure_str_list.append(structure_list)
  449. structure_score = torch.stack(score_list).mean().item()
  450. structure = ["<html>", "<body>", "<table>"] + structure_str_list[0] + ["</table>", "</body>", "</html>"]
  451. return {"structure": structure, "structure_score": structure_score}
  452. __all__ = [
  453. "SLANeXtImageProcessor",
  454. "SLANeXtConfig",
  455. "SLANeXtSLAHead",
  456. "SLANeXtBackbone",
  457. "SLANeXtForTableRecognition",
  458. "SLANeXtPreTrainedModel",
  459. ]