# image_segmentation.py (9.5 KB)
  1. from typing import Any, Union, overload
  2. import numpy as np
  3. from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
  4. from .base import Pipeline, build_pipeline_init_args
  5. if is_vision_available():
  6. from PIL import Image
  7. from ..image_utils import load_image
  8. if is_torch_available():
  9. from ..models.auto.modeling_auto import (
  10. MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
  11. MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES,
  12. MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
  13. MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES,
  14. )
  15. logger = logging.get_logger(__name__)
  16. @add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
  17. class ImageSegmentationPipeline(Pipeline):
  18. """
  19. Image segmentation pipeline using any `AutoModelForXXXSegmentation`. This pipeline predicts masks of objects and
  20. their classes.
  21. Example:
  22. ```python
  23. >>> from transformers import pipeline
  24. >>> segmenter = pipeline(model="facebook/detr-resnet-50-panoptic")
  25. >>> segments = segmenter("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
  26. >>> len(segments)
  27. 2
  28. >>> segments[0]["label"]
  29. 'bird'
  30. >>> segments[1]["label"]
  31. 'bird'
  32. >>> type(segments[0]["mask"]) # This is a black and white mask showing where is the bird on the original image.
  33. <class 'PIL.Image.Image'>
  34. >>> segments[0]["mask"].size
  35. (768, 512)
  36. ```
  37. This image segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
  38. `"image-segmentation"`.
  39. See the list of available models on
  40. [huggingface.co/models](https://huggingface.co/models?filter=image-segmentation).
  41. """
  42. _load_processor = False
  43. _load_image_processor = True
  44. _load_feature_extractor = False
  45. _load_tokenizer = None # Oneformer uses it but no-one else does
  46. def __init__(self, *args, **kwargs):
  47. super().__init__(*args, **kwargs)
  48. requires_backends(self, "vision")
  49. mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES.copy()
  50. mapping.update(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES)
  51. mapping.update(MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES)
  52. mapping.update(MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES)
  53. self.check_model_type(mapping)
  54. def _sanitize_parameters(self, **kwargs):
  55. preprocess_kwargs = {}
  56. postprocess_kwargs = {}
  57. if "subtask" in kwargs:
  58. postprocess_kwargs["subtask"] = kwargs["subtask"]
  59. preprocess_kwargs["subtask"] = kwargs["subtask"]
  60. if "threshold" in kwargs:
  61. postprocess_kwargs["threshold"] = kwargs["threshold"]
  62. if "mask_threshold" in kwargs:
  63. postprocess_kwargs["mask_threshold"] = kwargs["mask_threshold"]
  64. if "overlap_mask_area_threshold" in kwargs:
  65. postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
  66. if "timeout" in kwargs:
  67. preprocess_kwargs["timeout"] = kwargs["timeout"]
  68. return preprocess_kwargs, {}, postprocess_kwargs
  69. @overload
  70. def __call__(self, inputs: Union[str, "Image.Image"], **kwargs: Any) -> list[dict[str, Any]]: ...
  71. @overload
  72. def __call__(self, inputs: list[str] | list["Image.Image"], **kwargs: Any) -> list[list[dict[str, Any]]]: ...
  73. def __call__(
  74. self, inputs: Union[str, "Image.Image", list[str], list["Image.Image"]], **kwargs: Any
  75. ) -> list[dict[str, Any]] | list[list[dict[str, Any]]]:
  76. """
  77. Perform segmentation (detect masks & classes) in the image(s) passed as inputs.
  78. Args:
  79. inputs (`str`, `list[str]`, `PIL.Image` or `list[PIL.Image]`):
  80. The pipeline handles three types of images:
  81. - A string containing an HTTP(S) link pointing to an image
  82. - A string containing a local path to an image
  83. - An image loaded in PIL directly
  84. The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
  85. same format: all as HTTP(S) links, all as local paths, or all as PIL images.
  86. subtask (`str`, *optional*):
  87. Segmentation task to be performed, choose [`semantic`, `instance` and `panoptic`] depending on model
  88. capabilities. If not set, the pipeline will attempt tp resolve in the following order:
  89. `panoptic`, `instance`, `semantic`.
  90. threshold (`float`, *optional*, defaults to 0.9):
  91. Probability threshold to filter out predicted masks.
  92. mask_threshold (`float`, *optional*, defaults to 0.5):
  93. Threshold to use when turning the predicted masks into binary values.
  94. overlap_mask_area_threshold (`float`, *optional*, defaults to 0.5):
  95. Mask overlap threshold to eliminate small, disconnected segments.
  96. timeout (`float`, *optional*, defaults to None):
  97. The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
  98. the call may block forever.
  99. Return:
  100. If the input is a single image, will return a list of dictionaries, if the input is a list of several images,
  101. will return a list of list of dictionaries corresponding to each image.
  102. The dictionaries contain the mask, label and score (where applicable) of each detected object and contains
  103. the following keys:
  104. - **label** (`str`) -- The class label identified by the model.
  105. - **mask** (`PIL.Image`) -- A binary mask of the detected object as a Pil Image of shape (width, height) of
  106. the original image. Returns a mask filled with zeros if no object is found.
  107. - **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence of the
  108. "object" described by the label and the mask.
  109. """
  110. # After deprecation of this is completed, remove the default `None` value for `images`
  111. if "images" in kwargs:
  112. inputs = kwargs.pop("images")
  113. if inputs is None:
  114. raise ValueError("Cannot call the image-classification pipeline without an inputs argument!")
  115. return super().__call__(inputs, **kwargs)
  116. def preprocess(self, image, subtask=None, timeout=None):
  117. image = load_image(image, timeout=timeout)
  118. target_size = [(image.height, image.width)]
  119. if self.model.config.__class__.__name__ == "OneFormerConfig":
  120. if subtask is None:
  121. kwargs = {}
  122. else:
  123. kwargs = {"task_inputs": [subtask]}
  124. inputs = self.image_processor(images=[image], return_tensors="pt", **kwargs)
  125. inputs = inputs.to(self.dtype)
  126. inputs["task_inputs"] = self.tokenizer(
  127. inputs["task_inputs"],
  128. padding="max_length",
  129. max_length=self.model.config.task_seq_len,
  130. return_tensors="pt",
  131. )["input_ids"]
  132. else:
  133. inputs = self.image_processor(images=[image], return_tensors="pt")
  134. inputs = inputs.to(self.dtype)
  135. inputs["target_size"] = target_size
  136. return inputs
  137. def _forward(self, model_inputs):
  138. target_size = model_inputs.pop("target_size")
  139. model_outputs = self.model(**model_inputs)
  140. model_outputs["target_size"] = target_size
  141. return model_outputs
  142. def postprocess(
  143. self, model_outputs, subtask=None, threshold=0.9, mask_threshold=0.5, overlap_mask_area_threshold=0.5
  144. ):
  145. fn = None
  146. if subtask in {"panoptic", None} and hasattr(self.image_processor, "post_process_panoptic_segmentation"):
  147. fn = self.image_processor.post_process_panoptic_segmentation
  148. elif subtask in {"instance", None} and hasattr(self.image_processor, "post_process_instance_segmentation"):
  149. fn = self.image_processor.post_process_instance_segmentation
  150. if fn is not None:
  151. outputs = fn(
  152. model_outputs,
  153. threshold=threshold,
  154. mask_threshold=mask_threshold,
  155. overlap_mask_area_threshold=overlap_mask_area_threshold,
  156. target_sizes=model_outputs["target_size"],
  157. )[0]
  158. annotation = []
  159. segmentation = outputs["segmentation"]
  160. for segment in outputs["segments_info"]:
  161. mask = (segmentation == segment["id"]) * 255
  162. mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L")
  163. label = self.model.config.id2label[segment["label_id"]]
  164. score = segment["score"]
  165. annotation.append({"score": score, "label": label, "mask": mask})
  166. elif subtask in {"semantic", None} and hasattr(self.image_processor, "post_process_semantic_segmentation"):
  167. outputs = self.image_processor.post_process_semantic_segmentation(
  168. model_outputs, target_sizes=model_outputs["target_size"]
  169. )[0]
  170. annotation = []
  171. segmentation = outputs.numpy()
  172. labels = np.unique(segmentation)
  173. for label in labels:
  174. mask = (segmentation == label) * 255
  175. mask = Image.fromarray(mask.astype(np.uint8), mode="L")
  176. label = self.model.config.id2label[label]
  177. annotation.append({"score": None, "label": label, "mask": mask})
  178. else:
  179. raise ValueError(f"Subtask {subtask} is not supported for model {type(self.model)}")
  180. return annotation