modeling_focalnet.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934
  1. # Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch FocalNet model."""
  15. import collections.abc
  16. import math
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import BackboneOutput
  25. from ...modeling_utils import PreTrainedModel
  26. from ...utils import ModelOutput, auto_docstring, logging
  27. from ...utils.generic import can_return_tuple
  28. from .configuration_focalnet import FocalNetConfig
  29. logger = logging.get_logger(__name__)
  30. @dataclass
  31. @auto_docstring(
  32. custom_intro="""
  33. FocalNet encoder's outputs, with potential hidden states.
  34. """
  35. )
  36. class FocalNetEncoderOutput(ModelOutput):
  37. r"""
  38. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  39. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
  40. shape `(batch_size, hidden_size, height, width)`.
  41. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
  42. include the spatial dimensions.
  43. """
  44. last_hidden_state: torch.FloatTensor | None = None
  45. hidden_states: tuple[torch.FloatTensor] | None = None
  46. reshaped_hidden_states: tuple[torch.FloatTensor] | None = None
  47. @dataclass
  48. @auto_docstring(
  49. custom_intro="""
  50. FocalNet model's outputs that also contains a pooling of the last hidden states.
  51. """
  52. )
  53. class FocalNetModelOutput(ModelOutput):
  54. r"""
  55. pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
  56. Average pooling of the last layer hidden-state.
  57. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  58. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
  59. shape `(batch_size, hidden_size, height, width)`.
  60. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
  61. include the spatial dimensions.
  62. """
  63. last_hidden_state: torch.FloatTensor | None = None
  64. pooler_output: torch.FloatTensor | None = None
  65. hidden_states: tuple[torch.FloatTensor] | None = None
  66. reshaped_hidden_states: tuple[torch.FloatTensor] | None = None
  67. @dataclass
  68. @auto_docstring(
  69. custom_intro="""
  70. FocalNet masked image model outputs.
  71. """
  72. )
  73. class FocalNetMaskedImageModelingOutput(ModelOutput):
  74. r"""
  75. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
  76. Masked image modeling (MLM) loss.
  77. reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  78. Reconstructed pixel values.
  79. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  80. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
  81. shape `(batch_size, hidden_size, height, width)`.
  82. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
  83. include the spatial dimensions.
  84. """
  85. loss: torch.FloatTensor | None = None
  86. reconstruction: torch.FloatTensor | None = None
  87. hidden_states: tuple[torch.FloatTensor] | None = None
  88. reshaped_hidden_states: tuple[torch.FloatTensor] | None = None
  89. @dataclass
  90. @auto_docstring(
  91. custom_intro="""
  92. FocalNet outputs for image classification.
  93. """
  94. )
  95. class FocalNetImageClassifierOutput(ModelOutput):
  96. r"""
  97. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  98. Classification (or regression if config.num_labels==1) loss.
  99. logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
  100. Classification (or regression if config.num_labels==1) scores (before SoftMax).
  101. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  102. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
  103. shape `(batch_size, hidden_size, height, width)`.
  104. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
  105. include the spatial dimensions.
  106. """
  107. loss: torch.FloatTensor | None = None
  108. logits: torch.FloatTensor | None = None
  109. hidden_states: tuple[torch.FloatTensor] | None = None
  110. reshaped_hidden_states: tuple[torch.FloatTensor] | None = None
  111. class FocalNetEmbeddings(nn.Module):
  112. """
  113. Construct the patch embeddings and layernorm. Optionally, also the mask token.
  114. """
  115. def __init__(self, config, use_mask_token=False):
  116. super().__init__()
  117. self.patch_embeddings = FocalNetPatchEmbeddings(
  118. config=config,
  119. image_size=config.image_size,
  120. patch_size=config.patch_size,
  121. num_channels=config.num_channels,
  122. embed_dim=config.embed_dim,
  123. use_conv_embed=config.use_conv_embed,
  124. is_stem=True,
  125. )
  126. self.patch_grid = self.patch_embeddings.grid_size
  127. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
  128. self.norm = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_eps)
  129. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  130. def forward(
  131. self, pixel_values: torch.FloatTensor | None, bool_masked_pos: torch.BoolTensor | None = None
  132. ) -> tuple[torch.Tensor]:
  133. embeddings, output_dimensions = self.patch_embeddings(pixel_values)
  134. embeddings = self.norm(embeddings)
  135. batch_size, seq_len, _ = embeddings.size()
  136. if bool_masked_pos is not None:
  137. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  138. # replace the masked visual tokens by mask_tokens
  139. mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  140. embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  141. embeddings = self.dropout(embeddings)
  142. return embeddings, output_dimensions
  143. class FocalNetPatchEmbeddings(nn.Module):
  144. def __init__(
  145. self,
  146. config,
  147. image_size,
  148. patch_size,
  149. num_channels,
  150. embed_dim,
  151. add_norm=False,
  152. use_conv_embed=False,
  153. is_stem=False,
  154. ):
  155. super().__init__()
  156. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  157. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  158. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  159. self.image_size = image_size
  160. self.patch_size = patch_size
  161. self.num_channels = num_channels
  162. self.num_patches = num_patches
  163. self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  164. if use_conv_embed:
  165. # if we choose to use conv embedding, then we treat the stem and non-stem differently
  166. if is_stem:
  167. kernel_size = 7
  168. padding = 2
  169. stride = 4
  170. else:
  171. kernel_size = 3
  172. padding = 1
  173. stride = 2
  174. self.projection = nn.Conv2d(
  175. num_channels, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
  176. )
  177. else:
  178. self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
  179. if add_norm:
  180. self.norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  181. else:
  182. self.norm = None
  183. def maybe_pad(self, pixel_values, height, width):
  184. if width % self.patch_size[1] != 0:
  185. pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
  186. pixel_values = nn.functional.pad(pixel_values, pad_values)
  187. if height % self.patch_size[0] != 0:
  188. pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
  189. pixel_values = nn.functional.pad(pixel_values, pad_values)
  190. return pixel_values
  191. def forward(self, pixel_values: torch.FloatTensor | None) -> tuple[torch.Tensor, tuple[int]]:
  192. _, num_channels, height, width = pixel_values.shape
  193. if num_channels != self.num_channels:
  194. raise ValueError(
  195. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  196. )
  197. # pad the input to be divisible by self.patch_size, if needed
  198. pixel_values = self.maybe_pad(pixel_values, height, width)
  199. embeddings = self.projection(pixel_values)
  200. _, _, height, width = embeddings.shape
  201. output_dimensions = (height, width)
  202. embeddings = embeddings.flatten(2).transpose(1, 2)
  203. if self.norm is not None:
  204. embeddings = self.norm(embeddings)
  205. return embeddings, output_dimensions
  206. # Copied from transformers.models.beit.modeling_beit.drop_path
  207. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  208. """
  209. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  210. """
  211. if drop_prob == 0.0 or not training:
  212. return input
  213. keep_prob = 1 - drop_prob
  214. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  215. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  216. random_tensor.floor_() # binarize
  217. output = input.div(keep_prob) * random_tensor
  218. return output
  219. # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->FocalNet
  220. class FocalNetDropPath(nn.Module):
  221. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  222. def __init__(self, drop_prob: float | None = None) -> None:
  223. super().__init__()
  224. self.drop_prob = drop_prob
  225. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  226. return drop_path(hidden_states, self.drop_prob, self.training)
  227. def extra_repr(self) -> str:
  228. return f"p={self.drop_prob}"
  229. class FocalNetModulation(nn.Module):
  230. def __init__(self, config, index, dim, focal_factor=2, bias=True, projection_dropout=0.0):
  231. super().__init__()
  232. self.dim = dim
  233. self.focal_window = config.focal_windows[index]
  234. self.focal_level = config.focal_levels[index]
  235. self.focal_factor = focal_factor
  236. self.use_post_layernorm_in_modulation = config.use_post_layernorm_in_modulation
  237. self.normalize_modulator = config.normalize_modulator
  238. self.projection_in = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=bias)
  239. self.projection_context = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias)
  240. self.activation = nn.GELU()
  241. self.projection_out = nn.Linear(dim, dim)
  242. self.projection_dropout = nn.Dropout(projection_dropout)
  243. self.focal_layers = nn.ModuleList()
  244. self.kernel_sizes = []
  245. for k in range(self.focal_level):
  246. kernel_size = self.focal_factor * k + self.focal_window
  247. self.focal_layers.append(
  248. nn.Sequential(
  249. nn.Conv2d(
  250. dim, dim, kernel_size=kernel_size, stride=1, groups=dim, padding=kernel_size // 2, bias=False
  251. ),
  252. nn.GELU(),
  253. )
  254. )
  255. self.kernel_sizes.append(kernel_size)
  256. if self.use_post_layernorm_in_modulation:
  257. self.layernorm = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  258. def forward(self, hidden_state):
  259. """
  260. Args:
  261. hidden_state:
  262. Input features with shape of (batch_size, height, width, num_channels)
  263. """
  264. num_channels = hidden_state.shape[-1]
  265. # pre linear projection
  266. x = self.projection_in(hidden_state).permute(0, 3, 1, 2).contiguous()
  267. q, ctx, gates = torch.split(x, (num_channels, num_channels, self.focal_level + 1), 1)
  268. # context aggregation
  269. ctx_all = 0
  270. for level in range(self.focal_level):
  271. ctx = self.focal_layers[level](ctx)
  272. ctx_all = ctx_all + ctx * gates[:, level : level + 1]
  273. ctx_global = self.activation(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
  274. ctx_all = ctx_all + ctx_global * gates[:, self.focal_level :]
  275. # normalize context
  276. if self.normalize_modulator:
  277. ctx_all = ctx_all / (self.focal_level + 1)
  278. # focal modulation
  279. modulator = self.projection_context(ctx_all)
  280. x_out = q * modulator
  281. x_out = x_out.permute(0, 2, 3, 1).contiguous()
  282. if self.use_post_layernorm_in_modulation:
  283. x_out = self.layernorm(x_out)
  284. # post linear projection
  285. x_out = self.projection_out(x_out)
  286. x_out = self.projection_dropout(x_out)
  287. return x_out
  288. class FocalNetMlp(nn.Module):
  289. def __init__(self, config, in_features, hidden_features=None, out_features=None, drop=0.0):
  290. super().__init__()
  291. out_features = out_features or in_features
  292. hidden_features = hidden_features or in_features
  293. self.fc1 = nn.Linear(in_features, hidden_features)
  294. self.activation = ACT2FN[config.hidden_act]
  295. self.fc2 = nn.Linear(hidden_features, out_features)
  296. self.drop = nn.Dropout(drop)
  297. def forward(self, hidden_state):
  298. hidden_state = self.fc1(hidden_state)
  299. hidden_state = self.activation(hidden_state)
  300. hidden_state = self.drop(hidden_state)
  301. hidden_state = self.fc2(hidden_state)
  302. hidden_state = self.drop(hidden_state)
  303. return hidden_state
  304. class FocalNetLayer(nn.Module):
  305. r"""Focal Modulation Network layer (block).
  306. Args:
  307. config (`FocalNetConfig`):
  308. Model config.
  309. index (`int`):
  310. Layer index.
  311. dim (`int`):
  312. Number of input channels.
  313. input_resolution (`tuple[int]`):
  314. Input resolution.
  315. drop_path (`float`, *optional*, defaults to 0.0):
  316. Stochastic depth rate.
  317. """
  318. def __init__(self, config, index, dim, input_resolution, drop_path=0.0):
  319. super().__init__()
  320. self.config = config
  321. # layer-specific attributes
  322. self.dim = dim
  323. self.input_resolution = input_resolution
  324. # general attributes
  325. self.drop = config.hidden_dropout_prob
  326. self.use_post_layernorm = config.use_post_layernorm
  327. self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  328. self.modulation = FocalNetModulation(
  329. config=config,
  330. index=index,
  331. dim=dim,
  332. projection_dropout=self.drop,
  333. )
  334. self.drop_path = FocalNetDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  335. self.norm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  336. mlp_hidden_dim = int(dim * config.mlp_ratio)
  337. self.mlp = FocalNetMlp(config=config, in_features=dim, hidden_features=mlp_hidden_dim, drop=self.drop)
  338. self.gamma_1 = 1.0
  339. self.gamma_2 = 1.0
  340. if config.use_layerscale:
  341. self.gamma_1 = nn.Parameter(config.layerscale_value * torch.ones(dim), requires_grad=True)
  342. self.gamma_2 = nn.Parameter(config.layerscale_value * torch.ones(dim), requires_grad=True)
  343. def forward(self, hidden_state, input_dimensions):
  344. height, width = input_dimensions
  345. batch_size, _, num_channels = hidden_state.shape
  346. shortcut = hidden_state
  347. # Focal Modulation
  348. hidden_state = hidden_state if self.use_post_layernorm else self.norm1(hidden_state)
  349. hidden_state = hidden_state.view(batch_size, height, width, num_channels)
  350. hidden_state = self.modulation(hidden_state).view(batch_size, height * width, num_channels)
  351. hidden_state = hidden_state if not self.use_post_layernorm else self.norm1(hidden_state)
  352. # FFN
  353. hidden_state = shortcut + self.drop_path(self.gamma_1 * hidden_state)
  354. hidden_state = hidden_state + self.drop_path(
  355. self.gamma_2
  356. * (self.norm2(self.mlp(hidden_state)) if self.use_post_layernorm else self.mlp(self.norm2(hidden_state)))
  357. )
  358. return hidden_state
  359. class FocalNetStage(GradientCheckpointingLayer):
  360. def __init__(self, config, index, input_resolution):
  361. super().__init__()
  362. self.config = config
  363. self.num_stages = len(config.depths)
  364. embed_dim = [config.embed_dim * (2**i) for i in range(self.num_stages)]
  365. dim = embed_dim[index]
  366. out_dim = embed_dim[index + 1] if (index < self.num_stages - 1) else None
  367. downsample = FocalNetPatchEmbeddings if (index < self.num_stages - 1) else None
  368. # stochastic depth decay rule
  369. dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
  370. drop_path = dpr[sum(config.depths[:index]) : sum(config.depths[: index + 1])]
  371. self.layers = nn.ModuleList(
  372. [
  373. FocalNetLayer(
  374. config=config,
  375. index=index,
  376. dim=dim,
  377. input_resolution=input_resolution,
  378. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  379. )
  380. for i in range(config.depths[index])
  381. ]
  382. )
  383. if downsample is not None:
  384. self.downsample = downsample(
  385. config=config,
  386. image_size=input_resolution,
  387. patch_size=2,
  388. num_channels=dim,
  389. embed_dim=out_dim,
  390. add_norm=True,
  391. use_conv_embed=config.use_conv_embed,
  392. is_stem=False,
  393. )
  394. else:
  395. self.downsample = None
  396. self.pointing = False
  397. def forward(self, hidden_states: torch.Tensor, input_dimensions: tuple[int, int]) -> tuple[torch.Tensor]:
  398. height, width = input_dimensions
  399. for layer_module in self.layers:
  400. hidden_states = layer_module(hidden_states, input_dimensions)
  401. hidden_states_before_downsampling = hidden_states
  402. if self.downsample is not None:
  403. height, width = input_dimensions
  404. hidden_states = hidden_states.transpose(1, 2).reshape(
  405. hidden_states_before_downsampling.shape[0], -1, height, width
  406. )
  407. hidden_states, output_dimensions = self.downsample(hidden_states)
  408. else:
  409. output_dimensions = (height, width, height, width)
  410. stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
  411. return stage_outputs
  412. class FocalNetEncoder(nn.Module):
  413. def __init__(self, config, grid_size):
  414. super().__init__()
  415. self.num_stages = len(config.depths)
  416. self.config = config
  417. self.stages = nn.ModuleList(
  418. [
  419. FocalNetStage(
  420. config=config,
  421. index=i_layer,
  422. input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
  423. )
  424. for i_layer in range(self.num_stages)
  425. ]
  426. )
  427. self.gradient_checkpointing = False
  428. def forward(
  429. self,
  430. hidden_states: torch.Tensor,
  431. input_dimensions: tuple[int, int],
  432. output_hidden_states: bool | None = False,
  433. output_hidden_states_before_downsampling: bool | None = False,
  434. return_dict: bool | None = True,
  435. ) -> tuple | FocalNetEncoderOutput:
  436. all_hidden_states = () if output_hidden_states else None
  437. all_reshaped_hidden_states = () if output_hidden_states else None
  438. if output_hidden_states:
  439. batch_size, _, hidden_size = hidden_states.shape
  440. # rearrange b (h w) c -> b c h w
  441. reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
  442. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  443. all_hidden_states += (hidden_states,)
  444. all_reshaped_hidden_states += (reshaped_hidden_state,)
  445. for i, stage_module in enumerate(self.stages):
  446. stage_outputs = stage_module(hidden_states, input_dimensions)
  447. hidden_states = stage_outputs[0]
  448. hidden_states_before_downsampling = stage_outputs[1]
  449. output_dimensions = stage_outputs[2]
  450. input_dimensions = (output_dimensions[-2], output_dimensions[-1])
  451. if output_hidden_states and output_hidden_states_before_downsampling:
  452. batch_size, _, hidden_size = hidden_states_before_downsampling.shape
  453. # rearrange b (h w) c -> b c h w
  454. # here we use the original (not downsampled) height and width
  455. reshaped_hidden_state = hidden_states_before_downsampling.view(
  456. batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
  457. )
  458. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  459. all_hidden_states += (hidden_states_before_downsampling,)
  460. all_reshaped_hidden_states += (reshaped_hidden_state,)
  461. elif output_hidden_states and not output_hidden_states_before_downsampling:
  462. batch_size, _, hidden_size = hidden_states.shape
  463. # rearrange b (h w) c -> b c h w
  464. reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
  465. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  466. all_hidden_states += (hidden_states,)
  467. all_reshaped_hidden_states += (reshaped_hidden_state,)
  468. if not return_dict:
  469. return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
  470. return FocalNetEncoderOutput(
  471. last_hidden_state=hidden_states,
  472. hidden_states=all_hidden_states,
  473. reshaped_hidden_states=all_reshaped_hidden_states,
  474. )
  475. @auto_docstring
  476. class FocalNetPreTrainedModel(PreTrainedModel):
  477. config: FocalNetConfig
  478. base_model_prefix = "focalnet"
  479. main_input_name = "pixel_values"
  480. supports_gradient_checkpointing = True
  481. _no_split_modules = ["FocalNetStage"]
  482. @torch.no_grad()
  483. def _init_weights(self, module):
  484. """Initialize the weights"""
  485. super()._init_weights(module)
  486. if isinstance(module, FocalNetEmbeddings):
  487. if module.mask_token is not None:
  488. init.zeros_(module.mask_token)
  489. elif isinstance(module, FocalNetLayer):
  490. if self.config.use_layerscale:
  491. init.constant_(module.gamma_1, self.config.layerscale_value)
  492. init.constant_(module.gamma_2, self.config.layerscale_value)
  493. @auto_docstring
  494. class FocalNetModel(FocalNetPreTrainedModel):
  495. def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
  496. r"""
  497. add_pooling_layer (bool, *optional*, defaults to `True`):
  498. Whether to add a pooling layer
  499. use_mask_token (`bool`, *optional*, defaults to `False`):
  500. Whether to use a mask token for masked image modeling.
  501. """
  502. super().__init__(config)
  503. self.config = config
  504. self.num_stages = len(config.depths)
  505. self.num_features = int(config.embed_dim * 2 ** (self.num_stages - 1))
  506. self.embeddings = FocalNetEmbeddings(config, use_mask_token=use_mask_token)
  507. self.encoder = FocalNetEncoder(config, self.embeddings.patch_grid)
  508. self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
  509. self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
  510. # Initialize weights and apply final processing
  511. self.post_init()
  512. def get_input_embeddings(self):
  513. return self.embeddings.patch_embeddings
  514. @auto_docstring
  515. def forward(
  516. self,
  517. pixel_values: torch.FloatTensor | None = None,
  518. bool_masked_pos: torch.BoolTensor | None = None,
  519. output_hidden_states: bool | None = None,
  520. return_dict: bool | None = None,
  521. **kwargs,
  522. ) -> tuple | FocalNetModelOutput:
  523. r"""
  524. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
  525. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  526. """
  527. output_hidden_states = (
  528. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  529. )
  530. return_dict = return_dict if return_dict is not None else self.config.return_dict
  531. if pixel_values is None:
  532. raise ValueError("You have to specify pixel_values")
  533. embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
  534. encoder_outputs = self.encoder(
  535. embedding_output,
  536. input_dimensions,
  537. output_hidden_states=output_hidden_states,
  538. return_dict=return_dict,
  539. )
  540. sequence_output = encoder_outputs[0]
  541. sequence_output = self.layernorm(sequence_output)
  542. pooled_output = None
  543. if self.pooler is not None:
  544. pooled_output = self.pooler(sequence_output.transpose(1, 2))
  545. pooled_output = torch.flatten(pooled_output, 1)
  546. if not return_dict:
  547. output = (sequence_output, pooled_output) + encoder_outputs[1:]
  548. return output
  549. return FocalNetModelOutput(
  550. last_hidden_state=sequence_output,
  551. pooler_output=pooled_output,
  552. hidden_states=encoder_outputs.hidden_states,
  553. reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
  554. )
  555. @auto_docstring(
  556. custom_intro="""
  557. FocalNet Model with a decoder on top for masked image modeling.
  558. This follows the same implementation as in [SimMIM](https://huggingface.co/papers/2111.09886).
  559. <Tip>
  560. Note that we provide a script to pre-train this model on custom data in our [examples
  561. directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
  562. </Tip>
  563. """
  564. )
  565. class FocalNetForMaskedImageModeling(FocalNetPreTrainedModel):
  566. def __init__(self, config):
  567. super().__init__(config)
  568. self.focalnet = FocalNetModel(config, add_pooling_layer=False, use_mask_token=True)
  569. self.num_stages = len(config.depths)
  570. num_features = int(config.embed_dim * 2 ** (self.num_stages - 1))
  571. self.decoder = nn.Sequential(
  572. nn.Conv2d(
  573. in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
  574. ),
  575. nn.PixelShuffle(config.encoder_stride),
  576. )
  577. # Initialize weights and apply final processing
  578. self.post_init()
  @auto_docstring
  def forward(
      self,
      pixel_values: torch.FloatTensor | None = None,
      bool_masked_pos: torch.BoolTensor | None = None,
      output_hidden_states: bool | None = None,
      return_dict: bool | None = None,
      **kwargs,
  ) -> tuple | FocalNetMaskedImageModelingOutput:
      r"""
      bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
          Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

      Examples:

      ```python
      >>> from transformers import AutoImageProcessor, FocalNetConfig, FocalNetForMaskedImageModeling
      >>> import torch
      >>> from PIL import Image
      >>> import httpx
      >>> from io import BytesIO

      >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
      >>> with httpx.stream("GET", url) as response:
      ...     image = Image.open(BytesIO(response.read()))

      >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-base-simmim-window6-192")
      >>> config = FocalNetConfig()
      >>> model = FocalNetForMaskedImageModeling(config)

      >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
      >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
      >>> # create random boolean mask of shape (batch_size, num_patches)
      >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

      >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
      >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
      >>> list(reconstructed_pixel_values.shape)
      [1, 3, 192, 192]
      ```"""
      # Encode (optionally masked) patches, decode them back to pixel space and, when a
      # mask is given, compute an L1 reconstruction loss restricted to the masked pixels.
      return_dict = return_dict if return_dict is not None else self.config.return_dict

      outputs = self.focalnet(
          pixel_values,
          bool_masked_pos=bool_masked_pos,
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
      )

      sequence_output = outputs[0]

      # Reshape to (batch_size, num_channels, height, width). The token grid is assumed to be
      # square: height and width are both recovered as floor(sqrt(sequence_length)).
      sequence_output = sequence_output.transpose(1, 2)
      batch_size, num_channels, sequence_length = sequence_output.shape
      height = width = math.floor(sequence_length**0.5)
      sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)

      # Reconstruct pixel values
      reconstructed_pixel_values = self.decoder(sequence_output)

      masked_im_loss = None
      if bool_masked_pos is not None:
          # Expand the per-patch mask to pixel resolution by repeating every patch entry
          # `patch_size` times along both spatial axes, then add a singleton channel axis
          # so it broadcasts against the image: (batch, 1, size * patch_size, size * patch_size).
          size = self.config.image_size // self.config.patch_size
          bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
          mask = (
              bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
              .repeat_interleave(self.config.patch_size, 2)
              .unsqueeze(1)
              .contiguous()
          )
          # Average the per-pixel L1 error over masked pixels only. The mask has a single
          # channel and is broadcast over the channel axis, so dividing by num_channels also
          # averages over channels (assumes pixel_values has config.num_channels channels).
          # The 1e-5 keeps the division finite when no patch is masked.
          reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
          masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

      if not return_dict:
          # Drop the first two encoder entries (sequence output and, presumably, the pooled output).
          output = (reconstructed_pixel_values,) + outputs[2:]
          return ((masked_im_loss,) + output) if masked_im_loss is not None else output

      return FocalNetMaskedImageModelingOutput(
          loss=masked_im_loss,
          reconstruction=reconstructed_pixel_values,
          hidden_states=outputs.hidden_states,
          reshaped_hidden_states=outputs.reshaped_hidden_states,
      )
  @auto_docstring(
      custom_intro="""
      FocalNet Model with an image classification head on top (a linear layer on top of the pooled output) e.g. for
      ImageNet.
      """
  )
  class FocalNetForImageClassification(FocalNetPreTrainedModel):
      # Copied from transformers.models.swin.modeling_swin.SwinForImageClassification.__init__ with Swin->FocalNet, swin->focalnet
      def __init__(self, config):
          super().__init__(config)

          self.num_labels = config.num_labels
          self.focalnet = FocalNetModel(config)

          # Classifier head
          self.classifier = (
              nn.Linear(self.focalnet.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
          )

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          pixel_values: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple | FocalNetImageClassifierOutput:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
          """
          # Fall back to the model config when the caller did not choose an output format.
          if return_dict is None:
              return_dict = self.config.return_dict

          backbone_outputs = self.focalnet(
              pixel_values,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          # The second encoder entry is the pooled representation fed to the classifier head.
          logits = self.classifier(backbone_outputs[1])

          # A loss is only produced when ground-truth labels are supplied.
          loss = self.loss_function(labels, logits, self.config) if labels is not None else None

          if return_dict:
              return FocalNetImageClassifierOutput(
                  loss=loss,
                  logits=logits,
                  hidden_states=backbone_outputs.hidden_states,
                  reshaped_hidden_states=backbone_outputs.reshaped_hidden_states,
              )

          # Tuple form: (loss?, logits, *remaining encoder outputs after the first two).
          tuple_output = (logits,) + backbone_outputs[2:]
          return tuple_output if loss is None else (loss,) + tuple_output
  @auto_docstring(
      custom_intro="""
      FocalNet backbone, to be used with frameworks like X-Decoder.
      """
  )
  class FocalNetBackbone(BackboneMixin, FocalNetPreTrainedModel):
      has_attentions = False

      def __init__(self, config: FocalNetConfig):
          super().__init__(config)

          # Feature dimension listed per stage: the embedding dim first, then config.hidden_sizes.
          self.num_features = [config.embed_dim] + config.hidden_sizes
          self.focalnet = FocalNetModel(config)

          # initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @filter_output_hidden_states
      @auto_docstring
      def forward(
          self,
          pixel_values: torch.Tensor,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> BackboneOutput:
          r"""
          Examples:

          ```python
          >>> from transformers import AutoImageProcessor, AutoBackbone
          >>> import torch
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))

          >>> processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny-lrf")
          >>> model = AutoBackbone.from_pretrained("microsoft/focalnet-tiny-lrf")

          >>> inputs = processor(image, return_tensors="pt")
          >>> outputs = model(**inputs)
          ```"""
          # Resolve unset flags from the model config.
          if return_dict is None:
              return_dict = self.config.return_dict
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states

          # Hidden states are always requested from the trunk because the feature maps are
          # selected from its reshaped hidden states, regardless of what the caller asked for.
          encoder_outputs = self.focalnet(pixel_values, output_hidden_states=True, return_dict=True)
          stage_maps = encoder_outputs.reshaped_hidden_states

          # Keep only the stages named in `out_features`, in `stage_names` order.
          feature_maps = tuple(
              stage_maps[position]
              for position, stage_name in enumerate(self.stage_names)
              if stage_name in self.out_features
          )

          if return_dict:
              return BackboneOutput(
                  feature_maps=feature_maps,
                  hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
                  attentions=None,
              )

          tuple_output = (feature_maps,)
          if output_hidden_states:
              tuple_output += (encoder_outputs.hidden_states,)
          return tuple_output
  # Public names exported by this module when it is imported with `from ... import *`.
  __all__ = [
      "FocalNetForImageClassification",
      "FocalNetForMaskedImageModeling",
      "FocalNetBackbone",
      "FocalNetModel",
      "FocalNetPreTrainedModel",
  ]