modeling_textnet.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
# Copyright 2024 the Fast authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """PyTorch TextNet model."""
  15. from typing import Any
  16. import torch
  17. import torch.nn as nn
  18. from torch import Tensor
  19. from ...activations import ACT2CLS
  20. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  21. from ...modeling_outputs import (
  22. BackboneOutput,
  23. BaseModelOutputWithNoAttention,
  24. BaseModelOutputWithPoolingAndNoAttention,
  25. ImageClassifierOutputWithNoAttention,
  26. )
  27. from ...modeling_utils import PreTrainedModel
  28. from ...utils import auto_docstring, logging
  29. from ...utils.generic import can_return_tuple
  30. from .configuration_textnet import TextNetConfig
# Module-level logger obtained from the library's `logging` helper, keyed by this module's name.
logger = logging.get_logger(__name__)
  32. class TextNetConvLayer(nn.Module):
  33. def __init__(self, config: TextNetConfig):
  34. super().__init__()
  35. self.kernel_size = config.stem_kernel_size
  36. self.stride = config.stem_stride
  37. self.activation_function = config.stem_act_func
  38. padding = (
  39. (config.kernel_size[0] // 2, config.kernel_size[1] // 2)
  40. if isinstance(config.stem_kernel_size, tuple)
  41. else config.stem_kernel_size // 2
  42. )
  43. self.conv = nn.Conv2d(
  44. config.stem_num_channels,
  45. config.stem_out_channels,
  46. kernel_size=config.stem_kernel_size,
  47. stride=config.stem_stride,
  48. padding=padding,
  49. bias=False,
  50. )
  51. self.batch_norm = nn.BatchNorm2d(config.stem_out_channels, config.batch_norm_eps)
  52. self.activation = nn.Identity()
  53. if self.activation_function is not None:
  54. self.activation = ACT2CLS[self.activation_function]()
  55. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  56. hidden_states = self.conv(hidden_states)
  57. hidden_states = self.batch_norm(hidden_states)
  58. return self.activation(hidden_states)
  59. class TextNetRepConvLayer(nn.Module):
  60. r"""
  61. This layer supports re-parameterization by combining multiple convolutional branches
  62. (e.g., main convolution, vertical, horizontal, and identity branches) during training.
  63. At inference time, these branches can be collapsed into a single convolution for
  64. efficiency, as per the re-parameterization paradigm.
  65. The "Rep" in the name stands for "re-parameterization" (introduced by RepVGG).
  66. """
  67. def __init__(self, config: TextNetConfig, in_channels: int, out_channels: int, kernel_size: int, stride: int):
  68. super().__init__()
  69. self.num_channels = in_channels
  70. self.out_channels = out_channels
  71. self.kernel_size = kernel_size
  72. self.stride = stride
  73. padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
  74. self.activation_function = nn.ReLU()
  75. self.main_conv = nn.Conv2d(
  76. in_channels=in_channels,
  77. out_channels=out_channels,
  78. kernel_size=kernel_size,
  79. stride=stride,
  80. padding=padding,
  81. bias=False,
  82. )
  83. self.main_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)
  84. vertical_padding = ((kernel_size[0] - 1) // 2, 0)
  85. horizontal_padding = (0, (kernel_size[1] - 1) // 2)
  86. if kernel_size[1] != 1:
  87. self.vertical_conv = nn.Conv2d(
  88. in_channels=in_channels,
  89. out_channels=out_channels,
  90. kernel_size=(kernel_size[0], 1),
  91. stride=stride,
  92. padding=vertical_padding,
  93. bias=False,
  94. )
  95. self.vertical_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)
  96. else:
  97. self.vertical_conv, self.vertical_batch_norm = None, None
  98. if kernel_size[0] != 1:
  99. self.horizontal_conv = nn.Conv2d(
  100. in_channels=in_channels,
  101. out_channels=out_channels,
  102. kernel_size=(1, kernel_size[1]),
  103. stride=stride,
  104. padding=horizontal_padding,
  105. bias=False,
  106. )
  107. self.horizontal_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)
  108. else:
  109. self.horizontal_conv, self.horizontal_batch_norm = None, None
  110. self.rbr_identity = (
  111. nn.BatchNorm2d(num_features=in_channels, eps=config.batch_norm_eps)
  112. if out_channels == in_channels and stride == 1
  113. else None
  114. )
  115. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  116. main_outputs = self.main_conv(hidden_states)
  117. main_outputs = self.main_batch_norm(main_outputs)
  118. # applies a convolution with a vertical kernel
  119. if self.vertical_conv is not None:
  120. vertical_outputs = self.vertical_conv(hidden_states)
  121. vertical_outputs = self.vertical_batch_norm(vertical_outputs)
  122. main_outputs = main_outputs + vertical_outputs
  123. # applies a convolution with a horizontal kernel
  124. if self.horizontal_conv is not None:
  125. horizontal_outputs = self.horizontal_conv(hidden_states)
  126. horizontal_outputs = self.horizontal_batch_norm(horizontal_outputs)
  127. main_outputs = main_outputs + horizontal_outputs
  128. if self.rbr_identity is not None:
  129. id_out = self.rbr_identity(hidden_states)
  130. main_outputs = main_outputs + id_out
  131. return self.activation_function(main_outputs)
  132. class TextNetStage(nn.Module):
  133. def __init__(self, config: TextNetConfig, depth: int):
  134. super().__init__()
  135. kernel_size = config.conv_layer_kernel_sizes[depth]
  136. stride = config.conv_layer_strides[depth]
  137. num_layers = len(kernel_size)
  138. stage_in_channel_size = config.hidden_sizes[depth]
  139. stage_out_channel_size = config.hidden_sizes[depth + 1]
  140. in_channels = [stage_in_channel_size] + [stage_out_channel_size] * (num_layers - 1)
  141. out_channels = [stage_out_channel_size] * num_layers
  142. stage = []
  143. for stage_config in zip(in_channels, out_channels, kernel_size, stride):
  144. stage.append(TextNetRepConvLayer(config, *stage_config))
  145. self.stage = nn.ModuleList(stage)
  146. def forward(self, hidden_state):
  147. for block in self.stage:
  148. hidden_state = block(hidden_state)
  149. return hidden_state
  150. class TextNetEncoder(nn.Module):
  151. def __init__(self, config: TextNetConfig):
  152. super().__init__()
  153. stages = []
  154. num_stages = len(config.conv_layer_kernel_sizes)
  155. for stage_ix in range(num_stages):
  156. stages.append(TextNetStage(config, stage_ix))
  157. self.stages = nn.ModuleList(stages)
  158. def forward(
  159. self,
  160. hidden_state: torch.Tensor,
  161. output_hidden_states: bool | None = None,
  162. return_dict: bool | None = None,
  163. ) -> BaseModelOutputWithNoAttention:
  164. hidden_states = [hidden_state]
  165. for stage in self.stages:
  166. hidden_state = stage(hidden_state)
  167. hidden_states.append(hidden_state)
  168. if not return_dict:
  169. output = (hidden_state,)
  170. return output + (hidden_states,) if output_hidden_states else output
  171. return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
@auto_docstring
class TextNetPreTrainedModel(PreTrainedModel):
    # Common base class of the TextNet models in this file; it only declares class-level metadata.

    # Configuration class associated with these models.
    config: TextNetConfig
    # Same name as the `self.textnet` attribute under which the head/backbone classes in this
    # file store their `TextNetModel`.
    base_model_prefix = "textnet"
    # Name of the primary forward argument (the forward methods in this file take `pixel_values`).
    main_input_name = "pixel_values"
  177. @auto_docstring
  178. class TextNetModel(TextNetPreTrainedModel):
  179. def __init__(self, config):
  180. super().__init__(config)
  181. self.stem = TextNetConvLayer(config)
  182. self.encoder = TextNetEncoder(config)
  183. self.pooler = nn.AdaptiveAvgPool2d((2, 2))
  184. self.post_init()
  185. @auto_docstring
  186. def forward(
  187. self,
  188. pixel_values: Tensor,
  189. output_hidden_states: bool | None = None,
  190. return_dict: bool | None = None,
  191. **kwargs,
  192. ) -> tuple[Any, list[Any]] | tuple[Any] | BaseModelOutputWithPoolingAndNoAttention:
  193. return_dict = return_dict if return_dict is not None else self.config.return_dict
  194. output_hidden_states = (
  195. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  196. )
  197. hidden_state = self.stem(pixel_values)
  198. encoder_outputs = self.encoder(
  199. hidden_state, output_hidden_states=output_hidden_states, return_dict=return_dict
  200. )
  201. last_hidden_state = encoder_outputs[0]
  202. pooled_output = self.pooler(last_hidden_state)
  203. if not return_dict:
  204. output = (last_hidden_state, pooled_output)
  205. return output + (encoder_outputs[1],) if output_hidden_states else output
  206. return BaseModelOutputWithPoolingAndNoAttention(
  207. last_hidden_state=last_hidden_state,
  208. pooler_output=pooled_output,
  209. hidden_states=encoder_outputs[1] if output_hidden_states else None,
  210. )
  211. @auto_docstring(
  212. custom_intro="""
  213. TextNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  214. ImageNet.
  215. """
  216. )
  217. class TextNetForImageClassification(TextNetPreTrainedModel):
  218. def __init__(self, config):
  219. super().__init__(config)
  220. self.num_labels = config.num_labels
  221. self.textnet = TextNetModel(config)
  222. self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
  223. self.flatten = nn.Flatten()
  224. self.fc = nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
  225. # classification head
  226. self.classifier = nn.ModuleList([self.avg_pool, self.flatten])
  227. # initialize weights and apply final processing
  228. self.post_init()
  229. @auto_docstring
  230. def forward(
  231. self,
  232. pixel_values: torch.FloatTensor | None = None,
  233. labels: torch.LongTensor | None = None,
  234. output_hidden_states: bool | None = None,
  235. return_dict: bool | None = None,
  236. **kwargs,
  237. ) -> ImageClassifierOutputWithNoAttention:
  238. r"""
  239. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  240. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  241. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  242. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  243. Examples:
  244. ```python
  245. >>> import torch
  246. >>> import httpx
  247. >>> from io import BytesIO
  248. >>> from transformers import TextNetForImageClassification, TextNetImageProcessor
  249. >>> from PIL import Image
  250. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  251. >>> with httpx.stream("GET", url) as response:
  252. ... image = Image.open(BytesIO(response.read()))
  253. >>> processor = TextNetImageProcessor.from_pretrained("czczup/textnet-base")
  254. >>> model = TextNetForImageClassification.from_pretrained("czczup/textnet-base")
  255. >>> inputs = processor(images=image, return_tensors="pt")
  256. >>> with torch.no_grad():
  257. ... outputs = model(**inputs)
  258. >>> outputs.logits.shape
  259. torch.Size([1, 2])
  260. ```"""
  261. return_dict = return_dict if return_dict is not None else self.config.return_dict
  262. outputs = self.textnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  263. last_hidden_state = outputs[0]
  264. for layer in self.classifier:
  265. last_hidden_state = layer(last_hidden_state)
  266. logits = self.fc(last_hidden_state)
  267. loss = None
  268. if labels is not None:
  269. loss = self.loss_function(labels, logits, self.config)
  270. if not return_dict:
  271. output = (logits,) + outputs[2:]
  272. return (loss,) + output if loss is not None else output
  273. return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
  274. @auto_docstring(
  275. custom_intro="""
  276. TextNet backbone, to be used with frameworks like DETR and MaskFormer.
  277. """
  278. )
  279. class TextNetBackbone(BackboneMixin, TextNetPreTrainedModel):
  280. has_attentions = False
  281. def __init__(self, config):
  282. super().__init__(config)
  283. self.textnet = TextNetModel(config)
  284. self.num_features = config.hidden_sizes
  285. # initialize weights and apply final processing
  286. self.post_init()
  287. @can_return_tuple
  288. @filter_output_hidden_states
  289. @auto_docstring
  290. def forward(
  291. self,
  292. pixel_values: Tensor,
  293. output_hidden_states: bool | None = None,
  294. return_dict: bool | None = None,
  295. **kwargs,
  296. ) -> tuple[tuple] | BackboneOutput:
  297. r"""
  298. Examples:
  299. ```python
  300. >>> import torch
  301. >>> import httpx
  302. >>> from io import BytesIO
  303. >>> from PIL import Image
  304. >>> from transformers import AutoImageProcessor, AutoBackbone
  305. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  306. >>> with httpx.stream("GET", url) as response:
  307. ... image = Image.open(BytesIO(response.read()))
  308. >>> processor = AutoImageProcessor.from_pretrained("czczup/textnet-base")
  309. >>> model = AutoBackbone.from_pretrained("czczup/textnet-base")
  310. >>> inputs = processor(image, return_tensors="pt")
  311. >>> with torch.no_grad():
  312. >>> outputs = model(**inputs)
  313. ```"""
  314. return_dict = return_dict if return_dict is not None else self.config.return_dict
  315. output_hidden_states = (
  316. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  317. )
  318. outputs = self.textnet(pixel_values, output_hidden_states=True, return_dict=return_dict)
  319. hidden_states = outputs.hidden_states if return_dict else outputs[2]
  320. feature_maps = ()
  321. for idx, stage in enumerate(self.stage_names):
  322. if stage in self.out_features:
  323. feature_maps += (hidden_states[idx],)
  324. if not return_dict:
  325. output = (feature_maps,)
  326. if output_hidden_states:
  327. hidden_states = outputs.hidden_states if return_dict else outputs[2]
  328. output += (hidden_states,)
  329. return output
  330. return BackboneOutput(
  331. feature_maps=feature_maps,
  332. hidden_states=outputs.hidden_states if output_hidden_states else None,
  333. attentions=None,
  334. )
  335. __all__ = ["TextNetBackbone", "TextNetModel", "TextNetPreTrainedModel", "TextNetForImageClassification"]