depth_smooth.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. import torch
  19. from torch import nn
  20. # Based on
  21. # https://github.com/tensorflow/models/blob/master/research/struct2depth/model.py#L625-L641
  def _gradient_x(img: torch.Tensor) -> torch.Tensor:
      """Subtract each pixel's right-hand neighbour from it along the last axis.

      For a 4-D input of shape ``(B, C, H, W)`` the result has shape
      ``(B, C, H, W - 1)`` and element ``[..., w]`` equals
      ``img[..., w] - img[..., w + 1]``.

      Raises:
          AssertionError: if ``img`` is not 4-D (the offending shape is the
              exception argument).
      """
      if img.ndim == 4:
          left = img[..., :-1]
          right = img[..., 1:]
          return left - right
      raise AssertionError(img.shape)
  def _gradient_y(img: torch.Tensor) -> torch.Tensor:
      """Subtract each pixel's lower neighbour from it along the second-to-last axis.

      For a 4-D input of shape ``(B, C, H, W)`` the result has shape
      ``(B, C, H - 1, W)`` and element ``[..., h, :]`` equals
      ``img[..., h, :] - img[..., h + 1, :]``.

      Raises:
          AssertionError: if ``img`` is not 4-D (the offending shape is the
              exception argument).
      """
      if img.ndim == 4:
          upper = img[..., :-1, :]
          lower = img[..., 1:, :]
          return upper - lower
      raise AssertionError(img.shape)
  def inverse_depth_smoothness_loss(idepth: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
      r"""Criterion that computes image-aware inverse depth smoothness loss.

      .. math::

          \text{loss} = \left | \partial_x d_{ij} \right | e^{-\left \|
          \partial_x I_{ij} \right \|} + \left |
          \partial_y d_{ij} \right | e^{-\left \| \partial_y I_{ij} \right \|}

      Depth and image gradients are differences between horizontally (``x``)
      and vertically (``y``) adjacent pixels. The absolute image gradient is
      averaged over the channel dimension before the exponential, so strong
      image edges down-weight the depth-gradient penalty at that location.

      Args:
          idepth: tensor with the inverse depth with shape :math:`(N, 1, H, W)`.
          image: tensor with the input image with shape :math:`(N, 3, H, W)`.

      Return:
          a scalar with the computed loss. Along an axis of size 1 there are no
          neighbouring pixels, so that term contributes ``0`` rather than the
          ``nan`` produced by averaging an empty tensor.

      Raises:
          TypeError: if ``idepth`` or ``image`` is not a ``torch.Tensor``.
          ValueError: if either input is not 4-D, if their batch sizes or
              spatial sizes differ, or if they are on different devices or
              have different dtypes.

      Examples:
          >>> idepth = torch.rand(1, 1, 4, 5)
          >>> image = torch.rand(1, 3, 4, 5)
          >>> loss = inverse_depth_smoothness_loss(idepth, image)

      """
      if not isinstance(idepth, torch.Tensor):
          raise TypeError(f"Input idepth type is not a torch.Tensor. Got {type(idepth)}")

      if not isinstance(image, torch.Tensor):
          raise TypeError(f"Input image type is not a torch.Tensor. Got {type(image)}")

      if not len(idepth.shape) == 4:
          raise ValueError(f"Invalid idepth shape, we expect BxCxHxW. Got: {idepth.shape}")

      if not len(image.shape) == 4:
          raise ValueError(f"Invalid image shape, we expect BxCxHxW. Got: {image.shape}")

      if idepth.shape[0] != image.shape[0]:
          # Without this check a batch of size 1 on either side would silently
          # broadcast against the other batch, pairing depth maps with unrelated
          # images instead of failing.
          raise ValueError(f"idepth and image must have the same batch size. Got: {idepth.shape} and {image.shape}")

      if not idepth.shape[-2:] == image.shape[-2:]:
          raise ValueError(f"idepth and image shapes must be the same. Got: {idepth.shape} and {image.shape}")

      if not idepth.device == image.device:
          raise ValueError(f"idepth and image must be in the same device. Got: {idepth.device} and {image.device}")

      if not idepth.dtype == image.dtype:
          raise ValueError(f"idepth and image must be in the same dtype. Got: {idepth.dtype} and {image.dtype}")

      # compute the gradients
      idepth_dx: torch.Tensor = _gradient_x(idepth)
      idepth_dy: torch.Tensor = _gradient_y(idepth)
      image_dx: torch.Tensor = _gradient_x(image)
      image_dy: torch.Tensor = _gradient_y(image)

      # compute image weights: exp(-mean over channels of |image gradient|),
      # kept as a single channel so it broadcasts over the depth channels
      weights_x: torch.Tensor = torch.exp(-torch.mean(torch.abs(image_dx), dim=1, keepdim=True))
      weights_y: torch.Tensor = torch.exp(-torch.mean(torch.abs(image_dy), dim=1, keepdim=True))

      # apply image weights to depth
      smoothness_x: torch.Tensor = torch.abs(idepth_dx * weights_x)
      smoothness_y: torch.Tensor = torch.abs(idepth_dy * weights_y)

      # torch.mean of an empty tensor is nan. When W == 1 (or H == 1) the x (or y)
      # differences are empty, so fall back to sum(), which yields a zero with the
      # right dtype/device that is still connected to the autograd graph.
      loss_x = torch.mean(smoothness_x) if smoothness_x.numel() > 0 else smoothness_x.sum()
      loss_y = torch.mean(smoothness_y) if smoothness_y.numel() > 0 else smoothness_y.sum()
      return loss_x + loss_y
  class InverseDepthSmoothnessLoss(nn.Module):
      r"""Image-aware inverse depth smoothness loss packaged as an ``nn.Module``.

      Evaluates

      .. math::

          \text{loss} = \left | \partial_x d_{ij} \right | e^{-\left \|
          \partial_x I_{ij} \right \|} + \left |
          \partial_y d_{ij} \right | e^{-\left \| \partial_y I_{ij} \right \|}

      by delegating to :func:`inverse_depth_smoothness_loss`. The module defines
      no ``__init__`` of its own, so it holds no parameters or buffers.

      Shape:
          - Inverse Depth: :math:`(N, 1, H, W)`
          - Image: :math:`(N, 3, H, W)`
          - Output: scalar

      Examples:
          >>> idepth = torch.rand(1, 1, 4, 5)
          >>> image = torch.rand(1, 3, 4, 5)
          >>> smooth = InverseDepthSmoothnessLoss()
          >>> loss = smooth(idepth, image)

      """

      def forward(self, idepth: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
          """Return the smoothness loss of ``idepth`` weighted by the edges of ``image``."""
          smoothness = inverse_depth_smoothness_loss(idepth, image)
          return smoothness