clip_iqa.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING, Literal, Union
  15. import torch
  16. from torch import Tensor
  17. from torchmetrics.functional.multimodal.clip_score import _get_clip_model_and_processor
  18. from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
  19. from torchmetrics.utilities.imports import _PIQ_GREATER_EQUAL_0_8, _TRANSFORMERS_GREATER_EQUAL_4_10
  20. if TYPE_CHECKING:
  21. from transformers import CLIPModel as _CLIPModel
  22. from transformers import CLIPProcessor as _CLIPProcessor
  23. if _SKIP_SLOW_DOCTEST and _TRANSFORMERS_GREATER_EQUAL_4_10:
  24. from transformers import CLIPModel as _CLIPModel
  25. from transformers import CLIPProcessor as _CLIPProcessor
  26. def _download_clip_for_iqa_metric() -> None:
  27. _CLIPModel.from_pretrained("openai/clip-vit-base-patch16", resume_download=True)
  28. _CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16", resume_download=True)
  29. if not _try_proceed_with_timeout(_download_clip_for_iqa_metric):
  30. __doctest_skip__ = ["clip_image_quality_assessment"]
  31. else:
  32. __doctest_skip__ = ["clip_image_quality_assessment"]
  33. if not _PIQ_GREATER_EQUAL_0_8:
  34. __doctest_skip__ = ["clip_image_quality_assessment"]
  35. _PROMPTS: dict[str, tuple[str, str]] = {
  36. "quality": ("Good photo.", "Bad photo."),
  37. "brightness": ("Bright photo.", "Dark photo."),
  38. "noisiness": ("Clean photo.", "Noisy photo."),
  39. "colorfullness": ("Colorful photo.", "Dull photo."),
  40. "sharpness": ("Sharp photo.", "Blurry photo."),
  41. "contrast": ("High contrast photo.", "Low contrast photo."),
  42. "complexity": ("Complex photo.", "Simple photo."),
  43. "natural": ("Natural photo.", "Synthetic photo."),
  44. "happy": ("Happy photo.", "Sad photo."),
  45. "scary": ("Scary photo.", "Peaceful photo."),
  46. "new": ("New photo.", "Old photo."),
  47. "warm": ("Warm photo.", "Cold photo."),
  48. "real": ("Real photo.", "Abstract photo."),
  49. "beautiful": ("Beautiful photo.", "Ugly photo."),
  50. "lonely": ("Lonely photo.", "Sociable photo."),
  51. "relaxing": ("Relaxing photo.", "Stressful photo."),
  52. }
  53. def _get_clip_iqa_model_and_processor(
  54. model_name_or_path: Literal[
  55. "clip_iqa",
  56. "openai/clip-vit-base-patch16",
  57. "openai/clip-vit-base-patch32",
  58. "openai/clip-vit-large-patch14-336",
  59. "openai/clip-vit-large-patch14",
  60. ],
  61. ) -> tuple["_CLIPModel", "_CLIPProcessor"]:
  62. """Extract the CLIP model and processor from the model name or path."""
  63. from transformers import CLIPProcessor as _CLIPProcessor
  64. if model_name_or_path == "clip_iqa":
  65. if not _PIQ_GREATER_EQUAL_0_8:
  66. raise ValueError(
  67. "For metric `clip_iqa` to work with argument `model_name_or_path` set to default value `'clip_iqa'`"
  68. ", package `piq` version v0.8.0 or later must be installed. Either install with `pip install piq` or"
  69. "`pip install torchmetrics[multimodal]`"
  70. )
  71. import piq
  72. model = piq.clip_iqa.clip.load().eval()
  73. # any model checkpoint can be used here because the tokenizer is the same for all
  74. processor = _CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
  75. return model, processor
  76. return _get_clip_model_and_processor(model_name_or_path)
  77. def _clip_iqa_format_prompts(
  78. prompts: tuple[Union[str, tuple[str, str]], ...] = ("quality",),
  79. ) -> tuple[list[str], list[str]]:
  80. """Converts the provided keywords into a list of prompts for the model to calculate the anchor vectors.
  81. Args:
  82. prompts: A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be one
  83. of the available prompts (see above). Else the input is expected to be a tuple, where each element can
  84. be one of two things: either a string or a tuple of strings. If a string is provided, it must be one of the
  85. available prompts (see above). If tuple is provided, it must be of length 2 and the first string must be a
  86. positive prompt and the second string must be a negative prompt.
  87. Returns:
  88. Tuple containing a list of prompts and a list of the names of the prompts. The first list is double the length
  89. of the second list.
  90. Examples::
  91. >>> # single prompt
  92. >>> _clip_iqa_format_prompts(("quality",))
  93. (['Good photo.', 'Bad photo.'], ['quality'])
  94. >>> # multiple prompts
  95. >>> _clip_iqa_format_prompts(("quality", "brightness"))
  96. (['Good photo.', 'Bad photo.', 'Bright photo.', 'Dark photo.'], ['quality', 'brightness'])
  97. >>> # Custom prompts
  98. >>> _clip_iqa_format_prompts(("quality", ("Super good photo.", "Super bad photo.")))
  99. (['Good photo.', 'Bad photo.', 'Super good photo.', 'Super bad photo.'], ['quality', 'user_defined_0'])
  100. """
  101. if not isinstance(prompts, tuple):
  102. raise ValueError("Argument `prompts` must be a tuple containing strings or tuples of strings")
  103. prompts_names: list[str] = []
  104. prompts_list: list[str] = []
  105. count = 0
  106. for p in prompts:
  107. if not isinstance(p, (str, tuple)):
  108. raise ValueError("Argument `prompts` must be a tuple containing strings or tuples of strings")
  109. if isinstance(p, str):
  110. if p not in _PROMPTS:
  111. raise ValueError(
  112. f"All elements of `prompts` must be one of {_PROMPTS.keys()} if not custom tuple prompts, got {p}."
  113. )
  114. prompts_names.append(p)
  115. prompts_list.extend(_PROMPTS[p])
  116. if isinstance(p, tuple) and len(p) != 2:
  117. raise ValueError("If a tuple is provided in argument `prompts`, it must be of length 2")
  118. if isinstance(p, tuple):
  119. prompts_names.append(f"user_defined_{count}")
  120. prompts_list.extend(p)
  121. count += 1
  122. return prompts_list, prompts_names
  123. def _clip_iqa_get_anchor_vectors(
  124. model_name_or_path: str,
  125. model: "_CLIPModel",
  126. processor: "_CLIPProcessor",
  127. prompts_list: list[str],
  128. device: Union[str, torch.device],
  129. ) -> Tensor:
  130. """Calculates the anchor vectors for the CLIP IQA metric.
  131. Args:
  132. model_name_or_path: string indicating the version of the CLIP model to use.
  133. model: The CLIP model
  134. processor: The CLIP processor
  135. prompts_list: A list of prompts
  136. device: The device to use for the calculation
  137. """
  138. if model_name_or_path == "clip_iqa":
  139. text_processed = processor(text=prompts_list)
  140. anchors_text = torch.zeros(
  141. len(prompts_list), processor.tokenizer.model_max_length, dtype=torch.long, device=device
  142. )
  143. for i, tp in enumerate(text_processed["input_ids"]):
  144. anchors_text[i, : len(tp)] = torch.tensor(tp, dtype=torch.long, device=device)
  145. anchors = model.encode_text(anchors_text).float()
  146. else:
  147. text_processed = processor(text=prompts_list, return_tensors="pt", padding=True)
  148. anchors = model.get_text_features(
  149. text_processed["input_ids"].to(device), text_processed["attention_mask"].to(device)
  150. )
  151. return anchors / anchors.norm(p=2, dim=-1, keepdim=True)
  152. def _clip_iqa_update(
  153. model_name_or_path: str,
  154. images: Tensor,
  155. model: "_CLIPModel",
  156. processor: "_CLIPProcessor",
  157. data_range: float,
  158. device: Union[str, torch.device],
  159. ) -> Tensor:
  160. images = images / float(data_range)
  161. """Update function for CLIP IQA."""
  162. if model_name_or_path == "clip_iqa":
  163. # default mean and std from clip paper, see:
  164. # https://github.com/huggingface/transformers/blob/main/src/transformers/utils/constants.py
  165. default_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device).view(1, 3, 1, 1)
  166. default_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device).view(1, 3, 1, 1)
  167. images = (images - default_mean) / default_std
  168. img_features = model.encode_image(images.float(), pos_embedding=False).float()
  169. else:
  170. processed_input = processor(images=[i.cpu() for i in images], return_tensors="pt", padding=True)
  171. img_features = model.get_image_features(processed_input["pixel_values"].to(device))
  172. return img_features / img_features.norm(p=2, dim=-1, keepdim=True)
  173. def _clip_iqa_compute(
  174. img_features: Tensor,
  175. anchors: Tensor,
  176. prompts_names: list[str],
  177. format_as_dict: bool = True,
  178. ) -> Union[Tensor, dict[str, Tensor]]:
  179. """Final computation of CLIP IQA."""
  180. logits_per_image = 100 * img_features @ anchors.t()
  181. probs = logits_per_image.reshape(logits_per_image.shape[0], -1, 2).softmax(-1)[:, :, 0]
  182. if len(prompts_names) == 1:
  183. return probs.squeeze()
  184. if format_as_dict:
  185. return {p: probs[:, i] for i, p in enumerate(prompts_names)}
  186. return probs
  187. def clip_image_quality_assessment(
  188. images: Tensor,
  189. model_name_or_path: Literal[
  190. "clip_iqa",
  191. "openai/clip-vit-base-patch16",
  192. "openai/clip-vit-base-patch32",
  193. "openai/clip-vit-large-patch14-336",
  194. "openai/clip-vit-large-patch14",
  195. ] = "clip_iqa",
  196. data_range: float = 1.0,
  197. prompts: tuple[Union[str, tuple[str, str]], ...] = ("quality",),
  198. ) -> Union[Tensor, dict[str, Tensor]]:
  199. """Calculates `CLIP-IQA`_, that can be used to measure the visual content of images.
  200. The metric is based on the `CLIP`_ model, which is a neural network trained on a variety of (image, text) pairs to
  201. be able to generate a vector representation of the image and the text that is similar if the image and text are
  202. semantically similar.
  203. The metric works by calculating the cosine similarity between user provided images and pre-defined prompts. The
  204. prompts always come in pairs of "positive" and "negative" such as "Good photo." and "Bad photo.". By calculating
  205. the similartity between image embeddings and both the "positive" and "negative" prompt, the metric can determine
  206. which prompt the image is more similar to. The metric then returns the probability that the image is more similar
  207. to the first prompt than the second prompt.
  208. Build in prompts are:
  209. * quality: "Good photo." vs "Bad photo."
  210. * brightness: "Bright photo." vs "Dark photo."
  211. * noisiness: "Clean photo." vs "Noisy photo."
  212. * colorfullness: "Colorful photo." vs "Dull photo."
  213. * sharpness: "Sharp photo." vs "Blurry photo."
  214. * contrast: "High contrast photo." vs "Low contrast photo."
  215. * complexity: "Complex photo." vs "Simple photo."
  216. * natural: "Natural photo." vs "Synthetic photo."
  217. * happy: "Happy photo." vs "Sad photo."
  218. * scary: "Scary photo." vs "Peaceful photo."
  219. * new: "New photo." vs "Old photo."
  220. * warm: "Warm photo." vs "Cold photo."
  221. * real: "Real photo." vs "Abstract photo."
  222. * beautiful: "Beautiful photo." vs "Ugly photo."
  223. * lonely: "Lonely photo." vs "Sociable photo."
  224. * relaxing: "Relaxing photo." vs "Stressful photo."
  225. Args:
  226. images: Either a single ``[N, C, H, W]`` tensor or a list of ``[C, H, W]`` tensors
  227. model_name_or_path: string indicating the version of the CLIP model to use. By default this argument is set to
  228. ``clip_iqa`` which corresponds to the model used in the original paper. Other available models are
  229. `"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
  230. and `"openai/clip-vit-large-patch14"`
  231. data_range: The maximum value of the input tensor. For example, if the input images are in range [0, 255],
  232. data_range should be 255. The images are normalized by this value.
  233. prompts: A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be one
  234. of the available prompts (see above). Else the input is expected to be a tuple, where each element can
  235. be one of two things: either a string or a tuple of strings. If a string is provided, it must be one of the
  236. available prompts (see above). If tuple is provided, it must be of length 2 and the first string must be a
  237. positive prompt and the second string must be a negative prompt.
  238. .. hint::
  239. If using the default `clip_iqa` model, the package `piq` must be installed. Either install with
  240. `pip install piq` or `pip install torchmetrics[multimodal]`.
  241. Returns:
  242. A tensor of shape ``(N,)`` if a single prompts is provided. If a list of prompts is provided, a dictionary of
  243. with the prompts as keys and tensors of shape ``(N,)`` as values.
  244. Raises:
  245. ModuleNotFoundError:
  246. If transformers package is not installed or version is lower than 4.10.0
  247. ValueError:
  248. If not all images have format [C, H, W]
  249. ValueError:
  250. If prompts is a tuple and it is not of length 2
  251. ValueError:
  252. If prompts is a string and it is not one of the available prompts
  253. ValueError:
  254. If prompts is a list of strings and not all strings are one of the available prompts
  255. Example::
  256. Single prompt:
  257. >>> from torch import randint
  258. >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
  259. >>> imgs = randint(255, (2, 3, 224, 224)).float()
  260. >>> clip_image_quality_assessment(imgs, prompts=("quality",))
  261. tensor([0.8894, 0.8902])
  262. Example::
  263. Multiple prompts:
  264. >>> from torch import randint
  265. >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
  266. >>> imgs = randint(255, (2, 3, 224, 224)).float()
  267. >>> clip_image_quality_assessment(imgs, prompts=("quality", "brightness"))
  268. {'quality': tensor([0.8693, 0.8705]), 'brightness': tensor([0.5722, 0.4762])}
  269. Example::
  270. Custom prompts. Must always be a tuple of length 2, with a positive and negative prompt.
  271. >>> from torch import rand
  272. >>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
  273. >>> imgs = randint(255, (2, 3, 224, 224)).float()
  274. >>> clip_image_quality_assessment(imgs, prompts=(("Super good photo.", "Super bad photo."), "brightness"))
  275. {'user_defined_0': tensor([0.9578, 0.9654]), 'brightness': tensor([0.5495, 0.5764])}
  276. """
  277. prompts_list, prompts_names = _clip_iqa_format_prompts(prompts)
  278. model, processor = _get_clip_iqa_model_and_processor(model_name_or_path)
  279. device = images.device
  280. model = model.to(device)
  281. with torch.inference_mode():
  282. anchors = _clip_iqa_get_anchor_vectors(model_name_or_path, model, processor, prompts_list, device)
  283. img_features = _clip_iqa_update(model_name_or_path, images, model, processor, data_range, device)
  284. return _clip_iqa_compute(img_features, anchors, prompts_names)
  285. if TYPE_CHECKING:
  286. from functools import partial
  287. from typing import Any, cast
  288. images = cast(Any, None)
  289. f = partial(clip_image_quality_assessment, images=images)
  290. f(prompts=("colorfullness",))
  291. f(
  292. prompts=("quality", "brightness", "noisiness"),
  293. )
  294. f(
  295. prompts=("quality", "brightness", "noisiness", "colorfullness"),
  296. )
  297. f(prompts=(("Photo of a cat", "Photo of a dog"), "quality", ("Colorful photo", "Black and white photo")))