endpoint_error.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import torch
  18. from torch import Tensor, nn
  19. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
  20. def aepe(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
  21. r"""Create a function that calculates the average endpoint error (AEPE) between 2 flow maps.
  22. AEPE is the endpoint error between two 2D vectors (e.g., optical flow).
  23. Given a h x w x 2 optical flow map, the AEPE is:
  24. .. math::
  25. \text{AEPE}=\frac{1}{hw}\sum_{i=1, j=1}^{h, w}\sqrt{(I_{i,j,1}-T_{i,j,1})^{2}+(I_{i,j,2}-T_{i,j,2})^{2}}
  26. Args:
  27. input: the input flow map with shape :math:`(*, 2)`.
  28. target: the target flow map with shape :math:`(*, 2)`.
  29. reduction : Specifies the reduction to apply to the
  30. output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  31. ``'mean'``: the sum of the output will be divided by the number of elements
  32. in the output, ``'sum'``: the output will be summed.
  33. Return:
  34. the computed AEPE as a scalar.
  35. Examples:
  36. >>> ones = torch.ones(4, 4, 2)
  37. >>> aepe(ones, 1.2 * ones)
  38. tensor(0.2828)
  39. Reference:
  40. https://link.springer.com/content/pdf/10.1007/s11263-010-0390-2.pdf
  41. """
  42. KORNIA_CHECK_IS_TENSOR(input)
  43. KORNIA_CHECK_IS_TENSOR(target)
  44. KORNIA_CHECK_SHAPE(input, ["*", "2"])
  45. KORNIA_CHECK_SHAPE(target, ["*", "2"])
  46. KORNIA_CHECK(
  47. input.shape == target.shape, f"input and target shapes must be the same. Got: {input.shape} and {target.shape}"
  48. )
  49. epe: Tensor = ((input[..., 0] - target[..., 0]) ** 2 + (input[..., 1] - target[..., 1]) ** 2).sqrt()
  50. if reduction == "mean":
  51. epe = epe.mean()
  52. elif reduction == "sum":
  53. epe = epe.sum()
  54. elif reduction == "none":
  55. pass
  56. else:
  57. raise NotImplementedError("Invalid reduction option.")
  58. return epe
  59. class AEPE(nn.Module):
  60. r"""Computes the average endpoint error (AEPE) between 2 flow maps.
  61. EPE is the endpoint error between two 2D vectors (e.g., optical flow).
  62. Given a h x w x 2 optical flow map, the AEPE is:
  63. .. math::
  64. \text{AEPE}=\frac{1}{hw}\sum_{i=1, j=1}^{h, w}\sqrt{(I_{i,j,1}-T_{i,j,1})^{2}+(I_{i,j,2}-T_{i,j,2})^{2}}
  65. Args:
  66. reduction : Specifies the reduction to apply to the
  67. output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  68. ``'mean'``: the sum of the output will be divided by the number of elements
  69. in the output, ``'sum'``: the output will be summed.
  70. Shape:
  71. - input: :math:`(*, 2)`.
  72. - target :math:`(*, 2)`.
  73. - output: :math:`(1)`.
  74. Examples:
  75. >>> input1 = torch.rand(1, 4, 5, 2)
  76. >>> input2 = torch.rand(1, 4, 5, 2)
  77. >>> epe = AEPE(reduction="mean")
  78. >>> epe = epe(input1, input2)
  79. """
  80. def __init__(self, reduction: str = "mean") -> None:
  81. super().__init__()
  82. self.reduction: str = reduction
  83. def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
  84. return aepe(input, target, self.reduction)
# ``average_endpoint_error`` is a second name bound to the very same function object as ``aepe``.
average_endpoint_error = aepe