modeling_hiera.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404
  1. # Copyright 2024 Meta and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Hiera model."""
  15. import math
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BackboneOutput,
  25. BaseModelOutput,
  26. BaseModelOutputWithPooling,
  27. ImageClassifierOutput,
  28. ModelOutput,
  29. )
  30. from ...modeling_utils import PreTrainedModel
  31. from ...utils import auto_docstring, logging, torch_int
  32. from ...utils.generic import can_return_tuple
  33. from .configuration_hiera import HieraConfig
# Module-level logger obtained from the library's logging utilities.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Hiera encoder's outputs, with potential hidden states and attentions.
    """
)
class HieraEncoderOutput(ModelOutput):
    r"""
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # NOTE(review): field order presumably defines the tuple/index order of ModelOutput — keep it stable.
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Hiera model's outputs that also contains a pooling of the last hidden states.
    """
)
class HieraModelOutput(ModelOutput):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
        Tensor indicating which patches are masked (0) and which are not (1).
    ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
        Tensor containing the original index of the (shuffled) masked patches.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # NOTE(review): field order presumably defines the tuple/index order of ModelOutput — keep it stable.
    last_hidden_state: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    # Masking outputs are only meaningful when the embeddings run in MAE (random-masking) mode.
    bool_masked_pos: torch.BoolTensor | None = None
    ids_restore: torch.LongTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Hiera image classification outputs.
    """
)
class HieraForImageClassificationOutput(ImageClassifierOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Loss value for the training task.
    logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
        Prediction scores of the classification head (logits of the output layer).
    hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs.
    attentions (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
        heads.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Fields are re-declared here (on top of ImageClassifierOutput) so that `reshaped_hidden_states` is appended
    # after the inherited ones. NOTE(review): field order presumably defines ModelOutput tuple order.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Class for HieraForPreTraining's outputs, with potential hidden states and attentions.
    """
)
class HieraForPreTrainingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`):
        Pixel reconstruction loss.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
        Pixel reconstruction logits.
    bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
        Tensor indicating which patches are masked (0) and which are not (1).
    ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
        Tensor containing the original index of the (shuffled) masked patches.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
        shape `(batch_size, height, width, hidden_size)`. Hidden-states of the model at the output of each layer
        plus the initial embedding outputs reshaped to include the spatial dimensions.
    """

    # NOTE(review): field order presumably defines the tuple/index order of ModelOutput — keep it stable.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    bool_masked_pos: torch.BoolTensor | None = None
    ids_restore: torch.LongTensor | None = None
    # Variable-length tuples (one entry per stage/layer), matching the sibling output classes.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    reshaped_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  140. class HieraPatchEmbeddings(nn.Module):
  141. """
  142. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  143. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  144. Transformer.
  145. """
  146. def __init__(self, config, is_mae: bool = False):
  147. super().__init__()
  148. # Support any number of spatial dimensions
  149. self.spatial_dims = len(config.patch_size)
  150. if self.spatial_dims != 2:
  151. raise ValueError(f"The number of dimensions of the input image should be 2, but got {self.spatial_dims}.")
  152. self.num_channels = config.num_channels
  153. self.image_size = config.image_size[-2:]
  154. self.tokens_spatial_shape = [i // s for i, s in zip(config.image_size, config.patch_stride)]
  155. self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, config.masked_unit_size)]
  156. self.mask_ratio = config.mask_ratio
  157. self.is_mae = is_mae
  158. self.projection = nn.Conv2d(
  159. self.num_channels,
  160. config.embed_dim,
  161. kernel_size=config.patch_size,
  162. stride=config.patch_stride,
  163. padding=config.patch_padding,
  164. )
  165. def masked_conv(
  166. self, pixel_values: torch.FloatTensor, bool_masked_pos: torch.BoolTensor | None = None
  167. ) -> torch.Tensor:
  168. """Zero-out the masked regions of the input before conv.
  169. Prevents leakage of masked regions when using overlapping kernels.
  170. """
  171. if bool_masked_pos is None:
  172. return self.projection(pixel_values)
  173. target_size = pixel_values.shape[2:]
  174. # Reshape bool_masked_pos to (batch_size, 1, mask_unit_height, mask_unit_width)
  175. bool_masked_pos = bool_masked_pos.view(pixel_values.shape[0], 1, *self.mask_spatial_shape)
  176. bool_masked_pos = nn.functional.interpolate(bool_masked_pos.float(), size=target_size)
  177. return self.projection(pixel_values * bool_masked_pos)
  178. def random_masking(
  179. self, pixel_values: torch.FloatTensor, noise: torch.FloatTensor | None = None
  180. ) -> tuple[torch.BoolTensor, torch.LongTensor]:
  181. """
  182. Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
  183. noise.
  184. Args:
  185. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`)
  186. noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*) which is
  187. mainly used for testing purposes to control randomness and maintain the reproducibility
  188. """
  189. batch_size = pixel_values.shape[0]
  190. # Tokens selected for masking at mask unit level
  191. num_windows = math.prod(self.mask_spatial_shape)
  192. len_keep = int(num_windows * (1 - self.mask_ratio))
  193. if noise is None:
  194. noise = torch.rand(batch_size, num_windows, device=pixel_values.device)
  195. # Sort noise for each sample
  196. ids_shuffle = torch.argsort(noise, dim=1)
  197. # ascend: small is keep, large is remove
  198. ids_restore = torch.argsort(ids_shuffle, dim=1).to(pixel_values.device)
  199. # Generate the binary bool_masked_pos: 1 is *keep*, 0 is *remove*
  200. # Note this is opposite to original MAE
  201. bool_masked_pos = torch.zeros([batch_size, num_windows], device=pixel_values.device)
  202. bool_masked_pos[:, :len_keep] = 1
  203. # Unshuffle to get the binary bool_masked_pos
  204. bool_masked_pos = torch.gather(bool_masked_pos, dim=1, index=ids_restore).bool()
  205. return bool_masked_pos, ids_restore
  206. def forward(
  207. self,
  208. pixel_values: torch.FloatTensor,
  209. noise: torch.FloatTensor | None = None,
  210. ) -> tuple[torch.Tensor, torch.BoolTensor | None, torch.LongTensor | None]:
  211. (bool_masked_pos, ids_restore) = (
  212. self.random_masking(pixel_values, noise=noise) if self.is_mae else (None, None)
  213. )
  214. embeddings = self.masked_conv(pixel_values, bool_masked_pos)
  215. embeddings = embeddings.flatten(2).transpose(2, 1)
  216. return embeddings, bool_masked_pos, ids_restore
  217. class HieraEmbeddings(nn.Module):
  218. """
  219. Construct position and patch embeddings.
  220. """
  221. def __init__(self, config: HieraConfig, is_mae: bool = False) -> None:
  222. super().__init__()
  223. self.patch_stride = config.patch_stride
  224. tokens_spatial_shape = [i // s for i, s in zip(config.image_size, config.patch_stride)]
  225. self.mask_spatial_shape = [i // s for i, s in zip(tokens_spatial_shape, config.masked_unit_size)]
  226. self.num_tokens = math.prod(tokens_spatial_shape)
  227. self.is_mae = is_mae
  228. self.patch_embeddings = HieraPatchEmbeddings(config, is_mae=is_mae)
  229. self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_tokens, config.embed_dim))
  230. def interpolate_pos_encoding(
  231. self, embeddings: torch.Tensor, pos_embeds: torch.Tensor, height: int, width: int
  232. ) -> torch.Tensor:
  233. """
  234. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  235. images. This method is also adapted to support torch.jit tracing, no class embeddings, and different patch strides.
  236. Adapted from:
  237. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  238. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  239. """
  240. num_patches = embeddings.shape[1]
  241. num_positions = pos_embeds.shape[1]
  242. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  243. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  244. return pos_embeds
  245. dim = embeddings.shape[-1]
  246. new_height = height // self.patch_stride[0]
  247. new_width = width // self.patch_stride[1]
  248. sqrt_num_positions = torch_int(num_positions**0.5)
  249. pos_embeds = pos_embeds.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  250. pos_embeds = pos_embeds.permute(0, 3, 1, 2)
  251. pos_embeds = nn.functional.interpolate(
  252. pos_embeds,
  253. size=(new_height, new_width),
  254. mode="bicubic",
  255. align_corners=False,
  256. )
  257. pos_embeds = pos_embeds.permute(0, 2, 3, 1).view(1, -1, dim)
  258. return pos_embeds
  259. def get_position_embedding(
  260. self, embeddings: torch.Tensor, height: int, width: int, interpolate_pos_encoding: bool
  261. ) -> torch.FloatTensor:
  262. return (
  263. self.interpolate_pos_encoding(embeddings, self.position_embeddings, height, width)
  264. if interpolate_pos_encoding
  265. else self.position_embeddings
  266. )
  267. def forward(
  268. self,
  269. pixel_values: torch.FloatTensor,
  270. noise: torch.FloatTensor | None = None,
  271. interpolate_pos_encoding: bool = False,
  272. ) -> tuple[torch.Tensor, torch.BoolTensor | None, torch.LongTensor | None]:
  273. height, width = pixel_values.shape[-2:]
  274. embeddings, bool_masked_pos, ids_restore = self.patch_embeddings(pixel_values, noise=noise)
  275. embeddings = embeddings + self.get_position_embedding(embeddings, height, width, interpolate_pos_encoding)
  276. return embeddings, bool_masked_pos, ids_restore
  277. class HieraMaskUnitAttention(nn.Module):
  278. """
  279. Computes either Mask Unit or Global Attention. Also is able to perform query pooling.
  280. Note: this assumes the tokens have already been flattened and unrolled into mask units.
  281. """
  282. def __init__(
  283. self,
  284. hidden_size: int,
  285. hidden_size_output: int,
  286. num_heads: int,
  287. query_stride: int = 1,
  288. window_size: int = 0,
  289. use_mask_unit_attn: bool = False,
  290. ) -> None:
  291. super().__init__()
  292. self.num_heads = num_heads
  293. self.query_stride = query_stride
  294. self.hidden_size_output = hidden_size_output
  295. self.head_dim = hidden_size_output // num_heads
  296. self.scale = (self.head_dim) ** -0.5
  297. self.qkv = nn.Linear(hidden_size, 3 * hidden_size_output)
  298. self.proj = nn.Linear(hidden_size_output, hidden_size_output)
  299. self.window_size = window_size
  300. self.use_mask_unit_attn = use_mask_unit_attn
  301. def forward(
  302. self,
  303. hidden_states: torch.Tensor,
  304. output_attentions: bool = False,
  305. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  306. """Input should be of shape [batch, tokens, channels]."""
  307. batch_size, seq_len, _ = hidden_states.shape
  308. num_windows = 1
  309. if self.use_mask_unit_attn:
  310. num_windows = seq_len // (self.query_stride * self.window_size)
  311. qkv = self.qkv(hidden_states)
  312. qkv = qkv.reshape(batch_size, -1, num_windows, 3, self.num_heads, self.head_dim)
  313. qkv = qkv.permute(3, 0, 4, 2, 1, 5)
  314. query, key, value = qkv.unbind(0)
  315. if self.query_stride > 1:
  316. # Refer to unroll to see how this performs a maxpool-Nd
  317. query = query.view(batch_size, self.num_heads, num_windows, self.query_stride, -1, self.head_dim)
  318. query = query.max(dim=3).values
  319. attn_weights = (query * self.scale) @ key.transpose(-1, -2)
  320. attn_weights = attn_weights.softmax(dim=-1)
  321. attn_output = attn_weights @ value
  322. attn_output = attn_output.transpose(1, 3).reshape(batch_size, -1, self.hidden_size_output)
  323. attn_output = self.proj(attn_output)
  324. return (attn_output, attn_weights) if output_attentions else (attn_output, None)
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        input (`torch.Tensor`): Tensor whose first dimension indexes samples; one keep/drop draw is made per sample.
        drop_prob (`float`, defaults to 0.0): Probability of zeroing a sample's whole path.
        training (`bool`, defaults to False): Dropping is applied only when True.

    Returns:
        `torch.Tensor` with the same shape as `input`: returned unchanged when `drop_prob == 0.0` or not training;
        otherwise each sample is either zeroed or scaled by `1 / (1 - drop_prob)` so the expectation is preserved.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    # NOTE(review): drop_prob == 1.0 gives keep_prob == 0, so input.div(keep_prob) below produces inf/nan.
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else 0.
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Hiera
class HieraDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float | None = None) -> None:
        super().__init__()
        # NOTE(review): leaving drop_prob as None makes drop_path compute `1 - None` in training mode
        # (TypeError); callers in this file only construct this with a positive float.
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Active only in training mode (self.training); identity in eval mode.
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
  348. class HieraMlp(nn.Module):
  349. def __init__(self, config, dim: int) -> None:
  350. super().__init__()
  351. self.activation_fn = ACT2FN[config.hidden_act]
  352. self.fc1 = nn.Linear(dim, int(dim * config.mlp_ratio))
  353. self.fc2 = nn.Linear(int(dim * config.mlp_ratio), dim)
  354. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  355. hidden_states = self.fc1(hidden_states)
  356. hidden_states = self.activation_fn(hidden_states)
  357. hidden_states = self.fc2(hidden_states)
  358. return hidden_states
  359. class HieraLayer(nn.Module):
  360. def __init__(
  361. self,
  362. config,
  363. hidden_size: int,
  364. hidden_size_output: int,
  365. num_heads: int,
  366. drop_path: float = 0.0,
  367. query_stride: int = 1,
  368. window_size: int = 0,
  369. use_mask_unit_attn: bool = False,
  370. ) -> None:
  371. super().__init__()
  372. self.hidden_size = hidden_size
  373. self.hidden_size_output = hidden_size_output
  374. self.query_stride = query_stride
  375. self.layernorm_before = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  376. self.attn = HieraMaskUnitAttention(
  377. hidden_size=hidden_size,
  378. hidden_size_output=hidden_size_output,
  379. num_heads=num_heads,
  380. query_stride=query_stride,
  381. window_size=window_size,
  382. use_mask_unit_attn=use_mask_unit_attn,
  383. )
  384. self.layernorm_after = nn.LayerNorm(hidden_size_output, eps=config.layer_norm_eps)
  385. self.mlp = HieraMlp(config, hidden_size_output)
  386. self.drop_path = HieraDropPath(drop_path) if drop_path > 0 else nn.Identity()
  387. if hidden_size != hidden_size_output:
  388. self.proj = nn.Linear(hidden_size, hidden_size_output)
  389. def forward(
  390. self,
  391. hidden_states: torch.Tensor,
  392. output_attentions: bool = False,
  393. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  394. batch_size, seq_len, _ = hidden_states.shape
  395. # Attention + Q Pooling
  396. hidden_states_norm = self.layernorm_before(hidden_states)
  397. if self.hidden_size != self.hidden_size_output:
  398. hidden_states = self.proj(hidden_states_norm)
  399. # Refer to unroll to see how this performs a maxpool-Nd
  400. hidden_states = (
  401. hidden_states.view(batch_size, self.query_stride, -1, self.hidden_size_output).max(dim=1).values
  402. )
  403. (hidden_states_norm, attn_weights) = self.attn(hidden_states_norm, output_attentions=output_attentions)
  404. hidden_states = hidden_states + self.drop_path(hidden_states_norm)
  405. residual = hidden_states
  406. hidden_states = self.layernorm_after(hidden_states)
  407. hidden_states = self.mlp(hidden_states)
  408. hidden_states = residual + self.drop_path(hidden_states)
  409. return (hidden_states, attn_weights)
  410. class HieraStage(GradientCheckpointingLayer):
  411. def __init__(
  412. self,
  413. config,
  414. depth: int,
  415. hidden_size: int,
  416. hidden_size_output: int,
  417. num_heads: int,
  418. drop_path: list[float],
  419. query_stride: list[int],
  420. window_size: int,
  421. use_mask_unit_attn: bool,
  422. stage_num: int | None = None,
  423. ) -> None:
  424. super().__init__()
  425. # we need to know if the previous stage used masked attention
  426. # mask unit or global attention.
  427. # lag by 1 layer, so that global attention,
  428. # applied post pooling on lower resolution
  429. previous_stage_used_masked_attention = False
  430. if stage_num is not None:
  431. previous_stage_used_masked_attention = config.masked_unit_attention[stage_num - 1 if stage_num > 0 else 0]
  432. self.layers = nn.ModuleList(
  433. [
  434. HieraLayer(
  435. config=config,
  436. hidden_size=hidden_size if i == 0 else hidden_size_output,
  437. hidden_size_output=hidden_size_output,
  438. num_heads=num_heads,
  439. drop_path=drop_path[i],
  440. query_stride=query_stride[i],
  441. window_size=window_size,
  442. use_mask_unit_attn=use_mask_unit_attn or (previous_stage_used_masked_attention and i == 0),
  443. )
  444. for i in range(depth)
  445. ]
  446. )
  447. def forward(
  448. self, hidden_states: torch.Tensor, output_attentions: bool = False
  449. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  450. for i, layer_module in enumerate(self.layers):
  451. (hidden_states, attn_weights) = layer_module(hidden_states, output_attentions=output_attentions)
  452. return hidden_states, attn_weights
  453. def undo_windowing(hidden_states: torch.Tensor, shape: list[int], mask_unit_shape: list[int]) -> torch.Tensor:
  454. """
  455. Restore spatial organization by undoing windowed organization of mask units.
  456. Args:
  457. hidden_states (`torch.Tensor`): The hidden states tensor of shape `[batch_size, num_mask_unit_height*num_mask_unit_width, hidden_size]`.
  458. shape (`list[int]`): The original shape of the hidden states tensor before windowing.
  459. mask_unit_shape (`list[int]`): The shape of the mask units used for windowing.
  460. Returns:
  461. torch.Tensor: The restored hidden states tensor of shape [batch_size, num_mask_unit_height*mask_unit_height, num_mask_unit_width*mask_unit_width, hidden_size].
  462. """
  463. batch_size, hidden_size = hidden_states.shape[0], hidden_states.shape[-1]
  464. # From: [batch_size, num_mask_unit_height*num_mask_unit_width, hidden_size]
  465. # To: [batch_size, num_mask_unit_height, num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
  466. num_mask_units = [s // mu for s, mu in zip(shape, mask_unit_shape)]
  467. hidden_states = hidden_states.view(batch_size, *num_mask_units, *mask_unit_shape, hidden_size)
  468. # From: [batch_size, num_mask_unit_height, num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
  469. # To: [batch_size, num_mask_unit_height*mask_unit_height, num_mask_unit_width*mask_unit_width, hidden_size]
  470. hidden_states = hidden_states.permute(0, 1, 3, 2, 4, 5)
  471. hidden_states = hidden_states.reshape(batch_size, *shape, hidden_size)
  472. return hidden_states
  473. class HieraEncoder(nn.Module):
  474. def __init__(self, config: HieraConfig) -> None:
  475. super().__init__()
  476. total_depth = sum(config.depths)
  477. # stochastic depth decay rule
  478. dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, total_depth, device="cpu")]
  479. # query strides rule
  480. cumulative_depths = torch.tensor(config.depths, device="cpu").cumsum(0).tolist()
  481. query_pool_layer = cumulative_depths[: config.num_query_pool]
  482. query_strides = [math.prod(config.query_stride) if i in query_pool_layer else 1 for i in range(total_depth)]
  483. # Transformer blocks
  484. self.stages = nn.ModuleList()
  485. hidden_size = config.embed_dim
  486. stage_ends = [0] + cumulative_depths
  487. masked_unit_area = math.prod(config.masked_unit_size)
  488. query_stride_area = math.prod(config.query_stride)
  489. for idx_stage, depth in enumerate(config.depths):
  490. hidden_size_output = int(config.embed_dim * config.embed_dim_multiplier**idx_stage)
  491. stage = HieraStage(
  492. config=config,
  493. depth=depth,
  494. hidden_size=hidden_size,
  495. hidden_size_output=hidden_size_output,
  496. num_heads=config.num_heads[idx_stage],
  497. drop_path=dpr[stage_ends[idx_stage] : stage_ends[idx_stage + 1]],
  498. query_stride=query_strides[stage_ends[idx_stage] : stage_ends[idx_stage + 1]],
  499. window_size=int(masked_unit_area * query_stride_area**-idx_stage),
  500. use_mask_unit_attn=config.masked_unit_attention[idx_stage],
  501. stage_num=idx_stage,
  502. )
  503. hidden_size = hidden_size_output
  504. self.stages.append(stage)
  505. # Setting reroll schedule
  506. # The first stage has to reverse everything
  507. # The next stage has to reverse all but the first unroll, etc.
  508. stage_size = [i // s for i, s in zip(config.image_size, config.patch_stride)]
  509. unroll_schedule = [config.query_stride] * len(config.depths[:-1])
  510. self.schedule = {}
  511. for idx_stage in range(len(config.depths)):
  512. self.schedule[idx_stage] = unroll_schedule, stage_size
  513. if idx_stage < config.num_query_pool:
  514. stage_size = [i // s for i, s in zip(stage_size, config.query_stride)]
  515. unroll_schedule = unroll_schedule[1:]
  516. self.gradient_checkpointing = False
    def reroll(
        self, hidden_states: torch.Tensor, stage_idx: int, bool_masked_pos: torch.BoolTensor | None = None
    ) -> torch.Tensor:
        """
        Roll the given tensor back up to spatial order assuming it's from the given block.

        `self.schedule[stage_idx]` holds `(strides_to_undo, spatial_size)` for the stage: one stride list per
        unroll step still applied to this stage's tokens, and the stage's spatial token grid.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
                Unrolled hidden states produced by stage `stage_idx`.
            stage_idx (`int`):
                Index of the stage that produced `hidden_states`; selects the reroll schedule.
            bool_masked_pos (`torch.BoolTensor`, *optional*):
                Only its presence is checked: when given, the final un-windowing step is skipped.

        Returns:
            If no bool_masked_pos is provided:
            - [batch_size, height, width, hidden_size]
            If a bool_masked_pos is provided:
            - [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
        """
        schedule, size = self.schedule[stage_idx]
        batch_size, seq_len, hidden_size = hidden_states.shape

        num_dim = len(size)
        # Current mask unit size per spatial dim; grows by `strides` with every undone unroll step.
        mask_unit_shape = [1] * num_dim

        for strides in schedule:
            # Extract the current patch from seq_len
            hidden_states = hidden_states.view(
                batch_size, *strides, seq_len // math.prod(strides), *mask_unit_shape, hidden_size
            )

            # Move that patch into the current MU
            # Input: [batch_size, stride, stride, seq_len//(stride*stride), mask_unit_height, mask_unit_width, hidden_size]
            # Output: [batch_size, seq_len//(stride*stride), stride, mask_unit_height, stride, mask_unit_width, hidden_size]
            # NOTE: this 7-axis permutation only supports 2 spatial dimensions.
            hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5, 6)

            # Reshape to [batch_size, seq_len//(stride*stride), *mask_units, hidden_size]
            for i in range(num_dim):
                mask_unit_shape[i] *= strides[i]
            hidden_states = hidden_states.reshape(batch_size, -1, *mask_unit_shape, hidden_size)
            seq_len = hidden_states.shape[1]

        # Current shape (e.g., 2d: [batch_size, #num_mask_units_height*#num_mask_units_width, mask_unit_height, mask_unit_width, hidden_size])
        hidden_states = hidden_states.view(batch_size, seq_len, *mask_unit_shape, hidden_size)

        # If masked, return [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
        if bool_masked_pos is not None:
            return hidden_states

        # If not masked, we can return [batch_size, height, width, hidden_size]
        hidden_states = undo_windowing(hidden_states, size, mask_unit_shape)

        return hidden_states
  def forward(
      self,
      hidden_states: torch.Tensor,
      bool_masked_pos: torch.BoolTensor | None = None,
      output_attentions: bool = False,
      output_hidden_states: bool = False,
      return_dict: bool = True,
  ) -> tuple | BaseModelOutput:
      """
      Run every stage in `self.stages` over the token sequence (presumably already unrolled, since the
      re-rolled outputs go through `reroll`).

      When `output_hidden_states` is set, the input and each stage's output are recorded both as raw token
      sequences and, through `reroll`, in spatial order. Per-stage attention maps are recorded when
      `output_attentions` is set.

      Returns:
          `HieraEncoderOutput` if `return_dict`, otherwise a tuple of the non-`None` entries among
          (last hidden state, hidden states, attentions, reshaped hidden states).
      """
      hidden_history = [] if output_hidden_states else None
      spatial_history = [] if output_hidden_states else None
      attention_history = [] if output_attentions else None

      if output_hidden_states:
          hidden_history.append(hidden_states)
          spatial_history.append(self.reroll(hidden_states, stage_idx=0, bool_masked_pos=bool_masked_pos))

      for stage_idx, stage_module in enumerate(self.stages):
          stage_outputs = stage_module(hidden_states, output_attentions)
          hidden_states = stage_outputs[0]

          if output_attentions:
              attention_history.append(stage_outputs[1])

          if output_hidden_states:
              hidden_history.append(hidden_states)
              spatial_history.append(
                  self.reroll(hidden_states, stage_idx=stage_idx, bool_masked_pos=bool_masked_pos)
              )

      all_hidden_states = tuple(hidden_history) if hidden_history is not None else None
      all_reshaped_hidden_states = tuple(spatial_history) if spatial_history is not None else None
      all_self_attentions = tuple(attention_history) if attention_history is not None else None

      if return_dict:
          return HieraEncoderOutput(
              last_hidden_state=hidden_states,
              hidden_states=all_hidden_states,
              attentions=all_self_attentions,
              reshaped_hidden_states=all_reshaped_hidden_states,
          )

      candidates = (hidden_states, all_hidden_states, all_self_attentions, all_reshaped_hidden_states)
      return tuple(value for value in candidates if value is not None)
  def unroll(
      hidden_states: torch.Tensor, image_shape: tuple[int, int], patch_stride: tuple[int, int], schedule: list[list[int]]
  ) -> torch.Tensor:
      """
      Reorder raster-ordered tokens so that strided sub-grids become contiguous in memory.

      Each entry of `schedule` splits every spatial axis into (coarse, stride) factors and moves the stride
      factors in front of the coarse grid. For one stride of (s, s), the layout goes from
      [batch_size, (height, width), hidden_size] to [batch_size, (s, s, height // s, width // s), hidden_size].
      A max-pool over the stride is then just `x.view(batch_size, s * s, -1, hidden_size).max(dim=1)`.
      Applying several strides in sequence makes whole windows contiguous (three (2, 2) strides give 8x8
      windows). This lets mask-unit attention and repeated pooling work on plain views.

      Because the resulting tokens are no longer in height x width order, intermediate outputs must be
      re-rolled before being used as spatial feature maps.

      Args:
          hidden_states: Tokens of shape `[batch_size, num_tokens, hidden_size]` in raster order.
          image_shape: Input image size, one entry per spatial axis.
          patch_stride: Patch-embedding stride per axis; `image_shape // patch_stride` is the token grid.
          schedule: Per-axis strides to unroll, applied in order.

      Returns:
          Tensor of shape `[batch_size, num_tokens, hidden_size]` with tokens reordered.
      """
      batch_size, _, hidden_size = hidden_states.shape
      token_grid = [extent // stride for extent, stride in zip(image_shape, patch_stride)]
      num_tokens = math.prod(token_grid)

      grid = token_grid
      hidden_states = hidden_states.view(batch_size, *grid, hidden_size)

      for strides in schedule:
          # Split every axis into (coarse, stride): [batch_size, h // s, s, w // s, s, hidden_size] in 2-d.
          grid = [extent // stride for extent, stride in zip(grid, strides)]
          interleaved = []
          for coarse, stride in zip(grid, strides):
              interleaved.extend((coarse, stride))
          hidden_states = hidden_states.view(batch_size, *interleaved, hidden_size)

          # Bring all stride axes forward: [batch_size, s, s, h // s, w // s, hidden_size] in 2-d.
          rank = len(interleaved) + 2
          stride_axes = list(range(2, rank - 1, 2))
          coarse_axes = list(range(1, rank - 1, 2))
          hidden_states = hidden_states.permute(0, *stride_axes, *coarse_axes, rank - 1)

          # Fold the stride axes into the batch axis so the next step sees them as extra "images".
          hidden_states = hidden_states.flatten(0, len(strides))
          batch_size = batch_size * math.prod(strides)

      return hidden_states.reshape(-1, num_tokens, hidden_size)
  @auto_docstring
  class HieraPreTrainedModel(PreTrainedModel):
      config: HieraConfig
      base_model_prefix = "hiera"
      main_input_name = "pixel_values"
      input_modalities = ("image",)
      supports_gradient_checkpointing = True

      @torch.no_grad()
      def _init_weights(self, module) -> None:
          """
          Initialize the parameters of one sub-module.

          Position embeddings, the decoder mask token / position embeddings, and linear/conv weights are drawn
          from a truncated normal with std `config.initializer_range`. Biases of linear/conv/LayerNorm modules
          are set to the constant `config.initializer_range` (not zero). LayerNorm weights are set to
          `config.layer_norm_init`.
          """
          std = self.config.initializer_range

          if isinstance(module, HieraEmbeddings):
              init.trunc_normal_(module.position_embeddings, std=std)
              return

          if isinstance(module, HieraDecoder):
              for parameter in (module.mask_token, module.decoder_position_embeddings):
                  init.trunc_normal_(parameter, std=std)
              return

          if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
              init.trunc_normal_(module.weight, std=std)
              if module.bias is not None:
                  init.constant_(module.bias, std)
              return

          if isinstance(module, nn.LayerNorm):
              init.constant_(module.bias, std)
              init.constant_(module.weight, self.config.layer_norm_init)
  class HieraPooler(nn.Module):
      """Average the final-stage tokens over the sequence axis, then apply LayerNorm."""

      def __init__(self, config: HieraConfig):
          super().__init__()
          final_stage = len(config.depths) - 1
          num_features = int(config.embed_dim * config.embed_dim_multiplier**final_stage)
          self.layernorm = nn.LayerNorm(num_features, eps=config.layer_norm_eps)
          self.pooler = nn.AdaptiveAvgPool1d(1)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          # Input assumed to be (batch_size, seq_len, hidden_size); the 1-d pool wants channels first.
          channels_first = hidden_states.transpose(1, 2)
          pooled_output = self.pooler(channels_first).flatten(1)
          return self.layernorm(pooled_output)
  @auto_docstring
  class HieraModel(HieraPreTrainedModel):
      def __init__(self, config: HieraConfig, add_pooling_layer: bool = True, is_mae: bool = False):
          r"""
          add_pooling_layer (`bool`, *optional*, defaults to `True`):
              Whether or not to apply pooling layer.
          is_mae (`bool`, *optional*, defaults to `False`):
              Whether or not to run the model on MAE mode.
          """
          super().__init__(config)
          last_stage = len(config.depths) - 1
          self.num_features = int(config.embed_dim * config.embed_dim_multiplier**last_stage)

          self.embeddings = HieraEmbeddings(config, is_mae=is_mae)
          self.encoder = HieraEncoder(config)
          # One query-stride unroll per stage transition, i.e. for every stage except the last.
          self.unroll_schedule = [config.query_stride] * len(config.depths[:-1])
          self.pooler = HieraPooler(config) if add_pooling_layer else None

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self) -> HieraPatchEmbeddings:
          return self.embeddings.patch_embeddings

      @auto_docstring
      def forward(
          self,
          pixel_values: torch.Tensor | None = None,
          noise: torch.FloatTensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          interpolate_pos_encoding: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple | BaseModelOutputWithPooling:
          r"""
          noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
              Mainly used for testing purposes to control randomness and maintain the reproducibility
          """
          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states
          if return_dict is None:
              return_dict = self.config.return_dict

          if pixel_values is None:
              raise ValueError("You have to specify pixel_values")

          embedding_output, bool_masked_pos, ids_restore = self.embeddings(
              pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, noise=noise
          )

          # Reorder tokens so strided windows are contiguous; the encoder relies on this layout.
          hidden_states = unroll(
              embedding_output,
              image_shape=(pixel_values.shape[-2], pixel_values.shape[-1]),
              patch_stride=self.config.patch_stride,
              schedule=self.unroll_schedule,
          )

          if bool_masked_pos is not None:
              # Keep only tokens whose mask unit is True in bool_masked_pos. In unrolled order the mask-unit
              # index varies fastest, so tiling the per-unit mask lines it up with the token sequence.
              batch_size, _, hidden_size = hidden_states.shape
              tokens_per_unit = math.prod(self.config.masked_unit_size)
              keep = bool_masked_pos.unsqueeze(-1).tile(1, tokens_per_unit, hidden_size)
              hidden_states = hidden_states[keep].view(batch_size, -1, hidden_size)

          encoder_outputs = self.encoder(
              hidden_states,
              bool_masked_pos=bool_masked_pos,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          sequence_output = encoder_outputs[0]
          pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

          if not return_dict:
              head_outputs = (sequence_output,)
              if pooled_output is not None:
                  head_outputs += (pooled_output,)
              if bool_masked_pos is not None:
                  head_outputs += (bool_masked_pos, ids_restore)
              return head_outputs + encoder_outputs[1:]

          return HieraModelOutput(
              last_hidden_state=sequence_output,
              pooler_output=pooled_output,
              bool_masked_pos=bool_masked_pos,
              ids_restore=ids_restore,
              hidden_states=encoder_outputs.hidden_states,
              attentions=encoder_outputs.attentions,
              reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
          )
  class HieraDecoder(nn.Module):
      """
      MAE decoder. It places the visible mask-unit tokens back among learned mask tokens, restores spatial
      token order, and runs one `HieraStage`. Each token is then projected to
      `pred_stride ** len(query_stride) * num_channels` values (presumably the pixels of its patch).
      """

      def __init__(self, config: HieraConfig):
          super().__init__()
          num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
          # Token grid right after patch embedding, then after all `num_query_pool` query-pooling steps.
          tokens_spatial_shape = [i // s for i, s in zip(config.image_size, config.patch_stride)]
          self.tokens_spatial_shape_final = [
              i // s ** (config.num_query_pool) for i, s in zip(tokens_spatial_shape, config.query_stride)
          ]
          # Mask-unit extent after the same query-pooling steps.
          self.mask_unit_spatial_shape_final = [
              i // s ** (config.num_query_pool) for i, s in zip(config.masked_unit_size, config.query_stride)
          ]

          self.decoder_embeddings = nn.Linear(num_features, config.decoder_hidden_size)
          self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
          self.decoder_position_embeddings = nn.Parameter(
              torch.zeros(1, math.prod(self.tokens_spatial_shape_final), config.decoder_hidden_size)
          )
          # Global attention only (no mask-unit windows, no query pooling, no drop path).
          self.decoder_block = HieraStage(
              config=config,
              hidden_size=config.decoder_hidden_size,
              hidden_size_output=config.decoder_hidden_size,
              num_heads=config.decoder_num_heads,
              depth=config.decoder_depth,
              use_mask_unit_attn=False,
              drop_path=[0.0] * config.decoder_depth,
              query_stride=[1] * config.decoder_depth,
              window_size=0,
          )
          self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)

          # Patch stride of the prediction: patch stride times the cumulative query stride. Only the last
          # axis is read, so equal strides on every axis are assumed.
          self.pred_stride = config.patch_stride[-1] * (config.query_stride[-1] ** config.num_query_pool)
          pred_dim = (self.pred_stride ** len(config.query_stride)) * config.num_channels
          self.decoder_pred = nn.Linear(config.decoder_hidden_size, pred_dim)

      def forward(
          self,
          encoder_hidden_states: torch.Tensor,
          bool_masked_pos: torch.BoolTensor,
          output_attentions: bool = False,
      ) -> tuple[torch.Tensor, torch.BoolTensor]:
          """
          Args:
              encoder_hidden_states: Features of the visible mask units,
                  `[batch_size, num_mask_units_visible, *mask_unit_spatial_shape_final, hidden_size]`.
              bool_masked_pos: `[batch_size, num_mask_units]`. True entries mark the mask units present in
                  `encoder_hidden_states`; False entries are filled with the learned mask token.
              output_attentions: Forwarded to the decoder stage.

          Returns:
              `(predictions, bool_masked_pos)`. Predictions have shape `[batch_size, num_tokens, pred_dim]`,
              and the mask is re-laid out in spatial token order as `[batch_size, num_tokens]`.
          """
          # Embed tokens
          hidden_states = self.decoder_embeddings(encoder_hidden_states)

          mask_unit_height, mask_unit_width, decoder_hidden_size = hidden_states.shape[2:]
          batch_size, num_mask_units = bool_masked_pos.shape

          # Scatter the visible tokens into their mask-unit slots; all other slots start at zero.
          decoder_hidden_states = torch.zeros(
              batch_size,
              num_mask_units,
              mask_unit_height,
              mask_unit_width,
              decoder_hidden_size,
              device=hidden_states.device,
              dtype=hidden_states.dtype,
          )
          mask_tokens = self.mask_token.view(1, 1, 1, 1, -1)
          bool_masked_pos = bool_masked_pos.reshape(batch_size, num_mask_units, 1, 1, 1)
          bool_masked_pos = bool_masked_pos.expand(-1, -1, mask_unit_height, mask_unit_width, decoder_hidden_size)
          decoder_hidden_states[bool_masked_pos] = hidden_states.flatten()
          # Fill the remaining slots with the mask token. A select keeps the activations in their own dtype.
          # The previous blend with `bool_masked_pos.float()` upcast half-precision activations to float32
          # before they reached the half-precision decoder layers. Values are unchanged in float32.
          decoder_hidden_states = torch.where(bool_masked_pos, decoder_hidden_states, mask_tokens)

          # Get back spatial order
          hidden_states = undo_windowing(
              decoder_hidden_states,
              self.tokens_spatial_shape_final,
              self.mask_unit_spatial_shape_final,
          )
          bool_masked_pos = undo_windowing(
              bool_masked_pos[..., 0:1],
              self.tokens_spatial_shape_final,
              self.mask_unit_spatial_shape_final,
          )

          # Flatten to [batch_size, num_tokens, decoder_hidden_size] / [batch_size, num_tokens]
          hidden_states = hidden_states.reshape(hidden_states.shape[0], -1, hidden_states.shape[-1])
          bool_masked_pos = bool_masked_pos.view(hidden_states.shape[0], -1)

          # Add pos embed
          hidden_states = hidden_states + self.decoder_position_embeddings

          # Apply decoder blocks (attention weights are not returned by this module)
          hidden_states, _ = self.decoder_block(hidden_states, output_attentions=output_attentions)
          hidden_states = self.decoder_norm(hidden_states)

          # Predictor projection
          hidden_states = self.decoder_pred(hidden_states)

          return hidden_states, bool_masked_pos
  class HieraMultiScaleHead(nn.Module):
      """
      Fuse several per-mask-unit feature maps into the width of the last stage.

      For each of the first `num_query_pool` feature maps, a strided Conv2d shrinks the mask unit to
      `mask_unit_spatial_shape_final` and projects channels to `stage_dimensions[-1]`. The final feature map
      is passed through unchanged. The results are summed.
      """

      def __init__(self, config: HieraConfig):
          super().__init__()
          num_pools = config.num_query_pool
          self.mask_unit_spatial_shape_final = [
              unit // stride**num_pools for unit, stride in zip(config.masked_unit_size, config.query_stride)
          ]
          self.stage_dimensions = [
              int(config.embed_dim * config.embed_dim_multiplier**stage) for stage in range(len(config.depths))
          ]

          self.multi_scale_fusion_heads = nn.ModuleList()
          unit_size = config.masked_unit_size
          for stage in range(num_pools):
              # Kernel == stride: one output cell per block, reducing the current unit to the final unit shape.
              kernel = [current // final for current, final in zip(unit_size, self.mask_unit_spatial_shape_final)]
              unit_size = [current // stride for current, stride in zip(unit_size, config.query_stride)]
              head = nn.Conv2d(
                  self.stage_dimensions[stage],
                  self.stage_dimensions[-1],
                  kernel_size=kernel,
                  stride=kernel,
              )
              self.multi_scale_fusion_heads.append(head)
          self.multi_scale_fusion_heads.append(nn.Identity())

      def apply_fusion_head(self, head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
          """
          Apply `head` independently to every mask unit.

          `hidden_states` is `[batch_size, num_mask_units, unit_height, unit_width, hidden_size]`; the output keeps
          that layout with the head's spatial size and channel count.
          """
          if isinstance(head, nn.Identity):
              return hidden_states

          batch_size, num_mask_units, unit_height, unit_width, hidden_size = hidden_states.shape
          # Treat each mask unit as its own image in channels-first layout for the convolution.
          per_unit = hidden_states.reshape(batch_size * num_mask_units, unit_height, unit_width, hidden_size)
          fused = head(per_unit.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

          fused_height, fused_width, fused_size = fused.shape[1:]
          return fused.reshape(batch_size, num_mask_units, fused_height, fused_width, fused_size)

      def forward(self, feature_maps: list[torch.Tensor]) -> torch.Tensor:
          # Multi-scale fusion: sum of every head applied to its paired feature map (zip stops at the shorter).
          return sum(
              (
                  self.apply_fusion_head(head, feature_map)
                  for head, feature_map in zip(self.multi_scale_fusion_heads, feature_maps)
              ),
              0.0,
          )
  @auto_docstring(
      custom_intro="""
      The Hiera Model transformer with the decoder on top for self-supervised pre-training.
      <Tip>
      Note that we provide a script to pre-train this model on custom data in our [examples
      directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
      </Tip>
      """
  )
  class HieraForPreTraining(HieraPreTrainedModel):
      def __init__(self, config: HieraConfig) -> None:
          super().__init__(config)
          # Encoder (MAE mode, no pooler)
          self.hiera = HieraModel(config, add_pooling_layer=False, is_mae=True)
          self.encoder_norm = nn.LayerNorm(self.hiera.num_features, eps=config.layer_norm_eps)
          # Multi-scale fusion heads
          self.multiscale_fusion = HieraMultiScaleHead(config)
          # Decoder
          self.decoder = HieraDecoder(config)
          self.pred_stride = self.decoder.pred_stride

          # Initialize weights and apply final processing
          self.post_init()

      def get_pixel_label_2d(self, pixel_values: torch.Tensor, bool_masked_pos: torch.BoolTensor) -> torch.Tensor:
          """
          Cut images into `pred_stride` x `pred_stride` patches and return the patches selected by
          `bool_masked_pos` (True means *masked* here), each flattened to channels x patch pixels. Patches are
          optionally normalized per patch when `config.normalize_pixel_loss` is set.
          """
          patch = self.pred_stride
          channels_last = pixel_values.permute(0, 2, 3, 1)
          patches = channels_last.unfold(1, patch, patch).unfold(2, patch, patch)
          label = patches.flatten(1, 2).flatten(2)[bool_masked_pos]

          if self.config.normalize_pixel_loss:
              mean = label.mean(dim=-1, keepdim=True)
              var = label.var(dim=-1, keepdim=True)
              label = (label - mean) / (var + 1.0e-6) ** 0.5

          return label

      def forward_loss(self, pixel_values: torch.Tensor, logits: torch.Tensor, bool_masked_pos: torch.BoolTensor):
          """Mean squared error between predictions and target patches, over positions where `bool_masked_pos` is False."""
          # Flip the mask so that True marks the *masked* positions the loss is computed on.
          masked = ~bool_masked_pos
          target = self.get_pixel_label_2d(pixel_values, masked)
          return ((logits[masked] - target) ** 2).mean()

      @auto_docstring
      def forward(
          self,
          pixel_values: torch.Tensor | None = None,
          noise: torch.FloatTensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          interpolate_pos_encoding: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple | HieraForPreTrainingOutput:
          r"""
          noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
              Mainly used for testing purposes to control randomness and maintain the reproducibility

          Examples:

          ```python
          >>> from transformers import AutoImageProcessor, HieraForPreTraining
          >>> import torch
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))

          >>> image_processor = AutoImageProcessor.from_pretrained("facebook/hiera-tiny-224-mae-hf")
          >>> model = HieraForPreTraining.from_pretrained("facebook/hiera-tiny-224-mae-hf")

          >>> inputs = image_processor(images=image, return_tensors="pt")

          >>> outputs = model(**inputs)
          >>> logits = outputs.logits
          >>> loss = outputs.loss
          >>> print(list(logits.shape))
          [1, 196, 768]
          ```"""
          if return_dict is None:
              return_dict = self.config.return_dict
          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states

          # Hidden states are always requested: the fusion head consumes the re-rolled stage outputs.
          outputs = self.hiera(
              pixel_values,
              noise=noise,
              output_attentions=output_attentions,
              output_hidden_states=True,
              interpolate_pos_encoding=interpolate_pos_encoding,
              return_dict=return_dict,
          )
          # Without a pooler the layout is (last_hidden, bool_masked_pos, ids_restore, hidden_states,
          # [attentions], reshaped_hidden_states), both as a tuple and as indexed model output.
          bool_masked_pos = outputs[1]
          ids_to_restore = outputs[2]
          reshaped_states = outputs[-1]

          # Entry 0 of the re-rolled states is the embedding output; keep the next `num_query_pool` stage
          # outputs plus the final one.
          num_query_pool = self.hiera.config.num_query_pool
          selected_maps = reshaped_states[1 : num_query_pool + 1] + (reshaped_states[-1],)
          fused_hidden_states = self.encoder_norm(self.multiscale_fusion(selected_maps))

          # Reconstruct pixel values
          logits, bool_masked_pos = self.decoder(
              fused_hidden_states,
              bool_masked_pos=bool_masked_pos,
              output_attentions=output_attentions,
          )
          loss = self.forward_loss(pixel_values, logits, bool_masked_pos)

          if return_dict:
              return HieraForPreTrainingOutput(
                  loss=loss,
                  logits=logits,
                  bool_masked_pos=bool_masked_pos,
                  ids_restore=ids_to_restore,
                  hidden_states=outputs.hidden_states if output_hidden_states else None,
                  attentions=outputs.attentions,
                  reshaped_hidden_states=outputs.reshaped_hidden_states if output_hidden_states else None,
              )

          output = (logits, bool_masked_pos, ids_to_restore)
          if output_hidden_states:
              output += (outputs[3],)
          if output_attentions:
              output += (outputs[4],)
          if output_hidden_states:
              output += (outputs[-1],)
          return output if loss is None else (loss,) + output
  @auto_docstring(
      custom_intro="""
      Hiera Model transformer with an image classification head on top (a linear layer on top of the final hidden state with
      average pooling) e.g. for ImageNet.
      <Tip>
      Note that it's possible to fine-tune Hiera on higher resolution images than the ones it has been trained on, by
      setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
      position embeddings to the higher resolution.
      </Tip>
      """
  )
  class HieraForImageClassification(HieraPreTrainedModel):
      def __init__(self, config: HieraConfig) -> None:
          super().__init__(config)
          self.num_labels = config.num_labels
          self.hiera = HieraModel(config, add_pooling_layer=True, is_mae=False)

          # Classifier head over the pooled features; a pass-through when there are no labels.
          if config.num_labels > 0:
              self.classifier = nn.Linear(self.hiera.num_features, config.num_labels)
          else:
              self.classifier = nn.Identity()

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          pixel_values,
          labels: torch.Tensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          interpolate_pos_encoding: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple | HieraForImageClassificationOutput:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
          """
          if return_dict is None:
              return_dict = self.config.return_dict
          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states

          outputs = self.hiera(
              pixel_values,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              interpolate_pos_encoding=interpolate_pos_encoding,
              return_dict=return_dict,
          )
          # Index 1 is the pooler output (the pooling layer is always enabled here).
          logits = self.classifier(outputs[1])
          loss = self.loss_function(labels, logits, self.config) if labels is not None else None

          if return_dict:
              return HieraForImageClassificationOutput(
                  loss=loss,
                  logits=logits,
                  hidden_states=outputs.hidden_states,
                  attentions=outputs.attentions,
                  reshaped_hidden_states=outputs.reshaped_hidden_states,
              )

          output = (logits,) + outputs[2:]
          return output if loss is None else (loss,) + output
  @auto_docstring(
      custom_intro="""
      Hiera backbone, to be used with frameworks like DETR and MaskFormer.
      """
  )
  class HieraBackbone(BackboneMixin, HieraPreTrainedModel):
      def __init__(self, config: HieraConfig):
          super().__init__(config)
          # Channel count of the stem (embeddings) followed by each stage.
          self.num_features = [config.embed_dim] + [
              int(config.embed_dim * config.embed_dim_multiplier**i) for i in range(len(config.depths))
          ]
          self.embeddings = HieraEmbeddings(config, is_mae=False)
          self.encoder = HieraEncoder(config)
          # One query-stride unroll per stage transition, identical to HieraModel's schedule.
          self.unroll_schedule = [config.query_stride] * len(config.depths[:-1])

          # Add layer norms to hidden states of out_features
          hidden_states_norms = {}
          for stage, num_channels in zip(self.out_features, self.channels):
              hidden_states_norms[stage] = nn.LayerNorm(num_channels)
          self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.embeddings.patch_embeddings

      @can_return_tuple
      @filter_output_hidden_states
      def forward(
          self,
          pixel_values: torch.Tensor,
          output_hidden_states: bool | None = None,
          output_attentions: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> BackboneOutput:
          """
          Returns:

          Examples:

          ```python
          >>> from transformers import AutoImageProcessor, AutoBackbone
          >>> import torch
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))

          >>> processor = AutoImageProcessor.from_pretrained("facebook/hiera-tiny-224-hf")
          >>> model = AutoBackbone.from_pretrained(
          ...     "facebook/hiera-tiny-224-hf", out_features=["stage1", "stage2", "stage3", "stage4"]
          ... )

          >>> inputs = processor(image, return_tensors="pt")
          >>> outputs = model(**inputs)
          >>> feature_maps = outputs.feature_maps
          >>> list(feature_maps[-1].shape)
          [1, 768, 7, 7]
          ```"""
          return_dict = return_dict if return_dict is not None else self.config.return_dict
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

          embedding_output, _, _ = self.embeddings(pixel_values)
          # Unroll exactly as HieraModel does before its encoder. HieraEncoder.reroll (which builds the
          # reshaped hidden states used below) undoes this ordering, so raster-ordered input would come
          # back spatially scrambled.
          unrolled_embeddings = unroll(
              embedding_output,
              image_shape=(pixel_values.shape[-2], pixel_values.shape[-1]),
              patch_stride=self.config.patch_stride,
              schedule=self.unroll_schedule,
          )

          outputs = self.encoder(
              unrolled_embeddings,
              output_attentions=output_attentions,
              output_hidden_states=True,
              return_dict=return_dict,
          )

          # Last entry: re-rolled maps, [batch_size, height, width, channels] for the stem and every stage.
          hidden_states = outputs[-1]

          feature_maps = ()
          for stage, hidden_state in zip(self.stage_names, hidden_states):
              if stage in self.out_features:
                  batch_size, height, width, num_channels = hidden_state.shape
                  hidden_state = hidden_state.view(batch_size, height * width, num_channels)
                  hidden_state = self.hidden_states_norms[stage](hidden_state)
                  hidden_state = hidden_state.view(batch_size, height, width, num_channels)
                  hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                  feature_maps += (hidden_state,)

          if not return_dict:
              output = (feature_maps,)
              if output_hidden_states:
                  output += (outputs[1],)
              if output_attentions:
                  output += (outputs[2],)
              return output

          return BackboneOutput(
              feature_maps=feature_maps,
              hidden_states=outputs[1] if output_hidden_states else None,
              attentions=outputs[2] if output_attentions else None,
          )
  # Public classes exported by this module.
  __all__ = [
      "HieraForImageClassification",
      "HieraForPreTraining",
      "HieraBackbone",
      "HieraModel",
      "HieraPreTrainedModel",
  ]