modeling_uvdoc.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/uvdoc/modular_uvdoc.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_uvdoc.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import torch
  21. import torch.nn as nn
  22. from torch import Tensor
  23. from ...activations import ACT2FN
  24. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import BackboneOutput, BaseModelOutputWithNoAttention
  27. from ...modeling_utils import PreTrainedModel
  28. from ...processing_utils import Unpack
  29. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  30. from ...utils.generic import merge_with_config_defaults
  31. from ...utils.output_capturing import capture_outputs
  32. from .configuration_uvdoc import UVDocBackboneConfig, UVDocConfig
  33. class UVDocConvLayer(nn.Module):
  34. """Convolutional layer with batch normalization and activation."""
  35. def __init__(
  36. self,
  37. in_channels: int,
  38. out_channels: int,
  39. kernel_size: int = 3,
  40. stride: int = 1,
  41. padding: int = 0,
  42. padding_mode: str = "zeros",
  43. bias: bool = False,
  44. dilation: int = 1,
  45. activation: str = "relu",
  46. ):
  47. super().__init__()
  48. self.convolution = nn.Conv2d(
  49. in_channels,
  50. out_channels,
  51. bias=bias,
  52. kernel_size=kernel_size,
  53. stride=stride,
  54. padding=padding,
  55. padding_mode=padding_mode,
  56. dilation=dilation,
  57. )
  58. self.normalization = nn.BatchNorm2d(out_channels)
  59. self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
  60. def forward(self, input: Tensor) -> Tensor:
  61. hidden_state = self.convolution(input)
  62. hidden_state = self.normalization(hidden_state)
  63. hidden_state = self.activation(hidden_state)
  64. return hidden_state
  65. class UVDocResidualBlock(nn.Module):
  66. """Base residual block with dilation support."""
  67. def __init__(
  68. self,
  69. in_channels: int,
  70. out_channels: int,
  71. kernel_size: int,
  72. stride: int = 1,
  73. padding: int = 0,
  74. dilation: int = 1,
  75. downsample: bool = False,
  76. activation: str = "relu",
  77. ):
  78. super().__init__()
  79. self.conv_down = (
  80. UVDocConvLayer(
  81. in_channels=in_channels,
  82. out_channels=out_channels,
  83. kernel_size=kernel_size,
  84. stride=stride,
  85. padding=kernel_size // 2,
  86. bias=True,
  87. activation=None,
  88. )
  89. if downsample
  90. else nn.Identity()
  91. )
  92. self.conv_start = UVDocConvLayer(
  93. in_channels=in_channels,
  94. out_channels=out_channels,
  95. kernel_size=kernel_size,
  96. stride=stride,
  97. padding=padding,
  98. dilation=dilation,
  99. bias=True,
  100. )
  101. self.conv_final = UVDocConvLayer(
  102. in_channels=out_channels,
  103. out_channels=out_channels,
  104. kernel_size=kernel_size,
  105. stride=1,
  106. padding=padding,
  107. bias=True,
  108. dilation=dilation,
  109. activation=None,
  110. )
  111. self.act_fn = ACT2FN[activation] if activation is not None else nn.Identity()
  112. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  113. residual = self.conv_down(hidden_states)
  114. hidden_states = self.conv_start(hidden_states)
  115. hidden_states = self.conv_final(hidden_states)
  116. hidden_states = hidden_states + residual
  117. hidden_states = self.act_fn(hidden_states)
  118. return hidden_states
  119. class UVDocResNetStage(nn.Module):
  120. """A ResNet stage containing multiple residual blocks."""
  121. def __init__(self, config, stage_index):
  122. super().__init__()
  123. stages = config.resnet_configs[stage_index]
  124. self.layers = nn.ModuleList([])
  125. for in_channels, out_channels, dilation, downsample in stages:
  126. self.layers.append(
  127. UVDocResidualBlock(
  128. in_channels=in_channels,
  129. out_channels=out_channels,
  130. stride=2 if downsample else 1,
  131. padding=dilation * 2,
  132. dilation=dilation,
  133. downsample=downsample,
  134. kernel_size=config.kernel_size,
  135. )
  136. )
  137. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  138. for layer in self.layers:
  139. hidden_states = layer(hidden_states)
  140. return hidden_states
  141. class UVDocResNet(nn.Module):
  142. """Initial resnet_head and resnet_down."""
  143. def __init__(self, config):
  144. super().__init__()
  145. self.resnet_head = nn.ModuleList([])
  146. for i in range(len(config.resnet_head)):
  147. self.resnet_head.append(
  148. UVDocConvLayer(
  149. in_channels=config.resnet_head[i][0],
  150. out_channels=config.resnet_head[i][1],
  151. kernel_size=config.kernel_size,
  152. stride=2,
  153. padding=config.kernel_size // 2,
  154. )
  155. )
  156. self.resnet_down = nn.ModuleList([])
  157. for stage_index in range(len(config.resnet_configs)):
  158. stage = UVDocResNetStage(config, stage_index)
  159. self.resnet_down.append(stage)
  160. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  161. for head in self.resnet_head:
  162. hidden_states = head(hidden_states)
  163. for stage in self.resnet_down:
  164. hidden_states = stage(hidden_states)
  165. return hidden_states
  166. class UVDocBridgeBlock(GradientCheckpointingLayer):
  167. """Bridge module with dilated convolutions for long-range dependencies."""
  168. def __init__(self, config, bridge_index):
  169. super().__init__()
  170. self.blocks = nn.ModuleList([])
  171. bridge = config.stage_configs[bridge_index]
  172. for in_channels, dilation in bridge:
  173. self.blocks.append(UVDocConvLayer(in_channels, in_channels, padding=dilation, dilation=dilation))
  174. def forward(
  175. self,
  176. hidden_states: torch.Tensor,
  177. **kwargs: Unpack[TransformersKwargs],
  178. ) -> torch.Tensor:
  179. for block in self.blocks:
  180. hidden_states = block(hidden_states)
  181. return hidden_states
  182. class UVDocPointPositions2D(nn.Module):
  183. """Module for predicting 2D point positions for document rectification."""
  184. def __init__(self, config):
  185. super().__init__()
  186. self.conv_down = UVDocConvLayer(
  187. in_channels=config.out_point_positions2D[0][0],
  188. out_channels=config.out_point_positions2D[0][1],
  189. kernel_size=config.kernel_size,
  190. stride=1,
  191. padding=config.kernel_size // 2,
  192. padding_mode=config.padding_mode,
  193. activation=config.hidden_act,
  194. )
  195. self.conv_up = nn.Conv2d(
  196. in_channels=config.out_point_positions2D[1][0],
  197. out_channels=config.out_point_positions2D[1][1],
  198. kernel_size=config.kernel_size,
  199. stride=1,
  200. padding=config.kernel_size // 2,
  201. padding_mode=config.padding_mode,
  202. )
  203. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  204. hidden_states = self.conv_down(hidden_states)
  205. hidden_states = self.conv_up(hidden_states)
  206. return hidden_states
  207. @auto_docstring
  208. class UVDocPreTrainedModel(PreTrainedModel):
  209. """
  210. Base class for all PPOCRV5 Server Det pre-trained models. Handles model initialization,
  211. configuration, and loading of pre-trained weights, following the Transformers library conventions.
  212. """
  213. config: UVDocConfig
  214. base_model_prefix = "uvdoc"
  215. main_input_name = "pixel_values"
  216. input_modalities = ("image",)
  217. _can_compile_fullgraph = True
  218. supports_gradient_checkpointing = True
  219. _can_record_outputs = {
  220. "hidden_states": UVDocBridgeBlock,
  221. }
  222. @torch.no_grad()
  223. def _init_weights(self, module):
  224. super()._init_weights(module)
  225. """Initialize the weights."""
  226. if isinstance(module, nn.PReLU):
  227. module.reset_parameters()
  228. class UVDocBridge(UVDocPreTrainedModel):
  229. def __init__(self, config):
  230. super().__init__(config)
  231. self.bridge = nn.ModuleList([])
  232. for bridge_index in range(len(config.stage_configs)):
  233. self.bridge.append(UVDocBridgeBlock(config, bridge_index))
  234. self.post_init()
  235. @merge_with_config_defaults
  236. @capture_outputs
  237. def forward(
  238. self,
  239. hidden_states: torch.Tensor,
  240. **kwargs: Unpack[TransformersKwargs],
  241. ) -> torch.Tensor:
  242. for layer in self.bridge:
  243. feature = layer(hidden_states)
  244. return BaseModelOutputWithNoAttention(last_hidden_state=feature)
  245. @auto_docstring(
  246. custom_intro="""
  247. UVDoc backbone model for feature extraction.
  248. """
  249. )
  250. class UVDocBackbone(BackboneMixin, UVDocPreTrainedModel):
  251. has_attentions = False
  252. base_model_prefix = "backbone"
  253. def __init__(self, config: UVDocBackboneConfig):
  254. super().__init__(config)
  255. num_features = [config.resnet_head[-1][-1]]
  256. for stage in config.stage_configs:
  257. num_features.append(stage[0][1])
  258. self.num_features = num_features
  259. self.resnet = UVDocResNet(config)
  260. self.bridge = UVDocBridge(config)
  261. self.post_init()
  262. @can_return_tuple
  263. @filter_output_hidden_states
  264. @auto_docstring
  265. def forward(
  266. self,
  267. pixel_values: torch.FloatTensor,
  268. **kwargs: Unpack[TransformersKwargs],
  269. ) -> BackboneOutput:
  270. kwargs["output_hidden_states"] = True # required to extract layers for the stages
  271. hidden_states = self.resnet(pixel_values)
  272. outputs = self.bridge(hidden_states, **kwargs)
  273. feature_maps = ()
  274. for idx, stage in enumerate(self.stage_names):
  275. if stage in self.out_features:
  276. feature_maps += (outputs.hidden_states[idx],)
  277. return BackboneOutput(
  278. feature_maps=feature_maps,
  279. hidden_states=outputs.hidden_states,
  280. )
  281. class UVDocHead(nn.Module):
  282. def __init__(self, config):
  283. super().__init__()
  284. self.num_bridge_layers = len(config.backbone_config.stage_configs)
  285. self.bridge_connector = UVDocConvLayer(
  286. in_channels=config.bridge_connector[0] * self.num_bridge_layers,
  287. out_channels=config.bridge_connector[1],
  288. kernel_size=1,
  289. stride=1,
  290. padding=0,
  291. dilation=1,
  292. )
  293. self.out_point_positions2D = UVDocPointPositions2D(config)
  294. def forward(
  295. self,
  296. hidden_states: torch.Tensor,
  297. **kwargs: Unpack[TransformersKwargs],
  298. ) -> torch.torch.Tensor:
  299. hidden_states = self.bridge_connector(hidden_states)
  300. hidden_states = self.out_point_positions2D(hidden_states)
  301. return hidden_states
  302. @auto_docstring(
  303. custom_intro=r"""
  304. The model takes raw document images (pixel values) as input, processes them through the UVDoc backbone to predict spatial transformation parameters,
  305. and outputs the rectified (corrected) document image tensor.
  306. """
  307. )
  308. class UVDocModel(UVDocPreTrainedModel):
  309. def __init__(self, config: UVDocConfig):
  310. super().__init__(config)
  311. self.backbone = UVDocBackbone(config.backbone_config)
  312. self.head = UVDocHead(config)
  313. self.post_init()
  314. @can_return_tuple
  315. @auto_docstring
  316. def forward(
  317. self,
  318. pixel_values: torch.FloatTensor,
  319. **kwargs: Unpack[TransformersKwargs],
  320. ) -> tuple[torch.FloatTensor] | BaseModelOutputWithNoAttention:
  321. backbone_outputs = self.backbone(pixel_values, **kwargs)
  322. fused_outputs = torch.cat(backbone_outputs.feature_maps, dim=1)
  323. last_hidden_state = self.head(fused_outputs, **kwargs)
  324. return BaseModelOutputWithNoAttention(
  325. last_hidden_state=last_hidden_state,
  326. hidden_states=backbone_outputs.hidden_states,
  327. )
  328. __all__ = ["UVDocBridge", "UVDocBackbone", "UVDocModel", "UVDocPreTrainedModel"]