| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335 |
- from collections import defaultdict
- from typing import TYPE_CHECKING, Any, Union, overload
- from ..image_utils import load_image
- from ..utils import (
- add_end_docstrings,
- is_torch_available,
- logging,
- requires_backends,
- )
- from .base import ChunkPipeline, build_pipeline_init_args
- if is_torch_available():
- import torch
- from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
- if TYPE_CHECKING:
- from PIL import Image
- logger = logging.get_logger(__name__)
@add_end_docstrings(
    build_pipeline_init_args(has_image_processor=True),
    r"""
        points_per_batch (*optional*, int, default to 64):
            Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU
            memory.
        output_bboxes_mask (`bool`, *optional*, default to `False`):
            Whether or not to output the bounding box predictions.
        output_rle_mask (`bool`, *optional*, default to `False`):
            Whether or not to output the masks in `RLE` format""",
)
class MaskGenerationPipeline(ChunkPipeline):
    """
    Automatic mask generation for images using `SamForMaskGeneration`. This pipeline predicts binary masks for an
    image, given an image. It is a `ChunkPipeline` because you can separate the points in a mini-batch in order to
    avoid OOM issues. Use the `points_per_batch` argument to control the number of points that will be processed at the
    same time. Default is `64`.

    The pipeline works in 3 steps:
        1. `preprocess`: A grid of 1024 points evenly separated is generated along with bounding boxes and point
           labels. For more details on how the points and bounding boxes are created, check the
           `image_processor.generate_crop_boxes` function. The image is also preprocessed using the
           `image_processor`, and the image embeddings are computed only once here (through
           `self.model.get_image_embeddings`, inside the inference context and with the tensors moved to the
           pipeline device). This function `yields` mini-batches of `points_per_batch` points.

        2. `forward`: feeds each mini-batch produced by `preprocess` to the model, then (on each mini-batch):
            - `image_processor.post_process_masks` takes in the raw output masks and resizes them according to the
              image size.
            - `image_processor.filter_masks` filters the masks using `pred_iou_thresh`, `stability_score_thresh`,
              `mask_threshold` and `stability_score_offset`.

        3. `postprocess`: gathers the masks, scores and boxes of all the mini-batches and calls
           `image_processor.post_process_for_mask_generation`, which applies the NMS (controlled by
           `crops_nms_thresh`) on the masks to only keep relevant ones.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> generator = pipeline(model="facebook/sam-vit-base", task="mask-generation")
    >>> outputs = generator(
    ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
    ... )

    >>> outputs = generator(
    ...     "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", points_per_batch=128
    ... )
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"mask-generation"`.

    See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=mask-generation).
    """

    # Only an image processor is loaded for this pipeline.
    _load_processor = False
    _load_image_processor = True
    _load_feature_extractor = False
    _load_tokenizer = False

    def __init__(self, **kwargs):
        """Build the pipeline, then check the required backends and that the model type is supported."""
        super().__init__(**kwargs)
        requires_backends(self, "vision")
        requires_backends(self, "torch")

        self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)

    def _sanitize_parameters(self, **kwargs):
        """
        Split the call-time keyword arguments between the three pipeline stages.

        Only the keys listed below are forwarded; any other keyword argument is ignored here.

        Returns:
            `tuple(dict, dict, dict)`: the `preprocess`, `_forward` and `postprocess` keyword arguments, in that
            order.
        """
        preprocess_keys = (
            "points_per_batch",
            "points_per_crop",
            "crops_n_layers",
            "crop_overlap_ratio",
            "crop_n_points_downscale_factor",
            "timeout",
        )
        # These thresholds are consumed by `_forward`, where masks are resized and filtered per mini-batch.
        forward_keys = (
            "pred_iou_thresh",
            "stability_score_offset",
            "mask_threshold",
            "stability_score_thresh",
            "max_hole_area",
            "max_sprinkle_area",
        )
        postprocess_keys = (
            "crops_nms_thresh",
            "output_rle_mask",
            "output_bboxes_mask",
        )

        preprocess_kwargs = {key: kwargs[key] for key in preprocess_keys if key in kwargs}
        forward_params = {key: kwargs[key] for key in forward_keys if key in kwargs}
        postprocess_kwargs = {key: kwargs[key] for key in postprocess_keys if key in kwargs}
        return preprocess_kwargs, forward_params, postprocess_kwargs

    @overload
    def __call__(self, image: Union[str, "Image.Image"], *args: Any, **kwargs: Any) -> dict[str, Any]: ...

    @overload
    def __call__(self, image: list[str] | list["Image.Image"], *args: Any, **kwargs: Any) -> list[dict[str, Any]]: ...

    def __call__(
        self, image: Union[str, "Image.Image", list[str], list["Image.Image"]], *args: Any, **kwargs: Any
    ) -> dict[str, Any] | list[dict[str, Any]]:
        """
        Generates binary segmentation masks

        Args:
            image (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                Image or list of images.
            points_per_batch (`int`, *optional*, defaults to 64):
                Number of grid points sent to the model in each mini-batch. If `None`, all the points are sent in a
                single batch. Must be `>= 1` otherwise.
            points_per_crop (`int`, *optional*, defaults to 32):
                Forwarded to `image_processor.generate_crop_boxes` when building the grid of points.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                Threshold to use when turning the predicted masks into binary values.
            pred_iou_thresh (`float`, *optional*, defaults to 0.88):
                A filtering threshold in `[0,1]` applied on the model's predicted mask quality.
            stability_score_thresh (`float`, *optional*, defaults to 0.95):
                A filtering threshold in `[0,1]`, using the stability of the mask under changes to the cutoff used to
                binarize the model's mask predictions.
            stability_score_offset (`int`, *optional*, defaults to 1):
                The amount to shift the cutoff when calculated the stability score.
            max_hole_area (*optional*, defaults to `None`):
                When set, forwarded to `image_processor.post_process_masks`.
            max_sprinkle_area (*optional*, defaults to `None`):
                When set to a value `> 0`, forwarded to `image_processor.post_process_masks`.
            crops_nms_thresh (`float`, *optional*, defaults to 0.7):
                The box IoU cutoff used by non-maximal suppression to filter duplicate masks.
            crops_n_layers (`int`, *optional*, defaults to 0):
                If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of
                layers to run, where each layer has 2**i_layer number of image crops.
            crop_overlap_ratio (`float`, *optional*, defaults to `512 / 1500`):
                Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
                the image length. Later layers with more crops scale down this overlap.
            crop_n_points_downscale_factor (`int`, *optional*, defaults to `1`):
                The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
            output_rle_mask (`bool`, *optional*, defaults to `False`):
                Whether or not to add the `rle_mask` entry to the output.
            output_bboxes_mask (`bool`, *optional*, defaults to `False`):
                Whether or not to add the `bounding_boxes` entry to the output.
            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
                the call may block forever.

        Return:
            `Dict` (or a list of `Dict` when a list of images is given): A dictionary with the following keys:

            - **masks** -- The masks kept by `image_processor.post_process_for_mask_generation`.
            - **scores** -- The IoU scores associated with the kept masks.
            - **rle_mask** (*optional*) -- Only present when `output_rle_mask=True`.
            - **bounding_boxes** (*optional*) -- Only present when `output_bboxes_mask=True`.
        """
        # `num_workers` and `batch_size` are always passed explicitly to the parent (as `None` when not provided).
        num_workers = kwargs.pop("num_workers", None)
        batch_size = kwargs.pop("batch_size", None)
        return super().__call__(image, *args, num_workers=num_workers, batch_size=batch_size, **kwargs)

    def preprocess(
        self,
        image,
        points_per_batch=64,
        crops_n_layers: int = 0,
        crop_overlap_ratio: float = 512 / 1500,
        points_per_crop: int = 32,
        crop_n_points_downscale_factor: int = 1,
        timeout: float | None = None,
    ):
        """
        Load the image, compute its embeddings once and yield the grid points in mini-batches.

        This is a generator: each yielded dict holds a slice of at most `points_per_batch` points (`input_points`,
        `input_labels`), the crop boxes (`input_boxes`), an `is_last` flag that is `True` only for the final chunk,
        and the shared model inputs (including `image_embeddings`).

        Raises:
            `ValueError`: if `points_per_batch` is `<= 0` (raised once the generator is consumed).
        """
        image = load_image(image, timeout=timeout)
        target_size = self.image_processor.size.get("longest_edge", self.image_processor.size.get("height"))
        crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
            image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
        )
        model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")
        model_inputs = model_inputs.to(self.dtype)

        # The embeddings are computed a single time here and shared by every yielded chunk.
        with self.device_placement():
            inference_context = self.get_inference_context()
            with inference_context():
                model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
                embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))

                # Handle both SAM (single tensor) and SAM-HQ (tuple) outputs
                if isinstance(embeddings, tuple):
                    image_embeddings, intermediate_embeddings = embeddings
                    model_inputs["intermediate_embeddings"] = intermediate_embeddings
                else:
                    image_embeddings = embeddings
                # TODO: Identifying the model by the type of its returned embeddings is brittle.
                #       Consider using a more robust method for distinguishing model types here.

                model_inputs["image_embeddings"] = image_embeddings

        n_points = grid_points.shape[1]
        points_per_batch = points_per_batch if points_per_batch is not None else n_points

        if points_per_batch <= 0:
            raise ValueError(
                "Cannot have points_per_batch<=0. Must be >=1 to returned batched outputs. "
                "To return all points at once, set points_per_batch to None"
            )

        for i in range(0, n_points, points_per_batch):
            batched_points = grid_points[:, i : i + points_per_batch, :, :]
            labels = input_labels[:, i : i + points_per_batch]
            # The final chunk is the one whose window reaches the end of the grid. Comparing with
            # `n_points - points_per_batch` would never flag a last chunk when `n_points` is not a multiple of
            # `points_per_batch` (or when `points_per_batch > n_points`).
            # NOTE(review): presumably the `ChunkPipeline` machinery relies on this flag to know when all the
            # chunks of an image have been produced -- confirm against the base class.
            is_last = i + points_per_batch >= n_points
            yield {
                "input_points": batched_points,
                "input_labels": labels,
                "input_boxes": crop_boxes,
                "is_last": is_last,
                **model_inputs,
            }

    def _forward(
        self,
        model_inputs,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.95,
        mask_threshold=0,
        stability_score_offset=1,
        max_hole_area=None,
        max_sprinkle_area=None,
    ):
        """
        Run the model on one mini-batch of points, then resize and filter the predicted masks.

        Returns:
            `dict`: the filtered `masks`, `boxes` and `iou_scores` of this mini-batch, along with the `is_last` flag
            received from `preprocess`.
        """
        # Remove the entries that are not model inputs before calling the model.
        input_boxes = model_inputs.pop("input_boxes")
        is_last = model_inputs.pop("is_last")
        original_sizes = model_inputs.pop("original_sizes").tolist()
        # `reshaped_input_sizes` may be absent from the image processor output.
        reshaped_input_sizes = model_inputs.pop("reshaped_input_sizes", None)
        reshaped_input_sizes = reshaped_input_sizes.tolist() if reshaped_input_sizes is not None else None

        model_outputs = self.model(**model_inputs)

        # post processing happens here in order to avoid CPU GPU copies of ALL the masks
        low_resolution_masks = model_outputs["pred_masks"]

        # Only forward the hole/sprinkle options when they were explicitly requested.
        postprocess_kwargs = {}
        if max_hole_area is not None:
            postprocess_kwargs["max_hole_area"] = max_hole_area
        if max_sprinkle_area is not None and max_sprinkle_area > 0:
            postprocess_kwargs["max_sprinkle_area"] = max_sprinkle_area
        if postprocess_kwargs:
            # NOTE(review): when this branch runs, the masks go through `post_process_masks` twice (here and just
            # below) -- confirm that the second pass on already post-processed masks is intended.
            low_resolution_masks = self.image_processor.post_process_masks(
                low_resolution_masks,
                original_sizes,
                mask_threshold=mask_threshold,
                reshaped_input_sizes=reshaped_input_sizes,
                binarize=False,
                **postprocess_kwargs,
            )
        masks = self.image_processor.post_process_masks(
            low_resolution_masks,
            original_sizes,
            mask_threshold=mask_threshold,
            reshaped_input_sizes=reshaped_input_sizes,
            binarize=False,
        )
        iou_scores = model_outputs["iou_scores"]
        # Only the first element of each batched output is filtered.
        masks, iou_scores, boxes = self.image_processor.filter_masks(
            masks[0],
            iou_scores[0],
            original_sizes[0],
            input_boxes[0],
            pred_iou_thresh,
            stability_score_thresh,
            mask_threshold,
            stability_score_offset,
        )
        return {
            "masks": masks,
            "is_last": is_last,
            "boxes": boxes,
            "iou_scores": iou_scores,
        }

    def postprocess(
        self,
        model_outputs,
        output_rle_mask=False,
        output_bboxes_mask=False,
        crops_nms_thresh=0.7,
    ):
        """
        Merge the outputs of all the mini-batches of an image and keep only the relevant masks.

        Args:
            model_outputs:
                The per-chunk dicts returned by `_forward`. The `iou_scores`, `masks` and `boxes` entries are
                popped from them.
            output_rle_mask (`bool`, defaults to `False`):
                Whether to add the `rle_mask` entry to the result.
            output_bboxes_mask (`bool`, defaults to `False`):
                Whether to add the `bounding_boxes` entry to the result.
            crops_nms_thresh (`float`, defaults to 0.7):
                Forwarded to `image_processor.post_process_for_mask_generation`.

        Returns:
            `dict`: `masks` and `scores`, the optional `rle_mask` / `bounding_boxes` entries, plus any other key
            still present in the chunk outputs (collected as lists, one item per chunk).
        """
        all_scores = []
        all_masks = []
        all_boxes = []
        for model_output in model_outputs:
            all_scores.append(model_output.pop("iou_scores"))
            all_masks.extend(model_output.pop("masks"))
            all_boxes.append(model_output.pop("boxes"))

        all_scores = torch.cat(all_scores)
        all_boxes = torch.cat(all_boxes)
        output_masks, iou_scores, rle_mask, bounding_boxes = self.image_processor.post_process_for_mask_generation(
            all_masks, all_scores, all_boxes, crops_nms_thresh
        )

        # Whatever is left in the chunk outputs after the pops above is passed through, grouped by key.
        extra = defaultdict(list)
        for output in model_outputs:
            for k, v in output.items():
                extra[k].append(v)

        optional = {}
        if output_rle_mask:
            optional["rle_mask"] = rle_mask

        if output_bboxes_mask:
            optional["bounding_boxes"] = bounding_boxes

        return {"masks": output_masks, "scores": iou_scores, **optional, **extra}
|