structures.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  17. from dataclasses import dataclass, field
  18. @dataclass
  19. class HeatMapRefineCfg:
  20. mode: str = "local"
  21. ratio: float = 0.2
  22. valid_thresh: float = 0.001
  23. num_blocks: int = 20
  24. overlap_ratio: float = 0.5
  25. @dataclass
  26. class JunctionRefineCfg:
  27. num_perturbs: int = 9
  28. perturb_interval: float = 0.25
  29. @dataclass
  30. class LineDetectorCfg:
  31. detect_thresh: float = 0.5
  32. num_samples: int = 64
  33. inlier_thresh: float = 0.99
  34. use_candidate_suppression: bool = True
  35. nms_dist_tolerance: float = 3.0
  36. heatmap_low_thresh: float = 0.15
  37. heatmap_high_thresh: float = 0.2
  38. max_local_patch_radius: float = 3
  39. lambda_radius: float = 2.0
  40. use_heatmap_refinement: bool = True
  41. heatmap_refine_cfg: HeatMapRefineCfg = field(default_factory=HeatMapRefineCfg)
  42. use_junction_refinement: bool = True
  43. junction_refine_cfg: JunctionRefineCfg = field(default_factory=JunctionRefineCfg)
  44. @dataclass
  45. class LineMatcherCfg:
  46. cross_check: bool = True
  47. num_samples: int = 5
  48. min_dist_pts: int = 8
  49. top_k_candidates: int = 10
  50. grid_size: int = 4
  51. line_score: bool = False # True to compute saliency on a line
  52. @dataclass
  53. class BackboneCfg:
  54. input_channel: int = 1
  55. depth: int = 4
  56. num_stacks: int = 2
  57. num_blocks: int = 1
  58. num_classes: int = 5
  59. @dataclass
  60. class DetectorCfg:
  61. backbone_cfg: BackboneCfg = field(default_factory=BackboneCfg)
  62. use_descriptor: bool = False
  63. grid_size: int = 8
  64. keep_border_valid: bool = True
  65. detection_thresh: float = 0.0153846 # = 1/65: threshold of junction detection
  66. max_num_junctions: int = 500 # maximum number of junctions per image
  67. line_detector_cfg: LineDetectorCfg = field(default_factory=LineDetectorCfg)
  68. line_matcher_cfg: LineMatcherCfg = field(default_factory=LineMatcherCfg)