integral.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Optional, Tuple
  18. from kornia.core import ImageModule as Module
  19. from kornia.core import Tensor
  20. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE
  21. def integral_tensor(input: Tensor, dim: Optional[Tuple[int, ...]] = None) -> Tensor:
  22. """Calculate integral of the input tensor.
  23. The algorithm computes the integral image by summing over the specified dimensions.
  24. In case dim is specified, the contained dimensions must be unique and sorted in ascending order
  25. and not exceed the number of dimensions of the input tensor.
  26. Args:
  27. input: the input tensor with shape :math:`(*, D)`. Where D is the number of dimensions.
  28. dim: the dimension to be summed.
  29. Returns:
  30. Integral tensor for the input tensor with shape :math:`(*, D)`.
  31. Examples:
  32. >>> input = torch.ones(3, 5)
  33. >>> output = integral_tensor(input, (-2, -1))
  34. >>> output
  35. tensor([[ 1., 2., 3., 4., 5.],
  36. [ 2., 4., 6., 8., 10.],
  37. [ 3., 6., 9., 12., 15.]])
  38. """
  39. KORNIA_CHECK_SHAPE(input, ["*", "D"])
  40. if dim is None:
  41. dim = (-1,)
  42. KORNIA_CHECK(len(dim) > 0, "dim must be a non-empty tuple.")
  43. KORNIA_CHECK(len(dim) <= len(input.shape), "dim must be a tuple of length <= input.shape.")
  44. output = input
  45. for i in dim:
  46. output = output.cumsum(i)
  47. return output
  48. def integral_image(image: Tensor) -> Tensor:
  49. r"""Calculate integral of the input image tensor.
  50. This particular version sums over the last two dimensions.
  51. Args:
  52. image: the input image tensor with shape :math:`(*, H, W)`.
  53. Returns:
  54. Integral tensor for the input image tensor with shape :math:`(*, H, W)`.
  55. Examples:
  56. >>> input = torch.ones(1, 5, 5)
  57. >>> output = integral_image(input)
  58. >>> output
  59. tensor([[[ 1., 2., 3., 4., 5.],
  60. [ 2., 4., 6., 8., 10.],
  61. [ 3., 6., 9., 12., 15.],
  62. [ 4., 8., 12., 16., 20.],
  63. [ 5., 10., 15., 20., 25.]]])
  64. """
  65. KORNIA_CHECK_SHAPE(image, ["*", "H", "W"])
  66. return integral_tensor(image, (-2, -1))
  67. class IntegralTensor(Module):
  68. r"""Calculates integral of the input tensor.
  69. Args:
  70. image: the input tensor with shape :math:`(B,C,H,W)`.
  71. Returns:
  72. Integral tensor for the input tensor with shape :math:`(B,C,H,W)`.
  73. Shape:
  74. - Input: :math:`(B, C, H, W)`
  75. - Output: :math:`(B, C, H, W)`
  76. Examples:
  77. >>> input = torch.ones(3, 5)
  78. >>> dim = (-2, -1)
  79. >>> output = IntegralTensor(dim)(input)
  80. >>> output
  81. tensor([[ 1., 2., 3., 4., 5.],
  82. [ 2., 4., 6., 8., 10.],
  83. [ 3., 6., 9., 12., 15.]])
  84. """
  85. def __init__(self, dim: Optional[Tuple[int, ...]] = None) -> None:
  86. super().__init__()
  87. self.dim = dim
  88. def forward(self, input: Tensor) -> Tensor:
  89. return integral_tensor(input, self.dim)
  90. class IntegralImage(Module):
  91. """Calculates integral of the input image tensor.
  92. This particular version sums over the last two dimensions.
  93. Args:
  94. image: the input image tensor with shape :math:`(B,C,H,W)`.
  95. Returns:
  96. Integral tensor for the input image tensor with shape :math:`(B,C,H,W)`.
  97. Shape:
  98. - Input: :math:`(B, C, H, W)`
  99. - Output: :math:`(B, C, H, W)`
  100. Examples:
  101. >>> input = torch.ones(1, 5, 5)
  102. >>> output = IntegralImage()(input)
  103. >>> output
  104. tensor([[[ 1., 2., 3., 4., 5.],
  105. [ 2., 4., 6., 8., 10.],
  106. [ 3., 6., 9., 12., 15.],
  107. [ 4., 8., 12., 16., 20.],
  108. [ 5., 10., 15., 20., 25.]]])
  109. """
  110. def forward(self, input: Tensor) -> Tensor:
  111. return integral_image(input)