vmaf.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List, Union
  15. from torch import Tensor
  16. from torchmetrics.functional.video.vmaf import video_multi_method_assessment_fusion
  17. from torchmetrics.metric import Metric
  18. from torchmetrics.utilities.data import dim_zero_cat
  19. from torchmetrics.utilities.imports import _TORCH_VMAF_AVAILABLE
# The docstring examples of ``VideoMultiMethodAssessmentFusion`` instantiate the metric, which raises
# ``RuntimeError`` when vmaf-torch is missing; in that case list the class in ``__doctest_skip__``
# (presumably consumed by the project's doctest runner to skip those examples -- confirm against test config).
if not _TORCH_VMAF_AVAILABLE:
    __doctest_skip__ = ["VideoMultiMethodAssessmentFusion"]
  22. class VideoMultiMethodAssessmentFusion(Metric):
  23. """Calculates Video Multi-Method Assessment Fusion (VMAF) metric.
  24. VMAF is a full-reference video quality assessment algorithm that combines multiple quality assessment features
  25. such as detail loss, motion, and contrast using a machine learning model to predict human perception of video
  26. quality more accurately than traditional metrics like PSNR or SSIM.
  27. The metric works by:
  28. 1. Converting input videos to luma component (grayscale)
  29. 2. Computing multiple elementary features:
  30. - Additive Detail Measure (ADM): Evaluates detail preservation at different scales
  31. - Visual Information Fidelity (VIF): Measures preservation of visual information across frequency bands
  32. - Motion: Quantifies the amount of motion in the video
  33. 3. Combining these features using a trained SVM model to predict quality
  34. .. note::
  35. This implementation requires you to have vmaf-torch installed: https://github.com/alvitrioliks/VMAF-torch.
  36. Install either by cloning the repository and running ``pip install .``
  37. or with ``pip install torchmetrics[video]``.
  38. As input to ``forward`` and ``update`` the metric accepts the following input:
  39. - ``preds`` (:class:`~torch.Tensor`): Video tensor of shape ``(batch, channels, frames, height, width)``.
  40. Expected to be in RGB format with values in range [0, 1].
  41. - ``target`` (:class:`~torch.Tensor`): Video tensor of shape ``(batch, channels, frames, height, width)``.
  42. Expected to be in RGB format with values in range [0, 1].
  43. As output of ``forward`` and ``compute`` the metric returns the following output ``vmaf`` (:class:`~torch.Tensor`):
  44. - If ``features`` is False, returns a tensor with shape (batch, frame)
  45. of VMAF score for each frame in each video. Higher scores indicate better quality, with typical values
  46. ranging from 0 to 100.
  47. - If ``features`` is True, returns a dictionary where each value is a (batch, frame) tensor of the
  48. corresponding feature. The keys are:
  49. - 'integer_motion2': Integer motion feature
  50. - 'integer_motion': Integer motion feature
  51. - 'integer_adm2': Integer ADM feature
  52. - 'integer_adm_scale0': Integer ADM feature at scale 0
  53. - 'integer_adm_scale1': Integer ADM feature at scale 1
  54. - 'integer_adm_scale2': Integer ADM feature at scale 2
  55. - 'integer_adm_scale3': Integer ADM feature at scale 3
  56. - 'integer_vif_scale0': Integer VIF feature at scale 0
  57. - 'integer_vif_scale1': Integer VIF feature at scale 1
  58. - 'integer_vif_scale2': Integer VIF feature at scale 2
  59. - 'integer_vif_scale3': Integer VIF feature at scale 3
  60. - 'vmaf': VMAF score for each frame in each video
  61. Args:
  62. features: If True, all the elementary features (ADM, VIF, motion) are returned along with the VMAF score in
  63. a dictionary. This corresponds to the output you would get from the VMAF command line tool with
  64. the ``--csv`` option enabled. If False, only the VMAF score is returned as a tensor.
  65. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  66. Raises:
  67. RuntimeError:
  68. If vmaf-torch is not installed.
  69. ValueError:
  70. If ``features`` is not a boolean.
  71. Example:
  72. >>> import torch
  73. >>> from torchmetrics.video import VideoMultiMethodAssessmentFusion
  74. >>> # 2 videos, 3 channels, 10 frames, 32x32 resolution
  75. >>> preds = torch.rand(2, 3, 10, 32, 32, generator=torch.manual_seed(42))
  76. >>> target = torch.rand(2, 3, 10, 32, 32, generator=torch.manual_seed(43))
  77. >>> vmaf = VideoMultiMethodAssessmentFusion()
  78. >>> torch.round(vmaf(preds, target), decimals=2)
  79. tensor([[ 9.9900, 15.9000, 14.2600, 16.6100, 15.9100, 14.3000, 13.5800, 13.4900, 15.4700, 20.2800],
  80. [ 6.2500, 11.3000, 17.3000, 11.4600, 19.0600, 14.9300, 14.0500, 14.4100, 12.4700, 14.8200]])
  81. >>> vmaf = VideoMultiMethodAssessmentFusion(features=True)
  82. >>> vmaf_dict = vmaf(preds, target)
  83. >>> vmaf_dict['vmaf'].round(decimals=2)
  84. tensor([[ 9.9900, 15.9000, 14.2600, 16.6100, 15.9100, 14.3000, 13.5800, 13.4900, 15.4700, 20.2800],
  85. [ 6.2500, 11.3000, 17.3000, 11.4600, 19.0600, 14.9300, 14.0500, 14.4100, 12.4700, 14.8200]])
  86. >>> vmaf_dict['integer_adm2'].round(decimals=2)
  87. tensor([[0.4500, 0.4500, 0.3600, 0.4700, 0.4300, 0.3600, 0.3900, 0.4100, 0.3700, 0.4700],
  88. [0.4200, 0.3900, 0.4400, 0.3700, 0.4500, 0.3900, 0.3800, 0.4800, 0.3900, 0.3900]])
  89. """
  90. is_differentiable: bool = False
  91. higher_is_better: bool = True
  92. full_state_update: bool = False
  93. plot_lower_bound: float = 0.0
  94. plot_upper_bound: float = 100.0
  95. vmaf_score: List[Tensor]
  96. integer_motion2: List[Tensor]
  97. integer_motion: List[Tensor]
  98. integer_adm2: List[Tensor]
  99. integer_adm_scale0: List[Tensor]
  100. integer_adm_scale1: List[Tensor]
  101. integer_adm_scale2: List[Tensor]
  102. integer_adm_scale3: List[Tensor]
  103. integer_vif_scale0: List[Tensor]
  104. integer_vif_scale1: List[Tensor]
  105. integer_vif_scale2: List[Tensor]
  106. integer_vif_scale3: List[Tensor]
  107. def __init__(self, features: bool = False, **kwargs: Any) -> None:
  108. super().__init__(**kwargs)
  109. if not _TORCH_VMAF_AVAILABLE:
  110. raise RuntimeError("vmaf-torch is not installed. Please install with `pip install torchmetrics[video]`.")
  111. if not isinstance(features, bool):
  112. raise ValueError("Argument `elementary_features` should be a boolean, but got {features}.")
  113. self.features = features
  114. self.add_state("vmaf_score", default=[], dist_reduce_fx="cat")
  115. if self.features:
  116. self.add_state("integer_motion2", default=[], dist_reduce_fx="cat")
  117. self.add_state("integer_motion", default=[], dist_reduce_fx="cat")
  118. self.add_state("integer_adm2", default=[], dist_reduce_fx="cat")
  119. self.add_state("integer_adm_scale0", default=[], dist_reduce_fx="cat")
  120. self.add_state("integer_adm_scale1", default=[], dist_reduce_fx="cat")
  121. self.add_state("integer_adm_scale2", default=[], dist_reduce_fx="cat")
  122. self.add_state("integer_adm_scale3", default=[], dist_reduce_fx="cat")
  123. self.add_state("integer_vif_scale0", default=[], dist_reduce_fx="cat")
  124. self.add_state("integer_vif_scale1", default=[], dist_reduce_fx="cat")
  125. self.add_state("integer_vif_scale2", default=[], dist_reduce_fx="cat")
  126. self.add_state("integer_vif_scale3", default=[], dist_reduce_fx="cat")
  127. def update(self, preds: Tensor, target: Tensor) -> None:
  128. """Update state with predictions and targets."""
  129. score = video_multi_method_assessment_fusion(preds, target, self.features)
  130. if self.features and isinstance(score, dict):
  131. self.vmaf_score.append(score["vmaf"])
  132. self.integer_motion2.append(score["integer_motion2"])
  133. self.integer_motion.append(score["integer_motion"])
  134. self.integer_adm2.append(score["integer_adm2"])
  135. self.integer_adm_scale0.append(score["integer_adm_scale0"])
  136. self.integer_adm_scale1.append(score["integer_adm_scale1"])
  137. self.integer_adm_scale2.append(score["integer_adm_scale2"])
  138. self.integer_adm_scale3.append(score["integer_adm_scale3"])
  139. self.integer_vif_scale0.append(score["integer_vif_scale0"])
  140. self.integer_vif_scale1.append(score["integer_vif_scale1"])
  141. self.integer_vif_scale2.append(score["integer_vif_scale2"])
  142. self.integer_vif_scale3.append(score["integer_vif_scale3"])
  143. elif isinstance(score, Tensor):
  144. self.vmaf_score.append(score)
  145. def compute(self) -> Union[Tensor, Dict[str, Tensor]]:
  146. """Compute final VMAF score."""
  147. if self.features:
  148. return {
  149. "vmaf": dim_zero_cat(self.vmaf_score),
  150. "integer_motion2": dim_zero_cat(self.integer_motion2),
  151. "integer_motion": dim_zero_cat(self.integer_motion),
  152. "integer_adm2": dim_zero_cat(self.integer_adm2),
  153. "integer_adm_scale0": dim_zero_cat(self.integer_adm_scale0),
  154. "integer_adm_scale1": dim_zero_cat(self.integer_adm_scale1),
  155. "integer_adm_scale2": dim_zero_cat(self.integer_adm_scale2),
  156. "integer_adm_scale3": dim_zero_cat(self.integer_adm_scale3),
  157. "integer_vif_scale0": dim_zero_cat(self.integer_vif_scale0),
  158. "integer_vif_scale1": dim_zero_cat(self.integer_vif_scale1),
  159. "integer_vif_scale2": dim_zero_cat(self.integer_vif_scale2),
  160. "integer_vif_scale3": dim_zero_cat(self.integer_vif_scale3),
  161. }
  162. return dim_zero_cat(self.vmaf_score)