_box_convert.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. import torch
  2. from torch import Tensor
  3. def _box_cxcywh_to_xyxy(boxes: Tensor) -> Tensor:
  4. """
  5. Converts bounding boxes from (cx, cy, w, h) format to (x1, y1, x2, y2) format.
  6. (cx, cy) refers to center of bounding box
  7. (w, h) are width and height of bounding box
  8. Args:
  9. boxes (Tensor[N, 4]): boxes in (cx, cy, w, h) format which will be converted.
  10. Returns:
  11. boxes (Tensor(N, 4)): boxes in (x1, y1, x2, y2) format.
  12. """
  13. # We need to change all 4 of them so some temporary variable is needed.
  14. cx, cy, w, h = boxes.unbind(-1)
  15. x1 = cx - 0.5 * w
  16. y1 = cy - 0.5 * h
  17. x2 = cx + 0.5 * w
  18. y2 = cy + 0.5 * h
  19. boxes = torch.stack((x1, y1, x2, y2), dim=-1)
  20. return boxes
  21. def _box_xyxy_to_cxcywh(boxes: Tensor) -> Tensor:
  22. """
  23. Converts bounding boxes from (x1, y1, x2, y2) format to (cx, cy, w, h) format.
  24. (x1, y1) refer to top left of bounding box
  25. (x2, y2) refer to bottom right of bounding box
  26. Args:
  27. boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format which will be converted.
  28. Returns:
  29. boxes (Tensor(N, 4)): boxes in (cx, cy, w, h) format.
  30. """
  31. x1, y1, x2, y2 = boxes.unbind(-1)
  32. cx = (x1 + x2) / 2
  33. cy = (y1 + y2) / 2
  34. w = x2 - x1
  35. h = y2 - y1
  36. boxes = torch.stack((cx, cy, w, h), dim=-1)
  37. return boxes
  38. def _box_xywh_to_xyxy(boxes: Tensor) -> Tensor:
  39. """
  40. Converts bounding boxes from (x, y, w, h) format to (x1, y1, x2, y2) format.
  41. (x, y) refers to top left of bounding box.
  42. (w, h) refers to width and height of box.
  43. Args:
  44. boxes (Tensor[N, 4]): boxes in (x, y, w, h) which will be converted.
  45. Returns:
  46. boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format.
  47. """
  48. x, y, w, h = boxes.unbind(-1)
  49. boxes = torch.stack([x, y, x + w, y + h], dim=-1)
  50. return boxes
  51. def _box_xyxy_to_xywh(boxes: Tensor) -> Tensor:
  52. """
  53. Converts bounding boxes from (x1, y1, x2, y2) format to (x, y, w, h) format.
  54. (x1, y1) refer to top left of bounding box
  55. (x2, y2) refer to bottom right of bounding box
  56. Args:
  57. boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) which will be converted.
  58. Returns:
  59. boxes (Tensor[N, 4]): boxes in (x, y, w, h) format.
  60. """
  61. x1, y1, x2, y2 = boxes.unbind(-1)
  62. w = x2 - x1 # x2 - x1
  63. h = y2 - y1 # y2 - y1
  64. boxes = torch.stack((x1, y1, w, h), dim=-1)
  65. return boxes
  66. def _box_cxcywhr_to_xywhr(boxes: Tensor) -> Tensor:
  67. """
  68. Converts rotated bounding boxes from (cx, cy, w, h, r) format to (x1, y1, w, h, r) format.
  69. (cx, cy) refers to center of bounding box
  70. (w, h) refers to width and height of rotated bounding box
  71. (x1, y1) refers to top left of rotated bounding box
  72. r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan
  73. Args:
  74. boxes (Tensor[N, 5]): boxes in (cx, cy, w, h, r) format which will be converted.
  75. Returns:
  76. boxes (Tensor(N, 5)): rotated boxes in (x1, y1, w, h, r) format.
  77. """
  78. dtype = boxes.dtype
  79. need_cast = not boxes.is_floating_point()
  80. cx, cy, w, h, r = boxes.unbind(-1)
  81. r_rad = r * torch.pi / 180.0
  82. cos, sin = torch.cos(r_rad), torch.sin(r_rad)
  83. x1 = cx - w / 2 * cos - h / 2 * sin
  84. y1 = cy - h / 2 * cos + w / 2 * sin
  85. boxes = torch.stack((x1, y1, w, h, r), dim=-1)
  86. if need_cast:
  87. boxes.round_()
  88. boxes = boxes.to(dtype)
  89. return boxes
  90. def _box_xywhr_to_cxcywhr(boxes: Tensor) -> Tensor:
  91. """
  92. Converts rotated bounding boxes from (x1, y1, w, h, r) format to (cx, cy, w, h, r) format.
  93. (x1, y1) refers to top left of rotated bounding box
  94. (w, h) refers to width and height of rotated bounding box
  95. r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan
  96. Args:
  97. boxes (Tensor[N, 5]): rotated boxes in (x1, y1, w, h, r) format which will be converted.
  98. Returns:
  99. boxes (Tensor[N, 5]): rotated boxes in (cx, cy, w, h, r) format.
  100. """
  101. dtype = boxes.dtype
  102. need_cast = not boxes.is_floating_point()
  103. x1, y1, w, h, r = boxes.unbind(-1)
  104. r_rad = r * torch.pi / 180.0
  105. cos, sin = torch.cos(r_rad), torch.sin(r_rad)
  106. cx = x1 + w / 2 * cos + h / 2 * sin
  107. cy = y1 - w / 2 * sin + h / 2 * cos
  108. boxes = torch.stack([cx, cy, w, h, r], dim=-1)
  109. if need_cast:
  110. boxes.round_()
  111. boxes = boxes.to(dtype)
  112. return boxes
  113. def _box_xywhr_to_xyxyxyxy(boxes: Tensor) -> Tensor:
  114. """
  115. Converts rotated bounding boxes from (x1, y1, w, h, r) format to (x1, y1, x2, y2, x3, y3, x4, y4) format.
  116. (x1, y1) refer to top left of bounding box
  117. (w, h) are width and height of the rotated bounding box
  118. r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan
  119. (x1, y1) refer to top left of rotated bounding box
  120. (x2, y2) refer to top right of rotated bounding box
  121. (x3, y3) refer to bottom right of rotated bounding box
  122. (x4, y4) refer to bottom left ofrotated bounding box
  123. Args:
  124. boxes (Tensor[N, 5]): rotated boxes in (cx, cy, w, h, r) format which will be converted.
  125. Returns:
  126. boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x2, y2, x3, y3, x4, y4) format.
  127. """
  128. dtype = boxes.dtype
  129. need_cast = not boxes.is_floating_point()
  130. x1, y1, w, h, r = boxes.unbind(-1)
  131. r_rad = r * torch.pi / 180.0
  132. cos, sin = torch.cos(r_rad), torch.sin(r_rad)
  133. x2 = x1 + w * cos
  134. y2 = y1 - w * sin
  135. x3 = x2 + h * sin
  136. y3 = y2 + h * cos
  137. x4 = x1 + h * sin
  138. y4 = y1 + h * cos
  139. boxes = torch.stack((x1, y1, x2, y2, x3, y3, x4, y4), dim=-1)
  140. if need_cast:
  141. boxes.round_()
  142. boxes = boxes.to(dtype)
  143. return boxes
  144. def _box_xyxyxyxy_to_xywhr(boxes: Tensor) -> Tensor:
  145. """
  146. Converts rotated bounding boxes from (x1, y1, x2, y2, x3, y3, x4, y4) format to (x1, y1, w, h, r) format.
  147. (x1, y1) refer to top left of the rotated bounding box
  148. (x2, y2) refer to bottom left of the rotated bounding box
  149. (x3, y3) refer to bottom right of the rotated bounding box
  150. (x4, y4) refer to top right of the rotated bounding box
  151. (w, h) refers to width and height of rotated bounding box
  152. r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan
  153. Args:
  154. boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x2, y2, x3, y3, x4, y4) format.
  155. Returns:
  156. boxes (Tensor[N, 5]): rotated boxes in (x1, y1, w, h, r) format.
  157. """
  158. dtype = boxes.dtype
  159. need_cast = not boxes.is_floating_point()
  160. x1, y1, x2, y2, x3, y3, x4, y4 = boxes.unbind(-1)
  161. r_rad = torch.atan2(y1 - y2, x2 - x1)
  162. r = r_rad * 180 / torch.pi
  163. w = ((x2 - x1) ** 2 + (y1 - y2) ** 2).sqrt()
  164. h = ((x3 - x2) ** 2 + (y3 - y2) ** 2).sqrt()
  165. boxes = torch.stack((x1, y1, w, h, r), dim=-1)
  166. if need_cast:
  167. boxes.round_()
  168. boxes = boxes.to(dtype)
  169. return boxes