draw.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import List, Optional, Tuple, Union
  18. import torch
  19. from torch import Tensor
  20. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE
  21. # TODO: implement width of the line
  22. def draw_point2d(image: Tensor, points: Tensor, color: Tensor) -> Tensor:
  23. r"""Set one or more coordinates in a Tensor to a color.
  24. Args:
  25. image: the input image on which to draw the points with shape :math`(C,H,W)` or :math`(H,W)`.
  26. points: the [x, y] points to be drawn on the image.
  27. color: the color of the pixel with :math`(C)` where :math`C` is the number of channels of the image.
  28. Return:
  29. The image with points set to the color.
  30. """
  31. KORNIA_CHECK(
  32. (len(image.shape) == 2 and len(color.shape) == 1) or (image.shape[0] == color.shape[0]),
  33. "Color dim must match the channel dims of the provided image",
  34. )
  35. points = points.to(dtype=torch.int64, device=image.device)
  36. x, y = zip(*points)
  37. if len(color.shape) == 1:
  38. color = torch.unsqueeze(color, dim=1)
  39. color = color.to(dtype=image.dtype, device=image.device)
  40. if len(image.shape) == 2:
  41. image[y, x] = color
  42. else:
  43. image[:, y, x] = color
  44. return image
  45. def _draw_pixel(image: torch.Tensor, x: int, y: int, color: torch.Tensor) -> None:
  46. r"""Draws a pixel into an image.
  47. Args:
  48. image: the input image to where to draw the lines with shape :math`(C,H,W)`.
  49. x: the x coordinate of the pixel.
  50. y: the y coordinate of the pixel.
  51. color: the color of the pixel with :math`(C)` where :math`C` is the number of channels of the image.
  52. Return:
  53. Nothing is returned.
  54. """
  55. image[:, y, x] = color
  56. def draw_line(image: torch.Tensor, p1: torch.Tensor, p2: torch.Tensor, color: torch.Tensor) -> torch.Tensor:
  57. r"""Draw a single line into an image.
  58. Args:
  59. image: the input image to where to draw the lines with shape :math`(C,H,W)`.
  60. p1: the start point [x y] of the line with shape (2, ) or (B, 2).
  61. p2: the end point [x y] of the line with shape (2, ) or (B, 2).
  62. color: the color of the line with shape :math`(C)` where :math`C` is the number of channels of the image.
  63. Return:
  64. the image with containing the line.
  65. Examples:
  66. >>> image = torch.zeros(1, 8, 8)
  67. >>> draw_line(image, torch.tensor([6, 4]), torch.tensor([1, 4]), torch.tensor([255]))
  68. tensor([[[ 0., 0., 0., 0., 0., 0., 0., 0.],
  69. [ 0., 0., 0., 0., 0., 0., 0., 0.],
  70. [ 0., 0., 0., 0., 0., 0., 0., 0.],
  71. [ 0., 0., 0., 0., 0., 0., 0., 0.],
  72. [ 0., 255., 255., 255., 255., 255., 255., 0.],
  73. [ 0., 0., 0., 0., 0., 0., 0., 0.],
  74. [ 0., 0., 0., 0., 0., 0., 0., 0.],
  75. [ 0., 0., 0., 0., 0., 0., 0., 0.]]])
  76. """
  77. if (p1.shape[0] != p2.shape[0]) or (p1.shape[-1] != 2 or p2.shape[-1] != 2):
  78. raise ValueError(
  79. "Input points must be 2D points with shape (2, ) or (B, 2) and must have the same batch sizes."
  80. )
  81. if (
  82. (p1[..., 0] < 0).any()
  83. or (p1[..., 0] >= image.shape[-1]).any()
  84. or (p1[..., 1] < 0).any()
  85. or (p1[..., 1] >= image.shape[-2]).any()
  86. ):
  87. raise ValueError("p1 is out of bounds.")
  88. if (
  89. (p2[..., 0] < 0).any()
  90. or (p2[..., 0] >= image.shape[-1]).any()
  91. or (p2[..., 1] < 0).any()
  92. or (p2[..., 1] >= image.shape[-2]).any()
  93. ):
  94. raise ValueError("p2 is out of bounds.")
  95. if len(image.size()) != 3:
  96. raise ValueError("image must have 3 dimensions (C,H,W).")
  97. if color.size(0) != image.size(0):
  98. raise ValueError("color must have the same number of channels as the image.")
  99. # move p1 and p2 to the same device as the input image
  100. # move color to the same device and dtype as the input image
  101. p1 = p1.to(image.device).to(torch.int64)
  102. p2 = p2.to(image.device).to(torch.int64)
  103. color = color.to(image)
  104. x1, y1 = p1[..., 0], p1[..., 1]
  105. x2, y2 = p2[..., 0], p2[..., 1]
  106. dx = x2 - x1
  107. dy = y2 - y1
  108. dx_sign = torch.sign(dx)
  109. dy_sign = torch.sign(dy)
  110. dx, dy = torch.abs(dx), torch.abs(dy)
  111. dx_zero_mask = dx == 0
  112. dy_zero_mask = dy == 0
  113. dx_gt_dy_mask = (dx > dy) & ~(dx_zero_mask | dy_zero_mask)
  114. rest_mask = ~(dx_zero_mask | dy_zero_mask | dx_gt_dy_mask)
  115. dx_zero_x_coords, dx_zero_y_coords = [], []
  116. dy_zero_x_coords, dy_zero_y_coords = [], []
  117. dx_gt_dy_x_coords, dx_gt_dy_y_coords = [], []
  118. rest_x_coords, rest_y_coords = [], []
  119. if dx_zero_mask.any():
  120. dx_zero_x_coords = [
  121. x for x_i, dy_i in zip(x1[dx_zero_mask], dy[dx_zero_mask]) for x in x_i.repeat(int(dy_i.item() + 1))
  122. ]
  123. dx_zero_y_coords = [
  124. y
  125. for y_i, s, dy_ in zip(y1[dx_zero_mask], dy_sign[dx_zero_mask], dy[dx_zero_mask])
  126. for y in (y_i + s * torch.arange(0, dy_ + 1, 1, device=image.device))
  127. ]
  128. if dy_zero_mask.any():
  129. dy_zero_x_coords = [
  130. x
  131. for x_i, s, dx_i in zip(x1[dy_zero_mask], dx_sign[dy_zero_mask], dx[dy_zero_mask])
  132. for x in (x_i + s * torch.arange(0, dx_i + 1, 1, device=image.device))
  133. ]
  134. dy_zero_y_coords = [
  135. y for y_i, dx_i in zip(y1[dy_zero_mask], dx[dy_zero_mask]) for y in y_i.repeat(int(dx_i.item() + 1))
  136. ]
  137. if dx_gt_dy_mask.any():
  138. dx_gt_dy_x_coords = [
  139. x
  140. for x_i, s, dx_i in zip(x1[dx_gt_dy_mask], dx_sign[dx_gt_dy_mask], dx[dx_gt_dy_mask])
  141. for x in (x_i + s * torch.arange(0, dx_i + 1, 1, device=image.device))
  142. ]
  143. dx_gt_dy_y_coords = [
  144. y
  145. for y_i, s, dx_i, dy_i in zip(
  146. y1[dx_gt_dy_mask], dy_sign[dx_gt_dy_mask], dx[dx_gt_dy_mask], dy[dx_gt_dy_mask]
  147. )
  148. for y in (
  149. y_i + s * torch.arange(0, dy_i + 1, dy_i / dx_i, device=image.device)[: int(dx_i.item()) + 1].ceil()
  150. )
  151. ]
  152. if rest_mask.any():
  153. rest_x_coords = [
  154. x
  155. for x_i, s, dx_i, dy_ in zip(x1[rest_mask], dx_sign[rest_mask], dx[rest_mask], dy[rest_mask])
  156. for x in (
  157. x_i + s * torch.arange(0, dx_i + 1, dx_i / dy_, device=image.device)[: int(dy_.item()) + 1].ceil()
  158. )
  159. ]
  160. rest_y_coords = [
  161. y
  162. for y_i, s, dy_i in zip(y1[rest_mask], dy_sign[rest_mask], dy[rest_mask])
  163. for y in (y_i + s * torch.arange(0, dy_i + 1, 1, device=image.device))
  164. ]
  165. x_coords = torch.clamp(
  166. torch.tensor(dx_zero_x_coords + dy_zero_x_coords + dx_gt_dy_x_coords + rest_x_coords).long(),
  167. min=0,
  168. max=image.shape[-1] - 1,
  169. )
  170. y_coords = torch.clamp(
  171. torch.tensor(dx_zero_y_coords + dy_zero_y_coords + dx_gt_dy_y_coords + rest_y_coords).long(),
  172. min=0,
  173. max=image.shape[-2] - 1,
  174. )
  175. image[:, y_coords, x_coords] = color.view(-1, 1)
  176. return image
  177. def draw_rectangle(
  178. image: torch.Tensor, rectangle: torch.Tensor, color: Optional[torch.Tensor] = None, fill: Optional[bool] = None
  179. ) -> torch.Tensor:
  180. r"""Draw N rectangles on a batch of image tensors.
  181. Args:
  182. image: is tensor of BxCxHxW.
  183. rectangle: represents number of rectangles to draw in BxNx4
  184. N is the number of boxes to draw per batch index[x1, y1, x2, y2]
  185. 4 is in (top_left.x, top_left.y, bot_right.x, bot_right.y).
  186. color: a size 1, size 3, BxNx1, or BxNx3 tensor.
  187. If C is 3, and color is 1 channel it will be broadcasted.
  188. fill: is a flag used to fill the boxes with color if True.
  189. Returns:
  190. This operation modifies image inplace but also returns the drawn tensor for
  191. convenience with same shape the of the input BxCxHxW.
  192. Example:
  193. >>> img = torch.rand(2, 3, 10, 12)
  194. >>> rect = torch.tensor([[[0, 0, 4, 4]], [[4, 4, 10, 10]]])
  195. >>> out = draw_rectangle(img, rect)
  196. """
  197. batch, c, h, w = image.shape
  198. batch_rect, num_rectangle, num_points = rectangle.shape
  199. if batch != batch_rect:
  200. raise AssertionError("Image batch and rectangle batch must be equal")
  201. if num_points != 4:
  202. raise AssertionError("Number of points in rectangle must be 4")
  203. # clone rectangle, in case it's been expanded assignment from clipping causes problems
  204. rectangle = rectangle.long().clone()
  205. # clip rectangle to hxw bounds
  206. rectangle[:, :, 1::2] = torch.clamp(rectangle[:, :, 1::2], 0, h - 1)
  207. rectangle[:, :, ::2] = torch.clamp(rectangle[:, :, ::2], 0, w - 1)
  208. if color is None:
  209. color = torch.tensor([0.0] * c).expand(batch, num_rectangle, c)
  210. if fill is None:
  211. fill = False
  212. if len(color.shape) == 1:
  213. color = color.expand(batch, num_rectangle, c)
  214. b, n, color_channels = color.shape
  215. if color_channels == 1 and c == 3:
  216. color = color.expand(batch, num_rectangle, c)
  217. for b in range(batch):
  218. for n in range(num_rectangle):
  219. if fill:
  220. image[
  221. b,
  222. :,
  223. int(rectangle[b, n, 1]) : int(rectangle[b, n, 3] + 1),
  224. int(rectangle[b, n, 0]) : int(rectangle[b, n, 2] + 1),
  225. ] = color[b, n, :, None, None]
  226. else:
  227. image[b, :, int(rectangle[b, n, 1]) : int(rectangle[b, n, 3] + 1), rectangle[b, n, 0]] = color[
  228. b, n, :, None
  229. ]
  230. image[b, :, int(rectangle[b, n, 1]) : int(rectangle[b, n, 3] + 1), rectangle[b, n, 2]] = color[
  231. b, n, :, None
  232. ]
  233. image[b, :, rectangle[b, n, 1], int(rectangle[b, n, 0]) : int(rectangle[b, n, 2] + 1)] = color[
  234. b, n, :, None
  235. ]
  236. image[b, :, rectangle[b, n, 3], int(rectangle[b, n, 0]) : int(rectangle[b, n, 2] + 1)] = color[
  237. b, n, :, None
  238. ]
  239. return image
  240. def _get_convex_edges(polygon: Tensor, h: int, w: int) -> Tuple[Tensor, Tensor]:
  241. r"""Get the left and right edges of a polygon for each y-coordinate y \in [0, h).
  242. Args:
  243. polygon: represents polygons to draw in BxNx2
  244. N is the number of points
  245. 2 is (x, y).
  246. h: bottom most coordinate (top coordinate is assumed to be 0)
  247. w: right most coordinate (left coordinate is assumed to be 0)
  248. Returns:
  249. The left and right edges of the polygon of shape (B,B).
  250. """
  251. dtype = polygon.dtype
  252. # Check if polygons are in loop closed format, if not -> make it so
  253. if not torch.allclose(polygon[..., -1, :], polygon[..., 0, :]):
  254. polygon = torch.cat((polygon, polygon[..., :1, :]), dim=-2) # (B, N+1, 2)
  255. # Partition points into edges
  256. x_start, y_start = polygon[..., :-1, 0], polygon[..., :-1, 1]
  257. x_end, y_end = polygon[..., 1:, 0], polygon[..., 1:, 1]
  258. # Create scanlines, edge dx/dy, and produce x values
  259. ys = torch.arange(h, device=polygon.device, dtype=dtype)
  260. dx = ((x_end - x_start) / (y_end - y_start + 1e-12)).clamp(-w, w)
  261. xs = (ys[..., :, None] - y_start[..., None, :]) * dx[..., None, :] + x_start[..., None, :]
  262. # Only count edge in their active regions (i.e between the vertices)
  263. valid_edges = (y_start[..., None, :] <= ys[..., :, None]).logical_and(ys[..., :, None] <= y_end[..., None, :])
  264. valid_edges |= (y_start[..., None, :] >= ys[..., :, None]).logical_and(ys[..., :, None] >= y_end[..., None, :])
  265. x_left_edges = xs.clone()
  266. x_left_edges[~valid_edges] = w
  267. x_right_edges = xs.clone()
  268. x_right_edges[~valid_edges] = -1
  269. # Find smallest and largest x values for the valid edges
  270. x_left = x_left_edges.min(dim=-1).values
  271. x_right = x_right_edges.max(dim=-1).values
  272. return x_left, x_right
  273. def _batch_polygons(polygons: List[Tensor]) -> Tensor:
  274. r"""Convert a List of variable length polygons into a fixed size tensor.
  275. Works by repeating the last element in the tensor.
  276. Args:
  277. polygons: List of variable length polygons of shape [N_1 x 2, N_2 x 2, ..., N_B x 2].
  278. B is the batch size,
  279. N_i is the number of points,
  280. 2 is (x, y).
  281. Returns:
  282. A fixed size tensor of shape (B, N, 2) where N = max_i(N_i)
  283. """
  284. B, N = len(polygons), len(max(polygons, key=len))
  285. batched_polygons = torch.zeros(B, N, 2, dtype=polygons[0].dtype, device=polygons[0].device)
  286. for b, p in enumerate(polygons):
  287. batched_polygons[b] = torch.cat((p, p[-1:].expand(N - len(p), 2))) if len(p) < N else p
  288. return batched_polygons
  289. def draw_convex_polygon(images: Tensor, polygons: Union[Tensor, List[Tensor]], colors: Tensor) -> Tensor:
  290. r"""Draws convex polygons on a batch of image tensors.
  291. Args:
  292. images: is tensor of BxCxHxW.
  293. polygons: represents polygons as points, either BxNx2 or List of variable length polygons.
  294. N is the number of points.
  295. 2 is (x, y).
  296. colors: a B x 3 tensor or 3 tensor with color to fill in.
  297. Returns:
  298. This operation modifies image inplace but also returns the drawn tensor for
  299. convenience with same shape the of the input BxCxHxW.
  300. Note:
  301. This function assumes a coordinate system (0, h - 1), (0, w - 1) in the image, with (0, 0) being the center
  302. of the top-left pixel and (w - 1, h - 1) being the center of the bottom-right coordinate.
  303. Example:
  304. >>> img = torch.rand(1, 3, 12, 16)
  305. >>> poly = torch.tensor([[[4, 4], [12, 4], [12, 8], [4, 8]]])
  306. >>> color = torch.tensor([[0.5, 0.5, 0.5]])
  307. >>> out = draw_convex_polygon(img, poly, color)
  308. """
  309. # TODO: implement optional linetypes for smooth edges
  310. KORNIA_CHECK_SHAPE(images, ["B", "C", "H", "W"])
  311. b_i, c_i, h_i, w_i, device = *images.shape, images.device
  312. if isinstance(polygons, List):
  313. polygons = _batch_polygons(polygons)
  314. b_p, _, xy, device_p, dtype_p = *polygons.shape, polygons.device, polygons.dtype
  315. if len(colors.shape) == 1:
  316. colors = colors.expand(b_i, c_i)
  317. b_c, _, device_c = *colors.shape, colors.device
  318. KORNIA_CHECK(xy == 2, "Polygon vertices must be xy, i.e. 2-dimensional")
  319. KORNIA_CHECK(b_i == b_p == b_c, "Image, polygon, and color must have same batch dimension")
  320. KORNIA_CHECK(device == device_p == device_c, "Image, polygon, and color must have same device")
  321. x_left, x_right = _get_convex_edges(polygons, h_i, w_i)
  322. ws = torch.arange(w_i, device=device, dtype=dtype_p)[None, None, :]
  323. fill_region = (ws >= x_left[..., :, None]) & (ws <= x_right[..., :, None])
  324. images.mul_(~fill_region[:, None]).add_(fill_region[:, None] * colors[..., None, None])
  325. return images