modeling_bit.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832
  1. # Copyright 2022 Google AI and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch BiT model. Also supports backbone for ViT hybrid."""
  15. import collections
  16. import math
  17. import numpy as np
  18. import torch
  19. from torch import Tensor, nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  23. from ...modeling_outputs import (
  24. BackboneOutput,
  25. BaseModelOutputWithNoAttention,
  26. BaseModelOutputWithPoolingAndNoAttention,
  27. ImageClassifierOutputWithNoAttention,
  28. )
  29. from ...modeling_utils import PreTrainedModel
  30. from ...utils import auto_docstring, logging
  31. from ...utils.generic import can_return_tuple
  32. from .configuration_bit import BitConfig
  33. logger = logging.get_logger(__name__)
  34. def get_padding_value(padding=None, kernel_size=7, stride=1, dilation=1) -> tuple[tuple, bool]:
  35. r"""
  36. Utility function to get the tuple padding value given the kernel_size and padding.
  37. Args:
  38. padding (Union[`str`, `int`], *optional*):
  39. Padding value, can be either `"same"`, `"valid"`. If a different value is provided the default padding from
  40. PyTorch is used.
  41. kernel_size (`int`, *optional*, defaults to 7):
  42. Kernel size of the convolution layers.
  43. stride (`int`, *optional*, defaults to 1):
  44. Stride value of the convolution layers.
  45. dilation (`int`, *optional*, defaults to 1):
  46. Dilation value of the convolution layers.
  47. """
  48. dynamic = False
  49. if padding is None:
  50. padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
  51. return padding, dynamic
  52. if isinstance(padding, str):
  53. # for any string padding, the padding will be calculated for you, one of three ways
  54. padding = padding.lower()
  55. if padding == "same":
  56. # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
  57. if stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0:
  58. # static case, no extra overhead
  59. padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
  60. else:
  61. # dynamic 'SAME' padding, has runtime/GPU memory overhead
  62. padding = 0
  63. dynamic = True
  64. elif padding == "valid":
  65. # 'VALID' padding, same as padding=0
  66. padding = 0
  67. else:
  68. # Default to PyTorch style 'same'-ish symmetric padding
  69. padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
  70. return padding, dynamic
  71. class WeightStandardizedConv2d(nn.Conv2d):
  72. """Conv2d with Weight Standardization. Used for ViT Hybrid model.
  73. Paper: [Micro-Batch Training with Batch-Channel Normalization and Weight
  74. Standardization](https://huggingface.co/papers/1903.10520)
  75. """
  76. def __init__(
  77. self,
  78. in_channel,
  79. out_channels,
  80. kernel_size,
  81. stride=1,
  82. padding="SAME",
  83. dilation=1,
  84. groups=1,
  85. bias=False,
  86. eps=1e-6,
  87. ):
  88. padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
  89. super().__init__(
  90. in_channel,
  91. out_channels,
  92. kernel_size,
  93. stride=stride,
  94. padding=padding,
  95. dilation=dilation,
  96. groups=groups,
  97. bias=bias,
  98. )
  99. if is_dynamic:
  100. self.pad = DynamicPad2d(kernel_size, stride, dilation)
  101. else:
  102. self.pad = None
  103. self.eps = eps
  104. def forward(self, hidden_state):
  105. if self.pad is not None:
  106. hidden_state = self.pad(hidden_state)
  107. weight = nn.functional.batch_norm(
  108. self.weight.reshape(1, self.out_channels, -1), None, None, training=True, momentum=0.0, eps=self.eps
  109. ).reshape_as(self.weight)
  110. hidden_state = nn.functional.conv2d(
  111. hidden_state, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
  112. )
  113. return hidden_state
  114. class BitGroupNormActivation(nn.GroupNorm):
  115. r"""
  116. A module that combines group normalization with an activation function.
  117. """
  118. def __init__(self, config, num_channels, eps=1e-5, affine=True, apply_activation=True):
  119. super().__init__(config.num_groups, num_channels, eps=eps, affine=affine)
  120. if apply_activation:
  121. self.activation = ACT2FN[config.hidden_act]
  122. else:
  123. self.activation = nn.Identity()
  124. def forward(self, hidden_state):
  125. hidden_state = nn.functional.group_norm(hidden_state, self.num_groups, self.weight, self.bias, self.eps)
  126. hidden_state = self.activation(hidden_state)
  127. return hidden_state
  128. class DynamicPad2d(nn.Module):
  129. r"""
  130. A module that wraps dynamic padding of any input, given the parameters of the convolutional layer and the input
  131. hidden states.
  132. """
  133. def __init__(self, kernel_size, stride, dilation, value=0):
  134. super().__init__()
  135. # Safety checkers
  136. if isinstance(kernel_size, int):
  137. kernel_size = (kernel_size, kernel_size)
  138. if isinstance(stride, int):
  139. stride = (stride, stride)
  140. if isinstance(dilation, int):
  141. dilation = (dilation, dilation)
  142. self.kernel_size = kernel_size
  143. self.stride = stride
  144. self.dilation = dilation
  145. self.value = value
  146. def compute_padding(x, kernel_size, stride, dilation):
  147. return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0)
  148. self.compute_padding = compute_padding
  149. def forward(self, input):
  150. # Get width and height
  151. input_height, input_width = input.size()[-2:]
  152. # Compute the padding values
  153. padding_height = self.compute_padding(input_height, self.kernel_size[0], self.stride[0], self.dilation[0])
  154. padding_width = self.compute_padding(input_width, self.kernel_size[1], self.stride[1], self.dilation[1])
  155. # apply pad
  156. if padding_height > 0 or padding_width > 0:
  157. input = nn.functional.pad(
  158. input,
  159. [
  160. padding_width // 2,
  161. padding_width - padding_width // 2,
  162. padding_height // 2,
  163. padding_height - padding_height // 2,
  164. ],
  165. value=self.value,
  166. )
  167. return input
  168. class BitMaxPool2d(nn.MaxPool2d):
  169. def __init__(
  170. self,
  171. kernel_size: int,
  172. stride=None,
  173. dilation=1,
  174. ceil_mode=False,
  175. padding=(0, 0),
  176. padding_value=0,
  177. use_dynamic_padding=True,
  178. ):
  179. kernel_size = kernel_size if isinstance(kernel_size, collections.abc.Iterable) else (kernel_size, kernel_size)
  180. stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride)
  181. dilation = dilation if isinstance(dilation, collections.abc.Iterable) else (dilation, dilation)
  182. super().__init__(kernel_size, stride, padding, dilation, ceil_mode)
  183. if use_dynamic_padding:
  184. self.pad = DynamicPad2d(kernel_size, stride, dilation, padding_value)
  185. else:
  186. self.pad = nn.Identity()
  187. def forward(self, hidden_states):
  188. hidden_states = self.pad(hidden_states)
  189. return nn.functional.max_pool2d(
  190. hidden_states, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode
  191. )
  192. class BitEmbeddings(nn.Module):
  193. """
  194. BiT Embeddings (stem) composed of a single aggressive convolution.
  195. """
  196. def __init__(self, config: BitConfig):
  197. super().__init__()
  198. self.convolution = WeightStandardizedConv2d(
  199. config.num_channels,
  200. config.embedding_size,
  201. kernel_size=7,
  202. stride=2,
  203. eps=1e-8,
  204. padding=config.global_padding,
  205. )
  206. self.pooler = BitMaxPool2d(kernel_size=3, stride=2, use_dynamic_padding=config.embedding_dynamic_padding)
  207. # Use the same padding strategy as convolutional layers
  208. if config.global_padding is not None and config.global_padding.upper() == "SAME":
  209. self.pad = nn.Identity()
  210. else:
  211. self.pad = nn.ConstantPad2d(padding=(1, 1, 1, 1), value=0.0)
  212. if config.layer_type != "preactivation":
  213. self.norm = BitGroupNormActivation(config, num_channels=config.embedding_size)
  214. else:
  215. self.norm = nn.Identity()
  216. self.num_channels = config.num_channels
  217. def forward(self, pixel_values: Tensor) -> Tensor:
  218. num_channels = pixel_values.shape[1]
  219. if num_channels != self.num_channels:
  220. raise ValueError(
  221. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  222. )
  223. embedding = self.convolution(pixel_values)
  224. embedding = self.pad(embedding)
  225. embedding = self.norm(embedding)
  226. embedding = self.pooler(embedding)
  227. return embedding
  228. # Copied from transformers.models.convnext.modeling_convnext.drop_path
  229. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  230. """
  231. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  232. """
  233. if drop_prob == 0.0 or not training:
  234. return input
  235. keep_prob = 1 - drop_prob
  236. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  237. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  238. random_tensor.floor_() # binarize
  239. output = input.div(keep_prob) * random_tensor
  240. return output
  241. # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Bit
  242. class BitDropPath(nn.Module):
  243. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  244. def __init__(self, drop_prob: float | None = None) -> None:
  245. super().__init__()
  246. self.drop_prob = drop_prob
  247. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  248. return drop_path(hidden_states, self.drop_prob, self.training)
  249. def extra_repr(self) -> str:
  250. return f"p={self.drop_prob}"
  251. def make_div(value, divisor=8):
  252. min_value = divisor
  253. new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
  254. if new_value < 0.9 * value:
  255. new_value += divisor
  256. return new_value
  257. class BitPreActivationBottleneckLayer(nn.Module):
  258. """Pre-activation (v2) bottleneck block.
  259. Follows the implementation of "Identity Mappings in Deep Residual Networks":
  260. https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua
  261. Except it puts the stride on 3x3 conv when available.
  262. """
  263. def __init__(
  264. self,
  265. config,
  266. in_channels,
  267. out_channels=None,
  268. bottle_ratio=0.25,
  269. stride=1,
  270. dilation=1,
  271. first_dilation=None,
  272. groups=1,
  273. drop_path_rate=0.0,
  274. is_first_layer=False,
  275. ):
  276. super().__init__()
  277. first_dilation = first_dilation or dilation
  278. out_channels = out_channels or in_channels
  279. mid_channels = make_div(out_channels * bottle_ratio)
  280. if is_first_layer:
  281. self.downsample = BitDownsampleConv(
  282. config,
  283. in_channels,
  284. out_channels,
  285. stride=stride,
  286. preact=True,
  287. )
  288. else:
  289. self.downsample = None
  290. self.norm1 = BitGroupNormActivation(config, in_channels)
  291. self.conv1 = WeightStandardizedConv2d(in_channels, mid_channels, 1, eps=1e-8, padding=config.global_padding)
  292. self.norm2 = BitGroupNormActivation(config, num_channels=mid_channels)
  293. self.conv2 = WeightStandardizedConv2d(
  294. mid_channels, mid_channels, 3, stride=stride, groups=groups, eps=1e-8, padding=config.global_padding
  295. )
  296. self.norm3 = BitGroupNormActivation(config, mid_channels)
  297. self.conv3 = WeightStandardizedConv2d(mid_channels, out_channels, 1, eps=1e-8, padding=config.global_padding)
  298. self.drop_path = BitDropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
  299. def forward(self, hidden_states):
  300. hidden_states_preact = self.norm1(hidden_states)
  301. # shortcut branch
  302. shortcut = hidden_states
  303. if self.downsample is not None:
  304. shortcut = self.downsample(hidden_states_preact)
  305. # residual branch
  306. hidden_states = self.conv1(hidden_states_preact)
  307. hidden_states = self.conv2(self.norm2(hidden_states))
  308. hidden_states = self.conv3(self.norm3(hidden_states))
  309. hidden_states = self.drop_path(hidden_states)
  310. return hidden_states + shortcut
  311. class BitBottleneckLayer(nn.Module):
  312. """Non Pre-activation bottleneck block, equivalent to V1.5/V1b bottleneck. Used for ViT Hybrid."""
  313. def __init__(
  314. self,
  315. config,
  316. in_channels,
  317. out_channels=None,
  318. bottle_ratio=0.25,
  319. stride=1,
  320. dilation=1,
  321. first_dilation=None,
  322. groups=1,
  323. drop_path_rate=0.0,
  324. is_first_layer=False,
  325. ):
  326. super().__init__()
  327. first_dilation = first_dilation or dilation
  328. out_channels = out_channels or in_channels
  329. mid_chs = make_div(out_channels * bottle_ratio)
  330. if is_first_layer:
  331. self.downsample = BitDownsampleConv(
  332. config,
  333. in_channels,
  334. out_channels,
  335. stride=stride,
  336. preact=False,
  337. )
  338. else:
  339. self.downsample = None
  340. self.conv1 = WeightStandardizedConv2d(in_channels, mid_chs, 1, eps=1e-8, padding=config.global_padding)
  341. self.norm1 = BitGroupNormActivation(config, num_channels=mid_chs)
  342. self.conv2 = WeightStandardizedConv2d(
  343. mid_chs,
  344. mid_chs,
  345. 3,
  346. stride=stride,
  347. dilation=first_dilation,
  348. groups=groups,
  349. eps=1e-8,
  350. padding=config.global_padding,
  351. )
  352. self.norm2 = BitGroupNormActivation(config, num_channels=mid_chs)
  353. self.conv3 = WeightStandardizedConv2d(mid_chs, out_channels, 1, eps=1e-8, padding=config.global_padding)
  354. self.norm3 = BitGroupNormActivation(config, num_channels=out_channels, apply_activation=False)
  355. self.drop_path = BitDropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
  356. self.activation = ACT2FN[config.hidden_act]
  357. def forward(self, hidden_states):
  358. # shortcut branch
  359. shortcut = hidden_states
  360. if self.downsample is not None:
  361. shortcut = self.downsample(hidden_states)
  362. # residual
  363. hidden_states = self.conv1(hidden_states)
  364. hidden_states = self.norm1(hidden_states)
  365. hidden_states = self.conv2(hidden_states)
  366. hidden_states = self.norm2(hidden_states)
  367. hidden_states = self.conv3(hidden_states)
  368. hidden_states = self.norm3(hidden_states)
  369. hidden_states = self.drop_path(hidden_states)
  370. hidden_states = self.activation(hidden_states + shortcut)
  371. return hidden_states
  372. class BitDownsampleConv(nn.Module):
  373. def __init__(
  374. self,
  375. config,
  376. in_channels,
  377. out_channels,
  378. stride=1,
  379. preact=True,
  380. ):
  381. super().__init__()
  382. self.conv = WeightStandardizedConv2d(
  383. in_channels, out_channels, 1, stride=stride, eps=1e-8, padding=config.global_padding
  384. )
  385. self.norm = (
  386. nn.Identity()
  387. if preact
  388. else BitGroupNormActivation(config, num_channels=out_channels, apply_activation=False)
  389. )
  390. def forward(self, x):
  391. return self.norm(self.conv(x))
  392. class BitStage(nn.Module):
  393. """
  394. A ResNet v2 stage composed by stacked layers.
  395. """
  396. def __init__(
  397. self,
  398. config,
  399. in_channels,
  400. out_channels,
  401. stride,
  402. dilation,
  403. depth,
  404. bottle_ratio=0.25,
  405. layer_dropout=None,
  406. ):
  407. super().__init__()
  408. first_dilation = 1 if dilation in (1, 2) else 2
  409. # Get the layer type
  410. if config.layer_type == "bottleneck":
  411. layer_cls = BitBottleneckLayer
  412. else:
  413. layer_cls = BitPreActivationBottleneckLayer
  414. prev_chs = in_channels
  415. self.layers = nn.Sequential()
  416. for layer_idx in range(depth):
  417. # Get the current hyper-parameters
  418. stride, drop_path_rate, is_first_layer = self._get_updated_hyperparameters(
  419. layer_idx, stride, layer_dropout
  420. )
  421. self.layers.add_module(
  422. str(layer_idx),
  423. layer_cls(
  424. config,
  425. prev_chs,
  426. out_channels,
  427. stride=stride,
  428. dilation=dilation,
  429. bottle_ratio=bottle_ratio,
  430. first_dilation=first_dilation,
  431. drop_path_rate=drop_path_rate,
  432. is_first_layer=is_first_layer,
  433. ),
  434. )
  435. prev_chs = out_channels
  436. first_dilation = dilation
  437. def _get_updated_hyperparameters(self, layer_idx, stride, layer_dropout):
  438. r"""
  439. Get the new hyper-parameters with respect to the previous ones and the index of the current layer.
  440. """
  441. if layer_dropout:
  442. drop_path_rate = layer_dropout[layer_idx]
  443. else:
  444. drop_path_rate = 0.0
  445. if layer_idx != 0:
  446. stride = 1
  447. is_first_layer = layer_idx == 0
  448. return stride, drop_path_rate, is_first_layer
  449. def forward(self, input: Tensor) -> Tensor:
  450. hidden_state = input
  451. for _, layer in enumerate(self.layers):
  452. hidden_state = layer(hidden_state)
  453. return hidden_state
  454. class BitEncoder(nn.Module):
  455. def __init__(self, config: BitConfig):
  456. super().__init__()
  457. self.stages = nn.ModuleList([])
  458. prev_chs = config.embedding_size
  459. # These needs to stay hardcoded
  460. current_stride = 4
  461. dilation = 1
  462. layer_dropouts = [
  463. x.tolist()
  464. for x in torch.Tensor(np.linspace(0, config.drop_path_rate, sum(config.depths))).split(config.depths)
  465. ]
  466. for stage_idx, (current_depth, current_hidden_size, layer_dropout) in enumerate(
  467. zip(config.depths, config.hidden_sizes, layer_dropouts)
  468. ):
  469. # Get the updated hyper params
  470. out_channels, stride, dilation = self._get_updated_hyperparameters(
  471. stage_idx, current_stride, current_hidden_size, dilation, config
  472. )
  473. stage = BitStage(
  474. config,
  475. prev_chs,
  476. out_channels,
  477. stride=stride,
  478. dilation=dilation,
  479. depth=current_depth,
  480. layer_dropout=layer_dropout,
  481. )
  482. prev_chs = out_channels
  483. current_stride *= stride
  484. self.stages.add_module(str(stage_idx), stage)
  485. def _get_updated_hyperparameters(self, stage_idx, current_stride, current_hidden_size, dilation, config):
  486. out_channels = make_div(current_hidden_size * config.width_factor)
  487. stride = 1 if stage_idx == 0 else 2
  488. if current_stride >= config.output_stride:
  489. dilation *= stride
  490. stride = 1
  491. return out_channels, stride, dilation
  492. def forward(
  493. self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
  494. ) -> BaseModelOutputWithNoAttention:
  495. hidden_states = () if output_hidden_states else None
  496. for stage_module in self.stages:
  497. if output_hidden_states:
  498. hidden_states = hidden_states + (hidden_state,)
  499. hidden_state = stage_module(hidden_state)
  500. if output_hidden_states:
  501. hidden_states = hidden_states + (hidden_state,)
  502. if not return_dict:
  503. return tuple(v for v in [hidden_state, hidden_states] if v is not None)
  504. return BaseModelOutputWithNoAttention(
  505. last_hidden_state=hidden_state,
  506. hidden_states=hidden_states,
  507. )
  508. @auto_docstring
  509. class BitPreTrainedModel(PreTrainedModel):
  510. config: BitConfig
  511. base_model_prefix = "bit"
  512. input_modalities = ("image",)
  513. main_input_name = "pixel_values"
  514. _no_split_modules = ["BitEmbeddings"]
  515. @torch.no_grad()
  516. def _init_weights(self, module):
  517. if isinstance(module, nn.Conv2d):
  518. init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
  519. # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
  520. elif isinstance(module, nn.Linear):
  521. init.kaiming_uniform_(module.weight, a=math.sqrt(5))
  522. if module.bias is not None:
  523. fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
  524. bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
  525. init.uniform_(module.bias, -bound, bound)
  526. elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
  527. init.constant_(module.weight, 1)
  528. init.constant_(module.bias, 0)
  529. if getattr(module, "running_mean", None) is not None:
  530. init.zeros_(module.running_mean)
  531. init.ones_(module.running_var)
  532. init.zeros_(module.num_batches_tracked)
  533. @auto_docstring
  534. class BitModel(BitPreTrainedModel):
  535. def __init__(self, config):
  536. super().__init__(config)
  537. self.config = config
  538. self.embedder = BitEmbeddings(config)
  539. self.encoder = BitEncoder(config)
  540. self.norm = (
  541. BitGroupNormActivation(config, num_channels=config.hidden_sizes[-1])
  542. if config.layer_type == "preactivation"
  543. else nn.Identity()
  544. )
  545. self.pooler = nn.AdaptiveAvgPool2d((1, 1))
  546. # Initialize weights and apply final processing
  547. self.post_init()
  548. @auto_docstring
  549. def forward(
  550. self,
  551. pixel_values: Tensor,
  552. output_hidden_states: bool | None = None,
  553. return_dict: bool | None = None,
  554. **kwargs,
  555. ) -> BaseModelOutputWithPoolingAndNoAttention:
  556. output_hidden_states = (
  557. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  558. )
  559. return_dict = return_dict if return_dict is not None else self.config.return_dict
  560. embedding_output = self.embedder(pixel_values)
  561. encoder_outputs = self.encoder(
  562. embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict
  563. )
  564. last_hidden_state = encoder_outputs[0]
  565. last_hidden_state = self.norm(last_hidden_state)
  566. pooled_output = self.pooler(last_hidden_state)
  567. if not return_dict:
  568. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  569. return BaseModelOutputWithPoolingAndNoAttention(
  570. last_hidden_state=last_hidden_state,
  571. pooler_output=pooled_output,
  572. hidden_states=encoder_outputs.hidden_states,
  573. )
  574. @auto_docstring(
  575. custom_intro="""
  576. BiT Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  577. ImageNet.
  578. """
  579. )
  580. class BitForImageClassification(BitPreTrainedModel):
  581. def __init__(self, config):
  582. super().__init__(config)
  583. self.num_labels = config.num_labels
  584. self.bit = BitModel(config)
  585. # classification head
  586. self.classifier = nn.Sequential(
  587. nn.Flatten(),
  588. nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(),
  589. )
  590. # initialize weights and apply final processing
  591. self.post_init()
  592. @auto_docstring
  593. def forward(
  594. self,
  595. pixel_values: torch.FloatTensor | None = None,
  596. labels: torch.LongTensor | None = None,
  597. output_hidden_states: bool | None = None,
  598. return_dict: bool | None = None,
  599. **kwargs,
  600. ) -> ImageClassifierOutputWithNoAttention:
  601. r"""
  602. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  603. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  604. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  605. """
  606. return_dict = return_dict if return_dict is not None else self.config.return_dict
  607. outputs = self.bit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  608. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  609. logits = self.classifier(pooled_output)
  610. loss = None
  611. if labels is not None:
  612. loss = self.loss_function(labels, logits, self.config)
  613. if not return_dict:
  614. output = (logits,) + outputs[2:]
  615. return (loss,) + output if loss is not None else output
  616. return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
  617. @auto_docstring(
  618. custom_intro="""
  619. BiT backbone, to be used with frameworks like DETR and MaskFormer.
  620. """
  621. )
  622. class BitBackbone(BackboneMixin, BitPreTrainedModel):
  623. has_attentions = False
  624. def __init__(self, config):
  625. super().__init__(config)
  626. self.bit = BitModel(config)
  627. self.num_features = [config.embedding_size] + config.hidden_sizes
  628. # initialize weights and apply final processing
  629. self.post_init()
  630. @can_return_tuple
  631. @filter_output_hidden_states
  632. @auto_docstring
  633. def forward(
  634. self,
  635. pixel_values: Tensor,
  636. output_hidden_states: bool | None = None,
  637. return_dict: bool | None = None,
  638. **kwargs,
  639. ) -> BackboneOutput:
  640. r"""
  641. Examples:
  642. ```python
  643. >>> from transformers import AutoImageProcessor, AutoBackbone
  644. >>> import torch
  645. >>> from PIL import Image
  646. >>> import httpx
  647. >>> from io import BytesIO
  648. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  649. >>> with httpx.stream("GET", url) as response:
  650. ... image = Image.open(BytesIO(response.read()))
  651. >>> processor = AutoImageProcessor.from_pretrained("google/bit-50")
  652. >>> model = AutoBackbone.from_pretrained("google/bit-50")
  653. >>> inputs = processor(image, return_tensors="pt")
  654. >>> outputs = model(**inputs)
  655. ```"""
  656. return_dict = return_dict if return_dict is not None else self.config.return_dict
  657. output_hidden_states = (
  658. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  659. )
  660. outputs = self.bit(pixel_values, output_hidden_states=True, return_dict=True)
  661. hidden_states = outputs.hidden_states
  662. feature_maps = ()
  663. for idx, stage in enumerate(self.stage_names):
  664. if stage in self.out_features:
  665. feature_maps += (hidden_states[idx],)
  666. if not return_dict:
  667. output = (feature_maps,)
  668. if output_hidden_states:
  669. output += (outputs.hidden_states,)
  670. return output
  671. return BackboneOutput(
  672. feature_maps=feature_maps,
  673. hidden_states=outputs.hidden_states if output_hidden_states else None,
  674. attentions=None,
  675. )
  676. __all__ = ["BitForImageClassification", "BitModel", "BitPreTrainedModel", "BitBackbone"]