# LICENSE HEADER MANAGED BY add-license-header # # Copyright 2018 Kornia Team # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """Based from the original code from Meta Platforms, Inc. and affiliates. https://github.com/facebookresearch/segment- anything/blob/3518c86b78b3bc9cf4fbe3d18e682fad1c79dc51/segment_anything/build_sam.py https://github.com/facebookresearch/segment- anything/blob/3518c86b78b3bc9cf4fbe3d18e682fad1c79dc51/segment_anything/modeling/sam.py """ from __future__ import annotations import warnings from dataclasses import dataclass from enum import Enum from typing import Any, Optional import torch from kornia.contrib.models import SegmentationResults from kornia.contrib.models.base import ModelBase from kornia.contrib.models.sam.architecture.common import LayerNorm from kornia.contrib.models.sam.architecture.image_encoder import ImageEncoderViT from kornia.contrib.models.sam.architecture.mask_decoder import MaskDecoder from kornia.contrib.models.sam.architecture.prompt_encoder import PromptEncoder from kornia.contrib.models.sam.architecture.transformer import TwoWayTransformer from kornia.contrib.models.tiny_vit import TinyViT from kornia.core import Tensor from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE class SamModelType(Enum): """Map the SAM model types.""" vit_h = 0 vit_l = 1 vit_b = 2 mobile_sam = 3 @dataclass class SamConfig: """Encapsulate the Config to build a SAM model. Args: model_type: the available models are: - 0, 'vit_h' or :func:`kornia.contrib.sam.SamModelType.vit_h` - 1, 'vit_l' or :func:`kornia.contrib.sam.SamModelType.vit_l` - 2, 'vit_b' or :func:`kornia.contrib.sam.SamModelType.vit_b` - 3, 'mobile_sam', or :func:`kornia.contrib.sam.SamModelType.mobile_sam` checkpoint: URL or a path for a file with the weights of the model encoder_embed_dim: Patch embedding dimension. encoder_depth: Depth of ViT. encoder_num_heads: Number of attention heads in each ViT block. encoder_global_attn_indexes: Encoder indexes for blocks using global attention. """ model_type: Optional[str | int | SamModelType] = None checkpoint: Optional[str] = None pretrained: bool = False encoder_embed_dim: Optional[int] = None encoder_depth: Optional[int] = None encoder_num_heads: Optional[int] = None encoder_global_attn_indexes: Optional[tuple[int, ...]] = None class Sam(ModelBase[SamConfig]): mask_threshold: float = 0.0 def __init__( self, image_encoder: ImageEncoderViT | TinyViT, prompt_encoder: PromptEncoder, mask_decoder: MaskDecoder ) -> None: """SAM predicts object masks from an image and input prompts. Args: image_encoder: The backbone used to encode the image into image embeddings that allow for efficient mask prediction. prompt_encoder: Encodes various types of input prompts. mask_decoder: Predicts masks from the image embeddings and encoded prompts. """ super().__init__() self.image_encoder = image_encoder self.prompt_encoder = prompt_encoder self.mask_decoder = mask_decoder @staticmethod def from_name(name: str) -> Sam: """Build/load the SAM model based on it's name. Args: name: The name of the SAM model. Valid names are: - 'vit_b' - 'vit_l' - 'vit_h' - 'mobile_sam' Returns: The respective SAM model """ if name in ["vit_b", "vit_l", "vit_h", "mobile_sam"]: return Sam.from_config(SamConfig(name)) else: raise ValueError(f"Invalid SAM model name: {name}") @staticmethod def from_config(config: SamConfig) -> Sam: """Build/load the SAM model based on it's config. Args: config: The SamConfig data structure. If the model_type is available, build from it, otherwise will use the parameters set. Returns: The respective SAM model Example: >>> from kornia.contrib.models.sam import SamConfig >>> sam_model = Sam.from_config(SamConfig('vit_b')) """ model_type = config.model_type if isinstance(model_type, int): model_type = SamModelType(model_type) elif isinstance(model_type, str): _map_sam_type = { "vit_h": SamModelType.vit_h, "vit_l": SamModelType.vit_l, "vit_b": SamModelType.vit_b, "mobile_sam": SamModelType.mobile_sam, } model_type = _map_sam_type[model_type] if model_type == SamModelType.vit_b: model = _build_sam( encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12, encoder_global_attn_indexes=(2, 5, 8, 11) ) elif model_type == SamModelType.vit_l: model = _build_sam( encoder_embed_dim=1024, encoder_depth=24, encoder_num_heads=16, encoder_global_attn_indexes=(5, 11, 17, 23), ) elif model_type == SamModelType.vit_h: model = _build_sam( encoder_embed_dim=1280, encoder_depth=32, encoder_num_heads=16, encoder_global_attn_indexes=(7, 15, 23, 31), ) elif model_type == SamModelType.mobile_sam: # TODO: merge this with _build_sam() prompt_embed_dim = 256 image_size = 1024 vit_patch_size = 16 image_embedding_size = image_size // vit_patch_size model = Sam( image_encoder=TinyViT.from_config("5m", img_size=image_size, mobile_sam=True), prompt_encoder=PromptEncoder( embed_dim=prompt_embed_dim, image_embedding_size=(image_embedding_size, image_embedding_size), input_image_size=(image_size, image_size), mask_in_chans=16, ), mask_decoder=MaskDecoder( num_multimask_outputs=3, transformer=TwoWayTransformer(depth=2, embedding_dim=prompt_embed_dim, mlp_dim=2048, num_heads=8), transformer_dim=prompt_embed_dim, iou_head_depth=3, iou_head_hidden_dim=256, ), # pixel_mean=[123.675, 116.28, 103.53], # pixel_std=[58.395, 57.12, 57.375], ) elif ( isinstance(config.encoder_embed_dim, int) and isinstance(config.encoder_depth, int) and isinstance(config.encoder_num_heads, int) and isinstance(config.encoder_global_attn_indexes, int) ): model = _build_sam( encoder_embed_dim=config.encoder_embed_dim, encoder_depth=config.encoder_depth, encoder_num_heads=config.num_heads, encoder_global_attn_indexes=config.encoder_global_attn_indexes, ) else: raise NotImplementedError("Unexpected config. The model_type should be provide or the encoder configs.") checkpoint = config.checkpoint if config.pretrained: if checkpoint is None: checkpoint = { SamModelType.vit_b: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", SamModelType.vit_l: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", SamModelType.vit_h: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", SamModelType.mobile_sam: "https://github.com/ChaoningZhang/MobileSAM/raw/a509aac54fdd7af59f843135f2f7cee307283c88/weights/mobile_sam.pt", }[model_type] else: warnings.warn("checkpoint is not None. pretrained=True is ignored", stacklevel=1) if checkpoint: model.load_checkpoint(checkpoint) return model @torch.no_grad() def forward( self, images: Tensor, batched_prompts: list[dict[str, Any]], multimask_output: bool ) -> list[SegmentationResults]: """Predicts masks end-to-end from provided images and prompts. This method expects that the images have already been pre-processed, at least been normalized, resized and padded to be compatible with the `self.image_encoder`. .. note:: For each image :math:`(3, H, W)`, it is possible to input a batch (:math:`K`) of :math:`N` prompts, the results are batched by the number of prompts batch. So given a prompt with :math:`K=5`, and :math:`N=10`, the results will look like :math:`5xCxHxW` where :math:`C` is determined by multimask_output. And within each of these masks :math:`(5xC)`, it should be possible to find :math:`N` instances if the model succeed. Args: images: The image as a torch tensor in :math:`(B, 3, H, W)` format, already transformed for input to the model. batched_prompts: A list over the batch of images (list length should be :math:`B`), each a dictionary with the following keys. If it does not have the respective prompt, it should not be included in this dictionary. The options are: - "points": tuple of (Tensor, Tensor) within the coordinate keypoints and their respective labels. the tuple should look like (keypoints, labels), where: - The keypoints (a tensor) are a batched point prompts for this image, with shape :math:`(K, N, 2)`. Already transformed to the input frame of the model. - The labels (a tensor) are a batched labels for point prompts, with shape :math:`(K, N)`. Where 1 indicates a foreground point and 0 indicates a background point. - "boxes": (Tensor) Batched box inputs, with shape :math:`(K, 4)`. Already transformed to the input frame of the model. - "mask_inputs": (Tensor) Batched mask inputs to the model, in the form :math:`(K, 1, H, W)`. multimask_output: Whether the model should predict multiple disambiguating masks, or return a single mask. Returns: A list over input images, where each element is as SegmentationResults the following. - logits: Low resolution logits with shape :math:`(K, C, H, W)`. Can be passed as mask input to subsequent iterations of prediction. Where :math:`K` is the number of input prompts, :math:`C` is determined by multimask_output, and :math:`H=W=256` are the model output size. - scores: The model's predictions of mask quality (iou prediction), in shape BxC. """ KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"]) KORNIA_CHECK( images.shape[0] == len(batched_prompts), "The number of images (`B`) should match with the length of prompts!", ) image_embeddings = self.image_encoder(images) outputs = [] for prompt_record, curr_embedding in zip(batched_prompts, image_embeddings): # Embed prompts sparse_embeddings, dense_embeddings = self.prompt_encoder( points=prompt_record.get("points", None), boxes=prompt_record.get("boxes", None), masks=prompt_record.get("mask_inputs", None), ) # Predict masks low_res_logits, iou_predictions = self.mask_decoder( image_embeddings=curr_embedding[None, ...], image_pe=self.prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=multimask_output, ) # Save results outputs.append(SegmentationResults(low_res_logits, iou_predictions, self.mask_threshold)) return outputs def _build_sam( encoder_embed_dim: int, encoder_depth: int, encoder_num_heads: int, encoder_global_attn_indexes: tuple[int, ...] ) -> Sam: prompt_embed_dim = 256 image_size = 1024 vit_patch_size = 16 image_embedding_size = image_size // vit_patch_size return Sam( image_encoder=ImageEncoderViT( depth=encoder_depth, embed_dim=encoder_embed_dim, img_size=image_size, mlp_ratio=4, norm_layer=LayerNorm, num_heads=encoder_num_heads, patch_size=vit_patch_size, qkv_bias=True, use_rel_pos=True, global_attn_indexes=encoder_global_attn_indexes, window_size=14, out_chans=prompt_embed_dim, ), prompt_encoder=PromptEncoder( embed_dim=prompt_embed_dim, image_embedding_size=(image_embedding_size, image_embedding_size), input_image_size=(image_size, image_size), mask_in_chans=16, ), mask_decoder=MaskDecoder( num_multimask_outputs=3, transformer=TwoWayTransformer(depth=2, embedding_dim=prompt_embed_dim, mlp_dim=2048, num_heads=8), transformer_dim=prompt_embed_dim, iou_head_depth=3, iou_head_hidden_dim=256, ), # pixel_mean=[123.675, 116.28, 103.53], # pixel_std=[58.395, 57.12, 57.375], )