| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- from __future__ import annotations
- import cv2
- import numpy as np
- from tqdm.auto import tqdm
- from ultralytics.engine.results import Results
- from ultralytics.models.yolo.segment import SegmentationPredictor
- from ultralytics.utils.ops import scale_image
- import wandb
- from wandb.integration.ultralytics.bbox_utils import (
- get_ground_truth_bbox_annotations,
- get_mean_confidence_map,
- )
def instance_mask_to_semantic_mask(instance_mask, class_indices):
    """Collapse a stack of per-instance masks into a single class-id map.

    Args:
        instance_mask: Array whose shape unpacks as
            ``(height, width, num_instances)``. A pixel belongs to instance
            ``i`` where ``instance_mask[y, x, i] == 1``.
        class_indices: Sequence giving the class id of each instance, indexed
            in the same order as the last axis of ``instance_mask``.

    Returns:
        A ``(height, width)`` ``uint8`` array holding the class id of the
        covering instance, or ``0`` where no instance equals ``1``. Where
        instances overlap, the instance with the highest index wins.
    """
    rows, cols, instance_count = instance_mask.shape
    labelled = np.zeros((rows, cols), dtype=np.uint8)
    for layer in range(instance_count):
        # Only exact ones count as "inside"; fractional values are ignored.
        covered = instance_mask[..., layer] == 1
        labelled[covered] = class_indices[layer]
    return labelled
def get_boxes_and_masks(result: Results) -> tuple[dict, dict | None, dict]:
    """Build W&B box and mask overlay payloads from one prediction result.

    Args:
        result: Prediction result. Its box and mask tensors are converted
            with ``.numpy()``, so it is presumably expected to be on the CPU
            already.

    Returns:
        A tuple ``(boxes, masks, mean_confidence_map)``:

        * ``boxes`` -- ``{"predictions": {"box_data": [...],
          "class_labels": {...}}}`` with one entry per predicted box.
        * ``masks`` -- ``{"predictions": {"mask_data": ...,
          "class_labels": {...}}}``, or ``None`` when ``result.masks`` is
          ``None``.
        * ``mean_confidence_map`` -- the value returned by
          ``get_mean_confidence_map`` for the predicted classes.
    """
    boxes = result.boxes.xywh.long().numpy()
    classes = result.boxes.cls.long().numpy()
    confidence = result.boxes.conf.numpy()

    class_id_to_label = {int(k): str(v) for k, v in result.names.items()}
    # Extra id placed right after the model's own class ids; it labels the
    # pixels that no predicted instance covers.
    background_id = len(result.names)
    class_id_to_label[background_id] = "background"

    mean_confidence_map = get_mean_confidence_map(
        classes, confidence, class_id_to_label
    )

    masks = None
    if result.masks is not None:
        # Move the instance axis last, then resize to the original image.
        scaled_instance_mask = scale_image(
            np.transpose(result.masks.data.numpy(), (1, 2, 0)),
            result.orig_img[:, :, ::-1].shape,
        )
        scaled_semantic_mask = instance_mask_to_semantic_mask(
            scaled_instance_mask, classes.tolist()
        )
        # Background is derived from instance coverage, not from the value 0:
        # class id 0 is a real class, and testing ``semantic == 0`` would
        # relabel every pixel of a class-0 instance as background.
        uncovered = ~(scaled_instance_mask == 1).any(axis=-1)
        scaled_semantic_mask[uncovered] = background_id
        masks = {
            "predictions": {
                "mask_data": scaled_semantic_mask,
                "class_labels": class_id_to_label,
            }
        }

    # The first two box values are used as the box middle, the last two as
    # its width and height.
    box_data = [
        {
            "position": {
                "middle": [int(box[0]), int(box[1])],
                "width": int(box[2]),
                "height": int(box[3]),
            },
            "domain": "pixel",
            "class_id": int(class_id),
            "box_caption": class_id_to_label[int(class_id)],
            "scores": {"confidence": float(score)},
        }
        for box, class_id, score in zip(boxes, classes, confidence)
    ]
    boxes = {
        "predictions": {
            "box_data": box_data,
            "class_labels": class_id_to_label,
        },
    }
    return boxes, masks, mean_confidence_map
def plot_mask_predictions(
    result: Results, model_name: str, table: wandb.Table | None = None
) -> wandb.Table | tuple[wandb.Image, dict | None, dict, dict]:
    """Create a W&B image with box and mask overlays for one prediction.

    Args:
        result: Prediction result; it is moved to the CPU before use.
        model_name: Name written to the first column when ``table`` is given.
        table: Optional table to append a row to.

    Returns:
        When ``table`` is provided, the same table after a row
        ``(model_name, image, number of boxes, mean_confidence_map,
        result.speed)`` has been added. Otherwise the tuple
        ``(image, masks, boxes["predictions"], mean_confidence_map)``, where
        ``masks`` is whatever ``get_boxes_and_masks`` produced for the result.
    """
    result = result.to("cpu")
    boxes, masks, mean_confidence_map = get_boxes_and_masks(result)
    # The channel axis is reversed before logging (presumably BGR -> RGB).
    image = wandb.Image(result.orig_img[:, :, ::-1], boxes=boxes, masks=masks)
    if table is None:
        return image, masks, boxes["predictions"], mean_confidence_map
    table.add_data(
        model_name,
        image,
        len(boxes["predictions"]["box_data"]),
        mean_confidence_map,
        result.speed,
    )
    return table
def structure_prompts_and_image(
    image: np.ndarray, prompt: dict
) -> tuple[np.ndarray, dict]:
    """Turn a prompt into a W&B box overlay and mark the point on the image.

    Args:
        image: Image array. It is only copied and modified when a point
            prompt is present.
        prompt: Mapping that may hold ``"bboxes"`` (four values) and
            ``"points"`` (a coordinate pair). A missing key is treated the
            same as ``None``.

    Returns:
        ``(image, boxes)`` where ``image`` is the input array, or a ``uint8``
        copy with a filled green circle drawn at the point prompt, and
        ``boxes`` is ``{"prompts": {"box_data": [...],
        "class_labels": {1: "Prompt-Box"}}}``.
    """
    bbox = prompt.get("bboxes")
    point = prompt.get("points")

    box_data = []
    if bbox is not None:
        # NOTE(review): the four values are used as
        # [middle_x, middle_y, width, height]; confirm the caller does not
        # pass corner-format (x1, y1, x2, y2) boxes.
        box_data.append(
            {
                "position": {
                    "middle": [bbox[0], bbox[1]],
                    "width": bbox[2],
                    "height": bbox[3],
                },
                "domain": "pixel",
                "class_id": 1,
                "box_caption": "Prompt-Box",
            }
        )

    if point is not None:
        # Draw on a copy so the caller's array is left untouched.
        image = image.copy().astype(np.uint8)
        image = cv2.circle(
            image, tuple(point), 5, (0, 255, 0), -1, lineType=cv2.LINE_AA
        )

    wb_boxes = {
        "prompts": {
            "box_data": box_data,
            "class_labels": {1: "Prompt-Box"},
        }
    }
    return image, wb_boxes
def plot_sam_predictions(
    result: Results, prompt: dict, table: wandb.Table
) -> wandb.Table:
    """Append one prompted-segmentation prediction to ``table`` as an image.

    Args:
        result: Prediction result; it is moved to the CPU before use and its
            ``masks.data`` supplies the mask overlay.
        prompt: Prompt mapping forwarded to ``structure_prompts_and_image``.
        table: Table that receives a single-column row holding the image.

    Returns:
        The same ``table`` after the row has been added.
    """
    result = result.to("cpu")
    annotated_image, prompt_boxes = structure_prompts_and_image(
        result.orig_img[:, :, ::-1], prompt
    )
    # Integer mask with singleton axes removed; 0/1 map to the labels below.
    mask_array = np.squeeze(result.masks.data.cpu().numpy().astype(int))
    mask_overlay = {
        "predictions": {
            "mask_data": mask_array,
            "class_labels": {0: "Background", 1: "Prediction"},
        }
    }
    table.add_data(
        wandb.Image(annotated_image, boxes=prompt_boxes, masks=mask_overlay)
    )
    return table
def plot_segmentation_validation_results(
    dataloader,
    class_label_map,
    model_name: str,
    predictor: SegmentationPredictor,
    table: wandb.Table,
    max_validation_batches: int,
    epoch: int | None = None,
):
    """Log predictions alongside ground truth for validation batches.

    For each image of each processed batch, the predictor is run on the
    image file, and a row ``(model_name, [epoch,] data_idx, batch_idx, image,
    mean_confidence_map, speed)`` is appended to ``table``. The ``epoch``
    column is only present when ``epoch`` is not ``None``.

    Args:
        dataloader: Iterable of batches exposing ``dataset`` and
            ``batch_size``; each batch must provide ``batch["im_file"]``.
        class_label_map: Class-id to label mapping used for the ground-truth
            overlay.
        model_name: Value of the first column in every row.
        predictor: Callable run on the list of image files of a batch.
        table: Table that receives the rows.
        max_validation_batches: Upper bound on the number of batches to
            process; values ``<= 0`` process nothing.
        epoch: Optional epoch number inserted after ``model_name``.

    Returns:
        The same ``table`` after rows have been added.
    """
    import logging  # function-scope import: the module header is not part of this block

    # Ceiling division: a trailing partial batch (or a dataset smaller than
    # one batch) still counts as a batch. Floor division produced 0 here,
    # which made the stop condition below unreachable.
    num_dataloader_batches = -(-len(dataloader.dataset) // dataloader.batch_size)
    max_validation_batches = min(max_validation_batches, num_dataloader_batches)
    if max_validation_batches <= 0:
        return table

    data_idx = 0
    skipped = 0
    for batch_idx, batch in enumerate(dataloader):
        prediction_results = predictor(batch["im_file"])
        progress_bar_result_iterable = tqdm(
            enumerate(prediction_results),
            total=len(prediction_results),
            desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}",
        )
        for img_idx, prediction_result in progress_bar_result_iterable:
            prediction_result = prediction_result.to("cpu")
            (
                _,
                prediction_mask_data,
                prediction_box_data,
                mean_confidence_map,
            ) = plot_mask_predictions(prediction_result, model_name)
            try:
                ground_truth_data = get_ground_truth_bbox_annotations(
                    img_idx, batch["im_file"][img_idx], batch, class_label_map
                )
                wandb_image = wandb.Image(
                    batch["im_file"][img_idx],
                    boxes={
                        "ground-truth": {
                            "box_data": ground_truth_data,
                            "class_labels": class_label_map,
                        },
                        "predictions": prediction_box_data,
                    },
                    masks=prediction_mask_data,
                )
                row = [model_name]
                if epoch is not None:
                    row.append(epoch)
                row += [
                    data_idx,
                    batch_idx,
                    wandb_image,
                    mean_confidence_map,
                    prediction_result.speed,
                ]
                table.add_data(*row)
                data_idx += 1
            except TypeError:
                # Best effort: an image that cannot be logged is skipped, but
                # the skip is counted so it can be reported below.
                skipped += 1
        if batch_idx + 1 >= max_validation_batches:
            break

    if skipped:
        logging.getLogger(__name__).warning(
            "Skipped %d validation image(s) that raised TypeError while "
            "building the ground-truth overlay or table row.",
            skipped,
        )
    return table
|