model.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
"""Based on the original code from Meta Platforms, Inc. and affiliates.

https://github.com/facebookresearch/segment-anything/blob/3518c86b78b3bc9cf4fbe3d18e682fad1c79dc51/segment_anything/build_sam.py
https://github.com/facebookresearch/segment-anything/blob/3518c86b78b3bc9cf4fbe3d18e682fad1c79dc51/segment_anything/modeling/sam.py
"""
  23. from __future__ import annotations
  24. import warnings
  25. from dataclasses import dataclass
  26. from enum import Enum
  27. from typing import Any, Optional
  28. import torch
  29. from kornia.contrib.models import SegmentationResults
  30. from kornia.contrib.models.base import ModelBase
  31. from kornia.contrib.models.sam.architecture.common import LayerNorm
  32. from kornia.contrib.models.sam.architecture.image_encoder import ImageEncoderViT
  33. from kornia.contrib.models.sam.architecture.mask_decoder import MaskDecoder
  34. from kornia.contrib.models.sam.architecture.prompt_encoder import PromptEncoder
  35. from kornia.contrib.models.sam.architecture.transformer import TwoWayTransformer
  36. from kornia.contrib.models.tiny_vit import TinyViT
  37. from kornia.core import Tensor
  38. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE
  39. class SamModelType(Enum):
  40. """Map the SAM model types."""
  41. vit_h = 0
  42. vit_l = 1
  43. vit_b = 2
  44. mobile_sam = 3
  45. @dataclass
  46. class SamConfig:
  47. """Encapsulate the Config to build a SAM model.
  48. Args:
  49. model_type: the available models are:
  50. - 0, 'vit_h' or :func:`kornia.contrib.sam.SamModelType.vit_h`
  51. - 1, 'vit_l' or :func:`kornia.contrib.sam.SamModelType.vit_l`
  52. - 2, 'vit_b' or :func:`kornia.contrib.sam.SamModelType.vit_b`
  53. - 3, 'mobile_sam', or :func:`kornia.contrib.sam.SamModelType.mobile_sam`
  54. checkpoint: URL or a path for a file with the weights of the model
  55. encoder_embed_dim: Patch embedding dimension.
  56. encoder_depth: Depth of ViT.
  57. encoder_num_heads: Number of attention heads in each ViT block.
  58. encoder_global_attn_indexes: Encoder indexes for blocks using global attention.
  59. """
  60. model_type: Optional[str | int | SamModelType] = None
  61. checkpoint: Optional[str] = None
  62. pretrained: bool = False
  63. encoder_embed_dim: Optional[int] = None
  64. encoder_depth: Optional[int] = None
  65. encoder_num_heads: Optional[int] = None
  66. encoder_global_attn_indexes: Optional[tuple[int, ...]] = None
  67. class Sam(ModelBase[SamConfig]):
  68. mask_threshold: float = 0.0
  69. def __init__(
  70. self, image_encoder: ImageEncoderViT | TinyViT, prompt_encoder: PromptEncoder, mask_decoder: MaskDecoder
  71. ) -> None:
  72. """SAM predicts object masks from an image and input prompts.
  73. Args:
  74. image_encoder: The backbone used to encode the image into image embeddings that allow for efficient mask
  75. prediction.
  76. prompt_encoder: Encodes various types of input prompts.
  77. mask_decoder: Predicts masks from the image embeddings and encoded prompts.
  78. """
  79. super().__init__()
  80. self.image_encoder = image_encoder
  81. self.prompt_encoder = prompt_encoder
  82. self.mask_decoder = mask_decoder
  83. @staticmethod
  84. def from_name(name: str) -> Sam:
  85. """Build/load the SAM model based on it's name.
  86. Args:
  87. name: The name of the SAM model. Valid names are:
  88. - 'vit_b'
  89. - 'vit_l'
  90. - 'vit_h'
  91. - 'mobile_sam'
  92. Returns:
  93. The respective SAM model
  94. """
  95. if name in ["vit_b", "vit_l", "vit_h", "mobile_sam"]:
  96. return Sam.from_config(SamConfig(name))
  97. else:
  98. raise ValueError(f"Invalid SAM model name: {name}")
  99. @staticmethod
  100. def from_config(config: SamConfig) -> Sam:
  101. """Build/load the SAM model based on it's config.
  102. Args:
  103. config: The SamConfig data structure. If the model_type is available, build from it, otherwise will use
  104. the parameters set.
  105. Returns:
  106. The respective SAM model
  107. Example:
  108. >>> from kornia.contrib.models.sam import SamConfig
  109. >>> sam_model = Sam.from_config(SamConfig('vit_b'))
  110. """
  111. model_type = config.model_type
  112. if isinstance(model_type, int):
  113. model_type = SamModelType(model_type)
  114. elif isinstance(model_type, str):
  115. _map_sam_type = {
  116. "vit_h": SamModelType.vit_h,
  117. "vit_l": SamModelType.vit_l,
  118. "vit_b": SamModelType.vit_b,
  119. "mobile_sam": SamModelType.mobile_sam,
  120. }
  121. model_type = _map_sam_type[model_type]
  122. if model_type == SamModelType.vit_b:
  123. model = _build_sam(
  124. encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12, encoder_global_attn_indexes=(2, 5, 8, 11)
  125. )
  126. elif model_type == SamModelType.vit_l:
  127. model = _build_sam(
  128. encoder_embed_dim=1024,
  129. encoder_depth=24,
  130. encoder_num_heads=16,
  131. encoder_global_attn_indexes=(5, 11, 17, 23),
  132. )
  133. elif model_type == SamModelType.vit_h:
  134. model = _build_sam(
  135. encoder_embed_dim=1280,
  136. encoder_depth=32,
  137. encoder_num_heads=16,
  138. encoder_global_attn_indexes=(7, 15, 23, 31),
  139. )
  140. elif model_type == SamModelType.mobile_sam:
  141. # TODO: merge this with _build_sam()
  142. prompt_embed_dim = 256
  143. image_size = 1024
  144. vit_patch_size = 16
  145. image_embedding_size = image_size // vit_patch_size
  146. model = Sam(
  147. image_encoder=TinyViT.from_config("5m", img_size=image_size, mobile_sam=True),
  148. prompt_encoder=PromptEncoder(
  149. embed_dim=prompt_embed_dim,
  150. image_embedding_size=(image_embedding_size, image_embedding_size),
  151. input_image_size=(image_size, image_size),
  152. mask_in_chans=16,
  153. ),
  154. mask_decoder=MaskDecoder(
  155. num_multimask_outputs=3,
  156. transformer=TwoWayTransformer(depth=2, embedding_dim=prompt_embed_dim, mlp_dim=2048, num_heads=8),
  157. transformer_dim=prompt_embed_dim,
  158. iou_head_depth=3,
  159. iou_head_hidden_dim=256,
  160. ),
  161. # pixel_mean=[123.675, 116.28, 103.53],
  162. # pixel_std=[58.395, 57.12, 57.375],
  163. )
  164. elif (
  165. isinstance(config.encoder_embed_dim, int)
  166. and isinstance(config.encoder_depth, int)
  167. and isinstance(config.encoder_num_heads, int)
  168. and isinstance(config.encoder_global_attn_indexes, int)
  169. ):
  170. model = _build_sam(
  171. encoder_embed_dim=config.encoder_embed_dim,
  172. encoder_depth=config.encoder_depth,
  173. encoder_num_heads=config.num_heads,
  174. encoder_global_attn_indexes=config.encoder_global_attn_indexes,
  175. )
  176. else:
  177. raise NotImplementedError("Unexpected config. The model_type should be provide or the encoder configs.")
  178. checkpoint = config.checkpoint
  179. if config.pretrained:
  180. if checkpoint is None:
  181. checkpoint = {
  182. SamModelType.vit_b: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
  183. SamModelType.vit_l: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
  184. SamModelType.vit_h: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
  185. SamModelType.mobile_sam: "https://github.com/ChaoningZhang/MobileSAM/raw/a509aac54fdd7af59f843135f2f7cee307283c88/weights/mobile_sam.pt",
  186. }[model_type]
  187. else:
  188. warnings.warn("checkpoint is not None. pretrained=True is ignored", stacklevel=1)
  189. if checkpoint:
  190. model.load_checkpoint(checkpoint)
  191. return model
  192. @torch.no_grad()
  193. def forward(
  194. self, images: Tensor, batched_prompts: list[dict[str, Any]], multimask_output: bool
  195. ) -> list[SegmentationResults]:
  196. """Predicts masks end-to-end from provided images and prompts.
  197. This method expects that the images have already been pre-processed, at least been normalized, resized and
  198. padded to be compatible with the `self.image_encoder`.
  199. .. note:: For each image :math:`(3, H, W)`, it is possible to input a batch (:math:`K`) of :math:`N` prompts,
  200. the results are batched by the number of prompts batch. So given a prompt with :math:`K=5`, and
  201. :math:`N=10`, the results will look like :math:`5xCxHxW` where :math:`C` is determined by
  202. multimask_output. And within each of these masks :math:`(5xC)`, it should be possible to find
  203. :math:`N` instances if the model succeed.
  204. Args:
  205. images: The image as a torch tensor in :math:`(B, 3, H, W)` format, already transformed for input to the
  206. model.
  207. batched_prompts: A list over the batch of images (list length should be :math:`B`), each a dictionary with
  208. the following keys. If it does not have the respective prompt, it should not be included
  209. in this dictionary. The options are:
  210. - "points": tuple of (Tensor, Tensor) within the coordinate keypoints and their respective labels.
  211. the tuple should look like (keypoints, labels), where:
  212. - The keypoints (a tensor) are a batched point prompts for this image, with shape
  213. :math:`(K, N, 2)`. Already transformed to the input frame of the model.
  214. - The labels (a tensor) are a batched labels for point prompts, with shape :math:`(K, N)`.
  215. Where 1 indicates a foreground point and 0 indicates a background point.
  216. - "boxes": (Tensor) Batched box inputs, with shape :math:`(K, 4)`. Already transformed to the input
  217. frame of the model.
  218. - "mask_inputs": (Tensor) Batched mask inputs to the model, in the form :math:`(K, 1, H, W)`.
  219. multimask_output: Whether the model should predict multiple disambiguating masks, or return a single mask.
  220. Returns:
  221. A list over input images, where each element is as SegmentationResults the following.
  222. - logits: Low resolution logits with shape :math:`(K, C, H, W)`. Can be passed as mask input to
  223. subsequent iterations of prediction. Where :math:`K` is the number of input prompts,
  224. :math:`C` is determined by multimask_output, and :math:`H=W=256` are the model output size.
  225. - scores: The model's predictions of mask quality (iou prediction), in shape BxC.
  226. """
  227. KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
  228. KORNIA_CHECK(
  229. images.shape[0] == len(batched_prompts),
  230. "The number of images (`B`) should match with the length of prompts!",
  231. )
  232. image_embeddings = self.image_encoder(images)
  233. outputs = []
  234. for prompt_record, curr_embedding in zip(batched_prompts, image_embeddings):
  235. # Embed prompts
  236. sparse_embeddings, dense_embeddings = self.prompt_encoder(
  237. points=prompt_record.get("points", None),
  238. boxes=prompt_record.get("boxes", None),
  239. masks=prompt_record.get("mask_inputs", None),
  240. )
  241. # Predict masks
  242. low_res_logits, iou_predictions = self.mask_decoder(
  243. image_embeddings=curr_embedding[None, ...],
  244. image_pe=self.prompt_encoder.get_dense_pe(),
  245. sparse_prompt_embeddings=sparse_embeddings,
  246. dense_prompt_embeddings=dense_embeddings,
  247. multimask_output=multimask_output,
  248. )
  249. # Save results
  250. outputs.append(SegmentationResults(low_res_logits, iou_predictions, self.mask_threshold))
  251. return outputs
  252. def _build_sam(
  253. encoder_embed_dim: int, encoder_depth: int, encoder_num_heads: int, encoder_global_attn_indexes: tuple[int, ...]
  254. ) -> Sam:
  255. prompt_embed_dim = 256
  256. image_size = 1024
  257. vit_patch_size = 16
  258. image_embedding_size = image_size // vit_patch_size
  259. return Sam(
  260. image_encoder=ImageEncoderViT(
  261. depth=encoder_depth,
  262. embed_dim=encoder_embed_dim,
  263. img_size=image_size,
  264. mlp_ratio=4,
  265. norm_layer=LayerNorm,
  266. num_heads=encoder_num_heads,
  267. patch_size=vit_patch_size,
  268. qkv_bias=True,
  269. use_rel_pos=True,
  270. global_attn_indexes=encoder_global_attn_indexes,
  271. window_size=14,
  272. out_chans=prompt_embed_dim,
  273. ),
  274. prompt_encoder=PromptEncoder(
  275. embed_dim=prompt_embed_dim,
  276. image_embedding_size=(image_embedding_size, image_embedding_size),
  277. input_image_size=(image_size, image_size),
  278. mask_in_chans=16,
  279. ),
  280. mask_decoder=MaskDecoder(
  281. num_multimask_outputs=3,
  282. transformer=TwoWayTransformer(depth=2, embedding_dim=prompt_embed_dim, mlp_dim=2048, num_heads=8),
  283. transformer_dim=prompt_embed_dim,
  284. iou_head_depth=3,
  285. iou_head_hidden_dim=256,
  286. ),
  287. # pixel_mean=[123.675, 116.28, 103.53],
  288. # pixel_std=[58.395, 57.12, 57.375],
  289. )