ssim.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import List, Optional, Union
  16. import torch
  17. from torch import Tensor
  18. from torch.nn import functional as F # noqa: N812
  19. from typing_extensions import Literal
  20. from torchmetrics.functional.image.utils import _gaussian_kernel_2d, _gaussian_kernel_3d, _reflection_pad_3d
  21. from torchmetrics.utilities.checks import _check_same_shape
  22. from torchmetrics.utilities.distributed import reduce
  23. def _ssim_check_inputs(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:
  24. """Update and returns variables required to compute Structural Similarity Index Measure.
  25. Args:
  26. preds: Predicted tensor
  27. target: Ground truth tensor
  28. """
  29. if preds.dtype != target.dtype:
  30. target = target.to(preds.dtype)
  31. _check_same_shape(preds, target)
  32. if len(preds.shape) not in (4, 5):
  33. raise ValueError(
  34. "Expected `preds` and `target` to have BxCxHxW or BxCxDxHxW shape."
  35. f" Got preds: {preds.shape} and target: {target.shape}."
  36. )
  37. return preds, target
  38. def _ssim_update(
  39. preds: Tensor,
  40. target: Tensor,
  41. gaussian_kernel: bool = True,
  42. sigma: Union[float, Sequence[float]] = 1.5,
  43. kernel_size: Union[int, Sequence[int]] = 11,
  44. data_range: Optional[Union[float, tuple[float, float]]] = None,
  45. k1: float = 0.01,
  46. k2: float = 0.03,
  47. return_full_image: bool = False,
  48. return_contrast_sensitivity: bool = False,
  49. ) -> Union[Tensor, tuple[Tensor, Tensor]]:
  50. """Compute Structural Similarity Index Measure.
  51. Args:
  52. preds: estimated image
  53. target: ground truth image
  54. gaussian_kernel: If true (default), a gaussian kernel is used, if false a uniform kernel is used
  55. sigma: Standard deviation of the gaussian kernel, anisotropic kernels are possible.
  56. Ignored if a uniform kernel is used
  57. kernel_size: the size of the uniform kernel, anisotropic kernels are possible.
  58. Ignored if a Gaussian kernel is used
  59. data_range: Range of the image. If ``None``, it is determined from the image (max - min)
  60. k1: Parameter of SSIM.
  61. k2: Parameter of SSIM.
  62. return_full_image: If true, the full ``ssim`` image is returned as a second argument.
  63. Mutually exclusive with ``return_contrast_sensitivity``
  64. return_contrast_sensitivity: If true, the contrast term is returned as a second argument.
  65. The luminance term can be obtained with luminance=ssim/contrast
  66. Mutually exclusive with ``return_full_image``
  67. """
  68. is_3d = preds.ndim == 5
  69. if not isinstance(kernel_size, Sequence):
  70. kernel_size = 3 * [kernel_size] if is_3d else 2 * [kernel_size]
  71. if not isinstance(sigma, Sequence):
  72. sigma = 3 * [sigma] if is_3d else 2 * [sigma]
  73. if len(kernel_size) != len(target.shape) - 2:
  74. raise ValueError(
  75. f"`kernel_size` has dimension {len(kernel_size)}, but expected to be two less that target dimensionality,"
  76. f" which is: {len(target.shape)}"
  77. )
  78. if len(kernel_size) not in (2, 3):
  79. raise ValueError(
  80. f"Expected `kernel_size` dimension to be 2 or 3. `kernel_size` dimensionality: {len(kernel_size)}"
  81. )
  82. if len(sigma) != len(target.shape) - 2:
  83. raise ValueError(
  84. f"`kernel_size` has dimension {len(kernel_size)}, but expected to be two less that target dimensionality,"
  85. f" which is: {len(target.shape)}"
  86. )
  87. if len(sigma) not in (2, 3):
  88. raise ValueError(
  89. f"Expected `kernel_size` dimension to be 2 or 3. `kernel_size` dimensionality: {len(kernel_size)}"
  90. )
  91. if return_full_image and return_contrast_sensitivity:
  92. raise ValueError("Arguments `return_full_image` and `return_contrast_sensitivity` are mutually exclusive.")
  93. if any(x % 2 == 0 or x <= 0 for x in kernel_size):
  94. raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")
  95. if any(y <= 0 for y in sigma):
  96. raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")
  97. if data_range is None:
  98. data_range = max(preds.max() - preds.min(), target.max() - target.min()) # type: ignore[call-overload]
  99. elif isinstance(data_range, tuple):
  100. preds = torch.clamp(preds, min=data_range[0], max=data_range[1])
  101. target = torch.clamp(target, min=data_range[0], max=data_range[1])
  102. data_range = data_range[1] - data_range[0]
  103. c1 = pow(k1 * data_range, 2) # type: ignore[operator]
  104. c2 = pow(k2 * data_range, 2) # type: ignore[operator]
  105. device = preds.device
  106. channel = preds.size(1)
  107. dtype = preds.dtype
  108. gauss_kernel_size = [int(3.5 * s + 0.5) * 2 + 1 for s in sigma]
  109. if gaussian_kernel:
  110. pad_h = (gauss_kernel_size[0] - 1) // 2
  111. pad_w = (gauss_kernel_size[1] - 1) // 2
  112. else:
  113. pad_h = (kernel_size[0] - 1) // 2
  114. pad_w = (kernel_size[1] - 1) // 2
  115. if is_3d:
  116. pad_d = (kernel_size[2] - 1) // 2
  117. preds = _reflection_pad_3d(preds, pad_d, pad_w, pad_h)
  118. target = _reflection_pad_3d(target, pad_d, pad_w, pad_h)
  119. if gaussian_kernel:
  120. kernel = _gaussian_kernel_3d(channel, gauss_kernel_size, sigma, dtype, device)
  121. else:
  122. preds = F.pad(preds, (pad_w, pad_w, pad_h, pad_h), mode="reflect")
  123. target = F.pad(target, (pad_w, pad_w, pad_h, pad_h), mode="reflect")
  124. if gaussian_kernel:
  125. kernel = _gaussian_kernel_2d(channel, gauss_kernel_size, sigma, dtype, device)
  126. if not gaussian_kernel:
  127. kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / torch.prod(
  128. torch.tensor(kernel_size, dtype=dtype, device=device)
  129. )
  130. input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W)
  131. outputs = F.conv3d(input_list, kernel, groups=channel) if is_3d else F.conv2d(input_list, kernel, groups=channel)
  132. output_list = outputs.split(preds.shape[0])
  133. mu_pred_sq = output_list[0].pow(2)
  134. mu_target_sq = output_list[1].pow(2)
  135. mu_pred_target = output_list[0] * output_list[1]
  136. # Calculate the variance of the predicted and target images, should be non-negative
  137. sigma_pred_sq = torch.clamp(output_list[2] - mu_pred_sq, min=0.0)
  138. sigma_target_sq = torch.clamp(output_list[3] - mu_target_sq, min=0.0)
  139. sigma_pred_target = output_list[4] - mu_pred_target
  140. upper = 2 * sigma_pred_target.to(dtype) + c2
  141. lower = (sigma_pred_sq + sigma_target_sq).to(dtype) + c2
  142. ssim_idx_full_image = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower)
  143. if return_contrast_sensitivity:
  144. contrast_sensitivity = upper / lower
  145. if is_3d:
  146. contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w, pad_d:-pad_d]
  147. else:
  148. contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w]
  149. return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1), contrast_sensitivity.reshape(
  150. contrast_sensitivity.shape[0], -1
  151. ).mean(-1)
  152. if return_full_image:
  153. return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1), ssim_idx_full_image
  154. return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1)
  155. def _ssim_compute(
  156. similarities: Tensor,
  157. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  158. ) -> Tensor:
  159. """Apply the specified reduction to pre-computed structural similarity.
  160. Args:
  161. similarities: per image similarities for a batch of images.
  162. reduction: a method to reduce metric score over individual batch scores
  163. - ``'elementwise_mean'``: takes the mean
  164. - ``'sum'``: takes the sum
  165. - ``'none'`` or ``None``: no reduction will be applied
  166. Returns:
  167. The reduced SSIM score
  168. """
  169. return reduce(similarities, reduction)
  170. def structural_similarity_index_measure(
  171. preds: Tensor,
  172. target: Tensor,
  173. gaussian_kernel: bool = True,
  174. sigma: Union[float, Sequence[float]] = 1.5,
  175. kernel_size: Union[int, Sequence[int]] = 11,
  176. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  177. data_range: Optional[Union[float, tuple[float, float]]] = None,
  178. k1: float = 0.01,
  179. k2: float = 0.03,
  180. return_full_image: bool = False,
  181. return_contrast_sensitivity: bool = False,
  182. ) -> Union[Tensor, tuple[Tensor, Tensor]]:
  183. """Compute Structural Similarity Index Measure.
  184. Args:
  185. preds: estimated image
  186. target: ground truth image
  187. gaussian_kernel: If true (default), a gaussian kernel is used, if false a uniform kernel is used
  188. sigma: Standard deviation of the gaussian kernel, anisotropic kernels are possible.
  189. Ignored if a uniform kernel is used
  190. kernel_size: the size of the uniform kernel, anisotropic kernels are possible.
  191. Ignored if a Gaussian kernel is used
  192. reduction: a method to reduce metric score over labels.
  193. - ``'elementwise_mean'``: takes the mean
  194. - ``'sum'``: takes the sum
  195. - ``'none'`` or ``None``: no reduction will be applied
  196. data_range:
  197. the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
  198. the range is calculated as the difference and input is clamped between the values.
  199. k1: Parameter of SSIM.
  200. k2: Parameter of SSIM.
  201. return_full_image: If true, the full ``ssim`` image is returned as a second argument.
  202. Mutually exclusive with ``return_contrast_sensitivity``
  203. return_contrast_sensitivity: If true, the constant term is returned as a second argument.
  204. The luminance term can be obtained with luminance=ssim/contrast
  205. Mutually exclusive with ``return_full_image``
  206. Return:
  207. Tensor with SSIM score
  208. Raises:
  209. TypeError:
  210. If ``preds`` and ``target`` don't have the same data type.
  211. ValueError:
  212. If ``preds`` and ``target`` don't have ``BxCxHxW shape``.
  213. ValueError:
  214. If the length of ``kernel_size`` or ``sigma`` is not ``2``.
  215. ValueError:
  216. If one of the elements of ``kernel_size`` is not an ``odd positive number``.
  217. ValueError:
  218. If one of the elements of ``sigma`` is not a ``positive number``.
  219. Example:
  220. >>> from torchmetrics.functional.image import structural_similarity_index_measure
  221. >>> preds = torch.rand([3, 3, 256, 256])
  222. >>> target = preds * 0.75
  223. >>> structural_similarity_index_measure(preds, target)
  224. tensor(0.9219)
  225. """
  226. preds, target = _ssim_check_inputs(preds, target)
  227. similarity_pack = _ssim_update(
  228. preds,
  229. target,
  230. gaussian_kernel,
  231. sigma,
  232. kernel_size,
  233. data_range,
  234. k1,
  235. k2,
  236. return_full_image,
  237. return_contrast_sensitivity,
  238. )
  239. if isinstance(similarity_pack, tuple):
  240. similarity, image = similarity_pack
  241. return _ssim_compute(similarity, reduction), image
  242. similarity = similarity_pack
  243. return _ssim_compute(similarity, reduction)
  244. def _get_normalized_sim_and_cs(
  245. preds: Tensor,
  246. target: Tensor,
  247. gaussian_kernel: bool = True,
  248. sigma: Union[float, Sequence[float]] = 1.5,
  249. kernel_size: Union[int, Sequence[int]] = 11,
  250. data_range: Optional[Union[float, tuple[float, float]]] = None,
  251. k1: float = 0.01,
  252. k2: float = 0.03,
  253. normalize: Optional[Literal["relu", "simple"]] = None,
  254. ) -> tuple[Tensor, Tensor]:
  255. sim, contrast_sensitivity = _ssim_update(
  256. preds,
  257. target,
  258. gaussian_kernel,
  259. sigma,
  260. kernel_size,
  261. data_range,
  262. k1,
  263. k2,
  264. return_contrast_sensitivity=True,
  265. )
  266. if normalize == "relu":
  267. sim = torch.relu(sim)
  268. contrast_sensitivity = torch.relu(contrast_sensitivity)
  269. return sim, contrast_sensitivity
  270. def _multiscale_ssim_update(
  271. preds: Tensor,
  272. target: Tensor,
  273. gaussian_kernel: bool = True,
  274. sigma: Union[float, Sequence[float]] = 1.5,
  275. kernel_size: Union[int, Sequence[int]] = 11,
  276. data_range: Optional[Union[float, tuple[float, float]]] = None,
  277. k1: float = 0.01,
  278. k2: float = 0.03,
  279. betas: Union[tuple[float, float, float, float, float], tuple[float, ...]] = (
  280. 0.0448,
  281. 0.2856,
  282. 0.3001,
  283. 0.2363,
  284. 0.1333,
  285. ),
  286. normalize: Optional[Literal["relu", "simple"]] = None,
  287. ) -> Tensor:
  288. """Compute Multi-Scale Structural Similarity Index Measure.
  289. Adapted from: https://github.com/jorge-pessoa/pytorch-msssim/blob/master/pytorch_msssim/__init__.py.
  290. Args:
  291. preds: estimated image
  292. target: ground truth image
  293. gaussian_kernel: If true, a gaussian kernel is used, if false a uniform kernel is used
  294. sigma: Standard deviation of the gaussian kernel
  295. kernel_size: size of the gaussian kernel
  296. reduction: a method to reduce metric score over labels.
  297. - ``'elementwise_mean'``: takes the mean
  298. - ``'sum'``: takes the sum
  299. - ``'none'`` or ``None``: no reduction will be applied
  300. data_range: Range of the image. If ``None``, it is determined from the image (max - min)
  301. k1: Parameter of structural similarity index measure.
  302. k2: Parameter of structural similarity index measure.
  303. betas: Exponent parameters for individual similarities and contrastive sensitives returned by different image
  304. resolutions.
  305. normalize: When MultiScaleSSIM loss is used for training, it is desirable to use normalizes to improve the
  306. training stability. This `normalize` argument is out of scope of the original implementation [1], and it is
  307. adapted from https://github.com/jorge-pessoa/pytorch-msssim instead.
  308. Raises:
  309. ValueError:
  310. If the image height or width is smaller then ``2 ** len(betas)``.
  311. ValueError:
  312. If the image height is smaller than ``(kernel_size[0] - 1) * max(1, (len(betas) - 1)) ** 2``.
  313. ValueError:
  314. If the image width is smaller than ``(kernel_size[0] - 1) * max(1, (len(betas) - 1)) ** 2``.
  315. """
  316. mcs_list: List[Tensor] = []
  317. is_3d = preds.ndim == 5
  318. if not isinstance(kernel_size, Sequence):
  319. kernel_size = 3 * [kernel_size] if is_3d else 2 * [kernel_size]
  320. if not isinstance(sigma, Sequence):
  321. sigma = 3 * [sigma] if is_3d else 2 * [sigma]
  322. if preds.size()[-1] < 2 ** len(betas) or preds.size()[-2] < 2 ** len(betas):
  323. raise ValueError(
  324. f"For a given number of `betas` parameters {len(betas)}, the image height and width dimensions must be"
  325. f" larger than or equal to {2 ** len(betas)}."
  326. )
  327. _betas_div = max(1, (len(betas) - 1)) ** 2
  328. if preds.size()[-2] // _betas_div <= kernel_size[0] - 1:
  329. raise ValueError(
  330. f"For a given number of `betas` parameters {len(betas)} and kernel size {kernel_size[0]},"
  331. f" the image height must be larger than {(kernel_size[0] - 1) * _betas_div}."
  332. )
  333. if preds.size()[-1] // _betas_div <= kernel_size[1] - 1:
  334. raise ValueError(
  335. f"For a given number of `betas` parameters {len(betas)} and kernel size {kernel_size[1]},"
  336. f" the image width must be larger than {(kernel_size[1] - 1) * _betas_div}."
  337. )
  338. for _ in range(len(betas)):
  339. sim, contrast_sensitivity = _get_normalized_sim_and_cs(
  340. preds, target, gaussian_kernel, sigma, kernel_size, data_range, k1, k2, normalize=normalize
  341. )
  342. mcs_list.append(contrast_sensitivity)
  343. if len(kernel_size) == 2:
  344. preds = F.avg_pool2d(preds, (2, 2))
  345. target = F.avg_pool2d(target, (2, 2))
  346. elif len(kernel_size) == 3:
  347. preds = F.avg_pool3d(preds, (2, 2, 2))
  348. target = F.avg_pool3d(target, (2, 2, 2))
  349. else:
  350. raise ValueError("length of kernel_size is neither 2 nor 3")
  351. mcs_list[-1] = sim
  352. mcs_stack = torch.stack(mcs_list)
  353. if normalize == "simple":
  354. mcs_stack = (mcs_stack + 1) / 2
  355. betas = torch.tensor(betas, device=mcs_stack.device).view(-1, 1)
  356. mcs_weighted = mcs_stack**betas
  357. return torch.prod(mcs_weighted, axis=0) # type: ignore[call-overload]
  358. def _multiscale_ssim_compute(
  359. mcs_per_image: Tensor,
  360. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  361. ) -> Tensor:
  362. """Apply the specified reduction to pre-computed multi-scale structural similarity.
  363. Args:
  364. mcs_per_image: per image similarities for a batch of images.
  365. reduction: a method to reduce metric score over individual batch scores
  366. - ``'elementwise_mean'``: takes the mean
  367. - ``'sum'``: takes the sum
  368. - ``'none'`` or ``None``: no reduction will be applied
  369. Returns:
  370. The reduced multi-scale structural similarity
  371. """
  372. return reduce(mcs_per_image, reduction)
  373. def multiscale_structural_similarity_index_measure(
  374. preds: Tensor,
  375. target: Tensor,
  376. gaussian_kernel: bool = True,
  377. sigma: Union[float, Sequence[float]] = 1.5,
  378. kernel_size: Union[int, Sequence[int]] = 11,
  379. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  380. data_range: Optional[Union[float, tuple[float, float]]] = None,
  381. k1: float = 0.01,
  382. k2: float = 0.03,
  383. betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
  384. normalize: Optional[Literal["relu", "simple"]] = "relu",
  385. ) -> Tensor:
  386. """Compute `MultiScaleSSIM`_, Multi-scale Structural Similarity Index Measure.
  387. This metric is a generalization of Structural Similarity Index Measure by incorporating image details at different
  388. resolution scores.
  389. Args:
  390. preds: Predictions from model of shape ``[N, C, H, W]``
  391. target: Ground truth values of shape ``[N, C, H, W]``
  392. gaussian_kernel: If true, a gaussian kernel is used, if false a uniform kernel is used
  393. sigma: Standard deviation of the gaussian kernel
  394. kernel_size: size of the gaussian kernel
  395. reduction: a method to reduce metric score over labels.
  396. - ``'elementwise_mean'``: takes the mean
  397. - ``'sum'``: takes the sum
  398. - ``'none'`` or ``None``: no reduction will be applied
  399. data_range:
  400. the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
  401. the range is calculated as the difference and input is clamped between the values.
  402. k1: Parameter of structural similarity index measure.
  403. k2: Parameter of structural similarity index measure.
  404. betas: Exponent parameters for individual similarities and contrastive sensitivities returned by different image
  405. resolutions.
  406. normalize: When MultiScaleSSIM loss is used for training, it is desirable to use normalizes to improve the
  407. training stability. This `normalize` argument is out of scope of the original implementation [1], and it is
  408. adapted from https://github.com/jorge-pessoa/pytorch-msssim instead.
  409. Return:
  410. Tensor with Multi-Scale SSIM score
  411. Raises:
  412. TypeError:
  413. If ``preds`` and ``target`` don't have the same data type.
  414. ValueError:
  415. If ``preds`` and ``target`` don't have ``BxCxHxW shape``.
  416. ValueError:
  417. If the length of ``kernel_size`` or ``sigma`` is not ``2``.
  418. ValueError:
  419. If one of the elements of ``kernel_size`` is not an ``odd positive number``.
  420. ValueError:
  421. If one of the elements of ``sigma`` is not a ``positive number``.
  422. Example:
  423. >>> from torch import rand
  424. >>> from torchmetrics.functional.image import multiscale_structural_similarity_index_measure
  425. >>> preds = rand([3, 3, 256, 256])
  426. >>> target = preds * 0.75
  427. >>> multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
  428. tensor(0.9628)
  429. References:
  430. [1] Multi-Scale Structural Similarity For Image Quality Assessment by Zhou Wang, Eero P. Simoncelli and Alan C.
  431. Bovik `MultiScaleSSIM`_
  432. """
  433. if not isinstance(betas, tuple):
  434. raise ValueError("Argument `betas` is expected to be of a type tuple.")
  435. if isinstance(betas, tuple) and not all(isinstance(beta, float) for beta in betas):
  436. raise ValueError("Argument `betas` is expected to be a tuple of floats.")
  437. if normalize and normalize not in ("relu", "simple"):
  438. raise ValueError("Argument `normalize` to be expected either `None` or one of 'relu' or 'simple'")
  439. preds, target = _ssim_check_inputs(preds, target)
  440. mcs_per_image = _multiscale_ssim_update(
  441. preds, target, gaussian_kernel, sigma, kernel_size, data_range, k1, k2, betas, normalize
  442. )
  443. return _multiscale_ssim_compute(mcs_per_image, reduction)