modeling_swin2sr.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068
  1. # Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Swin2SR Transformer model."""
  15. import collections.abc
  16. import math
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput, ImageSuperResolutionOutput
  24. from ...modeling_utils import PreTrainedModel
  25. from ...utils import ModelOutput, auto_docstring, logging
  26. from .configuration_swin2sr import Swin2SRConfig
  27. logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
Swin2SR encoder's outputs, with potential hidden states and attentions.
"""
)
class Swin2SREncoderOutput(ModelOutput):
    # Token sequence produced by the last encoder stage (`Swin2SREncoder.forward` stores its final `hidden_states`).
    last_hidden_state: torch.FloatTensor | None = None
    # When `output_hidden_states` is set: the encoder input followed by the output of every stage.
    hidden_states: tuple[torch.FloatTensor] | None = None
    # When `output_attentions` is set: one entry per stage, each being the attention probabilities of that
    # stage's last layer.
    attentions: tuple[torch.FloatTensor] | None = None
  38. # Copied from transformers.models.swin.modeling_swin.window_partition
  39. def window_partition(input_feature, window_size):
  40. """
  41. Partitions the given input into windows.
  42. """
  43. batch_size, height, width, num_channels = input_feature.shape
  44. input_feature = input_feature.view(
  45. batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
  46. )
  47. windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
  48. return windows
  49. # Copied from transformers.models.swin.modeling_swin.window_reverse
  50. def window_reverse(windows, window_size, height, width):
  51. """
  52. Merges windows to produce higher resolution features.
  53. """
  54. num_channels = windows.shape[-1]
  55. windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
  56. windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
  57. return windows
  58. # Copied from transformers.models.beit.modeling_beit.drop_path
  59. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  60. """
  61. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  62. """
  63. if drop_prob == 0.0 or not training:
  64. return input
  65. keep_prob = 1 - drop_prob
  66. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  67. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  68. random_tensor.floor_() # binarize
  69. output = input.div(keep_prob) * random_tensor
  70. return output
  71. # Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->Swin2SR
  72. class Swin2SRDropPath(nn.Module):
  73. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  74. def __init__(self, drop_prob: float | None = None) -> None:
  75. super().__init__()
  76. self.drop_prob = drop_prob
  77. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  78. return drop_path(hidden_states, self.drop_prob, self.training)
  79. def extra_repr(self) -> str:
  80. return f"p={self.drop_prob}"
  81. class Swin2SREmbeddings(nn.Module):
  82. """
  83. Construct the patch and optional position embeddings.
  84. """
  85. def __init__(self, config):
  86. super().__init__()
  87. self.patch_embeddings = Swin2SRPatchEmbeddings(config)
  88. num_patches = self.patch_embeddings.num_patches
  89. if config.use_absolute_embeddings:
  90. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
  91. else:
  92. self.position_embeddings = None
  93. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  94. self.window_size = config.window_size
  95. def forward(self, pixel_values: torch.FloatTensor | None) -> tuple[torch.Tensor]:
  96. embeddings, output_dimensions = self.patch_embeddings(pixel_values)
  97. if self.position_embeddings is not None:
  98. embeddings = embeddings + self.position_embeddings
  99. embeddings = self.dropout(embeddings)
  100. return embeddings, output_dimensions
  101. class Swin2SRPatchEmbeddings(nn.Module):
  102. def __init__(self, config, normalize_patches=True):
  103. super().__init__()
  104. num_channels = config.embed_dim
  105. image_size, patch_size = config.image_size, config.patch_size
  106. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  107. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  108. patches_resolution = [image_size[0] // patch_size[0], image_size[1] // patch_size[1]]
  109. self.patches_resolution = patches_resolution
  110. self.num_patches = patches_resolution[0] * patches_resolution[1]
  111. self.projection = nn.Conv2d(num_channels, config.embed_dim, kernel_size=patch_size, stride=patch_size)
  112. self.layernorm = nn.LayerNorm(config.embed_dim) if normalize_patches else None
  113. def forward(self, embeddings: torch.FloatTensor | None) -> tuple[torch.Tensor, tuple[int]]:
  114. embeddings = self.projection(embeddings)
  115. _, _, height, width = embeddings.shape
  116. output_dimensions = (height, width)
  117. embeddings = embeddings.flatten(2).transpose(1, 2)
  118. if self.layernorm is not None:
  119. embeddings = self.layernorm(embeddings)
  120. return embeddings, output_dimensions
  121. class Swin2SRPatchUnEmbeddings(nn.Module):
  122. r"""Image to Patch Unembedding"""
  123. def __init__(self, config):
  124. super().__init__()
  125. self.embed_dim = config.embed_dim
  126. def forward(self, embeddings, x_size):
  127. batch_size, height_width, num_channels = embeddings.shape
  128. embeddings = embeddings.transpose(1, 2).view(batch_size, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
  129. return embeddings
  130. # Copied from transformers.models.swinv2.modeling_swinv2.Swinv2PatchMerging with Swinv2->Swin2SR
  131. class Swin2SRPatchMerging(nn.Module):
  132. """
  133. Patch Merging Layer.
  134. Args:
  135. input_resolution (`tuple[int]`):
  136. Resolution of input feature.
  137. dim (`int`):
  138. Number of input channels.
  139. norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
  140. Normalization layer class.
  141. """
  142. def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
  143. super().__init__()
  144. self.input_resolution = input_resolution
  145. self.dim = dim
  146. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  147. self.norm = norm_layer(2 * dim)
  148. def maybe_pad(self, input_feature, height, width):
  149. should_pad = (height % 2 == 1) or (width % 2 == 1)
  150. if should_pad:
  151. pad_values = (0, 0, 0, width % 2, 0, height % 2)
  152. input_feature = nn.functional.pad(input_feature, pad_values)
  153. return input_feature
  154. def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor:
  155. height, width = input_dimensions
  156. # `dim` is height * width
  157. batch_size, dim, num_channels = input_feature.shape
  158. input_feature = input_feature.view(batch_size, height, width, num_channels)
  159. # pad input to be divisible by width and height, if needed
  160. input_feature = self.maybe_pad(input_feature, height, width)
  161. # [batch_size, height/2, width/2, num_channels]
  162. input_feature_0 = input_feature[:, 0::2, 0::2, :]
  163. # [batch_size, height/2, width/2, num_channels]
  164. input_feature_1 = input_feature[:, 1::2, 0::2, :]
  165. # [batch_size, height/2, width/2, num_channels]
  166. input_feature_2 = input_feature[:, 0::2, 1::2, :]
  167. # [batch_size, height/2, width/2, num_channels]
  168. input_feature_3 = input_feature[:, 1::2, 1::2, :]
  169. # [batch_size, height/2 * width/2, 4*num_channels]
  170. input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
  171. input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # [batch_size, height/2 * width/2, 4*C]
  172. input_feature = self.reduction(input_feature)
  173. input_feature = self.norm(input_feature)
  174. return input_feature
  175. # Copied from transformers.models.swinv2.modeling_swinv2.Swinv2SelfAttention with Swinv2->Swin2SR
  176. class Swin2SRSelfAttention(nn.Module):
  177. def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[0, 0]):
  178. super().__init__()
  179. if dim % num_heads != 0:
  180. raise ValueError(
  181. f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
  182. )
  183. self.num_attention_heads = num_heads
  184. self.attention_head_size = int(dim / num_heads)
  185. self.all_head_size = self.num_attention_heads * self.attention_head_size
  186. self.window_size = (
  187. window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
  188. )
  189. self.pretrained_window_size = pretrained_window_size
  190. self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
  191. # mlp to generate continuous relative position bias
  192. self.continuous_position_bias_mlp = nn.Sequential(
  193. nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
  194. )
  195. relative_coords_table, relative_position_index = self.create_coords_table_and_index()
  196. self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
  197. self.register_buffer("relative_position_index", relative_position_index, persistent=False)
  198. self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  199. self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=False)
  200. self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  201. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  202. def forward(
  203. self,
  204. hidden_states: torch.Tensor,
  205. attention_mask: torch.FloatTensor | None = None,
  206. output_attentions: bool | None = False,
  207. ) -> tuple[torch.Tensor]:
  208. batch_size, dim, num_channels = hidden_states.shape
  209. query_layer = (
  210. self.query(hidden_states)
  211. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  212. .transpose(1, 2)
  213. )
  214. key_layer = (
  215. self.key(hidden_states)
  216. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  217. .transpose(1, 2)
  218. )
  219. value_layer = (
  220. self.value(hidden_states)
  221. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  222. .transpose(1, 2)
  223. )
  224. # cosine attention
  225. attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize(
  226. key_layer, dim=-1
  227. ).transpose(-2, -1)
  228. logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
  229. attention_scores = attention_scores * logit_scale
  230. relative_position_bias_table = self.continuous_position_bias_mlp(self.relative_coords_table).view(
  231. -1, self.num_attention_heads
  232. )
  233. # [window_height*window_width,window_height*window_width,num_attention_heads]
  234. relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
  235. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
  236. )
  237. # [num_attention_heads,window_height*window_width,window_height*window_width]
  238. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  239. relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
  240. attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
  241. if attention_mask is not None:
  242. # Apply the attention mask is (precomputed for all layers in Swin2SRModel forward() function)
  243. mask_shape = attention_mask.shape[0]
  244. attention_scores = attention_scores.view(
  245. batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
  246. ) + attention_mask.unsqueeze(1).unsqueeze(0)
  247. attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
  248. attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
  249. # Normalize the attention scores to probabilities.
  250. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  251. # This is actually dropping out entire tokens to attend to, which might
  252. # seem a bit unusual, but is taken from the original Transformer paper.
  253. attention_probs = self.dropout(attention_probs)
  254. # Mask heads if we want to
  255. context_layer = torch.matmul(attention_probs, value_layer)
  256. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  257. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  258. context_layer = context_layer.view(new_context_layer_shape)
  259. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  260. return outputs
  261. def create_coords_table_and_index(self):
  262. # get relative_coords_table
  263. relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float()
  264. relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float()
  265. relative_coords_table = (
  266. torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
  267. .permute(1, 2, 0)
  268. .contiguous()
  269. .unsqueeze(0)
  270. ) # [1, 2*window_height - 1, 2*window_width - 1, 2]
  271. if self.pretrained_window_size[0] > 0:
  272. relative_coords_table[:, :, :, 0] /= self.pretrained_window_size[0] - 1
  273. relative_coords_table[:, :, :, 1] /= self.pretrained_window_size[1] - 1
  274. elif self.window_size[0] > 1:
  275. relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
  276. relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
  277. relative_coords_table *= 8 # normalize to -8, 8
  278. relative_coords_table = (
  279. torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
  280. )
  281. # set to same dtype as mlp weight
  282. relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
  283. # get pair-wise relative position index for each token inside the window
  284. coords_h = torch.arange(self.window_size[0])
  285. coords_w = torch.arange(self.window_size[1])
  286. coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
  287. coords_flatten = torch.flatten(coords, 1)
  288. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
  289. relative_coords = relative_coords.permute(1, 2, 0).contiguous()
  290. relative_coords[:, :, 0] += self.window_size[0] - 1
  291. relative_coords[:, :, 1] += self.window_size[1] - 1
  292. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  293. relative_position_index = relative_coords.sum(-1)
  294. return relative_coords_table, relative_position_index
  295. # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->Swin2SR
  296. class Swin2SRSelfOutput(nn.Module):
  297. def __init__(self, config, dim):
  298. super().__init__()
  299. self.dense = nn.Linear(dim, dim)
  300. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  301. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  302. hidden_states = self.dense(hidden_states)
  303. hidden_states = self.dropout(hidden_states)
  304. return hidden_states
  305. # Copied from transformers.models.swinv2.modeling_swinv2.Swinv2Attention with Swinv2->Swin2SR
  306. class Swin2SRAttention(nn.Module):
  307. def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=0):
  308. super().__init__()
  309. self.self = Swin2SRSelfAttention(
  310. config=config,
  311. dim=dim,
  312. num_heads=num_heads,
  313. window_size=window_size,
  314. pretrained_window_size=pretrained_window_size
  315. if isinstance(pretrained_window_size, collections.abc.Iterable)
  316. else (pretrained_window_size, pretrained_window_size),
  317. )
  318. self.output = Swin2SRSelfOutput(config, dim)
  319. def forward(
  320. self,
  321. hidden_states: torch.Tensor,
  322. attention_mask: torch.FloatTensor | None = None,
  323. output_attentions: bool | None = False,
  324. ) -> tuple[torch.Tensor]:
  325. self_outputs = self.self(hidden_states, attention_mask, output_attentions)
  326. attention_output = self.output(self_outputs[0], hidden_states)
  327. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  328. return outputs
  329. # Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->Swin2SR
  330. class Swin2SRIntermediate(nn.Module):
  331. def __init__(self, config, dim):
  332. super().__init__()
  333. self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
  334. if isinstance(config.hidden_act, str):
  335. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  336. else:
  337. self.intermediate_act_fn = config.hidden_act
  338. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  339. hidden_states = self.dense(hidden_states)
  340. hidden_states = self.intermediate_act_fn(hidden_states)
  341. return hidden_states
  342. # Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->Swin2SR
  343. class Swin2SROutput(nn.Module):
  344. def __init__(self, config, dim):
  345. super().__init__()
  346. self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
  347. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  348. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  349. hidden_states = self.dense(hidden_states)
  350. hidden_states = self.dropout(hidden_states)
  351. return hidden_states
  352. # Copied from transformers.models.swinv2.modeling_swinv2.Swinv2Layer with Swinv2->Swin2SR
class Swin2SRLayer(nn.Module):
    """
    Single (shifted-)window transformer block.

    Window self-attention is followed by an MLP. Each branch's output is layer-normalized *before* being added to
    its residual: `layernorm_before` is applied to the attention output and `layernorm_after` to the MLP output.
    Stochastic depth (`drop_path`) is applied to both branches.

    Args:
        config: model config (`window_size`, `layer_norm_eps` and the sub-module settings are read from it).
        dim (`int`): number of channels per token.
        input_resolution (`tuple[int, int]`): token grid size, used to clip the window and to disable shifting.
        num_heads (`int`): number of attention heads.
        drop_path_rate (`float`, *optional*, defaults to 0.0): stochastic-depth rate; `nn.Identity` is used when 0.
        shift_size (`int`, *optional*, defaults to 0): cyclic shift applied before window partitioning.
        pretrained_window_size (`int` or pair of `int`, *optional*, defaults to 0): forwarded to the attention.
    """

    def __init__(
        self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0, pretrained_window_size=0
    ):
        super().__init__()
        self.input_resolution = input_resolution
        # Clip the window to the input resolution; shifting is turned off where one window already spans the input.
        window_size, shift_size = self._compute_window_shift(
            (config.window_size, config.window_size), (shift_size, shift_size)
        )
        # The rest of the layer treats windows as square, so only the first (height) component is kept.
        self.window_size = window_size[0]
        self.shift_size = shift_size[0]
        self.attention = Swin2SRAttention(
            config=config,
            dim=dim,
            num_heads=num_heads,
            window_size=self.window_size,
            pretrained_window_size=pretrained_window_size
            if isinstance(pretrained_window_size, collections.abc.Iterable)
            else (pretrained_window_size, pretrained_window_size),
        )
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.drop_path = Swin2SRDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.intermediate = Swin2SRIntermediate(config, dim)
        self.output = Swin2SROutput(config, dim)
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)

    def _compute_window_shift(self, target_window_size, target_shift_size) -> tuple[tuple[int, int], tuple[int, int]]:
        # Per dimension: window = min(resolution, target window); shift = 0 when the window spans that whole dimension.
        # NOTE(review): this returns lists despite the tuple annotation; callers only index them.
        window_size = [min(r, w) for r, w in zip(self.input_resolution, target_window_size)]
        shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
        return window_size, shift_size

    def get_attn_mask(self, height, width, dtype):
        """
        Build the additive attention mask for shifted windows, or return None when `shift_size` is 0.

        The `height x width` grid is split into 3x3 regions by the window and shift boundaries, and each region gets
        its own label. After window partitioning, token pairs with different labels receive -100.0 and pairs with the
        same label receive 0.0. The result has shape `(num_windows, window_size**2, window_size**2)`. It is created
        without an explicit device; `forward` moves it to the input's device.
        """
        if self.shift_size > 0:
            # calculate attention mask for shifted window multihead self attention
            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # Non-zero label difference means the two tokens came from different regions.
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
        else:
            attn_mask = None
        return attn_mask

    def maybe_pad(self, hidden_states, height, width):
        # Zero-pad bottom/right so height and width become multiples of the window size. `pad_values` follows the
        # `nn.functional.pad` order (channels, width, height); indices 3 and 5 hold the right and bottom padding.
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
        hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: `(batch_size, height * width, channels)` tokens.
            input_dimensions: `(height, width)` of the token grid.
            output_attentions: also return the attention probabilities.

        Returns:
            `(layer_output,)` or `(layer_output, attention_probs)`; `layer_output` has the input's shape.
        """
        height, width = input_dimensions
        batch_size, _, channels = hidden_states.size()
        shortcut = hidden_states
        # pad hidden_states to multiples of window size
        hidden_states = hidden_states.view(batch_size, height, width, channels)
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
        _, height_pad, width_pad, _ = hidden_states.shape
        # cyclic shift
        if self.shift_size > 0:
            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_hidden_states = hidden_states
        # partition windows
        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
        # The mask is rebuilt on every call from the padded size.
        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
        if attn_mask is not None:
            attn_mask = attn_mask.to(hidden_states_windows.device)
        attention_outputs = self.attention(hidden_states_windows, attn_mask, output_attentions=output_attentions)
        attention_output = attention_outputs[0]
        # Merge windows back into the (padded) grid.
        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
        # reverse cyclic shift
        if self.shift_size > 0:
            attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            attention_windows = shifted_windows
        # Crop away the padding added by `maybe_pad`.
        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()
        attention_windows = attention_windows.view(batch_size, height * width, channels)
        # Post-norm residual around attention, then post-norm residual around the MLP.
        hidden_states = self.layernorm_before(attention_windows)
        hidden_states = shortcut + self.drop_path(hidden_states)
        layer_output = self.intermediate(hidden_states)
        layer_output = self.output(layer_output)
        layer_output = hidden_states + self.drop_path(self.layernorm_after(layer_output))
        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs
  458. class Swin2SRStage(GradientCheckpointingLayer):
  459. """
  460. This corresponds to the Residual Swin Transformer Block (RSTB) in the original implementation.
  461. """
  462. def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, pretrained_window_size=0):
  463. super().__init__()
  464. self.config = config
  465. self.dim = dim
  466. self.layers = nn.ModuleList(
  467. [
  468. Swin2SRLayer(
  469. config=config,
  470. dim=dim,
  471. input_resolution=input_resolution,
  472. num_heads=num_heads,
  473. shift_size=0 if (i % 2 == 0) else config.window_size // 2,
  474. pretrained_window_size=pretrained_window_size,
  475. )
  476. for i in range(depth)
  477. ]
  478. )
  479. if config.resi_connection == "1conv":
  480. self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
  481. elif config.resi_connection == "3conv":
  482. # to save parameters and memory
  483. self.conv = nn.Sequential(
  484. nn.Conv2d(dim, dim // 4, 3, 1, 1),
  485. nn.LeakyReLU(negative_slope=0.2, inplace=True),
  486. nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
  487. nn.LeakyReLU(negative_slope=0.2, inplace=True),
  488. nn.Conv2d(dim // 4, dim, 3, 1, 1),
  489. )
  490. self.patch_embed = Swin2SRPatchEmbeddings(config, normalize_patches=False)
  491. self.patch_unembed = Swin2SRPatchUnEmbeddings(config)
  492. def forward(
  493. self,
  494. hidden_states: torch.Tensor,
  495. input_dimensions: tuple[int, int],
  496. output_attentions: bool | None = False,
  497. ) -> tuple[torch.Tensor]:
  498. residual = hidden_states
  499. height, width = input_dimensions
  500. for i, layer_module in enumerate(self.layers):
  501. layer_outputs = layer_module(hidden_states, input_dimensions, output_attentions)
  502. hidden_states = layer_outputs[0]
  503. output_dimensions = (height, width, height, width)
  504. hidden_states = self.patch_unembed(hidden_states, input_dimensions)
  505. hidden_states = self.conv(hidden_states)
  506. hidden_states, _ = self.patch_embed(hidden_states)
  507. hidden_states = hidden_states + residual
  508. stage_outputs = (hidden_states, output_dimensions)
  509. if output_attentions:
  510. stage_outputs += layer_outputs[1:]
  511. return stage_outputs
  512. class Swin2SREncoder(nn.Module):
  513. def __init__(self, config, grid_size):
  514. super().__init__()
  515. self.num_stages = len(config.depths)
  516. self.config = config
  517. dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
  518. self.stages = nn.ModuleList(
  519. [
  520. Swin2SRStage(
  521. config=config,
  522. dim=config.embed_dim,
  523. input_resolution=(grid_size[0], grid_size[1]),
  524. depth=config.depths[stage_idx],
  525. num_heads=config.num_heads[stage_idx],
  526. drop_path=dpr[sum(config.depths[:stage_idx]) : sum(config.depths[: stage_idx + 1])],
  527. pretrained_window_size=0,
  528. )
  529. for stage_idx in range(self.num_stages)
  530. ]
  531. )
  532. self.gradient_checkpointing = False
  533. def forward(
  534. self,
  535. hidden_states: torch.Tensor,
  536. input_dimensions: tuple[int, int],
  537. output_attentions: bool | None = False,
  538. output_hidden_states: bool | None = False,
  539. return_dict: bool | None = True,
  540. ) -> tuple | Swin2SREncoderOutput:
  541. all_input_dimensions = ()
  542. all_hidden_states = () if output_hidden_states else None
  543. all_self_attentions = () if output_attentions else None
  544. if output_hidden_states:
  545. all_hidden_states += (hidden_states,)
  546. for i, stage_module in enumerate(self.stages):
  547. layer_outputs = stage_module(hidden_states, input_dimensions, output_attentions)
  548. hidden_states = layer_outputs[0]
  549. output_dimensions = layer_outputs[1]
  550. input_dimensions = (output_dimensions[-2], output_dimensions[-1])
  551. all_input_dimensions += (input_dimensions,)
  552. if output_hidden_states:
  553. all_hidden_states += (hidden_states,)
  554. if output_attentions:
  555. all_self_attentions += layer_outputs[2:]
  556. if not return_dict:
  557. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  558. return Swin2SREncoderOutput(
  559. last_hidden_state=hidden_states,
  560. hidden_states=all_hidden_states,
  561. attentions=all_self_attentions,
  562. )
  563. @auto_docstring
  564. class Swin2SRPreTrainedModel(PreTrainedModel):
  565. config: Swin2SRConfig
  566. base_model_prefix = "swin2sr"
  567. main_input_name = "pixel_values"
  568. input_modalities = ("image",)
  569. supports_gradient_checkpointing = True
  570. @torch.no_grad()
  571. def _init_weights(self, module):
  572. """Initialize the weights"""
  573. if isinstance(module, (nn.Linear, nn.Conv2d)):
  574. init.trunc_normal_(module.weight, std=self.config.initializer_range)
  575. if module.bias is not None:
  576. init.zeros_(module.bias)
  577. elif isinstance(module, nn.LayerNorm):
  578. init.zeros_(module.bias)
  579. init.ones_(module.weight)
  580. elif isinstance(module, Swin2SRSelfAttention):
  581. init.constant_(module.logit_scale, math.log(10))
  582. relative_coords_table, relative_position_index = module.create_coords_table_and_index()
  583. init.copy_(module.relative_coords_table, relative_coords_table)
  584. init.copy_(module.relative_position_index, relative_position_index)
  585. elif isinstance(module, Swin2SRModel):
  586. if module.config.num_channels == 3 and module.config.num_channels_out == 3:
  587. mean = torch.tensor([0.4488, 0.4371, 0.4040]).view(1, 3, 1, 1)
  588. else:
  589. mean = torch.zeros(1, 1, 1, 1)
  590. init.copy_(module.mean, mean)
  591. @auto_docstring
  592. class Swin2SRModel(Swin2SRPreTrainedModel):
  593. def __init__(self, config):
  594. super().__init__(config)
  595. self.config = config
  596. if config.num_channels == 3 and config.num_channels_out == 3:
  597. mean = torch.tensor([0.4488, 0.4371, 0.4040]).view(1, 3, 1, 1)
  598. else:
  599. mean = torch.zeros(1, 1, 1, 1)
  600. self.register_buffer("mean", mean, persistent=False)
  601. self.img_range = config.img_range
  602. self.first_convolution = nn.Conv2d(config.num_channels, config.embed_dim, 3, 1, 1)
  603. self.embeddings = Swin2SREmbeddings(config)
  604. self.encoder = Swin2SREncoder(config, grid_size=self.embeddings.patch_embeddings.patches_resolution)
  605. self.layernorm = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_eps)
  606. self.patch_unembed = Swin2SRPatchUnEmbeddings(config)
  607. self.conv_after_body = nn.Conv2d(config.embed_dim, config.embed_dim, 3, 1, 1)
  608. # Initialize weights and apply final processing
  609. self.post_init()
  610. def get_input_embeddings(self):
  611. return self.embeddings.patch_embeddings
  612. def pad_and_normalize(self, pixel_values):
  613. _, _, height, width = pixel_values.size()
  614. # 1. pad
  615. window_size = self.config.window_size
  616. modulo_pad_height = (window_size - height % window_size) % window_size
  617. modulo_pad_width = (window_size - width % window_size) % window_size
  618. pixel_values = nn.functional.pad(pixel_values, (0, modulo_pad_width, 0, modulo_pad_height), "reflect")
  619. # 2. normalize
  620. mean = self.mean.type_as(pixel_values)
  621. pixel_values = (pixel_values - mean) * self.img_range
  622. return pixel_values
  623. @auto_docstring
  624. def forward(
  625. self,
  626. pixel_values: torch.FloatTensor,
  627. output_attentions: bool | None = None,
  628. output_hidden_states: bool | None = None,
  629. return_dict: bool | None = None,
  630. **kwargs,
  631. ) -> tuple | BaseModelOutput:
  632. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  633. output_hidden_states = (
  634. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  635. )
  636. return_dict = return_dict if return_dict is not None else self.config.return_dict
  637. _, _, height, width = pixel_values.shape
  638. # some preprocessing: padding + normalization
  639. pixel_values = self.pad_and_normalize(pixel_values)
  640. embeddings = self.first_convolution(pixel_values)
  641. embedding_output, input_dimensions = self.embeddings(embeddings)
  642. encoder_outputs = self.encoder(
  643. embedding_output,
  644. input_dimensions,
  645. output_attentions=output_attentions,
  646. output_hidden_states=output_hidden_states,
  647. return_dict=return_dict,
  648. )
  649. sequence_output = encoder_outputs[0]
  650. sequence_output = self.layernorm(sequence_output)
  651. sequence_output = self.patch_unembed(sequence_output, (height, width))
  652. sequence_output = self.conv_after_body(sequence_output) + embeddings
  653. if not return_dict:
  654. output = (sequence_output,) + encoder_outputs[1:]
  655. return output
  656. return BaseModelOutput(
  657. last_hidden_state=sequence_output,
  658. hidden_states=encoder_outputs.hidden_states,
  659. attentions=encoder_outputs.attentions,
  660. )
  661. class Upsample(nn.Module):
  662. """Upsample module.
  663. Args:
  664. scale (`int`):
  665. Scale factor. Supported scales: 2^n and 3.
  666. num_features (`int`):
  667. Channel number of intermediate features.
  668. """
  669. def __init__(self, scale, num_features):
  670. super().__init__()
  671. self.scale = scale
  672. if (scale & (scale - 1)) == 0:
  673. # scale = 2^n
  674. for i in range(int(math.log2(scale))):
  675. self.add_module(f"convolution_{i}", nn.Conv2d(num_features, 4 * num_features, 3, 1, 1))
  676. self.add_module(f"pixelshuffle_{i}", nn.PixelShuffle(2))
  677. elif scale == 3:
  678. self.convolution = nn.Conv2d(num_features, 9 * num_features, 3, 1, 1)
  679. self.pixelshuffle = nn.PixelShuffle(3)
  680. else:
  681. raise ValueError(f"Scale {scale} is not supported. Supported scales: 2^n and 3.")
  682. def forward(self, hidden_state):
  683. if (self.scale & (self.scale - 1)) == 0:
  684. for i in range(int(math.log2(self.scale))):
  685. hidden_state = self.__getattr__(f"convolution_{i}")(hidden_state)
  686. hidden_state = self.__getattr__(f"pixelshuffle_{i}")(hidden_state)
  687. elif self.scale == 3:
  688. hidden_state = self.convolution(hidden_state)
  689. hidden_state = self.pixelshuffle(hidden_state)
  690. return hidden_state
  691. class UpsampleOneStep(nn.Module):
  692. """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
  693. Used in lightweight SR to save parameters.
  694. Args:
  695. scale (int):
  696. Scale factor. Supported scales: 2^n and 3.
  697. in_channels (int):
  698. Channel number of intermediate features.
  699. out_channels (int):
  700. Channel number of output features.
  701. """
  702. def __init__(self, scale, in_channels, out_channels):
  703. super().__init__()
  704. self.conv = nn.Conv2d(in_channels, (scale**2) * out_channels, 3, 1, 1)
  705. self.pixel_shuffle = nn.PixelShuffle(scale)
  706. def forward(self, x):
  707. x = self.conv(x)
  708. x = self.pixel_shuffle(x)
  709. return x
  710. class PixelShuffleUpsampler(nn.Module):
  711. def __init__(self, config, num_features):
  712. super().__init__()
  713. self.conv_before_upsample = nn.Conv2d(config.embed_dim, num_features, 3, 1, 1)
  714. self.activation = nn.LeakyReLU(inplace=True)
  715. self.upsample = Upsample(config.upscale, num_features)
  716. self.final_convolution = nn.Conv2d(num_features, config.num_channels_out, 3, 1, 1)
  717. def forward(self, sequence_output):
  718. x = self.conv_before_upsample(sequence_output)
  719. x = self.activation(x)
  720. x = self.upsample(x)
  721. x = self.final_convolution(x)
  722. return x
  723. class NearestConvUpsampler(nn.Module):
  724. def __init__(self, config, num_features):
  725. super().__init__()
  726. if config.upscale != 4:
  727. raise ValueError("The nearest+conv upsampler only supports an upscale factor of 4 at the moment.")
  728. self.conv_before_upsample = nn.Conv2d(config.embed_dim, num_features, 3, 1, 1)
  729. self.activation = nn.LeakyReLU(inplace=True)
  730. self.conv_up1 = nn.Conv2d(num_features, num_features, 3, 1, 1)
  731. self.conv_up2 = nn.Conv2d(num_features, num_features, 3, 1, 1)
  732. self.conv_hr = nn.Conv2d(num_features, num_features, 3, 1, 1)
  733. self.final_convolution = nn.Conv2d(num_features, config.num_channels_out, 3, 1, 1)
  734. self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
  735. def forward(self, sequence_output):
  736. sequence_output = self.conv_before_upsample(sequence_output)
  737. sequence_output = self.activation(sequence_output)
  738. sequence_output = self.lrelu(
  739. self.conv_up1(torch.nn.functional.interpolate(sequence_output, scale_factor=2, mode="nearest"))
  740. )
  741. sequence_output = self.lrelu(
  742. self.conv_up2(torch.nn.functional.interpolate(sequence_output, scale_factor=2, mode="nearest"))
  743. )
  744. reconstruction = self.final_convolution(self.lrelu(self.conv_hr(sequence_output)))
  745. return reconstruction
  746. class PixelShuffleAuxUpsampler(nn.Module):
  747. def __init__(self, config, num_features):
  748. super().__init__()
  749. self.upscale = config.upscale
  750. self.conv_bicubic = nn.Conv2d(config.num_channels, num_features, 3, 1, 1)
  751. self.conv_before_upsample = nn.Conv2d(config.embed_dim, num_features, 3, 1, 1)
  752. self.activation = nn.LeakyReLU(inplace=True)
  753. self.conv_aux = nn.Conv2d(num_features, config.num_channels, 3, 1, 1)
  754. self.conv_after_aux = nn.Sequential(nn.Conv2d(3, num_features, 3, 1, 1), nn.LeakyReLU(inplace=True))
  755. self.upsample = Upsample(config.upscale, num_features)
  756. self.final_convolution = nn.Conv2d(num_features, config.num_channels_out, 3, 1, 1)
  757. def forward(self, sequence_output, bicubic, height, width):
  758. bicubic = self.conv_bicubic(bicubic)
  759. sequence_output = self.conv_before_upsample(sequence_output)
  760. sequence_output = self.activation(sequence_output)
  761. aux = self.conv_aux(sequence_output)
  762. sequence_output = self.conv_after_aux(aux)
  763. sequence_output = (
  764. self.upsample(sequence_output)[:, :, : height * self.upscale, : width * self.upscale]
  765. + bicubic[:, :, : height * self.upscale, : width * self.upscale]
  766. )
  767. reconstruction = self.final_convolution(sequence_output)
  768. return reconstruction, aux
  769. @auto_docstring(
  770. custom_intro="""
  771. Swin2SR Model transformer with an upsampler head on top for image super resolution and restoration.
  772. """
  773. )
  774. class Swin2SRForImageSuperResolution(Swin2SRPreTrainedModel):
  775. def __init__(self, config):
  776. super().__init__(config)
  777. self.swin2sr = Swin2SRModel(config)
  778. self.upsampler = config.upsampler
  779. self.upscale = config.upscale
  780. # Upsampler
  781. num_features = 64
  782. if self.upsampler == "pixelshuffle":
  783. self.upsample = PixelShuffleUpsampler(config, num_features)
  784. elif self.upsampler == "pixelshuffle_aux":
  785. self.upsample = PixelShuffleAuxUpsampler(config, num_features)
  786. elif self.upsampler == "pixelshuffledirect":
  787. # for lightweight SR (to save parameters)
  788. self.upsample = UpsampleOneStep(config.upscale, config.embed_dim, config.num_channels_out)
  789. elif self.upsampler == "nearest+conv":
  790. # for real-world SR (less artifacts)
  791. self.upsample = NearestConvUpsampler(config, num_features)
  792. else:
  793. # for image denoising and JPEG compression artifact reduction
  794. self.final_convolution = nn.Conv2d(config.embed_dim, config.num_channels_out, 3, 1, 1)
  795. # Initialize weights and apply final processing
  796. self.post_init()
  797. @auto_docstring
  798. def forward(
  799. self,
  800. pixel_values: torch.FloatTensor | None = None,
  801. labels: torch.LongTensor | None = None,
  802. output_attentions: bool | None = None,
  803. output_hidden_states: bool | None = None,
  804. return_dict: bool | None = None,
  805. **kwargs,
  806. ) -> tuple | ImageSuperResolutionOutput:
  807. r"""
  808. Example:
  809. ```python
  810. >>> import torch
  811. >>> import numpy as np
  812. >>> from PIL import Image
  813. >>> import httpx
  814. >>> from io import BytesIO
  815. >>> from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
  816. >>> processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
  817. >>> model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
  818. >>> url = "https://huggingface.co/spaces/jjourney1125/swin2sr/resolve/main/samples/butterfly.jpg"
  819. >>> with httpx.stream("GET", url) as response:
  820. ... image = Image.open(BytesIO(response.read()))
  821. >>> # prepare image for the model
  822. >>> inputs = processor(image, return_tensors="pt")
  823. >>> # forward pass
  824. >>> with torch.no_grad():
  825. ... outputs = model(**inputs)
  826. >>> output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
  827. >>> output = np.moveaxis(output, source=0, destination=-1)
  828. >>> output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
  829. >>> # you can visualize `output` with `Image.fromarray`
  830. ```"""
  831. return_dict = return_dict if return_dict is not None else self.config.return_dict
  832. loss = None
  833. if labels is not None:
  834. raise NotImplementedError("Training is not supported at the moment")
  835. height, width = pixel_values.shape[2:]
  836. if self.config.upsampler == "pixelshuffle_aux":
  837. bicubic = nn.functional.interpolate(
  838. pixel_values,
  839. size=(height * self.upscale, width * self.upscale),
  840. mode="bicubic",
  841. align_corners=False,
  842. )
  843. outputs = self.swin2sr(
  844. pixel_values,
  845. output_attentions=output_attentions,
  846. output_hidden_states=output_hidden_states,
  847. return_dict=return_dict,
  848. )
  849. sequence_output = outputs[0]
  850. if self.upsampler in ["pixelshuffle", "pixelshuffledirect", "nearest+conv"]:
  851. reconstruction = self.upsample(sequence_output)
  852. elif self.upsampler == "pixelshuffle_aux":
  853. reconstruction, aux = self.upsample(sequence_output, bicubic, height, width)
  854. aux = aux / self.swin2sr.img_range + self.swin2sr.mean
  855. else:
  856. reconstruction = pixel_values + self.final_convolution(sequence_output)
  857. reconstruction = reconstruction / self.swin2sr.img_range + self.swin2sr.mean
  858. reconstruction = reconstruction[:, :, : height * self.upscale, : width * self.upscale]
  859. if not return_dict:
  860. output = (reconstruction,) + outputs[1:]
  861. return ((loss,) + output) if loss is not None else output
  862. return ImageSuperResolutionOutput(
  863. loss=loss,
  864. reconstruction=reconstruction,
  865. hidden_states=outputs.hidden_states,
  866. attentions=outputs.attentions,
  867. )
# Public model classes exported by this module.
__all__ = ["Swin2SRForImageSuperResolution", "Swin2SRModel", "Swin2SRPreTrainedModel"]