flips.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import torch
  18. from kornia.core import ImageModule as Module
  19. from kornia.core import Tensor
# Public API of this module: the three module classes and their functional counterparts.
__all__ = ["Hflip", "Rot180", "Vflip", "hflip", "rot180", "vflip"]
  21. class Vflip(Module):
  22. r"""Vertically flip a tensor image or a batch of tensor images.
  23. Input must be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.
  24. Args:
  25. input: input tensor.
  26. Returns:
  27. The vertically flipped image tensor.
  28. Examples:
  29. >>> vflip = Vflip()
  30. >>> input = torch.tensor([[[
  31. ... [0., 0., 0.],
  32. ... [0., 0., 0.],
  33. ... [0., 1., 1.]
  34. ... ]]])
  35. >>> vflip(input)
  36. tensor([[[[0., 1., 1.],
  37. [0., 0., 0.],
  38. [0., 0., 0.]]]])
  39. """
  40. def forward(self, input: Tensor) -> Tensor:
  41. return vflip(input)
  42. def __repr__(self) -> str:
  43. return self.__class__.__name__
  44. class Hflip(Module):
  45. r"""Horizontally flip a tensor image or a batch of tensor images.
  46. Input must be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.
  47. Args:
  48. input: input tensor.
  49. Returns:
  50. The horizontally flipped image tensor.
  51. Examples:
  52. >>> hflip = Hflip()
  53. >>> input = torch.tensor([[[
  54. ... [0., 0., 0.],
  55. ... [0., 0., 0.],
  56. ... [0., 1., 1.]
  57. ... ]]])
  58. >>> hflip(input)
  59. tensor([[[[0., 0., 0.],
  60. [0., 0., 0.],
  61. [1., 1., 0.]]]])
  62. """
  63. def forward(self, input: Tensor) -> Tensor:
  64. return hflip(input)
  65. def __repr__(self) -> str:
  66. return self.__class__.__name__
  67. class Rot180(Module):
  68. r"""Rotate a tensor image or a batch of tensor images 180 degrees.
  69. Input must be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.
  70. Args:
  71. input: input tensor.
  72. Examples:
  73. >>> rot180 = Rot180()
  74. >>> input = torch.tensor([[[
  75. ... [0., 0., 0.],
  76. ... [0., 0., 0.],
  77. ... [0., 1., 1.]
  78. ... ]]])
  79. >>> rot180(input)
  80. tensor([[[[1., 1., 0.],
  81. [0., 0., 0.],
  82. [0., 0., 0.]]]])
  83. """
  84. def forward(self, input: Tensor) -> Tensor:
  85. return rot180(input)
  86. def __repr__(self) -> str:
  87. return self.__class__.__name__
  88. def rot180(input: Tensor) -> Tensor:
  89. r"""Rotate a tensor image or a batch of tensor images 180 degrees.
  90. .. image:: _static/img/rot180.png
  91. Input must be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.
  92. Args:
  93. input: input tensor.
  94. Returns:
  95. The rotated image tensor.
  96. """
  97. return torch.flip(input, [-2, -1])
  98. def hflip(input: Tensor) -> Tensor:
  99. r"""Horizontally flip a tensor image or a batch of tensor images.
  100. .. image:: _static/img/hflip.png
  101. Input must be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.
  102. Args:
  103. input: input tensor.
  104. Returns:
  105. The horizontally flipped image tensor.
  106. """
  107. return input.flip(-1).contiguous()
  108. def vflip(input: Tensor) -> Tensor:
  109. r"""Vertically flip a tensor image or a batch of tensor images.
  110. .. image:: _static/img/vflip.png
  111. Input must be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.
  112. Args:
  113. input: input tensor.
  114. Returns:
  115. The vertically flipped image tensor.
  116. """
  117. return input.flip(-2).contiguous()