mask_utils.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. from __future__ import annotations
  2. import cv2
  3. import numpy as np
  4. from tqdm.auto import tqdm
  5. from ultralytics.engine.results import Results
  6. from ultralytics.models.yolo.segment import SegmentationPredictor
  7. from ultralytics.utils.ops import scale_image
  8. import wandb
  9. from wandb.integration.ultralytics.bbox_utils import (
  10. get_ground_truth_bbox_annotations,
  11. get_mean_confidence_map,
  12. )
  13. def instance_mask_to_semantic_mask(instance_mask, class_indices):
  14. height, width, num_instances = instance_mask.shape
  15. semantic_mask = np.zeros((height, width), dtype=np.uint8)
  16. for i in range(num_instances):
  17. instance_map = instance_mask[:, :, i]
  18. class_index = class_indices[i]
  19. semantic_mask[instance_map == 1] = class_index
  20. return semantic_mask
  21. def get_boxes_and_masks(result: Results) -> tuple[dict, dict, dict]:
  22. boxes = result.boxes.xywh.long().numpy()
  23. classes = result.boxes.cls.long().numpy()
  24. confidence = result.boxes.conf.numpy()
  25. class_id_to_label = {int(k): str(v) for k, v in result.names.items()}
  26. class_id_to_label.update({len(result.names.items()): "background"})
  27. mean_confidence_map = get_mean_confidence_map(
  28. classes, confidence, class_id_to_label
  29. )
  30. masks = None
  31. if result.masks is not None:
  32. scaled_instance_mask = scale_image(
  33. np.transpose(result.masks.data.numpy(), (1, 2, 0)),
  34. result.orig_img[:, :, ::-1].shape,
  35. )
  36. scaled_semantic_mask = instance_mask_to_semantic_mask(
  37. scaled_instance_mask, classes.tolist()
  38. )
  39. scaled_semantic_mask[scaled_semantic_mask == 0] = len(result.names.items())
  40. masks = {
  41. "predictions": {
  42. "mask_data": scaled_semantic_mask,
  43. "class_labels": class_id_to_label,
  44. }
  45. }
  46. box_data, total_confidence = [], 0.0
  47. for idx in range(len(boxes)):
  48. box_data.append(
  49. {
  50. "position": {
  51. "middle": [int(boxes[idx][0]), int(boxes[idx][1])],
  52. "width": int(boxes[idx][2]),
  53. "height": int(boxes[idx][3]),
  54. },
  55. "domain": "pixel",
  56. "class_id": int(classes[idx]),
  57. "box_caption": class_id_to_label[int(classes[idx])],
  58. "scores": {"confidence": float(confidence[idx])},
  59. }
  60. )
  61. total_confidence += float(confidence[idx])
  62. boxes = {
  63. "predictions": {
  64. "box_data": box_data,
  65. "class_labels": class_id_to_label,
  66. },
  67. }
  68. return boxes, masks, mean_confidence_map
  69. def plot_mask_predictions(
  70. result: Results, model_name: str, table: wandb.Table | None = None
  71. ) -> tuple[wandb.Image, dict, dict, dict]:
  72. result = result.to("cpu")
  73. boxes, masks, mean_confidence_map = get_boxes_and_masks(result)
  74. image = wandb.Image(result.orig_img[:, :, ::-1], boxes=boxes, masks=masks)
  75. if table is not None:
  76. table.add_data(
  77. model_name,
  78. image,
  79. len(boxes["predictions"]["box_data"]),
  80. mean_confidence_map,
  81. result.speed,
  82. )
  83. return table
  84. return image, masks, boxes["predictions"], mean_confidence_map
  85. def structure_prompts_and_image(image: np.array, prompt: dict) -> dict:
  86. wb_box_data = []
  87. if prompt["bboxes"] is not None:
  88. wb_box_data.append(
  89. {
  90. "position": {
  91. "middle": [prompt["bboxes"][0], prompt["bboxes"][1]],
  92. "width": prompt["bboxes"][2],
  93. "height": prompt["bboxes"][3],
  94. },
  95. "domain": "pixel",
  96. "class_id": 1,
  97. "box_caption": "Prompt-Box",
  98. }
  99. )
  100. if prompt["points"] is not None:
  101. image = image.copy().astype(np.uint8)
  102. image = cv2.circle(
  103. image, tuple(prompt["points"]), 5, (0, 255, 0), -1, lineType=cv2.LINE_AA
  104. )
  105. wb_box_data = {
  106. "prompts": {
  107. "box_data": wb_box_data,
  108. "class_labels": {1: "Prompt-Box"},
  109. }
  110. }
  111. return image, wb_box_data
  112. def plot_sam_predictions(
  113. result: Results, prompt: dict, table: wandb.Table
  114. ) -> wandb.Table:
  115. result = result.to("cpu")
  116. image = result.orig_img[:, :, ::-1]
  117. image, wb_box_data = structure_prompts_and_image(image, prompt)
  118. image = wandb.Image(
  119. image,
  120. boxes=wb_box_data,
  121. masks={
  122. "predictions": {
  123. "mask_data": np.squeeze(result.masks.data.cpu().numpy().astype(int)),
  124. "class_labels": {0: "Background", 1: "Prediction"},
  125. }
  126. },
  127. )
  128. table.add_data(image)
  129. return table
  130. def plot_segmentation_validation_results(
  131. dataloader,
  132. class_label_map,
  133. model_name: str,
  134. predictor: SegmentationPredictor,
  135. table: wandb.Table,
  136. max_validation_batches: int,
  137. epoch: int | None = None,
  138. ):
  139. data_idx = 0
  140. num_dataloader_batches = len(dataloader.dataset) // dataloader.batch_size
  141. max_validation_batches = min(max_validation_batches, num_dataloader_batches)
  142. for batch_idx, batch in enumerate(dataloader):
  143. prediction_results = predictor(batch["im_file"])
  144. progress_bar_result_iterable = tqdm(
  145. enumerate(prediction_results),
  146. total=len(prediction_results),
  147. desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}",
  148. )
  149. for img_idx, prediction_result in progress_bar_result_iterable:
  150. prediction_result = prediction_result.to("cpu")
  151. (
  152. _,
  153. prediction_mask_data,
  154. prediction_box_data,
  155. mean_confidence_map,
  156. ) = plot_mask_predictions(prediction_result, model_name)
  157. try:
  158. ground_truth_data = get_ground_truth_bbox_annotations(
  159. img_idx, batch["im_file"][img_idx], batch, class_label_map
  160. )
  161. wandb_image = wandb.Image(
  162. batch["im_file"][img_idx],
  163. boxes={
  164. "ground-truth": {
  165. "box_data": ground_truth_data,
  166. "class_labels": class_label_map,
  167. },
  168. "predictions": prediction_box_data,
  169. },
  170. masks=prediction_mask_data,
  171. )
  172. table_rows = [
  173. data_idx,
  174. batch_idx,
  175. wandb_image,
  176. mean_confidence_map,
  177. prediction_result.speed,
  178. ]
  179. table_rows = [epoch] + table_rows if epoch is not None else table_rows
  180. table_rows = [model_name] + table_rows
  181. table.add_data(*table_rows)
  182. data_idx += 1
  183. except TypeError:
  184. pass
  185. if batch_idx + 1 == max_validation_batches:
  186. break
  187. return table