detector.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import torch.nn.functional as F
  18. from torch import nn
  19. from kornia.core import Module, Tensor
  20. class DeDoDeDetector(nn.Module):
  21. def __init__(self, encoder: Module, decoder: Module, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
  22. super().__init__(*args, **kwargs)
  23. self.encoder = encoder
  24. self.decoder = decoder
  25. def forward(
  26. self,
  27. images: Tensor,
  28. ) -> Tensor:
  29. dtype = images.dtype
  30. features, sizes = self.encoder(images)
  31. context = None
  32. logits = None
  33. scales = ["8", "4", "2", "1"]
  34. for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)):
  35. delta_logits, context = self.decoder(feature_map, context=context, scale=scale)
  36. if logits is None:
  37. logits = delta_logits
  38. else:
  39. logits = logits + delta_logits.float() # ensure float (need bf16 doesn't have f.interpolate)
  40. if idx < len(scales) - 1:
  41. size = sizes[-(idx + 2)]
  42. logits = F.interpolate(logits, size=size, mode="bicubic", align_corners=False)
  43. context = F.interpolate(context.float(), size=size, mode="bilinear", align_corners=False)
  44. return logits.to(dtype) # type: ignore