visual_prompter.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import Any, Optional
  19. import torch
  20. from kornia.augmentation import LongestMaxSize
  21. from kornia.augmentation.container.augment import AugmentationSequential
  22. from kornia.contrib.models import Prompts, SegmentationResults
  23. from kornia.contrib.models.sam import Sam, SamConfig
  24. from kornia.core import Tensor, pad, tensor
  25. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
  26. from kornia.enhance import normalize
  27. from kornia.geometry.boxes import Boxes
  28. from kornia.geometry.keypoints import Keypoints
  29. class VisualPrompter:
  30. r"""Allow the user to run multiple query with multiple prompts for a model.
  31. At the moment, we just support the SAM model. The model is loaded based on the given config.
  32. For default the images are transformed to have their long side with size of the `image_encoder.img_size`. This
  33. Prompter class ensure to transform the images and the prompts before prediction. Also, the image is passed
  34. automatically for the method `preprocess_image`, which is responsible for normalize the image and pad it to have
  35. the right size for the SAM model :math:`(\text{image_encoder.img_size}, \text{image_encoder.img_size})`. For
  36. default the image is normalized by the mean and standard deviation of the SAM dataset values.
  37. Args:
  38. config: A model config to generate the model. Now just the SAM model is supported.
  39. device: The desired device to use the model.
  40. dtype: The desired dtype to use the model.
  41. Example:
  42. >>> # prompter = VisualPrompter() # Will load the vit h for default
  43. >>> # You can load a custom SAM type for modifying the config
  44. >>> prompter = VisualPrompter(SamConfig('vit_b'))
  45. >>> image = torch.rand(3, 25, 30)
  46. >>> prompter.set_image(image)
  47. >>> boxes = Boxes(
  48. ... torch.tensor(
  49. ... [[[[0, 0], [0, 10], [10, 0], [10, 10]]]],
  50. ... device=prompter.device,
  51. ... dtype=torch.float32
  52. ... ),
  53. ... mode='xyxy'
  54. ... )
  55. >>> prediction = prompter.predict(boxes=boxes)
  56. >>> prediction.logits.shape
  57. torch.Size([1, 3, 256, 256])
  58. """
  59. def __init__(
  60. self,
  61. config: Optional[SamConfig] = None,
  62. device: Optional[torch.device] = None,
  63. dtype: Optional[torch.dtype] = None,
  64. ) -> None:
  65. super().__init__()
  66. if config is None:
  67. config = SamConfig(model_type="vit_h", pretrained=True)
  68. if isinstance(config, SamConfig):
  69. self.model = Sam.from_config(config)
  70. transforms = (LongestMaxSize(self.model.image_encoder.img_size, p=1.0),)
  71. self.pixel_mean: Optional[Tensor] = tensor([123.675, 116.28, 103.53], device=device, dtype=dtype) / 255
  72. self.pixel_std: Optional[Tensor] = tensor([58.395, 57.12, 57.375], device=device, dtype=dtype) / 255
  73. else:
  74. raise NotImplementedError
  75. self.model = self.model.to(device=device, dtype=dtype)
  76. self.transforms = AugmentationSequential(*transforms, same_on_batch=True)
  77. self.device = device
  78. self.dtype = dtype
  79. self._original_image_size: None | tuple[int, int] = None
  80. self._input_image_size: None | tuple[int, int] = None
  81. self._input_encoder_size: None | tuple[int, int] = None
  82. self.reset_image()
  83. def preprocess_image(self, x: Tensor, mean: Optional[Tensor] = None, std: Optional[Tensor] = None) -> Tensor:
  84. """Normalize and pad a tensor.
  85. For normalize the tensor: will prioritize the `mean` and `std` passed as argument, if None will use the default
  86. Sam Dataset values.
  87. For pad the tensor: Will pad the tensor into the right and bottom to match with the size of
  88. `self.model.image_encoder.img_size`
  89. Args:
  90. x: The image to be preprocessed
  91. mean: Mean for each channel.
  92. std: Standard deviations for each channel.
  93. Returns:
  94. The image preprocessed (normalized if has mean and str available and padded to encoder size)
  95. """
  96. if isinstance(mean, Tensor) and isinstance(std, Tensor):
  97. x = normalize(x, mean, std)
  98. elif isinstance(self.pixel_mean, Tensor) and isinstance(self.pixel_std, Tensor):
  99. x = normalize(x, self.pixel_mean, self.pixel_std)
  100. encoder_im_size = self.model.image_encoder.img_size
  101. pad_h = encoder_im_size - x.shape[-2]
  102. pad_w = encoder_im_size - x.shape[-1]
  103. x = pad(x, (0, pad_w, 0, pad_h))
  104. return x
  105. @torch.no_grad()
  106. def set_image(self, image: Tensor, mean: Optional[Tensor] = None, std: Optional[Tensor] = None) -> None:
  107. """Set the embeddings from the given image with `image_decoder` of the model.
  108. Prepare the given image with the selected transforms and the preprocess method.
  109. Args:
  110. image: RGB image. Normally images with range of [0-1], the model preprocess normalize the
  111. pixel values with the mean and std defined in its initialization. Expected to be into a float32
  112. dtype. Shape :math:`(3, H, W)`.
  113. mean: mean value of dataset for normalization.
  114. std: standard deviation of dataset for normalization.
  115. """
  116. KORNIA_CHECK_SHAPE(image, ["3", "H", "W"])
  117. self.reset_image()
  118. self._original_image_size = (image.shape[-2], image.shape[-1])
  119. image = self.transforms(image, data_keys=["input"])
  120. self._tfs_params = self.transforms._params
  121. self._input_image_size = (image.shape[-2], image.shape[-1])
  122. image = self.preprocess_image(image, mean, std)
  123. self._input_encoder_size = (image.shape[-2], image.shape[-1])
  124. self.image_embeddings = self.model.image_encoder(image)
  125. self.is_image_set = True
  126. def _valid_keypoints(self, keypoints: Keypoints | Tensor, labels: Tensor) -> Keypoints:
  127. """Validate the keypoints shape and ensure to be a Keypoints."""
  128. KORNIA_CHECK_SHAPE(keypoints.data, ["K", "N", "2"])
  129. KORNIA_CHECK_SHAPE(labels.data, ["K", "N"])
  130. KORNIA_CHECK(keypoints.shape[0] == labels.shape[0], "The keypoints and labels should have the same batch size")
  131. if isinstance(keypoints, Tensor):
  132. keypoints = Keypoints.from_tensor(keypoints)
  133. return keypoints
  134. def _valid_boxes(self, boxes: Boxes | Tensor) -> Boxes:
  135. """Validate the boxes shape and ensure to be a Boxes into xyxy mode."""
  136. if isinstance(boxes, Tensor):
  137. KORNIA_CHECK_SHAPE(boxes.data, ["K", "4"])
  138. boxes = Boxes(boxes, mode="xyxy")
  139. if boxes.mode == "xyxy":
  140. boxes_xyxy = boxes
  141. else:
  142. boxes_xyxy = Boxes(boxes.to_tensor(mode="xyxy"), mode="xyxy")
  143. return boxes_xyxy
  144. def _valid_masks(self, masks: Tensor) -> Tensor:
  145. """Validate the input masks shape."""
  146. KORNIA_CHECK_SHAPE(masks, ["K", "1", "256", "256"])
  147. return masks
  148. def _transform_prompts(
  149. self, *prompts: Tensor | Boxes | Keypoints, data_keys: Optional[list[str]] = None
  150. ) -> dict[str, Tensor | Boxes | Keypoints]:
  151. transformed_prompts = self.transforms(*prompts, data_keys=data_keys, params=self._tfs_params)
  152. if data_keys is None:
  153. data_keys = []
  154. # prevent unpacking tensor when creating the output dict (issue #2627)
  155. if not isinstance(transformed_prompts, (list, tuple)):
  156. transformed_prompts = [transformed_prompts]
  157. return {key: transformed_prompts[idx] for idx, key in enumerate(data_keys)}
  158. def preprocess_prompts(
  159. self,
  160. keypoints: Optional[Keypoints | Tensor] = None,
  161. keypoints_labels: Optional[Tensor] = None,
  162. boxes: Optional[Boxes | Tensor] = None,
  163. masks: Optional[Tensor] = None,
  164. ) -> Prompts:
  165. """Validate and preprocess the given prompts to be aligned with the input image."""
  166. data_keys = []
  167. to_transform: list[Keypoints | Boxes | Tensor] = []
  168. if isinstance(keypoints, (Keypoints, Tensor)) and isinstance(keypoints_labels, Tensor):
  169. keypoints = self._valid_keypoints(keypoints, keypoints_labels)
  170. data_keys.append("keypoints")
  171. to_transform.append(keypoints)
  172. if isinstance(boxes, (Boxes, Tensor)):
  173. self._valid_boxes(boxes)
  174. data_keys.append("bbox_xyxy")
  175. to_transform.append(boxes)
  176. if isinstance(masks, Tensor):
  177. self._valid_masks(masks)
  178. data = self._transform_prompts(*to_transform, data_keys=data_keys)
  179. if "keypoints" in data and isinstance(data["keypoints"], Keypoints):
  180. kpts_tensor = data["keypoints"].to_tensor()
  181. if KORNIA_CHECK_IS_TENSOR(kpts_tensor) and KORNIA_CHECK_IS_TENSOR(keypoints_labels):
  182. points = (kpts_tensor, keypoints_labels)
  183. else:
  184. points = None
  185. if "bbox_xyxy" in data and isinstance(data["bbox_xyxy"], Boxes):
  186. _bbox = data["bbox_xyxy"].to_tensor(mode="xyxy")
  187. if KORNIA_CHECK_IS_TENSOR(_bbox):
  188. bbox = _bbox
  189. else:
  190. bbox = None
  191. return Prompts(points=points, boxes=bbox, masks=masks)
  192. @torch.no_grad()
  193. def predict(
  194. self,
  195. keypoints: Optional[Keypoints | Tensor] = None,
  196. keypoints_labels: Optional[Tensor] = None,
  197. boxes: Optional[Boxes | Tensor] = None,
  198. masks: Optional[Tensor] = None,
  199. multimask_output: bool = True,
  200. output_original_size: bool = True,
  201. ) -> SegmentationResults:
  202. """Predict masks for the given image based on the input prompts.
  203. Args:
  204. keypoints: Point prompts to the model. Each point is in (X,Y) in pixels. Shape :math:`(K, N, 2)`. Where
  205. `N` is the number of points and `K` the number of prompts.
  206. keypoints_labels: Labels for the point prompts. 1 indicates a foreground point and 0 indicates a background
  207. point. Shape :math:`(K, N)`. Where `N` is the number of points, and `K` the number of
  208. prompts.
  209. boxes: A box prompt to the model. If a tensor, should be in a xyxy mode. Shape :math:`(K, 4)`
  210. masks: A low resolution mask input to the model, typically coming from a previous prediction
  211. iteration. Has shape :math:`(K, 1, H, W)`, where for SAM, H=W=256.
  212. multimask_output: If true, the model will return three masks. For ambiguous input prompts (such as a
  213. single click), this will often produce better masks than a single prediction. If only
  214. a single mask is needed, the model's predicted quality score can be used to select the
  215. best mask. For non-ambiguous prompts, such as multiple input prompts,
  216. multimask_output=False can give better results.
  217. output_original_size: If true, the logits of `SegmentationResults` will be post-process to match the
  218. original input image size.
  219. Returns:
  220. A prediction with the logits and scores (IoU of each predicted mask)
  221. """
  222. KORNIA_CHECK(self.is_image_set, "An image must be set with `self.set_image(...)` before `predict` be called!")
  223. prompts = self.preprocess_prompts(keypoints, keypoints_labels, boxes, masks)
  224. # Embed prompts
  225. sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
  226. points=prompts.points, boxes=prompts.boxes, masks=prompts.masks
  227. )
  228. del prompts
  229. # Predict masks
  230. logits, scores = self.model.mask_decoder(
  231. image_embeddings=self.image_embeddings,
  232. image_pe=self.model.prompt_encoder.get_dense_pe(),
  233. sparse_prompt_embeddings=sparse_embeddings,
  234. dense_prompt_embeddings=dense_embeddings,
  235. multimask_output=multimask_output,
  236. )
  237. results = SegmentationResults(logits, scores)
  238. if (
  239. output_original_size
  240. and isinstance(self._input_image_size, tuple)
  241. and isinstance(self._original_image_size, tuple)
  242. ):
  243. results.original_res_logits(self._input_image_size, self._original_image_size, self._input_encoder_size)
  244. # results = results.squeeze(0)
  245. return results
  246. def reset_image(self) -> None:
  247. self._tfs_params = None
  248. self._original_image_size = None
  249. self._input_image_size = None
  250. self._input_encoder_size = None
  251. if hasattr(self, "image_embeddings"):
  252. del self.image_embeddings
  253. self.image_embeddings = None
  254. self.is_image_set = False
  255. def compile(
  256. self,
  257. *,
  258. fullgraph: bool = False,
  259. dynamic: bool = False,
  260. backend: str = "inductor",
  261. mode: Optional[str] = None,
  262. options: Optional[dict[Any, Any]] = None,
  263. disable: bool = False,
  264. ) -> None:
  265. """Apply `torch.compile(...)`/dynamo API into the VisualPrompter API.
  266. .. note:: For more information about the dynamo API check the official docs
  267. https://pytorch.org/docs/stable/generated/torch.compile.html
  268. Args:
  269. fullgraph: Whether it is ok to break model into several subgraphs
  270. dynamic: Use dynamic shape tracing
  271. backend: backend to be used
  272. mode: Can be either “default”, “reduce-overhead” or “max-autotune”
  273. options: A dictionary of options to pass to the backend.
  274. disable: Turn torch.compile() into a no-op for testing
  275. Example:
  276. >>> # prompter = VisualPrompter()
  277. >>> # prompter.compile() # You should have torch >= 2.0.0 installed
  278. >>> # Use the prompter methods ...
  279. """
  280. # self.set_image = torch.compile( # type: ignore[method-assign]
  281. # self.set_image,
  282. # fullgraph=fullgraph,
  283. # dynamic=dynamic,
  284. # backend=backend,
  285. # mode=mode,
  286. # options=options,
  287. # disable=disable,
  288. # )
  289. # FIXME: compile set image will try to compile AugmentationSequential which fails
  290. self.model.image_encoder = torch.compile( # type: ignore
  291. self.model.image_encoder,
  292. fullgraph=fullgraph,
  293. dynamic=dynamic,
  294. backend=backend,
  295. mode=mode,
  296. options=options,
  297. disable=disable,
  298. )
  299. # self.preprocess_image = torch.compile( # type: ignore[method-assign]
  300. # self.preprocess_image,
  301. # fullgraph=fullgraph,
  302. # dynamic=dynamic,
  303. # backend=backend,
  304. # mode=mode,
  305. # options=options,
  306. # disable=disable,
  307. # )
  308. # FIXME: compile predict will try to compile Preproc prompts, which need to compileAugmentationSequential
  309. # which fails
  310. # self.predict = torch.compile( # type: ignore[method-assign]
  311. # self.predict,
  312. # fullgraph=fullgraph,
  313. # dynamic=dynamic,
  314. # backend=backend,
  315. # mode=mode,
  316. # options=options,
  317. # disable=disable,
  318. # )
  319. self.model.mask_decoder = torch.compile( # type: ignore
  320. self.model.mask_decoder,
  321. fullgraph=fullgraph,
  322. dynamic=dynamic,
  323. backend=backend,
  324. mode=mode,
  325. options=options,
  326. disable=disable,
  327. )
  328. self.model.prompt_encoder = torch.compile( # type: ignore
  329. self.model.prompt_encoder,
  330. fullgraph=fullgraph,
  331. dynamic=dynamic,
  332. backend=backend,
  333. mode=mode,
  334. options=options,
  335. disable=disable,
  336. )