modeling_vitmatte.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. # Copyright 2023 HUST-VL and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ViTMatte model."""
  15. from dataclasses import dataclass
  16. import torch
  17. from torch import nn
  18. from ... import initialization as init
  19. from ...backbone_utils import load_backbone
  20. from ...modeling_utils import PreTrainedModel
  21. from ...utils import ModelOutput, auto_docstring
  22. from .configuration_vitmatte import VitMatteConfig
  @dataclass
  @auto_docstring(
      custom_intro="""
      Class for outputs of image matting models.
      """
  )
  class ImageMattingOutput(ModelOutput):
      r"""
      loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
          Loss.
      alphas (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
          Estimated alpha values.
      hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
          Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
          one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
          (also called feature maps) of the model at the output of each stage.
      attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
          Attention tensors taken as-is from the backbone output. Their number and shape depend on the backbone
          in use.
      """

      # Every field defaults to `None` so the output can be built from whichever values are available.
      loss: torch.FloatTensor | None = None
      alphas: torch.FloatTensor | None = None
      hidden_states: tuple[torch.FloatTensor] | None = None
      attentions: tuple[torch.FloatTensor] | None = None
  @auto_docstring
  class VitMattePreTrainedModel(PreTrainedModel):
      # Class-level settings shared by the ViTMatte models defined in this file.
      config: VitMatteConfig
      main_input_name = "pixel_values"
      input_modalities = ("image",)
      supports_gradient_checkpointing = True
      _no_split_modules = []

      @torch.no_grad()
      def _init_weights(self, module: nn.Module):
          """
          Initialize the parameters and buffers of a single `nn.Conv2d` or `nn.BatchNorm2d` module.

          Weights are drawn from a normal distribution with mean 0 and standard deviation
          `config.initializer_range`, biases (when present) are zeroed, and batch-norm running statistics (when
          present) are reset. Any other module type is left untouched.
          """
          if not isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
              return

          init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

          bias = module.bias
          if bias is not None:
              init.zeros_(bias)

          # Only modules that actually track running statistics expose a non-None `running_mean`.
          running_mean = getattr(module, "running_mean", None)
          if running_mean is None:
              return

          init.zeros_(running_mean)
          init.ones_(module.running_var)
          init.zeros_(module.num_batches_tracked)
  class VitMatteBasicConv3x3(nn.Module):
      """
      Building block made of a bias-free 3x3 convolution followed by BatchNorm2d and ReLU.
      """

      def __init__(self, config, in_channels, out_channels, stride=2, padding=1):
          super().__init__()
          # The convolution carries no bias; it is immediately followed by batch normalization.
          self.conv = nn.Conv2d(
              in_channels,
              out_channels,
              kernel_size=3,
              stride=stride,
              padding=padding,
              bias=False,
          )
          self.batch_norm = nn.BatchNorm2d(out_channels, eps=config.batch_norm_eps)
          self.relu = nn.ReLU()

      def forward(self, hidden_state):
          """Apply convolution, batch normalization and ReLU, in that order."""
          normalized = self.batch_norm(self.conv(hidden_state))
          return self.relu(normalized)
  class VitMatteConvStream(nn.Module):
      """
      Stack of basic conv3x3 layers that extracts detail features from the input at successive stages.
      """

      def __init__(self, config):
          super().__init__()

          # Fall back to 4 input channels when no backbone config is set. This keeps backwards compatibility
          # and makes it possible to load HF backbone models.
          backbone_config = config.backbone_config
          in_channels = 4 if backbone_config is None else backbone_config.num_channels

          self.convs = nn.ModuleList()
          # Channel count before the first layer, followed by the channel count after every layer.
          self.conv_chans = [in_channels, *config.convstream_hidden_sizes]
          for source_channels, target_channels in zip(self.conv_chans[:-1], self.conv_chans[1:]):
              self.convs.append(VitMatteBasicConv3x3(config, source_channels, target_channels))

      def forward(self, pixel_values):
          """
          Run the input through every conv layer and collect each intermediate result.

          Returns a dict keyed `"detailed_feature_map_<k>"`, where key 0 holds the untouched input and key `k`
          holds the output of the k-th layer.
          """
          feature_maps = {"detailed_feature_map_0": pixel_values}
          current = pixel_values
          for level, conv in enumerate(self.convs, start=1):
              current = conv(current)
              feature_maps[f"detailed_feature_map_{level}"] = current
          return feature_maps
  class VitMatteFusionBlock(nn.Module):
      """
      Fusion block that merges a ConvStream detail map with upsampled features from the vision backbone.
      """

      def __init__(self, config, in_channels, out_channels):
          super().__init__()
          # Stride 1 with padding 1 keeps the spatial size unchanged through the 3x3 convolution.
          self.conv = VitMatteBasicConv3x3(config, in_channels, out_channels, stride=1, padding=1)

      def forward(self, features, detailed_feature_map):
          """Upsample `features` by a factor of 2, concatenate after the detail map on dim 1, then convolve."""
          upsampled = nn.functional.interpolate(features, scale_factor=2, mode="bilinear", align_corners=False)
          fused = torch.cat((detailed_feature_map, upsampled), dim=1)
          return self.conv(fused)
  class VitMatteHead(nn.Module):
      """
      Matting head built from one conv3x3 (with BatchNorm2d and ReLU) followed by one conv1x1.
      """

      def __init__(self, config):
          super().__init__()

          hidden_channels = 16
          input_channels = config.fusion_hidden_sizes[-1]

          # The positions inside the Sequential are part of the parameter names, so the order is fixed.
          layers = [
              nn.Conv2d(input_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
              nn.BatchNorm2d(hidden_channels),
              nn.ReLU(True),
              nn.Conv2d(hidden_channels, 1, kernel_size=1, stride=1, padding=0),
          ]
          self.matting_convs = nn.Sequential(*layers)

      def forward(self, hidden_state):
          """Project the fused features down to a single output channel."""
          return self.matting_convs(hidden_state)
  class VitMatteDetailCaptureModule(nn.Module):
      """
      Simple and lightweight Detail Capture Module for ViT Matting.

      It runs the ConvStream on the input, fuses the backbone features with the ConvStream detail maps one stage
      at a time (last detail map first, raw input last), and applies the matting head followed by a sigmoid.
      """

      def __init__(self, config):
          super().__init__()
          # One fusion stage is needed per ConvStream output plus one for the raw input.
          if len(config.fusion_hidden_sizes) != len(config.convstream_hidden_sizes) + 1:
              raise ValueError(
                  "The length of fusion_hidden_sizes should be equal to the length of convstream_hidden_sizes + 1."
              )

          self.config = config
          self.convstream = VitMatteConvStream(config)
          self.conv_chans = self.convstream.conv_chans

          self.fusion_blocks = nn.ModuleList()
          # `list(...)` mirrors how the ConvStream treats its own sizes and lets a tuple-valued
          # `fusion_hidden_sizes` work; a bare `list + tuple` would raise a TypeError.
          self.fusion_channels = [config.hidden_size] + list(config.fusion_hidden_sizes)
          # Detail maps are consumed from last to first, hence the reversed channel list.
          stage_channels = zip(self.fusion_channels[:-1], self.fusion_channels[1:], reversed(self.conv_chans))
          for fused_in, fused_out, detail_channels in stage_channels:
              self.fusion_blocks.append(
                  VitMatteFusionBlock(
                      config=config,
                      in_channels=fused_in + detail_channels,
                      out_channels=fused_out,
                  )
              )

          self.matting_head = VitMatteHead(config)

      def forward(self, features, pixel_values):
          """
          Fuse backbone `features` with the detail maps computed from `pixel_values`.

          Returns the matting head output passed through a sigmoid.
          """
          detail_features = self.convstream(pixel_values)
          num_blocks = len(self.fusion_blocks)
          for index, fusion_block in enumerate(self.fusion_blocks):
              # Block `index` consumes detail map `num_blocks - index - 1`, ending with map 0 (the raw input).
              detail_key = f"detailed_feature_map_{num_blocks - index - 1}"
              features = fusion_block(features, detail_features[detail_key])

          alphas = torch.sigmoid(self.matting_head(features))
          return alphas
  @auto_docstring(
      custom_intro="""
      ViTMatte framework leveraging any vision backbone for image matting, e.g. on Composition-1k.
      """
  )
  class VitMatteForImageMatting(VitMattePreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.config = config

          self.backbone = load_backbone(config)
          self.decoder = VitMatteDetailCaptureModule(config)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          pixel_values: torch.Tensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          labels: torch.Tensor | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ):
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
              Ground truth image matting for computing the loss. Training is not yet supported, so passing
              `labels` currently raises `NotImplementedError`.

          Examples:

          ```python
          >>> from transformers import VitMatteImageProcessor, VitMatteForImageMatting
          >>> import torch
          >>> from PIL import Image
          >>> from huggingface_hub import hf_hub_download

          >>> processor = VitMatteImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k")
          >>> model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k")

          >>> filepath = hf_hub_download(
          ...     repo_id="hf-internal-testing/image-matting-fixtures", filename="image.png", repo_type="dataset"
          ... )
          >>> image = Image.open(filepath).convert("RGB")
          >>> filepath = hf_hub_download(
          ...     repo_id="hf-internal-testing/image-matting-fixtures", filename="trimap.png", repo_type="dataset"
          ... )
          >>> trimap = Image.open(filepath).convert("L")

          >>> # prepare image + trimap for the model
          >>> inputs = processor(images=image, trimaps=trimap, return_tensors="pt")

          >>> with torch.no_grad():
          ...     alphas = model(**inputs).alphas
          >>> print(alphas.shape)
          torch.Size([1, 1, 640, 960])
          ```"""
          # No loss is implemented, so reject labels up front instead of carrying a loss that is always None.
          if labels is not None:
              raise NotImplementedError("Training is not yet supported")
          # Both the backbone and the decoder consume `pixel_values`; fail with a clear message when missing.
          if pixel_values is None:
              raise ValueError("You have to specify pixel_values")

          # Per-call flags take precedence; the config supplies the defaults.
          return_dict = return_dict if return_dict is not None else self.config.return_dict
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

          outputs = self.backbone.forward_with_filtered_kwargs(
              pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
          )

          # Only the last backbone feature map is decoded; the decoder also receives the raw input.
          features = outputs.feature_maps[-1]
          alphas = self.decoder(features, pixel_values)

          if not return_dict:
              # Tuple form: alphas followed by every backbone output entry after the first one.
              return (alphas,) + outputs[1:]

          return ImageMattingOutput(
              loss=None,
              alphas=alphas,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  # Names exported by `from <module> import *`.
  __all__ = [
      "VitMattePreTrainedModel",
      "VitMatteForImageMatting",
  ]