grid.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Optional
  18. import torch
  19. from kornia.core import Tensor, stack
  20. from kornia.utils._compat import torch_meshgrid
  21. def create_meshgrid(
  22. height: int,
  23. width: int,
  24. normalized_coordinates: bool = True,
  25. device: Optional[torch.device] = None,
  26. dtype: Optional[torch.dtype] = None,
  27. ) -> Tensor:
  28. """Generate a coordinate grid for an image.
  29. When the flag ``normalized_coordinates`` is set to True, the grid is
  30. normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
  31. function :py:func:`torch.nn.functional.grid_sample`.
  32. Args:
  33. height: the image height (rows).
  34. width: the image width (cols).
  35. normalized_coordinates: whether to normalize
  36. coordinates in the range :math:`[-1,1]` in order to be consistent with the
  37. PyTorch function :py:func:`torch.nn.functional.grid_sample`.
  38. device: the device on which the grid will be generated.
  39. dtype: the data type of the generated grid.
  40. Return:
  41. grid tensor with shape :math:`(1, H, W, 2)`.
  42. Example:
  43. >>> create_meshgrid(2, 2)
  44. tensor([[[[-1., -1.],
  45. [ 1., -1.]],
  46. <BLANKLINE>
  47. [[-1., 1.],
  48. [ 1., 1.]]]])
  49. >>> create_meshgrid(2, 2, normalized_coordinates=False)
  50. tensor([[[[0., 0.],
  51. [1., 0.]],
  52. <BLANKLINE>
  53. [[0., 1.],
  54. [1., 1.]]]])
  55. """
  56. xs: Tensor = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
  57. ys: Tensor = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
  58. # Fix TracerWarning
  59. # Note: normalize_pixel_coordinates still gots TracerWarning since new width and height
  60. # tensors will be generated.
  61. # Below is the code using normalize_pixel_coordinates:
  62. # base_grid: torch.Tensor = torch.stack(torch.meshgrid([xs, ys]), dim=2)
  63. # if normalized_coordinates:
  64. # base_grid = K.geometry.normalize_pixel_coordinates(base_grid, height, width)
  65. # return torch.unsqueeze(base_grid.transpose(0, 1), dim=0)
  66. if normalized_coordinates:
  67. xs = (xs / (width - 1) - 0.5) * 2
  68. ys = (ys / (height - 1) - 0.5) * 2
  69. # generate grid by stacking coordinates
  70. # TODO: torchscript doesn't like `torch_version_ge`
  71. # if torch_version_ge(1, 13, 0):
  72. # x, y = torch_meshgrid([xs, ys], indexing="xy")
  73. # return stack([x, y], -1).unsqueeze(0) # 1xHxWx2
  74. # TODO: remove after we drop support of old versions
  75. base_grid: Tensor = stack(torch_meshgrid([xs, ys], indexing="ij"), dim=-1) # WxHx2
  76. return base_grid.permute(1, 0, 2).unsqueeze(0) # 1xHxWx2
  77. def create_meshgrid3d(
  78. depth: int,
  79. height: int,
  80. width: int,
  81. normalized_coordinates: bool = True,
  82. device: Optional[torch.device] = None,
  83. dtype: Optional[torch.dtype] = None,
  84. ) -> Tensor:
  85. """Generate a coordinate grid for an image.
  86. When the flag ``normalized_coordinates`` is set to True, the grid is
  87. normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
  88. function :py:func:`torch.nn.functional.grid_sample`.
  89. Args:
  90. depth: the image depth (channels).
  91. height: the image height (rows).
  92. width: the image width (cols).
  93. normalized_coordinates: whether to normalize
  94. coordinates in the range :math:`[-1,1]` in order to be consistent with the
  95. PyTorch function :py:func:`torch.nn.functional.grid_sample`.
  96. device: the device on which the grid will be generated.
  97. dtype: the data type of the generated grid.
  98. Return:
  99. grid tensor with shape :math:`(1, D, H, W, 3)`.
  100. """
  101. xs: Tensor = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
  102. ys: Tensor = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
  103. zs: Tensor = torch.linspace(0, depth - 1, depth, device=device, dtype=dtype)
  104. # Fix TracerWarning
  105. if normalized_coordinates:
  106. xs = (xs / (width - 1) - 0.5) * 2
  107. ys = (ys / (height - 1) - 0.5) * 2
  108. zs = (zs / (depth - 1) - 0.5) * 2
  109. # generate grid by stacking coordinates
  110. base_grid = stack(torch_meshgrid([zs, xs, ys], indexing="ij"), dim=-1) # DxWxHx3
  111. return base_grid.permute(0, 2, 1, 3).unsqueeze(0) # 1xDxHxWx3