clip_iqa.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
  16. import torch
  17. from torch import Tensor
  18. from torchmetrics.functional.multimodal.clip_iqa import (
  19. _clip_iqa_compute,
  20. _clip_iqa_format_prompts,
  21. _clip_iqa_get_anchor_vectors,
  22. _clip_iqa_update,
  23. _get_clip_iqa_model_and_processor,
  24. )
  25. from torchmetrics.metric import Metric
  26. from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
  27. from torchmetrics.utilities.data import dim_zero_cat
  28. from torchmetrics.utilities.imports import (
  29. _MATPLOTLIB_AVAILABLE,
  30. _PIQ_GREATER_EQUAL_0_8,
  31. _TRANSFORMERS_GREATER_EQUAL_4_10,
  32. )
  33. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  34. if not _PIQ_GREATER_EQUAL_0_8:
  35. __doctest_skip__ = ["CLIPImageQualityAssessment", "CLIPImageQualityAssessment.plot"]
  36. if not _MATPLOTLIB_AVAILABLE:
  37. __doctest_skip__ = ["CLIPImageQualityAssessment.plot"]
  38. if _SKIP_SLOW_DOCTEST and _TRANSFORMERS_GREATER_EQUAL_4_10:
  39. from transformers import CLIPModel as _CLIPModel
  40. from transformers import CLIPProcessor as _CLIPProcessor
  41. def _download_clip_iqa_metric() -> None:
  42. _CLIPModel.from_pretrained("openai/clip-vit-large-patch14", resume_download=True)
  43. _CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14", resume_download=True)
  44. if not _try_proceed_with_timeout(_download_clip_iqa_metric):
  45. __doctest_skip__ = ["CLIPImageQualityAssessment", "CLIPImageQualityAssessment.plot"]
  46. else:
  47. __doctest_skip__ = ["CLIPImageQualityAssessment", "CLIPImageQualityAssessment.plot"]
  48. class CLIPImageQualityAssessment(Metric):
  49. """Calculates `CLIP-IQA`_, that can be used to measure the visual content of images.
  50. The metric is based on the `CLIP`_ model, which is a neural network trained on a variety of (image, text) pairs to
  51. be able to generate a vector representation of the image and the text that is similar if the image and text are
  52. semantically similar.
  53. The metric works by calculating the cosine similarity between user provided images and pre-defined prompts. The
  54. prompts always comes in pairs of "positive" and "negative" such as "Good photo." and "Bad photo.". By calculating
  55. the similartity between image embeddings and both the "positive" and "negative" prompt, the metric can determine
  56. which prompt the image is more similar to. The metric then returns the probability that the image is more similar
  57. to the first prompt than the second prompt.
  58. Build in prompts are:
  59. * quality: "Good photo." vs "Bad photo."
  60. * brightness: "Bright photo." vs "Dark photo."
  61. * noisiness: "Clean photo." vs "Noisy photo."
  62. * colorfullness: "Colorful photo." vs "Dull photo."
  63. * sharpness: "Sharp photo." vs "Blurry photo."
  64. * contrast: "High contrast photo." vs "Low contrast photo."
  65. * complexity: "Complex photo." vs "Simple photo."
  66. * natural: "Natural photo." vs "Synthetic photo."
  67. * happy: "Happy photo." vs "Sad photo."
  68. * scary: "Scary photo." vs "Peaceful photo."
  69. * new: "New photo." vs "Old photo."
  70. * warm: "Warm photo." vs "Cold photo."
  71. * real: "Real photo." vs "Abstract photo."
  72. * beautiful: "Beautiful photo." vs "Ugly photo."
  73. * lonely: "Lonely photo." vs "Sociable photo."
  74. * relaxing: "Relaxing photo." vs "Stressful photo."
  75. As input to ``forward`` and ``update`` the metric accepts the following input
  76. - ``images`` (:class:`~torch.Tensor`): tensor with images feed to the feature extractor with shape ``(N,C,H,W)``
  77. As output of `forward` and `compute` the metric returns the following output
  78. - ``clip_iqa`` (:class:`~torch.Tensor` or dict of tensors): tensor with the CLIP-IQA score. If a single prompt is
  79. provided, a single tensor with shape ``(N,)`` is returned. If a list of prompts is provided, a dict of tensors
  80. is returned with the prompt as key and the tensor with shape ``(N,)`` as value.
  81. Args:
  82. model_name_or_path: string indicating the version of the CLIP model to use. Available models are:
  83. - `"clip_iqa"`, model corresponding to the CLIP-IQA paper.
  84. - `"openai/clip-vit-base-patch16"`
  85. - `"openai/clip-vit-base-patch32"`
  86. - `"openai/clip-vit-large-patch14-336"`
  87. - `"openai/clip-vit-large-patch14"`
  88. data_range: The maximum value of the input tensor. For example, if the input images are in range [0, 255],
  89. data_range should be 255. The images are normalized by this value.
  90. prompts: A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be one
  91. of the available prompts (see above). Else the input is expected to be a tuple, where each element can
  92. be one of two things: either a string or a tuple of strings. If a string is provided, it must be one of the
  93. available prompts (see above). If tuple is provided, it must be of length 2 and the first string must be a
  94. positive prompt and the second string must be a negative prompt.
  95. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  96. .. hint::
  97. If using the default `clip_iqa` model, the package `piq` must be installed. Either install with
  98. `pip install piq` or `pip install torchmetrics[image]`.
  99. Raises:
  100. ModuleNotFoundError:
  101. If transformers package is not installed or version is lower than 4.10.0
  102. ValueError:
  103. If `prompts` is a tuple and it is not of length 2
  104. ValueError:
  105. If `prompts` is a string and it is not one of the available prompts
  106. ValueError:
  107. If `prompts` is a list of strings and not all strings are one of the available prompts
  108. Example::
  109. Single prompt:
  110. >>> from torch import randint
  111. >>> from torchmetrics.multimodal import CLIPImageQualityAssessment
  112. >>> imgs = randint(255, (2, 3, 224, 224)).float()
  113. >>> metric = CLIPImageQualityAssessment()
  114. >>> metric(imgs)
  115. tensor([0.8894, 0.8902])
  116. Example::
  117. Multiple prompts:
  118. >>> from torch import randint
  119. >>> from torchmetrics.multimodal import CLIPImageQualityAssessment
  120. >>> imgs = randint(255, (2, 3, 224, 224)).float()
  121. >>> metric = CLIPImageQualityAssessment(prompts=("quality", "brightness"))
  122. >>> metric(imgs)
  123. {'quality': tensor([0.8693, 0.8705]), 'brightness': tensor([0.5722, 0.4762])}
  124. Example::
  125. Custom prompts. Must always be a tuple of length 2, with a positive and negative prompt.
  126. >>> from torch import randint
  127. >>> from torchmetrics.multimodal import CLIPImageQualityAssessment
  128. >>> imgs = randint(255, (2, 3, 224, 224)).float()
  129. >>> metric = CLIPImageQualityAssessment(prompts=(("Super good photo.", "Super bad photo."), "brightness"))
  130. >>> metric(imgs)
  131. {'user_defined_0': tensor([0.9578, 0.9654]), 'brightness': tensor([0.5495, 0.5764])}
  132. """
  133. is_differentiable: bool = False
  134. higher_is_better: bool = True
  135. full_state_update: bool = True
  136. plot_lower_bound = 0.0
  137. plot_upper_bound = 100.0
  138. anchors: Tensor
  139. probs_list: List[Tensor]
  140. feature_network: str = "model"
  141. def __init__(
  142. self,
  143. model_name_or_path: Literal[
  144. "clip_iqa",
  145. "openai/clip-vit-base-patch16",
  146. "openai/clip-vit-base-patch32",
  147. "openai/clip-vit-large-patch14-336",
  148. "openai/clip-vit-large-patch14",
  149. ] = "clip_iqa",
  150. data_range: float = 1.0,
  151. prompts: tuple[Union[str, tuple[str, str]], ...] = ("quality",),
  152. **kwargs: Any,
  153. ) -> None:
  154. super().__init__(**kwargs)
  155. if not (isinstance(data_range, (int, float)) and data_range > 0):
  156. raise ValueError("Argument `data_range` should be a positive number.")
  157. self.data_range = data_range
  158. prompts_list, prompts_name = _clip_iqa_format_prompts(prompts)
  159. self.prompts_list = prompts_list
  160. self.prompts_name = prompts_name
  161. self.model, self.processor = _get_clip_iqa_model_and_processor(model_name_or_path)
  162. self.model_name_or_path = model_name_or_path
  163. with torch.inference_mode():
  164. anchors = _clip_iqa_get_anchor_vectors(
  165. model_name_or_path, self.model, self.processor, self.prompts_list, self.device
  166. )
  167. self.register_buffer("anchors", anchors)
  168. self.add_state("probs_list", [], dist_reduce_fx="cat")
  169. def update(self, images: Tensor) -> None:
  170. """Update metric state with new data."""
  171. with torch.inference_mode():
  172. img_features = _clip_iqa_update(
  173. self.model_name_or_path, images, self.model, self.processor, self.data_range, self.device
  174. )
  175. probs = _clip_iqa_compute(img_features, self.anchors, self.prompts_name, format_as_dict=False)
  176. if not isinstance(probs, Tensor):
  177. raise ValueError("Output probs should be a tensor")
  178. self.probs_list.append(probs)
  179. def compute(self) -> Union[Tensor, dict[str, Tensor]]:
  180. """Compute metric."""
  181. probs = dim_zero_cat(self.probs_list)
  182. if len(self.prompts_name) == 1:
  183. return probs.squeeze()
  184. return {p: probs[:, i] for i, p in enumerate(self.prompts_name)}
  185. def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
  186. """Plot a single or multiple values from the metric.
  187. Args:
  188. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  189. If no value is provided, will automatically call `metric.compute` and plot that result.
  190. ax: An matplotlib axis object. If provided will add plot to that axis
  191. Returns:
  192. Figure and Axes object
  193. Raises:
  194. ModuleNotFoundError:
  195. If `matplotlib` is not installed
  196. .. plot::
  197. :scale: 75
  198. >>> # Example plotting a single value
  199. >>> import torch
  200. >>> from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment
  201. >>> metric = CLIPImageQualityAssessment()
  202. >>> metric.update(torch.rand(1, 3, 224, 224))
  203. >>> fig_, ax_ = metric.plot()
  204. .. plot::
  205. :scale: 75
  206. >>> # Example plotting multiple values
  207. >>> import torch
  208. >>> from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment
  209. >>> metric = CLIPImageQualityAssessment()
  210. >>> values = [ ]
  211. >>> for _ in range(10):
  212. ... values.append(metric(torch.rand(1, 3, 224, 224)))
  213. >>> fig_, ax_ = metric.plot(values)
  214. """
  215. return self._plot(val, ax)
  216. if TYPE_CHECKING:
  217. f = CLIPImageQualityAssessment
  218. f(prompts=("colorfullness",))
  219. f(
  220. prompts=("quality", "brightness", "noisiness"),
  221. )
  222. f(
  223. prompts=("quality", "brightness", "noisiness", "colorfullness"),
  224. )
  225. f(prompts=(("Photo of a cat", "Photo of a dog"), "quality", ("Colorful photo", "Black and white photo")))