| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- """Based from the original code from Meta Platforms, Inc. and affiliates.
- https://github.com/facebookresearch/segment-anything/blob/3518c86b78b3bc9cf4fbe3d18e682fad1c79dc51/segment_anything/build_sam.py
- https://github.com/facebookresearch/segment-anything/blob/3518c86b78b3bc9cf4fbe3d18e682fad1c79dc51/segment_anything/modeling/sam.py
- """
- from __future__ import annotations
- import warnings
- from dataclasses import dataclass
- from enum import Enum
- from typing import Any, Optional
- import torch
- from kornia.contrib.models import SegmentationResults
- from kornia.contrib.models.base import ModelBase
- from kornia.contrib.models.sam.architecture.common import LayerNorm
- from kornia.contrib.models.sam.architecture.image_encoder import ImageEncoderViT
- from kornia.contrib.models.sam.architecture.mask_decoder import MaskDecoder
- from kornia.contrib.models.sam.architecture.prompt_encoder import PromptEncoder
- from kornia.contrib.models.sam.architecture.transformer import TwoWayTransformer
- from kornia.contrib.models.tiny_vit import TinyViT
- from kornia.core import Tensor
- from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE
class SamModelType(Enum):
    """Enumerate the SAM variants that `Sam.from_config` knows how to build.

    The integer values are part of the public contract: `SamConfig.model_type` may be given as one of these
    integers, as the member name (a string), or as the member itself.
    """

    # Built with the 1280-dim, 32-block ViT image encoder.
    vit_h = 0
    # Built with the 1024-dim, 24-block ViT image encoder.
    vit_l = 1
    # Built with the 768-dim, 12-block ViT image encoder.
    vit_b = 2
    # Built with a TinyViT image encoder instead of `ImageEncoderViT`.
    mobile_sam = 3
@dataclass
class SamConfig:
    """Describe which SAM model to build and which weights (if any) to load into it.

    Either pick a predefined variant through ``model_type`` or describe a custom ViT image encoder through the
    four ``encoder_*`` fields. When ``model_type`` is set, the ``encoder_*`` fields are not consulted.

    Args:
        model_type: the predefined variant to build. Accepted spellings are:

            - 0, 'vit_h' or ``SamModelType.vit_h``
            - 1, 'vit_l' or ``SamModelType.vit_l``
            - 2, 'vit_b' or ``SamModelType.vit_b``
            - 3, 'mobile_sam' or ``SamModelType.mobile_sam``
        checkpoint: URL or path of a file holding the model weights.
        pretrained: when True and ``checkpoint`` is not given, the default weights published for ``model_type``
            are loaded. When ``checkpoint`` is given, this flag is ignored (a warning is emitted).
        encoder_embed_dim: patch embedding dimension of the image encoder.
        encoder_depth: depth (number of blocks) of the ViT.
        encoder_num_heads: number of attention heads in each ViT block.
        encoder_global_attn_indexes: indexes of the encoder blocks that use global attention.

    """

    # --- which model -------------------------------------------------------
    model_type: Optional[str | int | SamModelType] = None

    # --- which weights -----------------------------------------------------
    checkpoint: Optional[str] = None
    pretrained: bool = False

    # --- custom image-encoder description (used only without `model_type`) --
    encoder_embed_dim: Optional[int] = None
    encoder_depth: Optional[int] = None
    encoder_num_heads: Optional[int] = None
    encoder_global_attn_indexes: Optional[tuple[int, ...]] = None
class Sam(ModelBase[SamConfig]):
    """Segment Anything Model (SAM): an image encoder, a prompt encoder and a mask decoder wired together."""

    # Threshold handed, together with the raw logits, to every `SegmentationResults` produced by `forward`.
    mask_threshold: float = 0.0

    def __init__(
        self, image_encoder: ImageEncoderViT | TinyViT, prompt_encoder: PromptEncoder, mask_decoder: MaskDecoder
    ) -> None:
        """SAM predicts object masks from an image and input prompts.

        Args:
            image_encoder: The backbone used to encode the image into image embeddings that allow for efficient mask
                           prediction.
            prompt_encoder: Encodes various types of input prompts.
            mask_decoder: Predicts masks from the image embeddings and encoded prompts.

        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder

    @staticmethod
    def from_name(name: str) -> Sam:
        """Build/load the SAM model based on its name.

        Args:
            name: The name of the SAM model. Valid names are:
                - 'vit_b'
                - 'vit_l'
                - 'vit_h'
                - 'mobile_sam'

        Returns:
            The respective SAM model

        Raises:
            ValueError: if ``name`` is not one of the valid names.

        """
        if name in ["vit_b", "vit_l", "vit_h", "mobile_sam"]:
            return Sam.from_config(SamConfig(name))
        else:
            raise ValueError(f"Invalid SAM model name: {name}")

    @staticmethod
    def from_config(config: SamConfig) -> Sam:
        """Build/load the SAM model based on its config.

        Args:
            config: The SamConfig data structure. If the model_type is available, build from it, otherwise will use
                    the parameters set.

        Returns:
            The respective SAM model

        Raises:
            ValueError: if ``config.model_type`` is an integer that is not a ``SamModelType`` value, or if
                ``config.pretrained`` is set without a checkpoint for a custom encoder config (there is no default
                checkpoint to fall back on).
            KeyError: if ``config.model_type`` is a string that is not a known model name.
            NotImplementedError: if neither a model type nor a complete custom encoder config is provided.

        Example:
            >>> from kornia.contrib.models.sam import SamConfig
            >>> sam_model = Sam.from_config(SamConfig('vit_b'))

        """
        model_type = config.model_type

        # Normalize the integer / string spellings of the model type to the enum.
        if isinstance(model_type, int):
            model_type = SamModelType(model_type)
        elif isinstance(model_type, str):
            _map_sam_type = {
                "vit_h": SamModelType.vit_h,
                "vit_l": SamModelType.vit_l,
                "vit_b": SamModelType.vit_b,
                "mobile_sam": SamModelType.mobile_sam,
            }
            model_type = _map_sam_type[model_type]

        if model_type == SamModelType.vit_b:
            model = _build_sam(
                encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12, encoder_global_attn_indexes=(2, 5, 8, 11)
            )

        elif model_type == SamModelType.vit_l:
            model = _build_sam(
                encoder_embed_dim=1024,
                encoder_depth=24,
                encoder_num_heads=16,
                encoder_global_attn_indexes=(5, 11, 17, 23),
            )

        elif model_type == SamModelType.vit_h:
            model = _build_sam(
                encoder_embed_dim=1280,
                encoder_depth=32,
                encoder_num_heads=16,
                encoder_global_attn_indexes=(7, 15, 23, 31),
            )

        elif model_type == SamModelType.mobile_sam:
            # TODO: merge this with _build_sam()
            prompt_embed_dim = 256
            image_size = 1024
            vit_patch_size = 16
            image_embedding_size = image_size // vit_patch_size

            model = Sam(
                image_encoder=TinyViT.from_config("5m", img_size=image_size, mobile_sam=True),
                prompt_encoder=PromptEncoder(
                    embed_dim=prompt_embed_dim,
                    image_embedding_size=(image_embedding_size, image_embedding_size),
                    input_image_size=(image_size, image_size),
                    mask_in_chans=16,
                ),
                mask_decoder=MaskDecoder(
                    num_multimask_outputs=3,
                    transformer=TwoWayTransformer(depth=2, embedding_dim=prompt_embed_dim, mlp_dim=2048, num_heads=8),
                    transformer_dim=prompt_embed_dim,
                    iou_head_depth=3,
                    iou_head_hidden_dim=256,
                ),
                # pixel_mean=[123.675, 116.28, 103.53],
                # pixel_std=[58.395, 57.12, 57.375],
            )

        elif (
            isinstance(config.encoder_embed_dim, int)
            and isinstance(config.encoder_depth, int)
            and isinstance(config.encoder_num_heads, int)
            # The global-attention indexes are a sequence of block indexes, not a single integer.
            and isinstance(config.encoder_global_attn_indexes, (tuple, list))
        ):
            model = _build_sam(
                encoder_embed_dim=config.encoder_embed_dim,
                encoder_depth=config.encoder_depth,
                encoder_num_heads=config.encoder_num_heads,
                encoder_global_attn_indexes=tuple(config.encoder_global_attn_indexes),
            )

        else:
            raise NotImplementedError("Unexpected config. The model_type should be provide or the encoder configs.")

        checkpoint = config.checkpoint

        if config.pretrained:
            if checkpoint is None:
                # Default weights exist only for the predefined variants; a custom encoder has none.
                if not isinstance(model_type, SamModelType):
                    raise ValueError(
                        "pretrained=True requires a `model_type`; a custom encoder config has no default checkpoint."
                    )
                checkpoint = {
                    SamModelType.vit_b: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
                    SamModelType.vit_l: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
                    SamModelType.vit_h: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
                    SamModelType.mobile_sam: "https://github.com/ChaoningZhang/MobileSAM/raw/a509aac54fdd7af59f843135f2f7cee307283c88/weights/mobile_sam.pt",
                }[model_type]
            else:
                # An explicit checkpoint always wins over the default pretrained weights.
                warnings.warn("checkpoint is not None. pretrained=True is ignored", stacklevel=1)

        if checkpoint:
            model.load_checkpoint(checkpoint)

        return model

    @torch.no_grad()
    def forward(
        self, images: Tensor, batched_prompts: list[dict[str, Any]], multimask_output: bool
    ) -> list[SegmentationResults]:
        """Predicts masks end-to-end from provided images and prompts.

        This method expects that the images have already been pre-processed, at least been normalized, resized and
        padded to be compatible with the `self.image_encoder`.

        .. note:: For each image :math:`(3, H, W)`, it is possible to input a batch (:math:`K`) of :math:`N` prompts,
                 the results are batched by the number of prompts batch. So given a prompt with :math:`K=5`, and
                 :math:`N=10`, the results will look like :math:`5xCxHxW` where :math:`C` is determined by
                 multimask_output. And within each of these masks :math:`(5xC)`, it should be possible to find
                 :math:`N` instances if the model succeed.

        Args:
            images: The image as a torch tensor in :math:`(B, 3, H, W)` format, already transformed for input to the
                    model.
            batched_prompts: A list over the batch of images (list length should be :math:`B`), each a dictionary with
                             the following keys. If it does not have the respective prompt, it should not be included
                             in this dictionary. The options are:

                - "points": tuple of (Tensor, Tensor) within the coordinate keypoints and their respective labels.
                            the tuple should look like (keypoints, labels), where:

                            - The keypoints (a tensor) are a batched point prompts for this image, with shape
                              :math:`(K, N, 2)`. Already transformed to the input frame of the model.
                            - The labels (a tensor) are a batched labels for point prompts, with shape :math:`(K, N)`.
                              Where 1 indicates a foreground point and 0 indicates a background point.

                - "boxes": (Tensor) Batched box inputs, with shape :math:`(K, 4)`. Already transformed to the input
                           frame of the model.
                - "mask_inputs": (Tensor) Batched mask inputs to the model, in the form :math:`(K, 1, H, W)`.

            multimask_output: Whether the model should predict multiple disambiguating masks, or return a single mask.

        Returns:
            A list over input images, where each element is as SegmentationResults the following.
                - logits: Low resolution logits with shape :math:`(K, C, H, W)`. Can be passed as mask input to
                          subsequent iterations of prediction. Where :math:`K` is the number of input prompts,
                          :math:`C` is determined by multimask_output, and :math:`H=W=256` are the model output size.
                - scores: The model's predictions of mask quality (iou prediction), in shape BxC.

        """
        KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
        KORNIA_CHECK(
            images.shape[0] == len(batched_prompts),
            "The number of images (`B`) should match with the length of prompts!",
        )

        # The whole batch is encoded once; prompts are then decoded image by image.
        image_embeddings = self.image_encoder(images)

        outputs = []
        for prompt_record, curr_embedding in zip(batched_prompts, image_embeddings):
            # Embed prompts
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=prompt_record.get("points", None),
                boxes=prompt_record.get("boxes", None),
                masks=prompt_record.get("mask_inputs", None),
            )

            # Predict masks
            low_res_logits, iou_predictions = self.mask_decoder(
                # Re-add the leading batch dimension dropped by iterating over the embeddings.
                image_embeddings=curr_embedding[None, ...],
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )

            # Save results
            outputs.append(SegmentationResults(low_res_logits, iou_predictions, self.mask_threshold))

        return outputs
def _build_sam(
    encoder_embed_dim: int, encoder_depth: int, encoder_num_heads: int, encoder_global_attn_indexes: tuple[int, ...]
) -> Sam:
    """Assemble a `Sam` model around an `ImageEncoderViT` with the given encoder hyper-parameters.

    Everything except the four encoder arguments is fixed: 1024-pixel input images, 16-pixel patches and a
    256-channel prompt/image embedding.

    Args:
        encoder_embed_dim: patch embedding dimension of the image encoder.
        encoder_depth: depth of the ViT image encoder.
        encoder_num_heads: number of attention heads in each ViT block.
        encoder_global_attn_indexes: indexes of the encoder blocks that use global attention.

    Returns:
        The assembled (randomly initialized) SAM model.

    """
    embed_channels = 256
    input_side = 1024
    patch_side = 16
    # Side length of the embedding grid produced by the encoder (image side measured in patches).
    grid_side = input_side // patch_side

    # The sub-modules are created in the same order as they are passed to `Sam`.
    encoder = ImageEncoderViT(
        depth=encoder_depth,
        embed_dim=encoder_embed_dim,
        img_size=input_side,
        mlp_ratio=4,
        norm_layer=LayerNorm,
        num_heads=encoder_num_heads,
        patch_size=patch_side,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=encoder_global_attn_indexes,
        window_size=14,
        out_chans=embed_channels,
    )

    prompts = PromptEncoder(
        embed_dim=embed_channels,
        image_embedding_size=(grid_side, grid_side),
        input_image_size=(input_side, input_side),
        mask_in_chans=16,
    )

    two_way = TwoWayTransformer(depth=2, embedding_dim=embed_channels, mlp_dim=2048, num_heads=8)
    decoder = MaskDecoder(
        num_multimask_outputs=3,
        transformer=two_way,
        transformer_dim=embed_channels,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
    )

    # NOTE: no normalization statistics are passed (pixel_mean=[123.675, 116.28, 103.53],
    # pixel_std=[58.395, 57.12, 57.375]); `Sam.__init__` takes only the three sub-modules and
    # `Sam.forward` expects already-normalized images.
    return Sam(image_encoder=encoder, prompt_encoder=prompts, mask_decoder=decoder)
|