descriptor.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import torch.nn.functional as F
  18. from kornia.core import Module, Tensor
  19. class DeDoDeDescriptor(Module):
  20. def __init__(self, encoder: Module, decoder: Module, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
  21. super().__init__(*args, **kwargs)
  22. self.encoder = encoder
  23. self.decoder = decoder
  24. def forward(
  25. self,
  26. images: Tensor,
  27. ) -> Tensor:
  28. features, sizes = self.encoder(images)
  29. context = None
  30. scales = self.decoder.scales
  31. for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)):
  32. if idx == 0:
  33. descriptions, context = self.decoder(feature_map, scale=scale, context=context)
  34. else:
  35. delta_descriptions, context = self.decoder(feature_map, scale=scale, context=context)
  36. descriptions = descriptions + delta_descriptions
  37. if idx < len(scales) - 1:
  38. size = sizes[-(idx + 2)]
  39. descriptions = F.interpolate(descriptions, size=size, mode="bilinear", align_corners=False)
  40. context = F.interpolate(context, size=size, mode="bilinear", align_corners=False)
  41. return descriptions