modeling_vitdet.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767
  1. # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ViTDet backbone."""
  15. import collections.abc
  16. import math
  17. import torch
  18. from torch import nn
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BackboneOutput, BaseModelOutput
  24. from ...modeling_utils import PreTrainedModel
  25. from ...utils import auto_docstring, logging
  26. from ...utils.generic import can_return_tuple
  27. from .configuration_vitdet import VitDetConfig
# Module-level logger for this file.
logger = logging.get_logger(__name__)
  29. class VitDetEmbeddings(nn.Module):
  30. """
  31. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  32. `hidden_states` (patch embeddings) to be consumed by a Transformer.
  33. """
  34. def __init__(self, config):
  35. super().__init__()
  36. image_size, patch_size = config.pretrain_image_size, config.patch_size
  37. num_channels, hidden_size = config.num_channels, config.hidden_size
  38. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  39. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  40. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  41. self.image_size = image_size
  42. self.patch_size = patch_size
  43. self.num_channels = num_channels
  44. self.num_patches = num_patches
  45. if config.use_absolute_position_embeddings:
  46. # Initialize absolute positional embedding with pretrain image size.
  47. num_positions = num_patches + 1
  48. self.position_embeddings = nn.Parameter(torch.zeros(1, num_positions, config.hidden_size))
  49. else:
  50. self.position_embeddings = None
  51. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  52. def get_absolute_positions(self, abs_pos_embeddings, has_cls_token, height, width):
  53. """
  54. Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token dimension for the
  55. original embeddings.
  56. Args:
  57. abs_pos_embeddings (`torch.Tensor`):
  58. Absolute positional embeddings with (1, num_position, num_channels).
  59. has_cls_token (`bool`):
  60. If true, has 1 embedding in abs_pos_embeddings for cls token.
  61. height (`int`):
  62. Height of input image tokens.
  63. width (`int`):
  64. Width of input image tokens.
  65. Returns:
  66. Absolute positional embeddings after processing with shape (1, height, width, num_channels)
  67. """
  68. if has_cls_token:
  69. abs_pos_embeddings = abs_pos_embeddings[:, 1:]
  70. num_position = abs_pos_embeddings.shape[1]
  71. size = int(math.sqrt(num_position)) # This is a constant and can be recorded as such in the ONNX export.
  72. if size * size != num_position:
  73. raise ValueError("Absolute position embeddings must be a square number.")
  74. if torch.jit.is_tracing() or (size != height or size != width):
  75. # nn.functional.interpolate is a noop in case size == height and size == width - we need to always capture this path with jit.trace.
  76. new_abs_pos_embeddings = nn.functional.interpolate(
  77. abs_pos_embeddings.reshape(1, size, size, -1).permute(0, 3, 1, 2),
  78. size=(height, width),
  79. mode="bicubic",
  80. align_corners=False,
  81. )
  82. return new_abs_pos_embeddings.permute(0, 2, 3, 1)
  83. else:
  84. return abs_pos_embeddings.reshape(1, height, width, -1)
  85. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  86. num_channels = pixel_values.shape[1]
  87. if num_channels != self.num_channels:
  88. raise ValueError(
  89. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  90. f" Expected {self.num_channels} but got {num_channels}."
  91. )
  92. embeddings = self.projection(pixel_values)
  93. if self.position_embeddings is not None:
  94. # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
  95. embeddings = embeddings.permute(0, 2, 3, 1)
  96. # add position embeddings
  97. embeddings = embeddings + self.get_absolute_positions(
  98. self.position_embeddings, True, embeddings.shape[1], embeddings.shape[2]
  99. )
  100. # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width)
  101. embeddings = embeddings.permute(0, 3, 1, 2)
  102. return embeddings
  103. @torch.jit.script_if_tracing # nn.functional.interpolate's `size` needs to be dynamic.
  104. def get_rel_pos(q_size, k_size, rel_pos):
  105. """
  106. Get relative positional embeddings according to the relative positions of query and key sizes.
  107. Args:
  108. q_size (`int`):
  109. Size of query q.
  110. k_size (`int`):
  111. Size of key k.
  112. rel_pos (`torch.Tensor`):
  113. Relative position embeddings (num_embeddings, num_channels).
  114. Returns:
  115. Extracted positional embeddings according to relative positions.
  116. """
  117. max_rel_dist = int(2 * max(q_size, k_size) - 1)
  118. # Interpolate rel pos if needed.
  119. if rel_pos.shape[0] != max_rel_dist:
  120. # Interpolate rel position embeddings.
  121. rel_pos_resized = nn.functional.interpolate(
  122. rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
  123. size=max_rel_dist,
  124. mode="linear",
  125. )
  126. rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
  127. else:
  128. rel_pos_resized = rel_pos
  129. # Scale the coords with short length if shapes for q and k are different.
  130. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
  131. k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
  132. relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
  133. return rel_pos_resized[relative_coords.long()]
  134. def add_decomposed_relative_positions(attn, queries, rel_pos_h, rel_pos_w, q_size, k_size):
  135. """
  136. Calculate decomposed Relative Positional Embeddings as introduced in
  137. [MViT2](https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py).
  138. Args:
  139. attn (`torch.Tensor`):
  140. Attention map.
  141. queries (`torch.Tensor`):
  142. Query q in the attention layer with shape (batch_size, queries_height * queries_width, num_channels).
  143. rel_pos_h (`torch.Tensor`):
  144. Relative position embeddings (Lh, num_channels) for height axis.
  145. rel_pos_w (`torch.Tensor`):
  146. Relative position embeddings (Lw, num_channels) for width axis.
  147. q_size (`tuple[int]`):
  148. Spatial sequence size of query q with (queries_height, queries_width).
  149. k_size (`tuple[int]`):
  150. Spatial sequence size of key k with (keys_height, keys_width).
  151. Returns:
  152. attn (Tensor): attention map with added relative positional embeddings.
  153. """
  154. queries_height, queries_width = q_size
  155. keys_height, keys_width = k_size
  156. relative_height = get_rel_pos(queries_height, keys_height, rel_pos_h)
  157. relative_width = get_rel_pos(queries_width, keys_width, rel_pos_w)
  158. batch_size, _, dim = queries.shape
  159. r_q = queries.reshape(batch_size, queries_height, queries_width, dim)
  160. relative_height = torch.einsum("bhwc,hkc->bhwk", r_q, relative_height)
  161. relative_weight = torch.einsum("bhwc,wkc->bhwk", r_q, relative_width)
  162. attn = (
  163. attn.view(batch_size, queries_height, queries_width, keys_height, keys_width)
  164. + relative_height[:, :, :, :, None]
  165. + relative_weight[:, :, :, None, :]
  166. ).view(batch_size, queries_height * queries_width, keys_height * keys_width)
  167. return attn
  168. class VitDetAttention(nn.Module):
  169. """Multi-head Attention block with relative position embeddings."""
  170. def __init__(self, config, input_size=None):
  171. """
  172. Args:
  173. config (`VitDetConfig`):
  174. Model configuration.
  175. input_size (`tuple[int]`, *optional*):
  176. Input resolution, only required in case relative position embeddings are added.
  177. """
  178. super().__init__()
  179. dim = config.hidden_size
  180. num_heads = config.num_attention_heads
  181. self.num_heads = num_heads
  182. head_dim = dim // num_heads
  183. self.scale = head_dim**-0.5
  184. self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
  185. self.proj = nn.Linear(dim, dim)
  186. self.use_relative_position_embeddings = config.use_relative_position_embeddings
  187. if self.use_relative_position_embeddings:
  188. # initialize relative positional embeddings
  189. self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
  190. self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
  191. def forward(self, hidden_state, output_attentions=False):
  192. batch_size, height, width, _ = hidden_state.shape
  193. # qkv with shape (3, batch_size, num_heads, height * width, num_channels)
  194. qkv = self.qkv(hidden_state).reshape(batch_size, height * width, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  195. # queries, keys and values have shape (batch_size * num_heads, height * width, num_channels)
  196. queries, keys, values = qkv.reshape(3, batch_size * self.num_heads, height * width, -1).unbind(0)
  197. attention_scores = (queries * self.scale) @ keys.transpose(-2, -1)
  198. if self.use_relative_position_embeddings:
  199. attention_scores = add_decomposed_relative_positions(
  200. attention_scores, queries, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
  201. )
  202. attention_probs = attention_scores.softmax(dim=-1)
  203. hidden_state = attention_probs @ values
  204. hidden_state = hidden_state.view(batch_size, self.num_heads, height, width, -1)
  205. hidden_state = hidden_state.permute(0, 2, 3, 1, 4)
  206. hidden_state = hidden_state.reshape(batch_size, height, width, -1)
  207. hidden_state = self.proj(hidden_state)
  208. if output_attentions:
  209. attention_probs = attention_probs.reshape(
  210. batch_size, self.num_heads, attention_probs.shape[-2], attention_probs.shape[-1]
  211. )
  212. outputs = (hidden_state, attention_probs)
  213. else:
  214. outputs = (hidden_state,)
  215. return outputs
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    In training mode, each sample along the first dimension is either zeroed (probability `drop_prob`) or kept and
    scaled by `1 / (1 - drop_prob)`.

    Args:
        input (`torch.Tensor`):
            Tensor whose first dimension indexes samples.
        drop_prob (`float`, *optional*, defaults to 0.0):
            Probability of dropping each sample.
        training (`bool`, *optional*, defaults to `False`):
            When `False` (or when `drop_prob == 0.0`) `input` is returned unchanged.

    Returns:
        `torch.Tensor` with the same shape as `input`.
    """
    if drop_prob == 0.0 or not training:
        return input
    # NOTE(review): drop_prob == 1.0 gives keep_prob == 0, so `input.div(keep_prob)` is inf and the result is
    # inf * 0 = NaN. Fix upstream in the BEiT source this block is copied from.
    keep_prob = 1 - drop_prob
    # One random value per sample, broadcast over every remaining dimension.
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else to 0.
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
class VitDetDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around `drop_path`; it only drops samples while the module is in training mode.
    """

    def __init__(self, drop_prob: float | None = None) -> None:
        super().__init__()
        # NOTE(review): the `None` default cannot be used in training mode: `drop_path` computes `1 - drop_prob`,
        # which raises a TypeError for None. VitDetLayer only constructs this with a positive float.
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        # Shown inside the module's repr, e.g. `VitDetDropPath(p=0.1)`.
        return f"p={self.drop_prob}"
  239. class VitDetLayerNorm(nn.Module):
  240. """
  241. A LayerNorm variant, popularized by Transformers, that performs point-wise mean and variance normalization over the
  242. channel dimension for inputs that have shape (batch_size, channels, height, width).
  243. https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119
  244. """
  245. def __init__(self, normalized_shape, eps=1e-6):
  246. super().__init__()
  247. self.weight = nn.Parameter(torch.ones(normalized_shape))
  248. self.bias = nn.Parameter(torch.zeros(normalized_shape))
  249. self.eps = eps
  250. self.normalized_shape = (normalized_shape,)
  251. def forward(self, x):
  252. u = x.mean(1, keepdim=True)
  253. s = (x - u).pow(2).mean(1, keepdim=True)
  254. x = (x - u) / torch.sqrt(s + self.eps)
  255. x = self.weight[:, None, None] * x + self.bias[:, None, None]
  256. return x
  257. class VitDetResBottleneckBlock(nn.Module):
  258. """
  259. The standard bottleneck residual block without the last activation layer. It contains 3 conv layers with kernels
  260. 1x1, 3x3, 1x1.
  261. """
  262. def __init__(self, config, in_channels, out_channels, bottleneck_channels):
  263. """
  264. Args:
  265. config (`VitDetConfig`):
  266. Model configuration.
  267. in_channels (`int`):
  268. Number of input channels.
  269. out_channels (`int`):
  270. Number of output channels.
  271. bottleneck_channels (`int`):
  272. Number of output channels for the 3x3 "bottleneck" conv layers.
  273. """
  274. super().__init__()
  275. self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, 1, bias=False)
  276. self.norm1 = VitDetLayerNorm(bottleneck_channels)
  277. self.act1 = ACT2FN[config.hidden_act]
  278. self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1, bias=False)
  279. self.norm2 = VitDetLayerNorm(bottleneck_channels)
  280. self.act2 = ACT2FN[config.hidden_act]
  281. self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, 1, bias=False)
  282. self.norm3 = VitDetLayerNorm(out_channels)
  283. def forward(self, x):
  284. out = x
  285. for layer in self.children():
  286. out = layer(out)
  287. out = x + out
  288. return out
  289. class VitDetMlp(nn.Module):
  290. def __init__(self, config, in_features: int, hidden_features: int) -> None:
  291. super().__init__()
  292. self.fc1 = nn.Linear(in_features, hidden_features)
  293. self.act = ACT2FN[config.hidden_act]
  294. self.fc2 = nn.Linear(hidden_features, in_features)
  295. self.drop = nn.Dropout(config.dropout_prob)
  296. def forward(self, x: torch.Tensor) -> torch.Tensor:
  297. x = self.fc1(x)
  298. x = self.act(x)
  299. x = self.drop(x)
  300. x = self.fc2(x)
  301. x = self.drop(x)
  302. return x
  303. def window_partition(hidden_state, window_size):
  304. """
  305. Partition into non-overlapping windows with padding if needed.
  306. Args:
  307. hidden_state (`torch.Tensor`):
  308. Input tokens with [batch_size, height, width, num_channels].
  309. window_size (`int`):
  310. Window size.
  311. Returns:
  312. `tuple(torch.FloatTensor)` comprising various elements:
  313. - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
  314. - (padded_height, padded_width): padded height and width before partition
  315. """
  316. batch_size, height, width, num_channels = hidden_state.shape
  317. pad_height = (window_size - height % window_size) % window_size
  318. pad_width = (window_size - width % window_size) % window_size
  319. # Noop in case pad_width == 0 and pad_height == 0.
  320. hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
  321. padded_height, padded_width = height + pad_height, width + pad_width
  322. hidden_state = hidden_state.view(
  323. batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels
  324. )
  325. windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
  326. return windows, (padded_height, padded_width)
  327. def window_unpartition(windows, window_size, pad_height_width, height_width):
  328. """
  329. Window unpartition into original sequences and removing padding.
  330. Args:
  331. windows (`torch.Tensor`):
  332. Input tokens with [batch_size * num_windows, window_size, window_size, num_channels].
  333. window_size (`int`):
  334. Window size.
  335. pad_height_width (`tuple[int]`):
  336. Padded height and width (padded_height, padded_width).
  337. height_width (`tuple[int]`):
  338. Original height and width before padding.
  339. Returns:
  340. hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels].
  341. """
  342. padded_height, padded_width = pad_height_width
  343. height, width = height_width
  344. batch_size = windows.shape[0] // (padded_height * padded_width // window_size // window_size)
  345. hidden_state = windows.view(
  346. batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1
  347. )
  348. hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous()
  349. hidden_state = hidden_state.view(batch_size, padded_height, padded_width, -1)
  350. # We always have height <= padded_height and width <= padded_width
  351. hidden_state = hidden_state[:, :height, :width, :].contiguous()
  352. return hidden_state
  353. class VitDetLayer(GradientCheckpointingLayer):
  354. """This corresponds to the Block class in the original implementation."""
  355. def __init__(
  356. self, config: VitDetConfig, drop_path_rate: float = 0, window_size: int = 0, use_residual_block: bool = False
  357. ) -> None:
  358. super().__init__()
  359. dim = config.hidden_size
  360. image_size = config.image_size
  361. image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)
  362. patch_size = config.patch_size
  363. patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)
  364. input_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  365. self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  366. self.attention = VitDetAttention(
  367. config, input_size=input_size if window_size == 0 else (window_size, window_size)
  368. )
  369. self.drop_path = VitDetDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
  370. self.norm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  371. self.mlp = VitDetMlp(config=config, in_features=dim, hidden_features=int(dim * config.mlp_ratio))
  372. self.window_size = window_size
  373. self.use_residual_block = use_residual_block
  374. if self.use_residual_block:
  375. # Use a residual block with bottleneck channel as dim // 2
  376. self.residual = VitDetResBottleneckBlock(
  377. config=config,
  378. in_channels=dim,
  379. out_channels=dim,
  380. bottleneck_channels=dim // 2,
  381. )
  382. def forward(
  383. self,
  384. hidden_states: torch.Tensor,
  385. output_attentions: bool = False,
  386. ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
  387. hidden_states = hidden_states.permute(0, 2, 3, 1)
  388. shortcut = hidden_states
  389. hidden_states = self.norm1(hidden_states)
  390. # Window partition
  391. if self.window_size > 0:
  392. height, width = hidden_states.shape[1], hidden_states.shape[2]
  393. hidden_states, pad_height_width = window_partition(hidden_states, self.window_size)
  394. self_attention_outputs = self.attention(
  395. hidden_states,
  396. output_attentions=output_attentions,
  397. )
  398. hidden_states = self_attention_outputs[0]
  399. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  400. # Reverse window partition
  401. if self.window_size > 0:
  402. hidden_states = window_unpartition(hidden_states, self.window_size, pad_height_width, (height, width))
  403. # first residual connection
  404. hidden_states = shortcut + self.drop_path(hidden_states)
  405. hidden_states = hidden_states + self.drop_path(self.mlp(self.norm2(hidden_states)))
  406. hidden_states = hidden_states.permute(0, 3, 1, 2)
  407. if self.use_residual_block:
  408. hidden_states = self.residual(hidden_states)
  409. outputs = (hidden_states,) + outputs
  410. return outputs
  411. class VitDetEncoder(nn.Module):
  412. def __init__(self, config: VitDetConfig) -> None:
  413. super().__init__()
  414. self.config = config
  415. depth = config.num_hidden_layers
  416. # stochastic depth decay rule
  417. drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, depth, device="cpu")]
  418. layers = []
  419. for i in range(depth):
  420. layers.append(
  421. VitDetLayer(
  422. config,
  423. drop_path_rate=drop_path_rate[i],
  424. window_size=config.window_size if i in config.window_block_indices else 0,
  425. use_residual_block=i in config.residual_block_indices,
  426. )
  427. )
  428. self.layer = nn.ModuleList(layers)
  429. self.gradient_checkpointing = False
  430. def forward(
  431. self,
  432. hidden_states: torch.Tensor,
  433. output_attentions: bool = False,
  434. output_hidden_states: bool = False,
  435. return_dict: bool = True,
  436. ) -> tuple | BaseModelOutput:
  437. all_hidden_states = () if output_hidden_states else None
  438. all_self_attentions = () if output_attentions else None
  439. for i, layer_module in enumerate(self.layer):
  440. if output_hidden_states:
  441. all_hidden_states = all_hidden_states + (hidden_states,)
  442. layer_outputs = layer_module(hidden_states, output_attentions)
  443. hidden_states = layer_outputs[0]
  444. if output_attentions:
  445. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  446. if output_hidden_states:
  447. all_hidden_states = all_hidden_states + (hidden_states,)
  448. if not return_dict:
  449. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  450. return BaseModelOutput(
  451. last_hidden_state=hidden_states,
  452. hidden_states=all_hidden_states,
  453. attentions=all_self_attentions,
  454. )
  455. @auto_docstring
  456. class VitDetPreTrainedModel(PreTrainedModel):
  457. config: VitDetConfig
  458. base_model_prefix = "vitdet"
  459. main_input_name = "pixel_values"
  460. input_modalities = ("image",)
  461. supports_gradient_checkpointing = True
  462. _no_split_modules = []
  463. @torch.no_grad()
  464. def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
  465. """Initialize the weights"""
  466. if isinstance(module, (nn.Linear, nn.Conv2d)):
  467. init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  468. if module.bias is not None:
  469. init.zeros_(module.bias)
  470. elif isinstance(module, nn.LayerNorm):
  471. init.zeros_(module.bias)
  472. init.ones_(module.weight)
  473. elif isinstance(module, VitDetEmbeddings):
  474. init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
  475. elif isinstance(module, VitDetAttention) and self.config.use_relative_position_embeddings:
  476. init.trunc_normal_(module.rel_pos_h, mean=0.0, std=self.config.initializer_range)
  477. init.trunc_normal_(module.rel_pos_w, mean=0.0, std=self.config.initializer_range)
  478. elif isinstance(module, VitDetResBottleneckBlock):
  479. for layer in [module.conv1, module.conv2, module.conv3]:
  480. init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
  481. if layer.bias is not None:
  482. init.constant_(layer.bias, 0)
  483. for layer in [module.norm1, module.norm2]:
  484. init.ones_(layer.weight)
  485. init.zeros_(layer.bias)
  486. # zero init last norm layer.
  487. init.zeros_(module.norm3.weight)
  488. init.zeros_(module.norm3.bias)
  489. @auto_docstring
  490. class VitDetModel(VitDetPreTrainedModel):
  491. def __init__(self, config: VitDetConfig):
  492. super().__init__(config)
  493. self.config = config
  494. self.embeddings = VitDetEmbeddings(config)
  495. self.encoder = VitDetEncoder(config)
  496. # Initialize weights and apply final processing
  497. self.post_init()
  498. def get_input_embeddings(self) -> VitDetEmbeddings:
  499. return self.embeddings.projection
  500. @auto_docstring
  501. def forward(
  502. self,
  503. pixel_values: torch.Tensor | None = None,
  504. output_attentions: bool | None = None,
  505. output_hidden_states: bool | None = None,
  506. return_dict: bool | None = None,
  507. **kwargs,
  508. ) -> tuple | BaseModelOutput:
  509. r"""
  510. Examples:
  511. ```python
  512. >>> from transformers import VitDetConfig, VitDetModel
  513. >>> import torch
  514. >>> config = VitDetConfig()
  515. >>> model = VitDetModel(config)
  516. >>> pixel_values = torch.randn(1, 3, 224, 224)
  517. >>> with torch.no_grad():
  518. ... outputs = model(pixel_values)
  519. >>> last_hidden_states = outputs.last_hidden_state
  520. >>> list(last_hidden_states.shape)
  521. [1, 768, 14, 14]
  522. ```"""
  523. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  524. output_hidden_states = (
  525. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  526. )
  527. return_dict = return_dict if return_dict is not None else self.config.return_dict
  528. if pixel_values is None:
  529. raise ValueError("You have to specify pixel_values")
  530. embedding_output = self.embeddings(pixel_values)
  531. encoder_outputs = self.encoder(
  532. embedding_output,
  533. output_attentions=output_attentions,
  534. output_hidden_states=output_hidden_states,
  535. return_dict=return_dict,
  536. )
  537. sequence_output = encoder_outputs[0]
  538. if not return_dict:
  539. return (sequence_output,) + encoder_outputs[1:]
  540. return BaseModelOutput(
  541. last_hidden_state=sequence_output,
  542. hidden_states=encoder_outputs.hidden_states,
  543. attentions=encoder_outputs.attentions,
  544. )
  545. @auto_docstring(
  546. custom_intro="""
  547. ViTDet backbone, to be used with frameworks like Mask R-CNN.
  548. """
  549. )
  550. class VitDetBackbone(BackboneMixin, VitDetPreTrainedModel):
  551. def __init__(self, config):
  552. super().__init__(config)
  553. self.embeddings = VitDetEmbeddings(config)
  554. self.encoder = VitDetEncoder(config)
  555. self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
  556. # initialize weights and apply final processing
  557. self.post_init()
  558. def get_input_embeddings(self) -> VitDetEmbeddings:
  559. return self.embeddings.projection
  560. @can_return_tuple
  561. @filter_output_hidden_states
  562. @auto_docstring
  563. def forward(
  564. self,
  565. pixel_values: torch.Tensor,
  566. output_hidden_states: bool | None = None,
  567. output_attentions: bool | None = None,
  568. return_dict: bool | None = None,
  569. **kwargs,
  570. ) -> BackboneOutput:
  571. r"""
  572. Examples:
  573. ```python
  574. >>> from transformers import VitDetConfig, VitDetBackbone
  575. >>> import torch
  576. >>> config = VitDetConfig()
  577. >>> model = VitDetBackbone(config)
  578. >>> pixel_values = torch.randn(1, 3, 224, 224)
  579. >>> with torch.no_grad():
  580. ... outputs = model(pixel_values)
  581. >>> feature_maps = outputs.feature_maps
  582. >>> list(feature_maps[-1].shape)
  583. [1, 768, 14, 14]
  584. ```"""
  585. return_dict = return_dict if return_dict is not None else self.config.return_dict
  586. output_hidden_states = (
  587. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  588. )
  589. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  590. embedding_output = self.embeddings(pixel_values)
  591. outputs = self.encoder(
  592. embedding_output,
  593. output_hidden_states=True,
  594. output_attentions=output_attentions,
  595. return_dict=return_dict,
  596. )
  597. hidden_states = outputs.hidden_states if return_dict else outputs[1]
  598. feature_maps = ()
  599. for stage, hidden_state in zip(self.stage_names, hidden_states):
  600. if stage in self.out_features:
  601. feature_maps += (hidden_state,)
  602. if not return_dict:
  603. if output_hidden_states:
  604. output = (feature_maps,) + outputs[1:]
  605. else:
  606. output = (feature_maps,) + outputs[2:]
  607. return output
  608. return BackboneOutput(
  609. feature_maps=feature_maps,
  610. hidden_states=outputs.hidden_states if output_hidden_states else None,
  611. attentions=outputs.attentions,
  612. )
# Public API of this module. The remaining classes and helpers are internal building blocks.
__all__ = ["VitDetModel", "VitDetPreTrainedModel", "VitDetBackbone"]