modular_uvdoc.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618
  1. # Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. import torch
  16. import torch.nn as nn
  17. import torch.nn.functional as F
  18. from huggingface_hub.dataclasses import strict
  19. from ...activations import ACT2FN
  20. from ...backbone_utils import (
  21. BackboneConfigMixin,
  22. BackboneMixin,
  23. consolidate_backbone_kwargs_to_config,
  24. filter_output_hidden_states,
  25. )
  26. from ...configuration_utils import PreTrainedConfig
  27. from ...feature_extraction_utils import BatchFeature
  28. from ...image_processing_backends import TorchvisionBackend
  29. from ...image_transforms import group_images_by_shape, reorder_images
  30. from ...image_utils import PILImageResampling, SizeDict
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import BackboneOutput, BaseModelOutputWithNoAttention
  33. from ...modeling_utils import PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  36. from ...utils.generic import TensorType, merge_with_config_defaults
  37. from ...utils.import_utils import requires
  38. from ...utils.output_capturing import capture_outputs
  39. from ..auto import AutoConfig
  40. from ..pp_lcnet.modeling_pp_lcnet import PPLCNetConvLayer
  41. from ..pp_ocrv5_server_det.modeling_pp_ocrv5_server_det import PPOCRV5ServerDetPreTrainedModel
  42. @auto_docstring(checkpoint="PaddlePaddle/UVDoc_safetensors")
  43. @strict
  44. class UVDocBackboneConfig(BackboneConfigMixin, PreTrainedConfig):
  45. r"""
  46. resnet_head (`Sequence[list[int] | tuple[int, ...]]`, *optional*, defaults to `((3, 32), (32, 32))`):
  47. Configuration for the ResNet head layers in format [in_channels, out_channels].
  48. resnet_configs (`Sequence[Sequence[tuple[int, int, int, bool] | list[int | bool]]]`, *optional*, defaults to `(((32, 32, 1, False),
  49. (32, 32, 3, False), (32, 32, 3, False)), ((32, 64, 1, True), (64, 64, 3, False), (64, 64, 3, False), (64, 64, 3, False)), ((64, 128, 1, True),
  50. (128, 128, 3, False), (128, 128, 3, False), (128, 128, 3, False), (128, 128, 3, False), (128, 128, 3, False)))`):
  51. Configuration for the ResNet stages in format [in_channels, out_channels, dilation_value, downsample].
  52. stage_configs (Sequence[Sequence[tuple[int, ...] | list[int]]], *optional*, defaults to `(((128, 1),), ((128, 2),),
  53. ((128, 5),), ((128, 8),(128, 3),(128, 2),), ((128, 12), (128, 7), (128, 4),), ((128, 18), (128, 12), (128, 6),),)`):
  54. Configuration for the bridge module stages in format [in_channels, dilation_value].
  55. Each inner sequence corresponds to a single bridge block, and the outer sequence groups blocks by bridge stage.
  56. """
  57. model_type = "uvdoc_backbone"
  58. _out_features: list[str] | None = None
  59. _out_indices: list[int] | None = None
  60. resnet_head: Sequence[list[int] | tuple[int, ...]] = (
  61. (3, 32),
  62. (32, 32),
  63. )
  64. resnet_configs: Sequence[Sequence[tuple[int, int, int, bool] | list[int | bool]]] = (
  65. (
  66. (32, 32, 1, False),
  67. (32, 32, 3, False),
  68. (32, 32, 3, False),
  69. ),
  70. (
  71. (32, 64, 1, True),
  72. (64, 64, 3, False),
  73. (64, 64, 3, False),
  74. (64, 64, 3, False),
  75. ),
  76. (
  77. (64, 128, 1, True),
  78. (128, 128, 3, False),
  79. (128, 128, 3, False),
  80. (128, 128, 3, False),
  81. (128, 128, 3, False),
  82. (128, 128, 3, False),
  83. ),
  84. )
  85. stage_configs: Sequence[Sequence[tuple[int, ...] | list[int]]] = (
  86. ((128, 1),),
  87. ((128, 2),),
  88. ((128, 5),),
  89. (
  90. (128, 8),
  91. (128, 3),
  92. (128, 2),
  93. ),
  94. (
  95. (128, 12),
  96. (128, 7),
  97. (128, 4),
  98. ),
  99. (
  100. (128, 18),
  101. (128, 12),
  102. (128, 6),
  103. ),
  104. )
  105. kernel_size: int = 5
  106. def __post_init__(self, **kwargs):
  107. self.depths = [len(stages) for stages in self.stage_configs]
  108. self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.stage_configs) + 1)]
  109. self.set_output_features_output_indices(
  110. out_indices=kwargs.pop("out_indices", None), out_features=kwargs.pop("out_features", None)
  111. )
  112. super().__post_init__(**kwargs)
  113. @auto_docstring(checkpoint="PaddlePaddle/UVDoc_safetensors")
  114. @strict
  115. class UVDocConfig(PreTrainedConfig):
  116. r"""
  117. padding_mode (`str`, *optional*, defaults to `"reflect"`):
  118. Padding mode for convolutional layers. Supported modes are `"reflect"`, `"constant"`, and `"replicate"`.
  119. kernel_size (`int`, *optional*, defaults to 5):
  120. Kernel size for convolutional layers in the backbone network.
  121. bridge_connector (`list[int] | tuple[int, ...]`, *optional*, defaults to `(128, 128)`):
  122. Configuration for the bridge connector in format [in_channels, out_channels].
  123. out_point_positions2D (`Sequence[list[int] | tuple[int, ...]]`, *optional*, defaults to `((128, 32), (32, 2))`):
  124. Configuration for the output point positions 2D layer in format [in_channels, out_channels].
  125. """
  126. model_type = "uvdoc"
  127. sub_configs = {"backbone_config": AutoConfig}
  128. backbone_config: dict | PreTrainedConfig | None = None
  129. hidden_act: str = "prelu"
  130. padding_mode: str = "reflect"
  131. kernel_size: int = 5
  132. bridge_connector: list[int] | tuple[int, ...] = (128, 128)
  133. out_point_positions2D: Sequence[list[int] | tuple[int, ...]] = ((128, 32), (32, 2))
  134. def __post_init__(self, **kwargs):
  135. self.backbone_config, kwargs = consolidate_backbone_kwargs_to_config(
  136. backbone_config=self.backbone_config,
  137. default_config_type="uvdoc_backbone",
  138. **kwargs,
  139. )
  140. super().__post_init__(**kwargs)
  141. @auto_docstring
  142. @requires(backends=("torch",))
  143. class UVDocImageProcessor(TorchvisionBackend):
  144. do_rescale = True
  145. do_resize = True
  146. size = {"height": 712, "width": 488}
  147. resample = PILImageResampling.BILINEAR
  148. def _preprocess(
  149. self,
  150. images: list["torch.Tensor"],
  151. do_resize: bool,
  152. size: SizeDict,
  153. do_rescale: bool,
  154. rescale_factor: float,
  155. do_normalize: bool,
  156. image_mean: float | list[float] | None,
  157. image_std: float | list[float] | None,
  158. disable_grouping: bool | None,
  159. return_tensors: str | TensorType | None,
  160. **kwargs,
  161. ) -> BatchFeature:
  162. grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
  163. processed_images_grouped = {}
  164. for shape, stacked_images in grouped_images.items():
  165. stacked_images = self.rescale_and_normalize(
  166. stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
  167. )
  168. # RGB to BGR conversion
  169. stacked_images = stacked_images[:, [2, 1, 0], :, :]
  170. processed_images_grouped[shape] = stacked_images
  171. rescale_and_normalize_images = reorder_images(processed_images_grouped, grouped_images_index)
  172. original_images = rescale_and_normalize_images.copy()
  173. grouped_images, grouped_images_index = group_images_by_shape(
  174. rescale_and_normalize_images, disable_grouping=disable_grouping
  175. )
  176. interpolated_images_grouped = {}
  177. # Upsample images and extract originals for post-processing
  178. for shape, stacked_images in grouped_images.items():
  179. # Interpolate to target size (use interpolate with align_corners=True to match original implementation)
  180. if do_resize:
  181. stacked_images = F.interpolate(
  182. stacked_images, size=(size.height, size.width), mode="bilinear", align_corners=True
  183. )
  184. interpolated_images_grouped[shape] = stacked_images
  185. pixel_values = reorder_images(interpolated_images_grouped, grouped_images_index)
  186. return BatchFeature(
  187. data={"pixel_values": pixel_values, "original_images": original_images},
  188. tensor_type=return_tensors,
  189. skip_tensor_conversion=["original_images"],
  190. )
  191. def post_process_document_rectification(
  192. self,
  193. prediction: torch.Tensor,
  194. original_images: list[torch.Tensor],
  195. scale: float = 255.0,
  196. ) -> list[dict[str, torch.Tensor]]:
  197. """
  198. Post-process document rectification predictions to convert them into rectified images.
  199. Args:
  200. prediction: Predicted 2D Bezier mesh coordinates, shape (B, 2, H, W)
  201. original_images: List of original input tensors, each of shape (C, H_i, W_i). Images may have different sizes.
  202. scale: Scaling factor for output images (default: 255.0)
  203. Returns:
  204. List of dictionaries containing rectified images. Each dictionary has:
  205. - "images": Rectified image tensor of shape (H, W, 3) with dtype torch.uint8
  206. and BGR channel order (suitable for OpenCV visualization)
  207. """
  208. image_list = list(original_images)
  209. scale = torch.tensor(float(scale), device=prediction.device)
  210. results = []
  211. for i, original_image in enumerate(image_list):
  212. # Ensure (1, C, H, W) for grid_sample
  213. if original_image.ndim == 3:
  214. original_image = original_image.unsqueeze(0)
  215. original_image = original_image.to(prediction.device)
  216. original_height, original_width = original_image.shape[2:]
  217. # Upsample predicted mesh for this image to its original size
  218. upsampled_mesh = F.interpolate(
  219. prediction[i : i + 1],
  220. size=(original_height, original_width),
  221. mode="bilinear",
  222. align_corners=True,
  223. )
  224. # Permute mesh for grid_sample: (1, H, W, 2)
  225. rearranged_mesh = upsampled_mesh.permute(0, 2, 3, 1)
  226. # Apply spatial transformation to rectify the document
  227. rectified = F.grid_sample(original_image, rearranged_mesh, align_corners=True)
  228. # Remove batch dimension and rearrange channels: (H, W, C)
  229. image = rectified.squeeze(0).permute(1, 2, 0)
  230. # Scale and convert to uint8 with BGR channel
  231. image = image * scale
  232. image = image.flip(dims=[-1]).to(dtype=torch.uint8, non_blocking=True, copy=False)
  233. results.append({"images": image})
  234. return results
class UVDocConvLayer(PPLCNetConvLayer):
    """
    Convolutional layer with batch normalization and activation.

    Extends `PPLCNetConvLayer` and redefines its `convolution` as an `nn.Conv2d` whose `padding_mode`, `bias` and
    `dilation` are configurable. The normalization and activation parts come from the parent layer, which is not
    shown here.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 0,
        padding_mode: str = "zeros",
        bias: bool = False,
        dilation: int = 1,
        activation: str | None = "relu",  # callers in this file also pass None (e.g. the residual shortcut)
    ):
        # NOTE(review): none of the arguments are forwarded here. This looks like modular-file syntax that relies on
        # the modular converter inlining the parent's constructor body. Confirm.
        super().__init__()
        self.convolution = nn.Conv2d(
            in_channels,
            out_channels,
            bias=bias,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            padding_mode=padding_mode,
            dilation=dilation,
        )
  260. class UVDocResidualBlock(nn.Module):
  261. """Base residual block with dilation support."""
  262. def __init__(
  263. self,
  264. in_channels: int,
  265. out_channels: int,
  266. kernel_size: int,
  267. stride: int = 1,
  268. padding: int = 0,
  269. dilation: int = 1,
  270. downsample: bool = False,
  271. activation: str = "relu",
  272. ):
  273. super().__init__()
  274. self.conv_down = (
  275. UVDocConvLayer(
  276. in_channels=in_channels,
  277. out_channels=out_channels,
  278. kernel_size=kernel_size,
  279. stride=stride,
  280. padding=kernel_size // 2,
  281. bias=True,
  282. activation=None,
  283. )
  284. if downsample
  285. else nn.Identity()
  286. )
  287. self.conv_start = UVDocConvLayer(
  288. in_channels=in_channels,
  289. out_channels=out_channels,
  290. kernel_size=kernel_size,
  291. stride=stride,
  292. padding=padding,
  293. dilation=dilation,
  294. bias=True,
  295. )
  296. self.conv_final = UVDocConvLayer(
  297. in_channels=out_channels,
  298. out_channels=out_channels,
  299. kernel_size=kernel_size,
  300. stride=1,
  301. padding=padding,
  302. bias=True,
  303. dilation=dilation,
  304. activation=None,
  305. )
  306. self.act_fn = ACT2FN[activation] if activation is not None else nn.Identity()
  307. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  308. residual = self.conv_down(hidden_states)
  309. hidden_states = self.conv_start(hidden_states)
  310. hidden_states = self.conv_final(hidden_states)
  311. hidden_states = hidden_states + residual
  312. hidden_states = self.act_fn(hidden_states)
  313. return hidden_states
  314. class UVDocResNetStage(nn.Module):
  315. """A ResNet stage containing multiple residual blocks."""
  316. def __init__(self, config, stage_index):
  317. super().__init__()
  318. stages = config.resnet_configs[stage_index]
  319. self.layers = nn.ModuleList([])
  320. for in_channels, out_channels, dilation, downsample in stages:
  321. self.layers.append(
  322. UVDocResidualBlock(
  323. in_channels=in_channels,
  324. out_channels=out_channels,
  325. stride=2 if downsample else 1,
  326. padding=dilation * 2,
  327. dilation=dilation,
  328. downsample=downsample,
  329. kernel_size=config.kernel_size,
  330. )
  331. )
  332. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  333. for layer in self.layers:
  334. hidden_states = layer(hidden_states)
  335. return hidden_states
  336. class UVDocResNet(nn.Module):
  337. """Initial resnet_head and resnet_down."""
  338. def __init__(self, config):
  339. super().__init__()
  340. self.resnet_head = nn.ModuleList([])
  341. for i in range(len(config.resnet_head)):
  342. self.resnet_head.append(
  343. UVDocConvLayer(
  344. in_channels=config.resnet_head[i][0],
  345. out_channels=config.resnet_head[i][1],
  346. kernel_size=config.kernel_size,
  347. stride=2,
  348. padding=config.kernel_size // 2,
  349. )
  350. )
  351. self.resnet_down = nn.ModuleList([])
  352. for stage_index in range(len(config.resnet_configs)):
  353. stage = UVDocResNetStage(config, stage_index)
  354. self.resnet_down.append(stage)
  355. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  356. for head in self.resnet_head:
  357. hidden_states = head(hidden_states)
  358. for stage in self.resnet_down:
  359. hidden_states = stage(hidden_states)
  360. return hidden_states
  361. class UVDocBridgeBlock(GradientCheckpointingLayer):
  362. """Bridge module with dilated convolutions for long-range dependencies."""
  363. def __init__(self, config, bridge_index):
  364. super().__init__()
  365. self.blocks = nn.ModuleList([])
  366. bridge = config.stage_configs[bridge_index]
  367. for in_channels, dilation in bridge:
  368. self.blocks.append(UVDocConvLayer(in_channels, in_channels, padding=dilation, dilation=dilation))
  369. def forward(
  370. self,
  371. hidden_states: torch.Tensor,
  372. **kwargs: Unpack[TransformersKwargs],
  373. ) -> torch.Tensor:
  374. for block in self.blocks:
  375. hidden_states = block(hidden_states)
  376. return hidden_states
  377. class UVDocPointPositions2D(nn.Module):
  378. """Module for predicting 2D point positions for document rectification."""
  379. def __init__(self, config):
  380. super().__init__()
  381. self.conv_down = UVDocConvLayer(
  382. in_channels=config.out_point_positions2D[0][0],
  383. out_channels=config.out_point_positions2D[0][1],
  384. kernel_size=config.kernel_size,
  385. stride=1,
  386. padding=config.kernel_size // 2,
  387. padding_mode=config.padding_mode,
  388. activation=config.hidden_act,
  389. )
  390. self.conv_up = nn.Conv2d(
  391. in_channels=config.out_point_positions2D[1][0],
  392. out_channels=config.out_point_positions2D[1][1],
  393. kernel_size=config.kernel_size,
  394. stride=1,
  395. padding=config.kernel_size // 2,
  396. padding_mode=config.padding_mode,
  397. )
  398. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  399. hidden_states = self.conv_down(hidden_states)
  400. hidden_states = self.conv_up(hidden_states)
  401. return hidden_states
  402. @auto_docstring
  403. class UVDocPreTrainedModel(PPOCRV5ServerDetPreTrainedModel):
  404. supports_gradient_checkpointing = True
  405. _can_record_outputs = {
  406. "hidden_states": UVDocBridgeBlock,
  407. }
  408. @torch.no_grad()
  409. def _init_weights(self, module):
  410. PreTrainedModel._init_weights(module)
  411. """Initialize the weights."""
  412. if isinstance(module, nn.PReLU):
  413. module.reset_parameters()
  414. class UVDocBridge(UVDocPreTrainedModel):
  415. def __init__(self, config):
  416. super().__init__(config)
  417. self.bridge = nn.ModuleList([])
  418. for bridge_index in range(len(config.stage_configs)):
  419. self.bridge.append(UVDocBridgeBlock(config, bridge_index))
  420. self.post_init()
  421. @merge_with_config_defaults
  422. @capture_outputs
  423. def forward(
  424. self,
  425. hidden_states: torch.Tensor,
  426. **kwargs: Unpack[TransformersKwargs],
  427. ) -> torch.Tensor:
  428. for layer in self.bridge:
  429. feature = layer(hidden_states)
  430. return BaseModelOutputWithNoAttention(last_hidden_state=feature)
  431. @auto_docstring(
  432. custom_intro="""
  433. UVDoc backbone model for feature extraction.
  434. """
  435. )
  436. class UVDocBackbone(BackboneMixin, UVDocPreTrainedModel):
  437. has_attentions = False
  438. base_model_prefix = "backbone"
  439. def __init__(self, config: UVDocBackboneConfig):
  440. super().__init__(config)
  441. num_features = [config.resnet_head[-1][-1]]
  442. for stage in config.stage_configs:
  443. num_features.append(stage[0][1])
  444. self.num_features = num_features
  445. self.resnet = UVDocResNet(config)
  446. self.bridge = UVDocBridge(config)
  447. self.post_init()
  448. @can_return_tuple
  449. @filter_output_hidden_states
  450. @auto_docstring
  451. def forward(
  452. self,
  453. pixel_values: torch.FloatTensor,
  454. **kwargs: Unpack[TransformersKwargs],
  455. ) -> BackboneOutput:
  456. kwargs["output_hidden_states"] = True # required to extract layers for the stages
  457. hidden_states = self.resnet(pixel_values)
  458. outputs = self.bridge(hidden_states, **kwargs)
  459. feature_maps = ()
  460. for idx, stage in enumerate(self.stage_names):
  461. if stage in self.out_features:
  462. feature_maps += (outputs.hidden_states[idx],)
  463. return BackboneOutput(
  464. feature_maps=feature_maps,
  465. hidden_states=outputs.hidden_states,
  466. )
  467. class UVDocHead(nn.Module):
  468. def __init__(self, config):
  469. super().__init__()
  470. self.num_bridge_layers = len(config.backbone_config.stage_configs)
  471. self.bridge_connector = UVDocConvLayer(
  472. in_channels=config.bridge_connector[0] * self.num_bridge_layers,
  473. out_channels=config.bridge_connector[1],
  474. kernel_size=1,
  475. stride=1,
  476. padding=0,
  477. dilation=1,
  478. )
  479. self.out_point_positions2D = UVDocPointPositions2D(config)
  480. def forward(
  481. self,
  482. hidden_states: torch.Tensor,
  483. **kwargs: Unpack[TransformersKwargs],
  484. ) -> torch.torch.Tensor:
  485. hidden_states = self.bridge_connector(hidden_states)
  486. hidden_states = self.out_point_positions2D(hidden_states)
  487. return hidden_states
  488. @auto_docstring(
  489. custom_intro=r"""
  490. The model takes raw document images (pixel values) as input, processes them through the UVDoc backbone to predict spatial transformation parameters,
  491. and outputs the rectified (corrected) document image tensor.
  492. """
  493. )
  494. class UVDocModel(UVDocPreTrainedModel):
  495. def __init__(self, config: UVDocConfig):
  496. super().__init__(config)
  497. self.backbone = UVDocBackbone(config.backbone_config)
  498. self.head = UVDocHead(config)
  499. self.post_init()
  500. @can_return_tuple
  501. @auto_docstring
  502. def forward(
  503. self,
  504. pixel_values: torch.FloatTensor,
  505. **kwargs: Unpack[TransformersKwargs],
  506. ) -> tuple[torch.FloatTensor] | BaseModelOutputWithNoAttention:
  507. backbone_outputs = self.backbone(pixel_values, **kwargs)
  508. fused_outputs = torch.cat(backbone_outputs.feature_maps, dim=1)
  509. last_hidden_state = self.head(fused_outputs, **kwargs)
  510. return BaseModelOutputWithNoAttention(
  511. last_hidden_state=last_hidden_state,
  512. hidden_states=backbone_outputs.hidden_states,
  513. )
# Public names exported by this module (and by the generated modeling / configuration / image-processing files).
__all__ = [
    "UVDocBridge",
    "UVDocBackbone",
    "UVDocBackboneConfig",
    "UVDocImageProcessor",
    "UVDocConfig",
    "UVDocModel",
    "UVDocPreTrainedModel",
]