| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from __future__ import annotations
- from typing import Any, Optional
- import torch
- from kornia.augmentation import LongestMaxSize
- from kornia.augmentation.container.augment import AugmentationSequential
- from kornia.contrib.models import Prompts, SegmentationResults
- from kornia.contrib.models.sam import Sam, SamConfig
- from kornia.core import Tensor, pad, tensor
- from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
- from kornia.enhance import normalize
- from kornia.geometry.boxes import Boxes
- from kornia.geometry.keypoints import Keypoints
class VisualPrompter:
    r"""Allow the user to run multiple queries with multiple prompts for a model.

    At the moment, only the SAM model is supported. The model is loaded based on the given config.

    By default the images are resized so that their longest side matches `image_encoder.img_size`. This
    prompter takes care of transforming both the image and the prompts before prediction. The image is also
    passed automatically through `preprocess_image`, which normalizes it and pads it to the size expected by
    the SAM model :math:`(\text{image_encoder.img_size}, \text{image_encoder.img_size})`. By default the image
    is normalized with the mean and standard deviation of the SAM dataset.

    Args:
        config: A model config to generate the model. Now just the SAM model is supported.
        device: The desired device to use the model.
        dtype: The desired dtype to use the model.

    Example:
        >>> # prompter = VisualPrompter() # Will load the vit h for default
        >>> # You can load a custom SAM type for modifying the config
        >>> prompter = VisualPrompter(SamConfig('vit_b'))
        >>> image = torch.rand(3, 25, 30)
        >>> prompter.set_image(image)
        >>> boxes = Boxes(
        ...     torch.tensor(
        ...          [[[[0, 0], [0, 10], [10, 0], [10, 10]]]],
        ...          device=prompter.device,
        ...          dtype=torch.float32
        ...     ),
        ...     mode='xyxy'
        ... )
        >>> prediction = prompter.predict(boxes=boxes)
        >>> prediction.logits.shape
        torch.Size([1, 3, 256, 256])

    """

    def __init__(
        self,
        config: Optional[SamConfig] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Build the model from `config` and prepare the image/prompt transforms.

        Args:
            config: Model config. When None, a pretrained SAM `vit_h` config is used.
            device: Device the model and the normalization statistics are moved to.
            dtype: Dtype the model and the normalization statistics are cast to.

        Raises:
            NotImplementedError: If `config` is not a `SamConfig`.

        """
        super().__init__()

        if config is None:
            config = SamConfig(model_type="vit_h", pretrained=True)

        if not isinstance(config, SamConfig):
            raise NotImplementedError(f"Unsupported config type: {type(config).__name__}")

        self.model = Sam.from_config(config)
        # Resize so that the longest image side matches the encoder input size.
        transforms = (LongestMaxSize(self.model.image_encoder.img_size, p=1.0),)
        # Per-channel statistics given on a 0-255 scale, rescaled to 0-1.
        self.pixel_mean: Optional[Tensor] = tensor([123.675, 116.28, 103.53], device=device, dtype=dtype) / 255
        self.pixel_std: Optional[Tensor] = tensor([58.395, 57.12, 57.375], device=device, dtype=dtype) / 255

        self.model = self.model.to(device=device, dtype=dtype)
        self.transforms = AugmentationSequential(*transforms, same_on_batch=True)

        self.device = device
        self.dtype = dtype
        self._original_image_size: None | tuple[int, int] = None
        self._input_image_size: None | tuple[int, int] = None
        self._input_encoder_size: None | tuple[int, int] = None
        self.reset_image()

    def preprocess_image(self, x: Tensor, mean: Optional[Tensor] = None, std: Optional[Tensor] = None) -> Tensor:
        """Normalize and pad a tensor.

        Normalization: the `mean` and `std` passed as arguments take priority; when they are not both tensors
        the prompter defaults (`self.pixel_mean` / `self.pixel_std`) are used, if available.

        Padding: the tensor is padded on the right and on the bottom to match
        `self.model.image_encoder.img_size`.

        Args:
            x: The image to be preprocessed.
            mean: Mean for each channel.
            std: Standard deviations for each channel.

        Returns:
            The preprocessed image (normalized if a mean and std are available, and padded to the encoder size).

        """
        if isinstance(mean, Tensor) and isinstance(std, Tensor):
            x = normalize(x, mean, std)
        elif isinstance(self.pixel_mean, Tensor) and isinstance(self.pixel_std, Tensor):
            x = normalize(x, self.pixel_mean, self.pixel_std)

        encoder_im_size = self.model.image_encoder.img_size
        pad_h = encoder_im_size - x.shape[-2]
        pad_w = encoder_im_size - x.shape[-1]
        # Pad order is (left, right, top, bottom): only right and bottom are padded.
        x = pad(x, (0, pad_w, 0, pad_h))

        return x

    @torch.no_grad()
    def set_image(self, image: Tensor, mean: Optional[Tensor] = None, std: Optional[Tensor] = None) -> None:
        """Compute and store the embeddings of the given image with the `image_encoder` of the model.

        The image goes through the selected transforms and then through `preprocess_image`. The transform
        parameters and the intermediate image sizes are stored so that prompts and outputs can be aligned
        with this image later on.

        Args:
            image: RGB image. Normally images with range of [0-1], the model preprocess normalize the
                pixel values with the mean and std defined in its initialization. Expected to be into a float32
                dtype. Shape :math:`(3, H, W)`.
            mean: mean value of dataset for normalization.
            std: standard deviation of dataset for normalization.

        """
        KORNIA_CHECK_SHAPE(image, ["3", "H", "W"])

        self.reset_image()

        self._original_image_size = (image.shape[-2], image.shape[-1])

        image = self.transforms(image, data_keys=["input"])
        # Keep the sampled parameters so the same transform can be replayed on the prompts.
        self._tfs_params = self.transforms._params
        self._input_image_size = (image.shape[-2], image.shape[-1])

        image = self.preprocess_image(image, mean, std)
        self._input_encoder_size = (image.shape[-2], image.shape[-1])

        self.image_embeddings = self.model.image_encoder(image)
        self.is_image_set = True

    def _valid_keypoints(self, keypoints: Keypoints | Tensor, labels: Tensor) -> Keypoints:
        """Validate the keypoints and labels shapes and ensure the keypoints are a `Keypoints` object."""
        KORNIA_CHECK_SHAPE(keypoints.data, ["K", "N", "2"])
        KORNIA_CHECK_SHAPE(labels.data, ["K", "N"])
        KORNIA_CHECK(keypoints.shape[0] == labels.shape[0], "The keypoints and labels should have the same batch size")

        if isinstance(keypoints, Tensor):
            keypoints = Keypoints.from_tensor(keypoints)

        return keypoints

    def _valid_boxes(self, boxes: Boxes | Tensor) -> Boxes:
        """Validate the boxes shape and ensure they are a `Boxes` object in xyxy mode.

        Tensors are expected with shape `(K, 4)` in xyxy format.
        """
        if isinstance(boxes, Tensor):
            KORNIA_CHECK_SHAPE(boxes.data, ["K", "4"])
            # `Boxes(...)` expects vertices; xyxy coordinates must go through `from_tensor`.
            return Boxes.from_tensor(boxes, mode="xyxy")

        if boxes.mode == "xyxy":
            return boxes

        return Boxes.from_tensor(boxes.to_tensor(mode="xyxy"), mode="xyxy")

    def _valid_masks(self, masks: Tensor) -> Tensor:
        """Validate the input masks shape, expected to be `(K, 1, 256, 256)`."""
        KORNIA_CHECK_SHAPE(masks, ["K", "1", "256", "256"])
        return masks

    def _transform_prompts(
        self, *prompts: Tensor | Boxes | Keypoints, data_keys: Optional[list[str]] = None
    ) -> dict[str, Tensor | Boxes | Keypoints]:
        """Apply the transforms used on the current image to the prompts.

        Args:
            *prompts: The prompts to be transformed, in the same order as `data_keys`.
            data_keys: The data key of each prompt.

        Returns:
            A dict mapping each data key to its transformed prompt. Empty when there is nothing to transform.

        """
        if not prompts:
            # Nothing to transform (e.g. mask-only prediction): skip the augmentation pipeline.
            return {}

        transformed_prompts = self.transforms(*prompts, data_keys=data_keys, params=self._tfs_params)

        if data_keys is None:
            data_keys = []

        # prevent unpacking tensor when creating the output dict (issue #2627)
        if not isinstance(transformed_prompts, (list, tuple)):
            transformed_prompts = [transformed_prompts]

        return {key: transformed_prompts[idx] for idx, key in enumerate(data_keys)}

    def preprocess_prompts(
        self,
        keypoints: Optional[Keypoints | Tensor] = None,
        keypoints_labels: Optional[Tensor] = None,
        boxes: Optional[Boxes | Tensor] = None,
        masks: Optional[Tensor] = None,
    ) -> Prompts:
        """Validate and preprocess the given prompts to be aligned with the input image.

        Keypoints are only used when `keypoints_labels` is also given. Keypoints and boxes are transformed
        with the parameters stored by `set_image`; masks are only validated.

        Args:
            keypoints: Point prompts. Shape :math:`(K, N, 2)`.
            keypoints_labels: Labels for the point prompts. Shape :math:`(K, N)`.
            boxes: Box prompts. If a tensor, should be in xyxy mode with shape :math:`(K, 4)`.
            masks: Mask prompts. Shape :math:`(K, 1, 256, 256)`.

        Returns:
            The prompts holding the transformed points (with labels), the transformed boxes as a xyxy tensor,
            and the masks. Prompts which were not provided are set to None.

        """
        data_keys: list[str] = []
        to_transform: list[Keypoints | Boxes | Tensor] = []

        if isinstance(keypoints, (Keypoints, Tensor)) and isinstance(keypoints_labels, Tensor):
            to_transform.append(self._valid_keypoints(keypoints, keypoints_labels))
            data_keys.append("keypoints")

        if isinstance(boxes, (Boxes, Tensor)):
            # Forward the validated `Boxes`, not the raw input, so tensor prompts are not dropped below.
            to_transform.append(self._valid_boxes(boxes))
            data_keys.append("bbox_xyxy")

        if isinstance(masks, Tensor):
            masks = self._valid_masks(masks)

        data = self._transform_prompts(*to_transform, data_keys=data_keys)

        points: Optional[tuple[Tensor, Tensor]] = None
        transformed_keypoints = data.get("keypoints")
        if isinstance(transformed_keypoints, Keypoints):
            kpts_tensor = transformed_keypoints.to_tensor()
            if KORNIA_CHECK_IS_TENSOR(kpts_tensor) and KORNIA_CHECK_IS_TENSOR(keypoints_labels):
                points = (kpts_tensor, keypoints_labels)

        bbox: Optional[Tensor] = None
        transformed_boxes = data.get("bbox_xyxy")
        if isinstance(transformed_boxes, Boxes):
            _bbox = transformed_boxes.to_tensor(mode="xyxy")
            if KORNIA_CHECK_IS_TENSOR(_bbox):
                bbox = _bbox

        return Prompts(points=points, boxes=bbox, masks=masks)

    @torch.no_grad()
    def predict(
        self,
        keypoints: Optional[Keypoints | Tensor] = None,
        keypoints_labels: Optional[Tensor] = None,
        boxes: Optional[Boxes | Tensor] = None,
        masks: Optional[Tensor] = None,
        multimask_output: bool = True,
        output_original_size: bool = True,
    ) -> SegmentationResults:
        """Predict masks for the given image based on the input prompts.

        Args:
            keypoints: Point prompts to the model. Each point is in (X,Y) in pixels. Shape :math:`(K, N, 2)`. Where
                       `N` is the number of points and `K` the number of prompts.
            keypoints_labels: Labels for the point prompts. 1 indicates a foreground point and 0 indicates a background
                              point. Shape :math:`(K, N)`. Where `N` is the number of points, and `K` the number of
                              prompts.
            boxes: A box prompt to the model. If a tensor, should be in a xyxy mode. Shape :math:`(K, 4)`
            masks: A low resolution mask input to the model, typically coming from a previous prediction
                   iteration. Has shape :math:`(K, 1, H, W)`, where for SAM, H=W=256.
            multimask_output: If true, the model will return three masks. For ambiguous input prompts (such as a
                              single click), this will often produce better masks than a single prediction. If only
                              a single mask is needed, the model's predicted quality score can be used to select the
                              best mask. For non-ambiguous prompts, such as multiple input prompts,
                              multimask_output=False can give better results.
            output_original_size: If true, the logits of `SegmentationResults` will be post-process to match the
                                  original input image size.

        Returns:
            A prediction with the logits and scores (IoU of each predicted mask)

        """
        KORNIA_CHECK(self.is_image_set, "An image must be set with `self.set_image(...)` before `predict` be called!")

        prompts = self.preprocess_prompts(keypoints, keypoints_labels, boxes, masks)

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=prompts.points, boxes=prompts.boxes, masks=prompts.masks
        )
        del prompts

        # Predict masks
        logits, scores = self.model.mask_decoder(
            image_embeddings=self.image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        results = SegmentationResults(logits, scores)
        if (
            output_original_size
            and isinstance(self._input_image_size, tuple)
            and isinstance(self._original_image_size, tuple)
        ):
            results.original_res_logits(self._input_image_size, self._original_image_size, self._input_encoder_size)

        return results

    def reset_image(self) -> None:
        """Forget the current image: clear its embeddings, sizes and transform parameters."""
        self._tfs_params = None
        self._original_image_size = None
        self._input_image_size = None
        self._input_encoder_size = None

        # Drop the previous embeddings (if any) before resetting the attribute.
        if hasattr(self, "image_embeddings"):
            del self.image_embeddings

        self.image_embeddings = None
        self.is_image_set = False

    def compile(
        self,
        *,
        fullgraph: bool = False,
        dynamic: bool = False,
        backend: str = "inductor",
        mode: Optional[str] = None,
        options: Optional[dict[Any, Any]] = None,
        disable: bool = False,
    ) -> None:
        """Apply `torch.compile(...)`/dynamo API into the VisualPrompter API.

        Only the model sub-modules (`image_encoder`, `mask_decoder` and `prompt_encoder`) are compiled.

        .. note:: For more information about the dynamo API check the official docs
                  https://pytorch.org/docs/stable/generated/torch.compile.html

        Args:
            fullgraph: Whether it is ok to break model into several subgraphs
            dynamic: Use dynamic shape tracing
            backend: backend to be used
            mode: Can be either "default", "reduce-overhead" or "max-autotune"
            options: A dictionary of options to pass to the backend.
            disable: Turn torch.compile() into a no-op for testing

        Example:
            >>> # prompter = VisualPrompter()
            >>> # prompter.compile() # You should have torch >= 2.0.0 installed
            >>> # Use the prompter methods ...

        """
        # FIXME: `set_image`, `preprocess_image` and `predict` are not compiled: compiling them would try to
        # compile `AugmentationSequential` (directly or through the prompts preprocessing), which fails.
        compile_kwargs: dict[str, Any] = {
            "fullgraph": fullgraph,
            "dynamic": dynamic,
            "backend": backend,
            "mode": mode,
            "options": options,
            "disable": disable,
        }

        for module_name in ("image_encoder", "mask_decoder", "prompt_encoder"):
            compiled = torch.compile(getattr(self.model, module_name), **compile_kwargs)  # type: ignore
            setattr(self.model, module_name, compiled)
|