patch_embed.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  # Copyright (c) Meta Platforms, Inc. and affiliates.
  # All rights reserved.
  #
  # This source code is licensed under the license found in the
  # LICENSE file in the root directory of this source tree.
  # References:
  # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
  # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
  9. from typing import Callable, Optional, Tuple, Union
  10. from torch import Tensor
  11. import torch.nn as nn
  def make_2tuple(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
      """Normalize a size argument to a ``(first, second)`` pair.

      Args:
          x: Either an ``int``, which is duplicated into ``(x, x)``, or a
              tuple of length 2, which is returned as-is (same object).

      Returns:
          A 2-tuple.

      Raises:
          AssertionError: If ``x`` is a tuple whose length is not 2, or is
              neither a tuple nor an ``int``.
      """
      if isinstance(x, tuple):
          assert len(x) == 2, f"Expected a tuple of length 2, got length {len(x)}"
          return x

      assert isinstance(x, int), f"Expected an int or a 2-tuple, got {type(x).__name__}"
      return (x, x)
  class PatchEmbed(nn.Module):
      """
      2D image to patch embedding: (B,C,H,W) -> (B,N,D)

      The image is cut into non-overlapping patches by a convolution whose
      kernel size and stride both equal the patch size, so each output
      position is a linear projection of exactly one patch.

      Args:
          img_size: Image size.
          patch_size: Patch token size.
          in_chans: Number of input image channels.
          embed_dim: Number of linear projection output channels.
          norm_layer: Normalization layer. Called as ``norm_layer(embed_dim)``;
              when falsy, ``nn.Identity()`` is used instead.
          flatten_embedding: If True (default), ``forward`` returns the patch
              tokens as (B, N, D). If False, they are reshaped to
              (B, H', W', D), where H' and W' are the patch-grid dimensions.
      """

      def __init__(
          self,
          img_size: Union[int, Tuple[int, int]] = 224,
          patch_size: Union[int, Tuple[int, int]] = 16,
          in_chans: int = 3,
          embed_dim: int = 768,
          norm_layer: Optional[Callable] = None,
          flatten_embedding: bool = True,
      ) -> None:
          super().__init__()

          image_HW = make_2tuple(img_size)
          patch_HW = make_2tuple(patch_size)
          # Floor division: a partial patch at the border is not counted.
          patch_grid_size = (
              image_HW[0] // patch_HW[0],
              image_HW[1] // patch_HW[1],
          )

          self.img_size = image_HW
          self.patch_size = patch_HW
          self.patches_resolution = patch_grid_size
          self.num_patches = patch_grid_size[0] * patch_grid_size[1]

          self.in_chans = in_chans
          self.embed_dim = embed_dim

          self.flatten_embedding = flatten_embedding

          # kernel_size == stride, so patches do not overlap.
          self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
          self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

      def forward(self, x: Tensor) -> Tensor:
          """Project an image batch to patch embeddings.

          Args:
              x: Input with four dimensions; the last two are read as height
                 and width and must each be a multiple of the patch size.
                 They are checked against ``self.patch_size`` only, not
                 against ``self.img_size``.

          Returns:
              Tensor reshaped to (B, N, D) when ``flatten_embedding`` is True,
              otherwise to (B, H', W', D), with H' and W' taken from the
              projection output.

          Raises:
              AssertionError: If the height or width is not a multiple of the
                  corresponding patch dimension.
          """
          _, _, H, W = x.shape
          patch_H, patch_W = self.patch_size

          assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
          assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

          x = self.proj(x)  # B C H W
          # From here on H and W are the patch-grid size, not the image size.
          H, W = x.size(2), x.size(3)
          x = x.flatten(2).transpose(1, 2)  # B HW C
          x = self.norm(x)
          if not self.flatten_embedding:
              x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
          return x

      def flops(self) -> int:
          """Estimate the operation count for one image of size ``img_size``.

          Counts the projection as one operation per (output element, input
          channel, kernel element), plus one operation per output element
          for the normalization layer when a real one is configured.

          Returns:
              The estimated operation count as an ``int``.
          """
          Ho, Wo = self.patches_resolution
          flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
          # __init__ never leaves self.norm as None: without a norm_layer it
          # is nn.Identity(), which does no work and must not be counted.
          if not isinstance(self.norm, nn.Identity):
              flops += Ho * Wo * self.embed_dim
          return flops