modeling_beit.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448
  1. # Copyright 2021 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch BEiT model."""
  15. import collections.abc
  16. import math
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import Tensor, nn
  20. from torch.nn import CrossEntropyLoss
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import (
  26. BackboneOutput,
  27. BaseModelOutput,
  28. BaseModelOutputWithPooling,
  29. ImageClassifierOutput,
  30. MaskedLMOutput,
  31. SemanticSegmenterOutput,
  32. )
  33. from ...modeling_utils import PreTrainedModel
  34. from ...pytorch_utils import compile_compatible_method_lru_cache
  35. from ...utils import auto_docstring, logging, torch_int
  36. from ...utils.generic import can_return_tuple
  37. from .configuration_beit import BeitConfig
# Module-level logger; used in this file e.g. to warn when SDPA attention is asked for attention weights.
logger = logging.get_logger(__name__)
# Output container for `BeitModel`. It declares no fields of its own: everything is inherited from
# `BaseModelOutputWithPooling`, and only the documented meaning of `pooler_output` is BEiT-specific.
# NOTE: the docstring below is parsed by `auto_docstring`, so keep its argument-style layout intact.
@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of [`BeitModel`].
    """
)
class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
        *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
        will be returned.
    """
  52. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  53. """
  54. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  55. """
  56. if drop_prob == 0.0 or not training:
  57. return input
  58. keep_prob = 1 - drop_prob
  59. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  60. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  61. random_tensor.floor_() # binarize
  62. output = input.div(keep_prob) * random_tensor
  63. return output
  64. class BeitDropPath(nn.Module):
  65. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  66. def __init__(self, drop_prob: float | None = None) -> None:
  67. super().__init__()
  68. self.drop_prob = drop_prob
  69. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  70. return drop_path(hidden_states, self.drop_prob, self.training)
  71. def extra_repr(self) -> str:
  72. return f"p={self.drop_prob}"
  73. # Based on timm implementation, which can be found here:
  74. # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  75. class BeitEmbeddings(nn.Module):
  76. """
  77. Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
  78. """
  79. def __init__(self, config: BeitConfig) -> None:
  80. super().__init__()
  81. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  82. if config.use_mask_token:
  83. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  84. else:
  85. self.mask_token = None
  86. self.patch_embeddings = BeitPatchEmbeddings(config)
  87. self.patch_size = config.patch_size
  88. self.image_size = (
  89. config.image_size
  90. if isinstance(config.image_size, collections.abc.Iterable)
  91. else (config.image_size, config.image_size)
  92. )
  93. num_patches = self.patch_embeddings.num_patches
  94. if config.use_absolute_position_embeddings:
  95. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  96. else:
  97. self.position_embeddings = None
  98. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  99. # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
  100. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  101. """
  102. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  103. images. This method is also adapted to support torch.jit tracing.
  104. Adapted from:
  105. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  106. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  107. """
  108. num_patches = embeddings.shape[1] - 1
  109. num_positions = self.position_embeddings.shape[1] - 1
  110. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  111. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  112. return self.position_embeddings
  113. class_pos_embed = self.position_embeddings[:, :1]
  114. patch_pos_embed = self.position_embeddings[:, 1:]
  115. dim = embeddings.shape[-1]
  116. new_height = height // self.patch_size
  117. new_width = width // self.patch_size
  118. sqrt_num_positions = torch_int(num_positions**0.5)
  119. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  120. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  121. patch_pos_embed = nn.functional.interpolate(
  122. patch_pos_embed,
  123. size=(new_height, new_width),
  124. mode="bicubic",
  125. align_corners=False,
  126. )
  127. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  128. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  129. def forward(
  130. self,
  131. pixel_values: torch.Tensor,
  132. bool_masked_pos: torch.BoolTensor | None = None,
  133. ) -> torch.Tensor:
  134. _, _, height, width = pixel_values.shape
  135. embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
  136. batch_size, seq_len, _ = embeddings.size()
  137. if bool_masked_pos is not None:
  138. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  139. # replace the masked visual tokens by mask_tokens
  140. w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  141. embeddings = embeddings * (1 - w) + mask_tokens * w
  142. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  143. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  144. if self.position_embeddings is not None:
  145. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  146. embeddings = self.dropout(embeddings)
  147. return embeddings, (patch_height, patch_width)
  148. class BeitPatchEmbeddings(nn.Module):
  149. """
  150. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  151. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  152. Transformer.
  153. """
  154. def __init__(self, config):
  155. super().__init__()
  156. image_size, patch_size = config.image_size, config.patch_size
  157. num_channels, hidden_size = config.num_channels, config.hidden_size
  158. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  159. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  160. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  161. patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  162. self.image_size = image_size
  163. self.patch_size = patch_size
  164. self.num_channels = num_channels
  165. self.num_patches = num_patches
  166. self.patch_shape = patch_shape
  167. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  168. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  169. batch_size, num_channels, height, width = pixel_values.shape
  170. if num_channels != self.num_channels:
  171. raise ValueError(
  172. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  173. )
  174. embeddings = self.projection(pixel_values.to(self.projection.weight.dtype))
  175. patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
  176. embeddings = embeddings.flatten(2).transpose(1, 2)
  177. return embeddings, (patch_height, patch_width)
  178. class BeitSelfAttention(nn.Module):
  179. def __init__(self, config: BeitConfig, window_size: tuple | None = None) -> None:
  180. super().__init__()
  181. self.config = config
  182. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  183. raise ValueError(
  184. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  185. f"heads {config.num_attention_heads}."
  186. )
  187. self.num_attention_heads = config.num_attention_heads
  188. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  189. self.all_head_size = self.num_attention_heads * self.attention_head_size
  190. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  191. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
  192. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  193. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  194. self.has_relative_position_bias = bool(window_size)
  195. if self.has_relative_position_bias:
  196. self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)
  197. def forward(
  198. self,
  199. hidden_states: torch.Tensor,
  200. output_attentions: bool = False,
  201. relative_position_bias: torch.Tensor | None = None,
  202. interpolate_pos_encoding: bool = False,
  203. resolution: tuple[int] | None = None,
  204. ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
  205. input_shape = hidden_states.shape[:-1]
  206. hidden_shape = (*input_shape, -1, self.attention_head_size)
  207. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  208. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  209. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  210. # Take the dot product between "query" and "key" to get the raw attention scores.
  211. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  212. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  213. # Add relative position bias if present.
  214. if self.has_relative_position_bias:
  215. height, width = resolution
  216. window_size = (height // self.config.patch_size, width // self.config.patch_size)
  217. attention_scores = attention_scores + self.relative_position_bias(
  218. window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
  219. )
  220. # Add shared relative position bias if provided.
  221. if relative_position_bias is not None:
  222. attention_scores = attention_scores + relative_position_bias
  223. # Normalize the attention scores to probabilities.
  224. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  225. # This is actually dropping out entire tokens to attend to, which might
  226. # seem a bit unusual, but is taken from the original Transformer paper.
  227. attention_probs = self.dropout(attention_probs)
  228. context_layer = torch.matmul(attention_probs, value_layer)
  229. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  230. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  231. context_layer = context_layer.view(*new_context_layer_shape)
  232. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  233. return outputs
  234. class BeitSdpaSelfAttention(BeitSelfAttention):
  235. def forward(
  236. self,
  237. hidden_states: torch.Tensor,
  238. output_attentions: bool = False,
  239. relative_position_bias: torch.Tensor | None = None,
  240. interpolate_pos_encoding: bool = False,
  241. resolution: tuple[int] | None = None,
  242. ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
  243. if output_attentions:
  244. logger.warning_once(
  245. f"{self.__class__.__name__} does not support `output_attentions=True`. The returned attention weights will "
  246. "be `None`. If you want to get attention weights, please set `attn_implementation='eager'` when loading the model."
  247. )
  248. input_shape = hidden_states.shape[:-1]
  249. hidden_shape = (*input_shape, -1, self.attention_head_size)
  250. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  251. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  252. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  253. attn_bias = None
  254. if self.has_relative_position_bias:
  255. height, width = resolution
  256. window_size = (height // self.config.patch_size, width // self.config.patch_size)
  257. attn_bias = self.relative_position_bias(
  258. window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
  259. )
  260. # Add shared relative position bias if provided.
  261. if relative_position_bias is not None:
  262. if attn_bias is None:
  263. attn_bias = relative_position_bias
  264. else:
  265. attn_bias += relative_position_bias
  266. scaling = 1 / math.sqrt(self.attention_head_size)
  267. context_layer = torch.nn.functional.scaled_dot_product_attention(
  268. query_layer,
  269. key_layer,
  270. value_layer,
  271. attn_mask=attn_bias,
  272. dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0,
  273. is_causal=False,
  274. scale=scaling,
  275. )
  276. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  277. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  278. context_layer = context_layer.view(*new_context_layer_shape)
  279. return context_layer, None
  280. class BeitSelfOutput(nn.Module):
  281. """
  282. The residual connection is defined in BeitLayer instead of here (as is the case with other models), due to the
  283. layernorm applied before each block.
  284. """
  285. def __init__(self, config: BeitConfig) -> None:
  286. super().__init__()
  287. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  288. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  289. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor:
  290. hidden_states = self.dense(hidden_states)
  291. hidden_states = self.dropout(hidden_states)
  292. return hidden_states
  293. BEIT_SELF_ATTENTION_CLASSES = {
  294. "eager": BeitSelfAttention,
  295. "sdpa": BeitSdpaSelfAttention,
  296. }
  297. class BeitAttention(nn.Module):
  298. def __init__(self, config: BeitConfig, window_size: tuple | None = None) -> None:
  299. super().__init__()
  300. self.attention = BEIT_SELF_ATTENTION_CLASSES[config._attn_implementation](config, window_size=window_size)
  301. self.output = BeitSelfOutput(config)
  302. def forward(
  303. self,
  304. hidden_states: torch.Tensor,
  305. output_attentions: bool = False,
  306. relative_position_bias: torch.Tensor | None = None,
  307. interpolate_pos_encoding: bool = False,
  308. resolution: tuple[int] | None = None,
  309. ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
  310. self_outputs = self.attention(
  311. hidden_states, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
  312. )
  313. attention_output = self.output(self_outputs[0], hidden_states)
  314. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  315. return outputs
  316. class BeitIntermediate(nn.Module):
  317. def __init__(self, config: BeitConfig) -> None:
  318. super().__init__()
  319. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  320. if isinstance(config.hidden_act, str):
  321. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  322. else:
  323. self.intermediate_act_fn = config.hidden_act
  324. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  325. hidden_states = self.dense(hidden_states)
  326. hidden_states = self.intermediate_act_fn(hidden_states)
  327. return hidden_states
  328. class BeitOutput(nn.Module):
  329. def __init__(self, config: BeitConfig) -> None:
  330. super().__init__()
  331. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  332. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  333. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  334. hidden_states = self.dense(hidden_states)
  335. hidden_states = self.dropout(hidden_states)
  336. return hidden_states
  337. class BeitLayer(GradientCheckpointingLayer):
  338. """This corresponds to the Block class in the timm implementation."""
  339. def __init__(self, config: BeitConfig, window_size: tuple | None = None, drop_path_rate: float = 0.0) -> None:
  340. super().__init__()
  341. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  342. self.seq_len_dim = 1
  343. self.attention = BeitAttention(config, window_size=window_size)
  344. self.intermediate = BeitIntermediate(config)
  345. self.output = BeitOutput(config)
  346. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  347. self.drop_path = BeitDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
  348. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  349. init_values = config.layer_scale_init_value
  350. if init_values > 0:
  351. self.lambda_1 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
  352. self.lambda_2 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
  353. else:
  354. self.lambda_1, self.lambda_2 = None, None
  355. def forward(
  356. self,
  357. hidden_states: torch.Tensor,
  358. output_attentions: bool = False,
  359. relative_position_bias: torch.Tensor | None = None,
  360. interpolate_pos_encoding: bool = False,
  361. resolution: tuple[int, int] | None = None,
  362. ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
  363. self_attention_outputs = self.attention(
  364. self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
  365. output_attentions=output_attentions,
  366. relative_position_bias=relative_position_bias,
  367. interpolate_pos_encoding=interpolate_pos_encoding,
  368. resolution=resolution,
  369. )
  370. attention_output = self_attention_outputs[0]
  371. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  372. # apply lambda_1 if present
  373. if self.lambda_1 is not None:
  374. attention_output = self.lambda_1 * attention_output
  375. # first residual connection
  376. hidden_states = self.drop_path(attention_output) + hidden_states
  377. # in BEiT, layernorm is also applied after self-attention
  378. layer_output = self.layernorm_after(hidden_states)
  379. layer_output = self.intermediate(layer_output)
  380. layer_output = self.output(layer_output)
  381. if self.lambda_2 is not None:
  382. layer_output = self.lambda_2 * layer_output
  383. # second residual connection
  384. layer_output = self.drop_path(layer_output) + hidden_states
  385. outputs = (layer_output,) + outputs
  386. return outputs
  387. class BeitRelativePositionBias(nn.Module):
  388. def __init__(self, config: BeitConfig, window_size: tuple) -> None:
  389. super().__init__()
  390. self.window_size = window_size
  391. self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
  392. self.relative_position_bias_table = nn.Parameter(
  393. torch.zeros(self.num_relative_distance, config.num_attention_heads)
  394. ) # 2*Wh-1 * 2*Ww-1, nH
  395. # cls to token & token 2 cls & cls to cls
  396. @compile_compatible_method_lru_cache(maxsize=10)
  397. def generate_relative_position_index(self, window_size: tuple[int, int]) -> torch.Tensor:
  398. """
  399. This method creates the relative position index, modified to support arbitrary window sizes,
  400. as introduced in [MiDaS v3.1](https://huggingface.co/papers/2307.14460).
  401. """
  402. num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
  403. # cls to token & token 2 cls & cls to cls
  404. # get pair-wise relative position index for each token inside the window
  405. window_area = window_size[0] * window_size[1]
  406. grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
  407. coords = torch.stack(grid) # 2, Wh, Ww
  408. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  409. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
  410. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
  411. relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
  412. relative_coords[:, :, 1] += window_size[1] - 1
  413. relative_coords[:, :, 0] *= 2 * window_size[1] - 1
  414. relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
  415. relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  416. relative_position_index[0, 0:] = num_relative_distance - 3
  417. relative_position_index[0:, 0] = num_relative_distance - 2
  418. relative_position_index[0, 0] = num_relative_distance - 1
  419. return relative_position_index
  420. def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor:
  421. """
  422. Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
  423. """
  424. old_height = 2 * self.window_size[0] - 1
  425. old_width = 2 * self.window_size[1] - 1
  426. new_height = 2 * window_size[0] - 1
  427. new_width = 2 * window_size[1] - 1
  428. old_relative_position_bias_table = self.relative_position_bias_table
  429. old_num_relative_distance = self.num_relative_distance
  430. new_num_relative_distance = new_height * new_width + 3
  431. old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]
  432. old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
  433. new_sub_table = nn.functional.interpolate(
  434. old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
  435. )
  436. new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
  437. new_relative_position_bias_table = torch.cat(
  438. [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
  439. )
  440. relative_position_index = self.generate_relative_position_index(window_size)
  441. relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)]
  442. # patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
  443. relative_position_bias = relative_position_bias.view(
  444. window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
  445. )
  446. # num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height
  447. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
  448. if interpolate_pos_encoding:
  449. relative_position_bias = nn.functional.interpolate(
  450. relative_position_bias.unsqueeze(1),
  451. size=(dim_size, dim_size),
  452. mode="bilinear",
  453. align_corners=False,
  454. ).squeeze(1)
  455. return relative_position_bias.unsqueeze(0)
  456. class BeitEncoder(nn.Module):
  457. def __init__(self, config: BeitConfig, window_size: tuple | None = None) -> None:
  458. super().__init__()
  459. self.config = config
  460. self.has_relative_position_bias = config.use_shared_relative_position_bias
  461. if self.has_relative_position_bias:
  462. self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)
  463. # stochastic depth decay rule
  464. dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
  465. self.layer = nn.ModuleList(
  466. [
  467. BeitLayer(
  468. config,
  469. window_size=window_size if config.use_relative_position_bias else None,
  470. drop_path_rate=dpr[i],
  471. )
  472. for i in range(config.num_hidden_layers)
  473. ]
  474. )
  475. self.gradient_checkpointing = False
  476. def forward(
  477. self,
  478. hidden_states: torch.Tensor,
  479. output_attentions: bool = False,
  480. output_hidden_states: bool = False,
  481. interpolate_pos_encoding: bool = False,
  482. resolution: tuple[int, int] | None = None,
  483. return_dict: bool = True,
  484. ) -> tuple | BaseModelOutput:
  485. all_hidden_states = () if output_hidden_states else None
  486. all_self_attentions = () if output_attentions else None
  487. for i, layer_module in enumerate(self.layer):
  488. if output_hidden_states:
  489. all_hidden_states = all_hidden_states + (hidden_states,)
  490. if self.has_relative_position_bias:
  491. height, width = resolution
  492. window_size = (height // self.config.patch_size, width // self.config.patch_size)
  493. relative_position_bias = self.relative_position_bias(
  494. window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
  495. )
  496. else:
  497. relative_position_bias = None
  498. layer_outputs = layer_module(
  499. hidden_states,
  500. output_attentions=output_attentions,
  501. relative_position_bias=relative_position_bias,
  502. interpolate_pos_encoding=interpolate_pos_encoding,
  503. resolution=resolution,
  504. )
  505. hidden_states = layer_outputs[0]
  506. if output_attentions:
  507. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  508. if output_hidden_states:
  509. all_hidden_states = all_hidden_states + (hidden_states,)
  510. if not return_dict:
  511. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  512. return BaseModelOutput(
  513. last_hidden_state=hidden_states,
  514. hidden_states=all_hidden_states,
  515. attentions=all_self_attentions,
  516. )
  517. @auto_docstring
  518. class BeitPreTrainedModel(PreTrainedModel):
  519. config: BeitConfig
  520. base_model_prefix = "beit"
  521. input_modalities = ("image",)
  522. main_input_name = "pixel_values"
  523. supports_gradient_checkpointing = True
  524. _no_split_modules = ["BeitLayer"]
  525. _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
  526. _supports_sdpa = True
  527. @torch.no_grad()
  528. def _init_weights(self, module):
  529. """Initialize the weights"""
  530. super()._init_weights(module)
  531. if isinstance(module, BeitEmbeddings):
  532. init.zeros_(module.cls_token)
  533. if module.mask_token is not None:
  534. init.zeros_(module.mask_token)
  535. if module.position_embeddings is not None:
  536. init.zeros_(module.position_embeddings)
  537. elif isinstance(module, BeitRelativePositionBias):
  538. init.zeros_(module.relative_position_bias_table)
  539. elif isinstance(module, BeitLayer):
  540. if module.lambda_1 is not None:
  541. init.constant_(module.lambda_1, self.config.layer_scale_init_value)
  542. init.constant_(module.lambda_2, self.config.layer_scale_init_value)
  543. @auto_docstring
  544. class BeitModel(BeitPreTrainedModel):
  545. def __init__(self, config: BeitConfig, add_pooling_layer: bool = True) -> None:
  546. r"""
  547. add_pooling_layer (bool, *optional*, defaults to `True`):
  548. Whether to add a pooling layer
  549. """
  550. super().__init__(config)
  551. self.config = config
  552. self.embeddings = BeitEmbeddings(config)
  553. self.encoder = BeitEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape)
  554. self.layernorm = (
  555. nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  556. )
  557. self.pooler = BeitPooler(config) if add_pooling_layer else None
  558. # Initialize weights and apply final processing
  559. self.post_init()
  560. def get_input_embeddings(self):
  561. return self.embeddings.patch_embeddings
  562. @auto_docstring
  563. def forward(
  564. self,
  565. pixel_values: torch.Tensor,
  566. bool_masked_pos: torch.BoolTensor | None = None,
  567. output_attentions: bool | None = None,
  568. output_hidden_states: bool | None = None,
  569. interpolate_pos_encoding: bool = False,
  570. return_dict: bool | None = None,
  571. **kwargs,
  572. ) -> tuple | BeitModelOutputWithPooling:
  573. r"""
  574. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  575. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  576. """
  577. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  578. output_hidden_states = (
  579. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  580. )
  581. return_dict = return_dict if return_dict is not None else self.config.return_dict
  582. embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
  583. resolution = pixel_values.shape[2:]
  584. encoder_outputs = self.encoder(
  585. embedding_output,
  586. output_attentions=output_attentions,
  587. output_hidden_states=output_hidden_states,
  588. resolution=resolution,
  589. return_dict=return_dict,
  590. interpolate_pos_encoding=interpolate_pos_encoding,
  591. )
  592. sequence_output = encoder_outputs[0]
  593. sequence_output = self.layernorm(sequence_output)
  594. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  595. if not return_dict:
  596. head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
  597. return head_outputs + encoder_outputs[1:]
  598. return BeitModelOutputWithPooling(
  599. last_hidden_state=sequence_output,
  600. pooler_output=pooled_output,
  601. hidden_states=encoder_outputs.hidden_states,
  602. attentions=encoder_outputs.attentions,
  603. )
  604. class BeitPooler(nn.Module):
  605. def __init__(self, config: BeitConfig) -> None:
  606. super().__init__()
  607. self.layernorm = (
  608. nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
  609. )
  610. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  611. if self.layernorm is not None:
  612. # Mean pool the final hidden states of the patch tokens
  613. patch_tokens = hidden_states[:, 1:, :]
  614. pooled_output = self.layernorm(patch_tokens.mean(1))
  615. else:
  616. # Pool by simply taking the final hidden state of the [CLS] token
  617. pooled_output = hidden_states[:, 0]
  618. return pooled_output
  619. @auto_docstring(
  620. custom_intro="""
  621. Beit Model transformer with a 'language' modeling head on top. BEiT does masked image modeling by predicting
  622. visual tokens of a Vector-Quantize Variational Autoencoder (VQ-VAE), whereas other vision models like ViT and DeiT
  623. predict RGB pixel values. As a result, this class is incompatible with [`AutoModelForMaskedImageModeling`], so you
  624. will need to use [`BeitForMaskedImageModeling`] directly if you wish to do masked image modeling with BEiT.
  625. """
  626. )
  627. class BeitForMaskedImageModeling(BeitPreTrainedModel):
  628. def __init__(self, config: BeitConfig) -> None:
  629. super().__init__(config)
  630. self.num_labels = config.num_labels
  631. self.beit = BeitModel(config, add_pooling_layer=False)
  632. # Classifier head
  633. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  634. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
  635. # Initialize weights and apply final processing
  636. self.post_init()
  637. def get_output_embeddings(self):
  638. return None
  639. @auto_docstring
  640. def forward(
  641. self,
  642. pixel_values: torch.Tensor | None = None,
  643. bool_masked_pos: torch.BoolTensor | None = None,
  644. labels: torch.Tensor | None = None,
  645. output_attentions: bool | None = None,
  646. output_hidden_states: bool | None = None,
  647. interpolate_pos_encoding: bool = False,
  648. return_dict: bool | None = None,
  649. **kwargs,
  650. ) -> tuple | MaskedLMOutput:
  651. r"""
  652. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
  653. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  654. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  655. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  656. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  657. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  658. Examples:
  659. ```python
  660. >>> from transformers import AutoImageProcessor, BeitForMaskedImageModeling
  661. >>> import torch
  662. >>> from PIL import Image
  663. >>> import httpx
  664. >>> from io import BytesIO
  665. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  666. >>> with httpx.stream("GET", url) as response:
  667. ... image = Image.open(BytesIO(response.read()))
  668. >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
  669. >>> model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
  670. >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
  671. >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
  672. >>> # create random boolean mask of shape (batch_size, num_patches)
  673. >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
  674. >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
  675. >>> loss, logits = outputs.loss, outputs.logits
  676. >>> list(logits.shape)
  677. [1, 196, 8192]
  678. ```"""
  679. return_dict = return_dict if return_dict is not None else self.config.return_dict
  680. outputs = self.beit(
  681. pixel_values,
  682. bool_masked_pos=bool_masked_pos,
  683. output_attentions=output_attentions,
  684. output_hidden_states=output_hidden_states,
  685. interpolate_pos_encoding=interpolate_pos_encoding,
  686. return_dict=return_dict,
  687. )
  688. sequence_output = outputs[0]
  689. sequence_output = self.layernorm(sequence_output)
  690. prediction_scores = self.lm_head(sequence_output[:, 1:])
  691. masked_lm_loss = None
  692. if labels is not None:
  693. loss_fct = CrossEntropyLoss() # -100 index = padding token
  694. masked_lm_loss = loss_fct(prediction_scores[bool_masked_pos], labels)
  695. if not return_dict:
  696. output = (prediction_scores,) + outputs[1:]
  697. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  698. return MaskedLMOutput(
  699. loss=masked_lm_loss,
  700. logits=prediction_scores,
  701. hidden_states=outputs.hidden_states,
  702. attentions=outputs.attentions,
  703. )
  704. @auto_docstring(
  705. custom_intro="""
  706. Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
  707. hidden states of the patch tokens) e.g. for ImageNet.
  708. """
  709. )
  710. class BeitForImageClassification(BeitPreTrainedModel):
  711. def __init__(self, config: BeitConfig) -> None:
  712. super().__init__(config)
  713. self.num_labels = config.num_labels
  714. self.beit = BeitModel(config, add_pooling_layer=True)
  715. # Classifier head
  716. self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  717. # Initialize weights and apply final processing
  718. self.post_init()
  719. @auto_docstring
  720. def forward(
  721. self,
  722. pixel_values: torch.Tensor | None = None,
  723. labels: torch.Tensor | None = None,
  724. output_attentions: bool | None = None,
  725. output_hidden_states: bool | None = None,
  726. interpolate_pos_encoding: bool = False,
  727. return_dict: bool | None = None,
  728. **kwargs,
  729. ) -> tuple | ImageClassifierOutput:
  730. r"""
  731. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  732. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  733. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  734. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  735. """
  736. return_dict = return_dict if return_dict is not None else self.config.return_dict
  737. outputs = self.beit(
  738. pixel_values,
  739. output_attentions=output_attentions,
  740. output_hidden_states=output_hidden_states,
  741. interpolate_pos_encoding=interpolate_pos_encoding,
  742. return_dict=return_dict,
  743. )
  744. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  745. logits = self.classifier(pooled_output)
  746. loss = None
  747. if labels is not None:
  748. loss = self.loss_function(labels, logits, self.config)
  749. if not return_dict:
  750. output = (logits,) + outputs[2:]
  751. return ((loss,) + output) if loss is not None else output
  752. return ImageClassifierOutput(
  753. loss=loss,
  754. logits=logits,
  755. hidden_states=outputs.hidden_states,
  756. attentions=outputs.attentions,
  757. )
  758. class BeitConvModule(nn.Module):
  759. """
  760. A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
  761. layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
  762. Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
  763. """
  764. def __init__(
  765. self,
  766. in_channels: int,
  767. out_channels: int,
  768. kernel_size: int | tuple[int, int],
  769. padding: int | tuple[int, int] | str = 0,
  770. bias: bool = False,
  771. dilation: int | tuple[int, int] = 1,
  772. ) -> None:
  773. super().__init__()
  774. self.conv = nn.Conv2d(
  775. in_channels=in_channels,
  776. out_channels=out_channels,
  777. kernel_size=kernel_size,
  778. padding=padding,
  779. bias=bias,
  780. dilation=dilation,
  781. )
  782. self.bn = nn.BatchNorm2d(out_channels)
  783. self.activation = nn.ReLU()
  784. def forward(self, input: torch.Tensor) -> torch.Tensor:
  785. output = self.conv(input)
  786. output = self.bn(output)
  787. output = self.activation(output)
  788. return output
  789. class BeitPyramidPoolingBlock(nn.Module):
  790. def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
  791. super().__init__()
  792. self.layers = [
  793. nn.AdaptiveAvgPool2d(pool_scale),
  794. BeitConvModule(in_channels, channels, kernel_size=1),
  795. ]
  796. for i, layer in enumerate(self.layers):
  797. self.add_module(str(i), layer)
  798. def forward(self, input: torch.Tensor) -> torch.Tensor:
  799. hidden_state = input
  800. for layer in self.layers:
  801. hidden_state = layer(hidden_state)
  802. return hidden_state
  803. class BeitPyramidPoolingModule(nn.Module):
  804. """
  805. Pyramid Pooling Module (PPM) used in PSPNet.
  806. Args:
  807. pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
  808. Module.
  809. in_channels (int): Input channels.
  810. channels (int): Channels after modules, before conv_seg.
  811. align_corners (bool): align_corners argument of F.interpolate.
  812. Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
  813. """
  814. def __init__(self, pool_scales: tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None:
  815. super().__init__()
  816. self.pool_scales = pool_scales
  817. self.align_corners = align_corners
  818. self.in_channels = in_channels
  819. self.channels = channels
  820. self.blocks = []
  821. for i, pool_scale in enumerate(pool_scales):
  822. block = BeitPyramidPoolingBlock(pool_scale=pool_scale, in_channels=in_channels, channels=channels)
  823. self.blocks.append(block)
  824. self.add_module(str(i), block)
  825. def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
  826. ppm_outs = []
  827. for ppm in self.blocks:
  828. ppm_out = ppm(x)
  829. upsampled_ppm_out = nn.functional.interpolate(
  830. ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
  831. )
  832. ppm_outs.append(upsampled_ppm_out)
  833. return ppm_outs
class BeitUperHead(nn.Module):
    """
    Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
    [UPerNet](https://huggingface.co/papers/1807.10221).

    Takes a list of feature maps, each with `config.hidden_size` channels (four levels, see `self.in_channels`).
    - The last map goes through a Pyramid Pooling Module.
    - The first three get 1x1 lateral projections.
    - All levels are then fused top-down FPN-style, resized to the spatial size of the first level,
      concatenated, fused again, and projected by a 1x1 classifier to `config.num_labels` channels.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, config: BeitConfig) -> None:
        super().__init__()

        self.pool_scales = config.pool_scales  # e.g. (1, 2, 3, 6)
        self.in_channels = [config.hidden_size] * 4  # e.g. [768, 768, 768, 768]
        self.channels = config.hidden_size
        self.align_corners = False
        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)

        # PSP Module
        self.psp_modules = BeitPyramidPoolingModule(
            self.pool_scales,
            self.in_channels[-1],
            self.channels,
            align_corners=self.align_corners,
        )
        # Fuses the top feature map with its len(pool_scales) pooled copies, concatenated on the channel dim
        self.bottleneck = BeitConvModule(
            self.in_channels[-1] + len(self.pool_scales) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels[:-1]:  # skip the top layer
            l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1)
            fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # Fuses all levels once they are resized to a common size and concatenated on the channel dim
        self.fpn_bottleneck = BeitConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )

    def psp_forward(self, inputs):
        """Run the Pyramid Pooling Module on the last feature map in `inputs` and fuse it with the bottleneck."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, encoder_hidden_states: list[torch.Tensor]) -> torch.Tensor:
        # build laterals
        laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]

        laterals.append(self.psp_forward(encoder_hidden_states))

        # build top-down path: starting from the last level, each level is resized to the spatial size of the
        # previous one and added into it in place, so lower-index levels accumulate every level above them.
        # The iteration order matters.
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
                laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
            )

        # build outputs
        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
        # append psp feature
        fpn_outs.append(laterals[-1])

        # resize every level to the spatial size of fpn_outs[0] so they can be concatenated on channels
        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = nn.functional.interpolate(
                fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
            )
        fpn_outs = torch.cat(fpn_outs, dim=1)
        output = self.fpn_bottleneck(fpn_outs)
        output = self.classifier(output)

        return output
  904. class BeitFCNHead(nn.Module):
  905. """
  906. Fully Convolution Networks for Semantic Segmentation. This head is implemented of
  907. [FCNNet](https://huggingface.co/papers/1411.4038>).
  908. Args:
  909. config (BeitConfig): Configuration.
  910. in_channels
  911. kernel_size (int): The kernel size for convs in the head. Default: 3.
  912. dilation (int): The dilation rate for convs in the head. Default: 1.
  913. Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
  914. """
  915. def __init__(
  916. self, config: BeitConfig, in_index: int = 2, kernel_size: int = 3, dilation: int | tuple[int, int] = 1
  917. ) -> None:
  918. super().__init__()
  919. self.in_channels = config.hidden_size
  920. self.channels = config.auxiliary_channels
  921. self.num_convs = config.auxiliary_num_convs
  922. self.concat_input = config.auxiliary_concat_input
  923. self.in_index = in_index
  924. conv_padding = (kernel_size // 2) * dilation
  925. convs = []
  926. convs.append(
  927. BeitConvModule(
  928. self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
  929. )
  930. )
  931. for i in range(self.num_convs - 1):
  932. convs.append(
  933. BeitConvModule(
  934. self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
  935. )
  936. )
  937. if self.num_convs == 0:
  938. self.convs = nn.Identity()
  939. else:
  940. self.convs = nn.Sequential(*convs)
  941. if self.concat_input:
  942. self.conv_cat = BeitConvModule(
  943. self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
  944. )
  945. self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
  946. def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
  947. # just take the relevant feature maps
  948. hidden_states = encoder_hidden_states[self.in_index]
  949. output = self.convs(hidden_states)
  950. if self.concat_input:
  951. output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
  952. output = self.classifier(output)
  953. return output
  954. @auto_docstring
  955. class BeitForSemanticSegmentation(BeitPreTrainedModel):
  956. def __init__(self, config: BeitConfig) -> None:
  957. super().__init__(config)
  958. self.num_labels = config.num_labels
  959. self.beit = BeitModel(config, add_pooling_layer=False)
  960. # FPNs
  961. if len(self.config.out_indices) != 4:
  962. raise ValueError(
  963. "BeitForSemanticSegmentation requires config.out_indices to be a list of 4 integers, "
  964. "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
  965. "a base-sized architecture."
  966. )
  967. self.fpn1 = nn.Sequential(
  968. nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
  969. nn.BatchNorm2d(config.hidden_size),
  970. nn.GELU(),
  971. nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
  972. )
  973. self.fpn2 = nn.Sequential(
  974. nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
  975. )
  976. self.fpn3 = nn.Identity()
  977. self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
  978. # Semantic segmentation head(s)
  979. self.decode_head = BeitUperHead(config)
  980. self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None
  981. # Initialize weights and apply final processing
  982. self.post_init()
  983. def compute_loss(self, logits, auxiliary_logits, labels):
  984. # upsample logits to the images' original size
  985. upsampled_logits = nn.functional.interpolate(
  986. logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
  987. )
  988. if auxiliary_logits is not None:
  989. upsampled_auxiliary_logits = nn.functional.interpolate(
  990. auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
  991. )
  992. # compute weighted loss
  993. loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
  994. main_loss = loss_fct(upsampled_logits, labels)
  995. loss = main_loss
  996. if auxiliary_logits is not None:
  997. auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
  998. loss += self.config.auxiliary_loss_weight * auxiliary_loss
  999. return loss
  1000. @auto_docstring
  1001. def forward(
  1002. self,
  1003. pixel_values: torch.Tensor | None = None,
  1004. labels: torch.Tensor | None = None,
  1005. output_attentions: bool | None = None,
  1006. output_hidden_states: bool | None = None,
  1007. interpolate_pos_encoding: bool = False,
  1008. return_dict: bool | None = None,
  1009. **kwargs,
  1010. ) -> tuple | SemanticSegmenterOutput:
  1011. r"""
  1012. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  1013. Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
  1014. config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
  1015. Examples:
  1016. ```python
  1017. >>> from transformers import AutoImageProcessor, BeitForSemanticSegmentation
  1018. >>> from PIL import Image
  1019. >>> import httpx
  1020. >>> from io import BytesIO
  1021. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1022. >>> with httpx.stream("GET", url) as response:
  1023. ... image = Image.open(BytesIO(response.read()))
  1024. >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
  1025. >>> model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
  1026. >>> inputs = image_processor(images=image, return_tensors="pt")
  1027. >>> outputs = model(**inputs)
  1028. >>> # logits are of shape (batch_size, num_labels, height, width)
  1029. >>> logits = outputs.logits
  1030. ```"""
  1031. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1032. output_hidden_states = (
  1033. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1034. )
  1035. if labels is not None and self.config.num_labels == 1:
  1036. raise ValueError("The number of labels should be greater than one")
  1037. outputs = self.beit(
  1038. pixel_values,
  1039. output_attentions=output_attentions,
  1040. output_hidden_states=True, # we need the intermediate hidden states
  1041. interpolate_pos_encoding=interpolate_pos_encoding,
  1042. return_dict=return_dict,
  1043. )
  1044. encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
  1045. # only keep certain features, and reshape
  1046. # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
  1047. features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
  1048. batch_size = pixel_values.shape[0]
  1049. patch_resolution = self.config.image_size // self.config.patch_size
  1050. features = [
  1051. x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
  1052. ]
  1053. # apply FPNs
  1054. ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
  1055. for i in range(len(features)):
  1056. features[i] = ops[i](features[i])
  1057. logits = self.decode_head(features)
  1058. auxiliary_logits = None
  1059. if self.auxiliary_head is not None:
  1060. auxiliary_logits = self.auxiliary_head(features)
  1061. loss = None
  1062. if labels is not None:
  1063. loss = self.compute_loss(logits, auxiliary_logits, labels)
  1064. if not return_dict:
  1065. if output_hidden_states:
  1066. output = (logits,) + outputs[1:]
  1067. else:
  1068. output = (logits,) + outputs[2:]
  1069. return ((loss,) + output) if loss is not None else output
  1070. return SemanticSegmenterOutput(
  1071. loss=loss,
  1072. logits=logits,
  1073. hidden_states=outputs.hidden_states if output_hidden_states else None,
  1074. attentions=outputs.attentions,
  1075. )
  1076. @auto_docstring(
  1077. custom_intro="""
  1078. BEiT backbone, to be used with frameworks like DETR and MaskFormer.
  1079. """
  1080. )
  1081. class BeitBackbone(BackboneMixin, BeitPreTrainedModel):
  1082. def __init__(self, config):
  1083. super().__init__(config)
  1084. self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
  1085. self.embeddings = BeitEmbeddings(config)
  1086. self.encoder = BeitEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape)
  1087. if config.add_fpn:
  1088. if len(self.config.out_indices) != 4:
  1089. raise ValueError(
  1090. "BeitBackbone requires config.out_indices to be a list of 4 integers, "
  1091. "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
  1092. "a base-sized architecture."
  1093. )
  1094. hidden_size = config.hidden_size
  1095. self.fpn1 = nn.Sequential(
  1096. nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2),
  1097. nn.BatchNorm2d(hidden_size, eps=config.batch_norm_eps),
  1098. nn.GELU(),
  1099. nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2),
  1100. )
  1101. self.fpn2 = nn.Sequential(nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2))
  1102. self.fpn3 = nn.Identity()
  1103. self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
  1104. # initialize weights and apply final processing
  1105. self.post_init()
  1106. def get_input_embeddings(self):
  1107. return self.embeddings.patch_embeddings
  1108. @can_return_tuple
  1109. @filter_output_hidden_states
  1110. @auto_docstring
  1111. def forward(
  1112. self,
  1113. pixel_values: Tensor,
  1114. output_hidden_states: bool | None = None,
  1115. output_attentions: bool | None = None,
  1116. return_dict: bool | None = None,
  1117. **kwargs,
  1118. ) -> BackboneOutput:
  1119. r"""
  1120. Examples:
  1121. ```python
  1122. >>> from transformers import AutoImageProcessor, AutoBackbone
  1123. >>> import torch
  1124. >>> from PIL import Image
  1125. >>> import httpx
  1126. >>> from io import BytesIO
  1127. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1128. >>> with httpx.stream("GET", url) as response:
  1129. ... image = Image.open(BytesIO(response.read()))
  1130. >>> processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
  1131. >>> model = AutoBackbone.from_pretrained(
  1132. ... "microsoft/beit-base-patch16-224", out_features=["stage1", "stage2", "stage3", "stage4"]
  1133. ... )
  1134. >>> inputs = processor(image, return_tensors="pt")
  1135. >>> outputs = model(**inputs)
  1136. >>> feature_maps = outputs.feature_maps
  1137. >>> list(feature_maps[-1].shape)
  1138. [1, 768, 14, 14]
  1139. ```"""
  1140. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1141. output_hidden_states = (
  1142. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1143. )
  1144. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1145. batch_size = pixel_values.shape[0]
  1146. embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values)
  1147. resolution = pixel_values.shape[2:]
  1148. outputs = self.encoder(
  1149. embedding_output,
  1150. output_hidden_states=True,
  1151. output_attentions=output_attentions,
  1152. resolution=resolution,
  1153. return_dict=return_dict,
  1154. )
  1155. hidden_states = outputs.hidden_states if return_dict else outputs[1]
  1156. feature_maps = ()
  1157. for stage, hidden_state in zip(self.stage_names, hidden_states):
  1158. if stage in self.out_features:
  1159. if self.config.reshape_hidden_states:
  1160. hidden_state = hidden_state[:, 1:, :]
  1161. hidden_state = hidden_state.permute(0, 2, 1)
  1162. hidden_state = hidden_state.reshape(batch_size, -1, patch_height, patch_width)
  1163. feature_maps += (hidden_state,)
  1164. if self.config.add_fpn:
  1165. feature_maps = [
  1166. self.fpn1(feature_maps[0]),
  1167. self.fpn2(feature_maps[1]),
  1168. self.fpn3(feature_maps[2]),
  1169. self.fpn4(feature_maps[3]),
  1170. ]
  1171. feature_maps = tuple(feature_maps)
  1172. if not return_dict:
  1173. if output_hidden_states:
  1174. output = (feature_maps,) + outputs[1:]
  1175. else:
  1176. output = (feature_maps,) + outputs[2:]
  1177. return output
  1178. return BackboneOutput(
  1179. feature_maps=feature_maps,
  1180. hidden_states=outputs.hidden_states if output_hidden_states else None,
  1181. attentions=outputs.attentions,
  1182. )
# Public names exported by this module (what `from ... import *` brings into scope).
__all__ = [
    "BeitForImageClassification",
    "BeitForMaskedImageModeling",
    "BeitForSemanticSegmentation",
    "BeitModel",
    "BeitPreTrainedModel",
    "BeitBackbone",
]