modeling_upernet.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch UperNet model. Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation."""
  15. import torch
  16. from torch import nn
  17. from torch.nn import CrossEntropyLoss
  18. from ...backbone_utils import load_backbone
  19. from ...modeling_outputs import SemanticSegmenterOutput
  20. from ...modeling_utils import PreTrainedModel
  21. from ...utils import auto_docstring
  22. from .configuration_upernet import UperNetConfig
  23. class UperNetConvModule(nn.Module):
  24. """
  25. A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
  26. layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
  27. """
  28. def __init__(
  29. self,
  30. in_channels: int,
  31. out_channels: int,
  32. kernel_size: int | tuple[int, int],
  33. padding: int | tuple[int, int] | str = 0,
  34. bias: bool = False,
  35. dilation: int | tuple[int, int] = 1,
  36. ) -> None:
  37. super().__init__()
  38. self.conv = nn.Conv2d(
  39. in_channels=in_channels,
  40. out_channels=out_channels,
  41. kernel_size=kernel_size,
  42. padding=padding,
  43. bias=bias,
  44. dilation=dilation,
  45. )
  46. self.batch_norm = nn.BatchNorm2d(out_channels)
  47. self.activation = nn.ReLU()
  48. def forward(self, input: torch.Tensor) -> torch.Tensor:
  49. output = self.conv(input)
  50. output = self.batch_norm(output)
  51. output = self.activation(output)
  52. return output
  53. class UperNetPyramidPoolingBlock(nn.Module):
  54. def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
  55. super().__init__()
  56. self.layers = [
  57. nn.AdaptiveAvgPool2d(pool_scale),
  58. UperNetConvModule(in_channels, channels, kernel_size=1),
  59. ]
  60. for i, layer in enumerate(self.layers):
  61. self.add_module(str(i), layer)
  62. def forward(self, input: torch.Tensor) -> torch.Tensor:
  63. hidden_state = input
  64. for layer in self.layers:
  65. hidden_state = layer(hidden_state)
  66. return hidden_state
  67. class UperNetPyramidPoolingModule(nn.Module):
  68. """
  69. Pyramid Pooling Module (PPM) used in PSPNet.
  70. Args:
  71. pool_scales (`tuple[int]`):
  72. Pooling scales used in Pooling Pyramid Module.
  73. in_channels (`int`):
  74. Input channels.
  75. channels (`int`):
  76. Channels after modules, before conv_seg.
  77. align_corners (`bool`):
  78. align_corners argument of F.interpolate.
  79. """
  80. def __init__(self, pool_scales: tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None:
  81. super().__init__()
  82. self.pool_scales = pool_scales
  83. self.align_corners = align_corners
  84. self.in_channels = in_channels
  85. self.channels = channels
  86. self.blocks = []
  87. for i, pool_scale in enumerate(pool_scales):
  88. block = UperNetPyramidPoolingBlock(pool_scale=pool_scale, in_channels=in_channels, channels=channels)
  89. self.blocks.append(block)
  90. self.add_module(str(i), block)
  91. def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
  92. ppm_outs = []
  93. for ppm in self.blocks:
  94. ppm_out = ppm(x)
  95. upsampled_ppm_out = nn.functional.interpolate(
  96. ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
  97. )
  98. ppm_outs.append(upsampled_ppm_out)
  99. return ppm_outs
  100. class UperNetHead(nn.Module):
  101. """
  102. Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
  103. [UPerNet](https://huggingface.co/papers/1807.10221).
  104. """
  105. def __init__(self, config, in_channels):
  106. super().__init__()
  107. self.config = config
  108. self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
  109. self.in_channels = in_channels
  110. self.channels = config.hidden_size
  111. self.align_corners = False
  112. self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
  113. # PSP Module
  114. self.psp_modules = UperNetPyramidPoolingModule(
  115. self.pool_scales,
  116. self.in_channels[-1],
  117. self.channels,
  118. align_corners=self.align_corners,
  119. )
  120. self.bottleneck = UperNetConvModule(
  121. self.in_channels[-1] + len(self.pool_scales) * self.channels,
  122. self.channels,
  123. kernel_size=3,
  124. padding=1,
  125. )
  126. # FPN Module
  127. self.lateral_convs = nn.ModuleList()
  128. self.fpn_convs = nn.ModuleList()
  129. for in_channels in self.in_channels[:-1]: # skip the top layer
  130. l_conv = UperNetConvModule(in_channels, self.channels, kernel_size=1)
  131. fpn_conv = UperNetConvModule(self.channels, self.channels, kernel_size=3, padding=1)
  132. self.lateral_convs.append(l_conv)
  133. self.fpn_convs.append(fpn_conv)
  134. self.fpn_bottleneck = UperNetConvModule(
  135. len(self.in_channels) * self.channels,
  136. self.channels,
  137. kernel_size=3,
  138. padding=1,
  139. )
  140. def psp_forward(self, inputs):
  141. x = inputs[-1]
  142. psp_outs = [x]
  143. psp_outs.extend(self.psp_modules(x))
  144. psp_outs = torch.cat(psp_outs, dim=1)
  145. output = self.bottleneck(psp_outs)
  146. return output
  147. def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
  148. # build laterals
  149. laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
  150. laterals.append(self.psp_forward(encoder_hidden_states))
  151. # build top-down path
  152. used_backbone_levels = len(laterals)
  153. for i in range(used_backbone_levels - 1, 0, -1):
  154. prev_shape = laterals[i - 1].shape[2:]
  155. laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
  156. laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
  157. )
  158. # build outputs
  159. fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
  160. # append psp feature
  161. fpn_outs.append(laterals[-1])
  162. for i in range(used_backbone_levels - 1, 0, -1):
  163. fpn_outs[i] = nn.functional.interpolate(
  164. fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
  165. )
  166. fpn_outs = torch.cat(fpn_outs, dim=1)
  167. output = self.fpn_bottleneck(fpn_outs)
  168. output = self.classifier(output)
  169. return output
  170. class UperNetFCNHead(nn.Module):
  171. """
  172. Fully Convolution Networks for Semantic Segmentation. This head is the implementation of
  173. [FCNNet](https://huggingface.co/papers/1411.4038>).
  174. Args:
  175. config:
  176. Configuration.
  177. in_channels (int):
  178. Number of input channels.
  179. kernel_size (int):
  180. The kernel size for convs in the head. Default: 3.
  181. dilation (int):
  182. The dilation rate for convs in the head. Default: 1.
  183. """
  184. def __init__(
  185. self, config, in_channels, in_index: int = 2, kernel_size: int = 3, dilation: int | tuple[int, int] = 1
  186. ) -> None:
  187. super().__init__()
  188. self.config = config
  189. self.in_channels = (
  190. in_channels[in_index] if config.auxiliary_in_channels is None else config.auxiliary_in_channels
  191. )
  192. self.channels = config.auxiliary_channels
  193. self.num_convs = config.auxiliary_num_convs
  194. self.concat_input = config.auxiliary_concat_input
  195. self.in_index = in_index
  196. conv_padding = (kernel_size // 2) * dilation
  197. convs = []
  198. convs.append(
  199. UperNetConvModule(
  200. self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
  201. )
  202. )
  203. for i in range(self.num_convs - 1):
  204. convs.append(
  205. UperNetConvModule(
  206. self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
  207. )
  208. )
  209. if self.num_convs == 0:
  210. self.convs = nn.Identity()
  211. else:
  212. self.convs = nn.Sequential(*convs)
  213. if self.concat_input:
  214. self.conv_cat = UperNetConvModule(
  215. self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
  216. )
  217. self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
  218. def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
  219. # just take the relevant feature maps
  220. hidden_states = encoder_hidden_states[self.in_index]
  221. output = self.convs(hidden_states)
  222. if self.concat_input:
  223. output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
  224. output = self.classifier(output)
  225. return output
  226. @auto_docstring
  227. class UperNetPreTrainedModel(PreTrainedModel):
  228. config: UperNetConfig
  229. main_input_name = "pixel_values"
  230. input_modalities = ("image",)
  231. _no_split_modules = []
  232. @auto_docstring(
  233. custom_intro="""
  234. UperNet framework leveraging any vision backbone e.g. for ADE20k, CityScapes.
  235. """
  236. )
  237. class UperNetForSemanticSegmentation(UperNetPreTrainedModel):
  238. def __init__(self, config):
  239. super().__init__(config)
  240. self.backbone = load_backbone(config)
  241. # Semantic segmentation head(s)
  242. self.decode_head = UperNetHead(config, in_channels=self.backbone.channels)
  243. self.auxiliary_head = (
  244. UperNetFCNHead(config, in_channels=self.backbone.channels) if config.use_auxiliary_head else None
  245. )
  246. # Initialize weights and apply final processing
  247. self.post_init()
  248. @auto_docstring
  249. def forward(
  250. self,
  251. pixel_values: torch.Tensor | None = None,
  252. output_attentions: bool | None = None,
  253. output_hidden_states: bool | None = None,
  254. labels: torch.Tensor | None = None,
  255. return_dict: bool | None = None,
  256. **kwargs,
  257. ) -> tuple | SemanticSegmenterOutput:
  258. r"""
  259. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  260. Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
  261. config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
  262. Examples:
  263. ```python
  264. >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
  265. >>> from PIL import Image
  266. >>> from huggingface_hub import hf_hub_download
  267. >>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-tiny")
  268. >>> model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-tiny")
  269. >>> filepath = hf_hub_download(
  270. ... repo_id="hf-internal-testing/fixtures_ade20k", filename="ADE_val_00000001.jpg", repo_type="dataset"
  271. ... )
  272. >>> image = Image.open(filepath).convert("RGB")
  273. >>> inputs = image_processor(images=image, return_tensors="pt")
  274. >>> outputs = model(**inputs)
  275. >>> logits = outputs.logits # shape (batch_size, num_labels, height, width)
  276. >>> list(logits.shape)
  277. [1, 150, 512, 512]
  278. ```"""
  279. if labels is not None and self.config.num_labels == 1:
  280. raise ValueError("The number of labels should be greater than one")
  281. return_dict = return_dict if return_dict is not None else self.config.return_dict
  282. output_hidden_states = (
  283. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  284. )
  285. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  286. outputs = self.backbone.forward_with_filtered_kwargs(
  287. pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
  288. )
  289. features = outputs.feature_maps
  290. logits = self.decode_head(features)
  291. logits = nn.functional.interpolate(logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False)
  292. auxiliary_logits = None
  293. if self.auxiliary_head is not None:
  294. auxiliary_logits = self.auxiliary_head(features)
  295. auxiliary_logits = nn.functional.interpolate(
  296. auxiliary_logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False
  297. )
  298. loss = None
  299. if labels is not None:
  300. # compute weighted loss
  301. loss_fct = CrossEntropyLoss(ignore_index=self.config.loss_ignore_index)
  302. loss = loss_fct(logits, labels)
  303. if auxiliary_logits is not None:
  304. auxiliary_loss = loss_fct(auxiliary_logits, labels)
  305. loss += self.config.auxiliary_loss_weight * auxiliary_loss
  306. if not return_dict:
  307. if output_hidden_states:
  308. output = (logits,) + outputs[1:]
  309. else:
  310. output = (logits,) + outputs[2:]
  311. return ((loss,) + output) if loss is not None else output
  312. return SemanticSegmenterOutput(
  313. loss=loss,
  314. logits=logits,
  315. hidden_states=outputs.hidden_states,
  316. attentions=outputs.attentions,
  317. )
# Public names of this module, i.e. what `from <module> import *` exports.
__all__ = ["UperNetForSemanticSegmentation", "UperNetPreTrainedModel"]