vmaf.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Dict, Union
  15. import torch
  16. from torch import Tensor
  17. from torchmetrics.utilities.imports import _EINOPS_AVAILABLE, _TORCH_VMAF_AVAILABLE
  18. if _TORCH_VMAF_AVAILABLE:
  19. import pandas as pd # pandas is installed as a dependency of vmaf-torch
  20. from vmaf_torch import VMAF
  21. else:
  22. __doctest_skip__ = ["video_multi_method_assessment_fusion"]
  23. if _EINOPS_AVAILABLE:
  24. from einops import rearrange
  25. def calculate_luma(video: Tensor) -> Tensor:
  26. """Calculate the luma component of a video tensor."""
  27. r = video[:, 0, :, :, :]
  28. g = video[:, 1, :, :, :]
  29. b = video[:, 2, :, :, :]
  30. return (0.299 * r + 0.587 * g + 0.114 * b).unsqueeze(1) * 255 # [0, 1] -> [0, 255]
  31. def video_multi_method_assessment_fusion(
  32. preds: Tensor,
  33. target: Tensor,
  34. features: bool = False,
  35. ) -> Union[Tensor, Dict[str, Tensor]]:
  36. """Calculates Video Multi-Method Assessment Fusion (VMAF) metric.
  37. VMAF is a full-reference video quality assessment algorithm that combines multiple quality assessment features
  38. such as detail loss, motion, and contrast using a machine learning model to predict human perception of video
  39. quality more accurately than traditional metrics like PSNR or SSIM.
  40. The metric works by:
  41. 1. Converting input videos to luma component (grayscale)
  42. 2. Computing multiple elementary features:
  43. - Additive Detail Measure (ADM): Evaluates detail preservation at different scales
  44. - Visual Information Fidelity (VIF): Measures preservation of visual information across frequency bands
  45. - Motion: Quantifies the amount of motion in the video
  46. 3. Combining these features using a trained SVM model to predict quality
  47. .. note::
  48. This implementation requires you to have vmaf-torch installed: https://github.com/alvitrioliks/VMAF-torch.
  49. Install either by cloning the repository and running `pip install .` or with `pip install torchmetrics[video]`.
  50. Args:
  51. preds: Video tensor of shape (batch, channels, frames, height, width). Expected to be in RGB format
  52. with values in range [0, 1].
  53. target: Video tensor of shape (batch, channels, frames, height, width). Expected to be in RGB format
  54. with values in range [0, 1].
  55. features: If True, all the elementary features (ADM, VIF, motion) are returned along with the VMAF score in
  56. a dictionary. This corresponds to the output you would get from the VMAF command line tool with the `--csv`
  57. option enabled. If False, only the VMAF score is returned as a tensor.
  58. Returns:
  59. - If `features` is False, returns a tensor with shape (batch, frame) of VMAF score for each frame in
  60. each video. Higher scores indicate better quality, with typical values ranging from 0 to 100.
  61. - If `features` is True, returns a dictionary where each value is a (batch, frame) tensor of the
  62. corresponding feature. The keys are:
  63. - 'integer_motion2': Integer motion feature
  64. - 'integer_motion': Integer motion feature
  65. - 'integer_adm2': Integer ADM feature
  66. - 'integer_adm_scale0': Integer ADM feature at scale 0
  67. - 'integer_adm_scale1': Integer ADM feature at scale 1
  68. - 'integer_adm_scale2': Integer ADM feature at scale 2
  69. - 'integer_adm_scale3': Integer ADM feature at scale 3
  70. - 'integer_vif_scale0': Integer VIF feature at scale 0
  71. - 'integer_vif_scale1': Integer VIF feature at scale 1
  72. - 'integer_vif_scale2': Integer VIF feature at scale 2
  73. - 'integer_vif_scale3': Integer VIF feature at scale 3
  74. - 'vmaf': VMAF score for each frame in each video
  75. Example:
  76. >>> import torch
  77. >>> from torchmetrics.functional.video import video_multi_method_assessment_fusion
  78. >>> # 2 videos, 3 channels, 10 frames, 32x32 resolution
  79. >>> preds = torch.rand(2, 3, 10, 32, 32, generator=torch.manual_seed(42))
  80. >>> target = torch.rand(2, 3, 10, 32, 32, generator=torch.manual_seed(43))
  81. >>> vmaf_score = video_multi_method_assessment_fusion(preds, target)
  82. >>> torch.round(vmaf_score, decimals=2)
  83. tensor([[ 9.9900, 15.9000, 14.2600, 16.6100, 15.9100, 14.3000, 13.5800, 13.4900, 15.4700, 20.2800],
  84. [ 6.2500, 11.3000, 17.3000, 11.4600, 19.0600, 14.9300, 14.0500, 14.4100, 12.4700, 14.8200]])
  85. >>> vmaf_dict = video_multi_method_assessment_fusion(preds, target, features=True)
  86. >>> # show a couple of features, more features are available
  87. >>> vmaf_dict['vmaf'].round(decimals=2)
  88. tensor([[ 9.9900, 15.9000, 14.2600, 16.6100, 15.9100, 14.3000, 13.5800, 13.4900, 15.4700, 20.2800],
  89. [ 6.2500, 11.3000, 17.3000, 11.4600, 19.0600, 14.9300, 14.0500, 14.4100, 12.4700, 14.8200]])
  90. >>> vmaf_dict['integer_adm2'].round(decimals=2)
  91. tensor([[0.4500, 0.4500, 0.3600, 0.4700, 0.4300, 0.3600, 0.3900, 0.4100, 0.3700, 0.4700],
  92. [0.4200, 0.3900, 0.4400, 0.3700, 0.4500, 0.3900, 0.3800, 0.4800, 0.3900, 0.3900]])
  93. """
  94. if not _TORCH_VMAF_AVAILABLE:
  95. raise RuntimeError("vmaf-torch is not installed. Please install with `pip install torchmetrics[video]`.")
  96. b = preds.shape[0]
  97. orig_dtype, device = preds.dtype, preds.device
  98. preds_luma = calculate_luma(preds)
  99. target_luma = calculate_luma(target)
  100. vmaf = VMAF().to(device)
  101. # we need to compute the model for each video separately
  102. if not features:
  103. scores = [
  104. vmaf.compute_vmaf_score(
  105. rearrange(target_luma[video], "c f h w -> f c h w"), rearrange(preds_luma[video], "c f h w -> f c h w")
  106. )
  107. for video in range(b)
  108. ]
  109. return torch.cat(scores, dim=1).t().to(orig_dtype)
  110. scores_and_features = [
  111. vmaf.table(
  112. rearrange(target_luma[video], "c f h w -> f c h w"), rearrange(preds_luma[video], "c f h w -> f c h w")
  113. )
  114. for video in range(b)
  115. ]
  116. dfs = [scores_and_features[video].apply(pd.to_numeric, errors="coerce") for video in range(b)]
  117. result = [
  118. {col: torch.tensor(dfs[video][col].values, dtype=orig_dtype) for col in dfs[video].columns if col != "Frame"}
  119. for video in range(b)
  120. ]
  121. return {col: torch.stack([result[video][col] for video in range(b)]) for col in result[0]}