utils.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, ClassVar, List, Optional, Tuple, Union
  18. from kornia.core import Module, Tensor, rand, tensor
  19. from kornia.core.mixin.onnx import ONNXExportMixin
  20. __all__ = ["BoxFiltering"]
  21. class BoxFiltering(Module, ONNXExportMixin):
  22. """Filter boxes according to the desired threshold.
  23. Args:
  24. confidence_threshold: an 0-d scalar that represents the desired threshold.
  25. classes_to_keep: a 1-d list of classes to keep. If None, keep all classes.
  26. filter_as_zero: whether to filter boxes as zero.
  27. """
  28. ONNX_DEFAULT_INPUTSHAPE: ClassVar[List[int]] = [-1, -1, 6]
  29. ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[List[int]] = [-1, -1, 6]
  30. ONNX_EXPORT_PSEUDO_SHAPE: ClassVar[List[int]] = [5, 20, 6]
  31. def __init__(
  32. self,
  33. confidence_threshold: Optional[Union[Tensor, float]] = None,
  34. classes_to_keep: Optional[Union[Tensor, List[int]]] = None,
  35. filter_as_zero: bool = False,
  36. ) -> None:
  37. super().__init__()
  38. self.filter_as_zero = filter_as_zero
  39. self.classes_to_keep = None
  40. self.confidence_threshold = None
  41. if classes_to_keep is not None:
  42. self.classes_to_keep = classes_to_keep if isinstance(classes_to_keep, Tensor) else tensor(classes_to_keep)
  43. if confidence_threshold is not None:
  44. self.confidence_threshold = (
  45. confidence_threshold or confidence_threshold
  46. if isinstance(confidence_threshold, Tensor)
  47. else tensor(confidence_threshold)
  48. )
  49. def forward(
  50. self, boxes: Tensor, confidence_threshold: Optional[Tensor] = None, classes_to_keep: Optional[Tensor] = None
  51. ) -> Union[Tensor, List[Tensor]]:
  52. """Filter boxes according to the desired threshold.
  53. To be ONNX-friendly, the inputs for direct forwarding need to be all tensors.
  54. Args:
  55. boxes: [B, D, 6], where B is the batchsize, D is the number of detections in the image,
  56. 6 represent (class_id, confidence_score, x, y, w, h).
  57. confidence_threshold: an 0-d scalar that represents the desired threshold.
  58. classes_to_keep: a 1-d tensor of classes to keep. If None, keep all classes.
  59. Returns:
  60. Union[Tensor, List[Tensor]]
  61. If `filter_as_zero` is True, return a tensor of shape [D, 6], where D is the total number of
  62. detections as input.
  63. If `filter_as_zero` is False, return a list of tensors of shape [D, 6], where D is the number of
  64. valid detections for each element in the batch.
  65. """
  66. # Apply confidence filtering
  67. zero_tensor = tensor(0.0, device=boxes.device, dtype=boxes.dtype)
  68. confidence_threshold = (
  69. confidence_threshold or self.confidence_threshold or zero_tensor
  70. ) # If None, use 0 as threshold
  71. confidence_mask = boxes[:, :, 1] > confidence_threshold # [B, D]
  72. # Apply class filtering
  73. classes_to_keep = classes_to_keep or self.classes_to_keep
  74. if classes_to_keep is not None:
  75. class_ids = boxes[:, :, 0:1] # [B, D, 1]
  76. classes_to_keep = classes_to_keep.view(1, 1, -1) # [1, 1, C] for broadcasting
  77. class_mask = (class_ids == classes_to_keep).any(dim=-1) # [B, D]
  78. else:
  79. # If no class filtering is needed, just use a mask of all `True`
  80. class_mask = (confidence_mask * 0 + 1).bool()
  81. # Combine the confidence and class masks
  82. combined_mask = confidence_mask & class_mask # [B, D]
  83. if self.filter_as_zero:
  84. filtered_boxes = boxes * combined_mask[:, :, None]
  85. return filtered_boxes
  86. filtered_boxes_list = []
  87. for i in range(boxes.shape[0]):
  88. box = boxes[i]
  89. mask = combined_mask[i] # [D]
  90. valid_boxes = box[mask]
  91. filtered_boxes_list.append(valid_boxes)
  92. return filtered_boxes_list
  93. def _create_dummy_input(
  94. self, input_shape: List[int], pseudo_shape: Optional[List[int]] = None
  95. ) -> Union[Tuple[Any, ...], Tensor]:
  96. pseudo_input = rand(
  97. *[
  98. ((self.ONNX_EXPORT_PSEUDO_SHAPE[i] if pseudo_shape is None else pseudo_shape[i]) if dim == -1 else dim)
  99. for i, dim in enumerate(input_shape)
  100. ]
  101. )
  102. if self.confidence_threshold is None:
  103. return pseudo_input, 0.1
  104. return pseudo_input