| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017 |
- # Copyright 2021 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import base64
- import os
- from collections.abc import Iterable
- from dataclasses import dataclass, fields
- from io import BytesIO
- from typing import Any, Union
- import httpx
- import numpy as np
- from .utils import (
- ExplicitEnum,
- is_numpy_array,
- is_torch_available,
- is_torch_tensor,
- is_torchvision_available,
- is_vision_available,
- logging,
- requires_backends,
- to_numpy,
- )
- from .utils.constants import ( # noqa: F401
- IMAGENET_DEFAULT_MEAN,
- IMAGENET_DEFAULT_STD,
- IMAGENET_STANDARD_MEAN,
- IMAGENET_STANDARD_STD,
- OPENAI_CLIP_MEAN,
- OPENAI_CLIP_STD,
- )
- if is_vision_available():
- import PIL.Image
- import PIL.ImageOps
- PILImageResampling = PIL.Image.Resampling
- if is_torchvision_available():
- from torchvision.transforms import InterpolationMode
- pil_torch_interpolation_mapping = {
- PILImageResampling.NEAREST: InterpolationMode.NEAREST_EXACT,
- PILImageResampling.BOX: InterpolationMode.BOX,
- PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
- PILImageResampling.HAMMING: InterpolationMode.HAMMING,
- PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
- PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
- }
- # Create inverse mapping: InterpolationMode -> PILImageResampling
- torch_pil_interpolation_mapping = {v: k for k, v in pil_torch_interpolation_mapping.items()}
- else:
- pil_torch_interpolation_mapping = {}
- torch_pil_interpolation_mapping = {}
- if is_torch_available():
- import torch
# Module-level logger, obtained through the library's own logging helper.
logger = logging.get_logger(__name__)


# Accepted image inputs: a single PIL image, NumPy array or torch tensor, or a list of any one of those.
# The PIL and torch members are string forward references because both backends are imported conditionally.
ImageInput = Union[
    "PIL.Image.Image", np.ndarray, "torch.Tensor", list["PIL.Image.Image"], list[np.ndarray], list["torch.Tensor"]
]
class ChannelDimension(ExplicitEnum):
    """
    Position of the channel axis of an image array relative to its spatial (height, width) axes.
    """

    # Channels precede the two spatial axes, e.g. (num_channels, height, width).
    FIRST = "channels_first"
    # Channels follow the two spatial axes, e.g. (height, width, num_channels).
    LAST = "channels_last"
class AnnotationFormat(ExplicitEnum):
    """
    Names of the annotation layouts this module knows how to validate.
    """

    # Layout checked by `is_valid_annotation_coco_detection`.
    COCO_DETECTION = "coco_detection"
    # Layout checked by `is_valid_annotation_coco_panoptic`.
    COCO_PANOPTIC = "coco_panoptic"
# Type alias for one annotation record: string keys mapped to an int, a str, or a list of dicts.
AnnotationType = dict[str, int | str | list[dict]]
def is_pil_image(img):
    """
    Tells whether `img` is a `PIL.Image.Image`.

    The vision backend is checked first so that `PIL` is never touched when it is not available; in that case the
    answer is always negative.
    """
    if not is_vision_available():
        return False
    return isinstance(img, PIL.Image.Image)
class ImageType(ExplicitEnum):
    """
    Kind of container an image is stored in, as reported by `get_image_type`.
    """

    # A `PIL.Image.Image` instance.
    PIL = "pillow"
    # A torch tensor.
    TORCH = "torch"
    # A NumPy array.
    NUMPY = "numpy"
def get_image_type(image):
    """
    Classifies `image` by the container it is stored in.

    Args:
        image:
            The object to classify.

    Returns:
        `ImageType`: `ImageType.PIL`, `ImageType.TORCH` or `ImageType.NUMPY`.

    Raises:
        ValueError: If `image` is none of the supported containers.
    """
    # Checked in this order; the first matching predicate decides the result.
    detectors = (
        (is_pil_image, ImageType.PIL),
        (is_torch_tensor, ImageType.TORCH),
        (is_numpy_array, ImageType.NUMPY),
    )
    for matches, image_type in detectors:
        if matches(image):
            return image_type
    raise ValueError(f"Unrecognized image type {type(image)}")
def is_valid_image(img):
    """
    Tells whether `img` is one of the supported image containers: a PIL image, a NumPy array or a torch tensor.
    """
    # Same order as a plain `or` chain; evaluation stops at the first predicate that accepts `img`.
    return any(accepts(img) for accepts in (is_pil_image, is_numpy_array, is_torch_tensor))
def is_valid_list_of_images(images: list) -> bool:
    """
    Tells whether `images` is a non-empty sequence in which every element is a valid image.

    Args:
        images (`list`):
            The candidate list of images.

    Returns:
        `bool`: `True` only if `images` is non-empty and every element passes `is_valid_image`. An empty (or
        otherwise falsy) input yields `False`.
    """
    # `bool(...)` guarantees a real boolean; a bare `images and ...` would hand back the empty container itself.
    return bool(images) and all(is_valid_image(image) for image in images)
def concatenate_list(input_list):
    """
    Concatenates the elements of `input_list`, dispatching on the type of its first element.

    Args:
        input_list (`list`):
            A non-empty list whose elements are all lists, all NumPy arrays or all torch tensors.

    Returns:
        A flat list (for lists), or the arrays/tensors concatenated along their first axis. `None` is returned when
        the first element is of none of these types.
    """
    first = input_list[0]
    if isinstance(first, list):
        return [item for sublist in input_list for item in sublist]
    if isinstance(first, np.ndarray):
        return np.concatenate(input_list, axis=0)
    # `torch` is only imported when available, so go through `is_torch_tensor` (as the rest of this module does)
    # instead of referencing `torch.Tensor` directly, which would raise a `NameError` without torch installed.
    if is_torch_tensor(first):
        return torch.cat(input_list, dim=0)
    return None
def valid_images(imgs):
    """
    Recursively checks that `imgs` only contains valid images.

    Args:
        imgs:
            A single image (or batched array/tensor of images), or an arbitrarily nested list/tuple of them.

    Returns:
        `bool`: `True` if every leaf is a valid image. An empty list or tuple is considered valid.
    """
    # Anything that is not a list or tuple is a leaf: a single image or a batched array/tensor of images.
    if not isinstance(imgs, (list, tuple)):
        return bool(is_valid_image(imgs))
    # Containers are valid when each of their entries is, checked recursively.
    return all(valid_images(entry) for entry in imgs)
def is_batched(img):
    """
    Tells whether `img` is a list or tuple of images.

    Only the first element is inspected.

    Args:
        img:
            The object to check.

    Returns:
        `bool`: `True` if `img` is a non-empty list/tuple whose first element is a valid image, `False` otherwise
        (including for an empty list or tuple).
    """
    # Guard against empty containers: there is no first element to inspect.
    if isinstance(img, (list, tuple)) and len(img) > 0:
        return is_valid_image(img[0])
    return False
def is_scaled_image(image: np.ndarray) -> bool:
    """
    Checks whether the pixel values of `image` already lie in [0, 1].

    Args:
        image (`np.ndarray`):
            The image to inspect.

    Returns:
        `False` for `uint8` images; otherwise whether the minimum is at least 0 and the maximum at most 1.
    """
    if image.dtype != np.uint8:
        # A floating image can still hold values in [0, 255], so the actual value range has to be inspected.
        return np.min(image) >= 0 and np.max(image) <= 1
    return False
def make_list_of_images(images, expected_ndims: int = 3) -> list[ImageInput]:
    """
    Turns the input into a list of images.

    A list/tuple of images is returned unchanged, a single image is wrapped into a one-element list, and a batched
    array/tensor is split along its first axis.

    Args:
        images (`ImageInput`):
            Image or images to turn into a list of images.
        expected_ndims (`int`, *optional*, defaults to 3):
            Expected number of dimensions for a single input image. An array/tensor with `expected_ndims + 1`
            dimensions is treated as a batch.

    Raises:
        ValueError: If `images` is not a supported image type, or if an array/tensor has neither `expected_ndims`
        nor `expected_ndims + 1` dimensions.
    """
    # Already a list/tuple of images: nothing to do
    if is_batched(images):
        return images

    # PIL images are never batched
    if is_pil_image(images):
        return [images]

    if not is_valid_image(images):
        raise ValueError(
            f"Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, or torch.Tensor, but got {type(images)}."
        )

    ndim = images.ndim
    if ndim == expected_ndims + 1:
        # Batch of images: split along the leading axis
        return list(images)
    if ndim == expected_ndims:
        # Single image
        return [images]
    raise ValueError(
        f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
        f" {images.ndim} dimensions."
    )
def make_flat_list_of_images(
    images: list[ImageInput] | ImageInput,
    expected_ndims: int = 3,
) -> ImageInput:
    """
    Turns the input into a flat list of images.

    Nested lists are flattened by one level, a list of batched arrays/tensors is unrolled, a single image is wrapped
    into a one-element list and a single batched array/tensor is split along its first axis.

    Args:
        images (`Union[list[ImageInput], ImageInput]`):
            The input image(s).
        expected_ndims (`int`, *optional*, defaults to 3):
            The expected number of dimensions for a single input image.

    Returns:
        list: A flat list of images.

    Raises:
        ValueError: If the input matches none of the supported layouts.
    """
    if isinstance(images, (list, tuple)):
        # Nested list of images (empty inner lists are tolerated): flatten one level
        only_sequences = all(isinstance(group, (list, tuple)) for group in images)
        if only_sequences and all(is_valid_list_of_images(group) or not group for group in images):
            return [img for group in images for img in group]

        if is_valid_list_of_images(images):
            # Only the first element is used to decide how the list is laid out
            head = images[0]
            if is_pil_image(head) or head.ndim == expected_ndims:
                return images
            if head.ndim == expected_ndims + 1:
                return [img for batch in images for img in batch]

    if is_valid_image(images):
        if is_pil_image(images) or images.ndim == expected_ndims:
            return [images]
        if images.ndim == expected_ndims + 1:
            return list(images)

    raise ValueError(f"Could not make a flat list of images from {images}")
def make_nested_list_of_images(
    images: list[ImageInput] | ImageInput,
    expected_ndims: int = 3,
) -> list[ImageInput]:
    """
    Turns the input into a nested list of images (a list of batches).

    Args:
        images (`Union[list[ImageInput], ImageInput]`):
            The input image(s).
        expected_ndims (`int`, *optional*, defaults to 3):
            The expected number of dimensions for a single input image.

    Returns:
        list: A list of lists of images.

    Raises:
        ValueError: If the input matches none of the supported layouts.
    """
    if isinstance(images, (list, tuple)):
        # A list of batches is already in the right format (empty batches are tolerated)
        only_sequences = all(isinstance(group, (list, tuple)) for group in images)
        if only_sequences and all(is_valid_list_of_images(group) or not group for group in images):
            return images

        # A flat list of images is a single batch; a list of batched arrays/tensors is one batch per element
        if is_valid_list_of_images(images):
            head = images[0]
            if is_pil_image(head) or head.ndim == expected_ndims:
                return [images]
            if head.ndim == expected_ndims + 1:
                return [list(batch) for batch in images]

    # A single image (or single batched array/tensor) becomes a list holding one batch
    if is_valid_image(images):
        if is_pil_image(images) or images.ndim == expected_ndims:
            return [[images]]
        if images.ndim == expected_ndims + 1:
            return [list(images)]

    raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.")
def to_numpy_array(img) -> np.ndarray:
    """
    Converts a supported image into a NumPy array.

    Args:
        img:
            A PIL image, NumPy array or torch tensor.

    Returns:
        `np.ndarray`: PIL images go through `np.array`, everything else through the `to_numpy` helper.

    Raises:
        ValueError: If `img` is not a supported image type.
    """
    if not is_valid_image(img):
        raise ValueError(f"Invalid image type: {type(img)}")

    if is_pil_image(img):
        return np.array(img)
    return to_numpy(img)
def infer_channel_dimension_format(
    image: np.ndarray, num_channels: int | tuple[int, ...] | None = None
) -> ChannelDimension:
    """
    Infers where the channel axis of `image` is, based on its shape.

    Args:
        image (`np.ndarray`):
            The image to inspect. It must have 3, 4 or 5 dimensions.
        num_channels (`int` or `tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
            The channel counts that are considered plausible.

    Returns:
        `ChannelDimension`: The inferred channel dimension format. When both candidate axes have a plausible size,
        a warning is logged and `ChannelDimension.FIRST` is returned.

    Raises:
        ValueError: If `image` has an unsupported number of dimensions, or if neither candidate axis has a size
        listed in `num_channels`.
    """
    if num_channels is None:
        num_channels = (1, 3)
    elif isinstance(num_channels, int):
        num_channels = (num_channels,)

    # For each supported rank: (axis used for channels-first, axis used for channels-last)
    candidate_axes = {3: (0, 2), 4: (1, 3), 5: (2, 4)}
    if image.ndim not in candidate_axes:
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
    first_dim, last_dim = candidate_axes[image.ndim]

    first_is_plausible = image.shape[first_dim] in num_channels
    last_is_plausible = image.shape[last_dim] in num_channels

    if first_is_plausible and last_is_plausible:
        logger.warning(
            f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension. Use the [input_data_format](https://huggingface.co/docs/transformers/main/internal/image_processing_utils#transformers.image_transforms.rescale.input_data_format) parameter to assign the channel dimension."
        )
        return ChannelDimension.FIRST
    if first_is_plausible:
        return ChannelDimension.FIRST
    if last_is_plausible:
        return ChannelDimension.LAST
    raise ValueError("Unable to infer channel dimension format")
def get_channel_dimension_axis(image: np.ndarray, input_data_format: ChannelDimension | str | None = None) -> int:
    """
    Returns the index of the channel axis of `image`.

    Args:
        image (`np.ndarray`):
            The image to get the channel dimension axis of.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the image. If `None`, it is inferred from the image shape.

    Returns:
        `int`: `image.ndim - 3` for channels-first images and `image.ndim - 1` for channels-last images.

    Raises:
        ValueError: If `input_data_format` is neither channels-first nor channels-last.
    """
    data_format = input_data_format
    if data_format is None:
        data_format = infer_channel_dimension_format(image)

    if data_format == ChannelDimension.LAST:
        return image.ndim - 1
    if data_format == ChannelDimension.FIRST:
        return image.ndim - 3
    raise ValueError(f"Unsupported data format: {data_format}")
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension | None = None) -> tuple[int, int]:
    """
    Returns the (height, width) of `image`.

    Args:
        image (`np.ndarray`):
            The image to get the dimensions of.
        channel_dim (`ChannelDimension`, *optional*):
            Where the channel axis is. If `None`, it is inferred from the image shape.

    Returns:
        `tuple[int, int]`: The image's height and width.

    Raises:
        ValueError: If `channel_dim` is neither channels-first nor channels-last.
    """
    if channel_dim is None:
        channel_dim = infer_channel_dimension_format(image)

    # Index from the end so that any leading axes are ignored
    if channel_dim == ChannelDimension.FIRST:
        height_axis, width_axis = -2, -1
    elif channel_dim == ChannelDimension.LAST:
        height_axis, width_axis = -3, -2
    else:
        raise ValueError(f"Unsupported data format: {channel_dim}")
    return image.shape[height_axis], image.shape[width_axis]
def get_image_size_for_max_height_width(
    image_size: tuple[int, int],
    max_height: int,
    max_width: int,
) -> tuple[int, int]:
    """
    Computes the largest size that fits within (`max_height`, `max_width`) while keeping the aspect ratio.

    Note that images smaller than the bounds are scaled up: the result always has at least one edge equal to
    `max_height` or `max_width` (up to integer truncation).

    For example:
    - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
    - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)

    Args:
        image_size (`tuple[int, int]`):
            The (height, width) of the input image.
        max_height (`int`):
            The maximum allowed height.
        max_width (`int`):
            The maximum allowed width.

    Returns:
        `tuple[int, int]`: The (height, width) of the resized image, truncated to integers.
    """
    height, width = image_size
    # The tighter of the two constraints decides the common scale factor
    scale = min(max_height / height, max_width / width)
    return int(height * scale), int(width * scale)
def max_across_indices(values: Iterable[Any]) -> list[Any]:
    """
    Computes the position-wise maximum over an iterable of sequences.

    The sequences are zipped together, so the result is as long as the shortest of them.
    """
    columns = zip(*values)
    return list(map(max, columns))
def get_max_height_width(
    images: list[Union["torch.Tensor", np.ndarray]], input_data_format: str | ChannelDimension = ChannelDimension.FIRST
) -> tuple[int, int]:
    """
    Get the maximum height and width across all images in a batch.

    Args:
        images (`list[Union[torch.Tensor, np.ndarray]]`):
            The images to inspect. Only their `shape` is read.
        input_data_format (`str` or `ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
            The channel dimension format of the images.

    Returns:
        `tuple[int, int]`: The maximum height and the maximum width found in `images`.

    Raises:
        ValueError: If `input_data_format` is neither channels-first nor channels-last.
    """
    # Select the spatial axes from the end of the shape (as `get_image_size` does) so that shapes with extra
    # leading axes are handled as well.
    if input_data_format == ChannelDimension.FIRST:
        spatial_axes = slice(-2, None)
    elif input_data_format == ChannelDimension.LAST:
        spatial_axes = slice(-3, -1)
    else:
        raise ValueError(f"Invalid channel dimension format: {input_data_format}")

    max_height, max_width = max_across_indices([img.shape[spatial_axes] for img in images])
    return (max_height, max_width)
- def is_valid_annotation_coco_detection(annotation: dict[str, list | tuple]) -> bool:
- if (
- isinstance(annotation, dict)
- and "image_id" in annotation
- and "annotations" in annotation
- and isinstance(annotation["annotations"], (list, tuple))
- and (
- # an image can have no annotations
- len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
- )
- ):
- return True
- return False
- def is_valid_annotation_coco_panoptic(annotation: dict[str, list | tuple]) -> bool:
- if (
- isinstance(annotation, dict)
- and "image_id" in annotation
- and "segments_info" in annotation
- and "file_name" in annotation
- and isinstance(annotation["segments_info"], (list, tuple))
- and (
- # an image can have no segments
- len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
- )
- ):
- return True
- return False
- def valid_coco_detection_annotations(annotations: Iterable[dict[str, list | tuple]]) -> bool:
- return all(is_valid_annotation_coco_detection(ann) for ann in annotations)
- def valid_coco_panoptic_annotations(annotations: Iterable[dict[str, list | tuple]]) -> bool:
- return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
def load_image(image: Union[str, "PIL.Image.Image"], timeout: float | None = None) -> "PIL.Image.Image":
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format. A string is interpreted, in this order, as a URL starting
            with `http://` or `https://`, as a path to an existing local file, or as base64-encoded image data
            (optionally prefixed by a `data:image/...,` header).
        timeout (`float`, *optional*):
            The timeout value in seconds for the URL request.

    Returns:
        `PIL.Image.Image`: A PIL Image, transposed according to its EXIF orientation and converted to RGB.

    Raises:
        ValueError: If `image` is a string that is neither a URL, an existing file, nor decodable base64 image data.
        TypeError: If `image` is neither a string nor a PIL image.
    """
    requires_backends(load_image, ["vision"])
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            # We need to actually check for a real protocol, otherwise it's impossible to use a local file
            # like http_huggingface_co.png
            response = httpx.get(image, timeout=timeout, follow_redirects=True)
            image = PIL.Image.open(BytesIO(response.content))
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            # Try to load as base64. The data-URI header is stripped inside the `try` so that a malformed header
            # (no comma) is reported through the same `ValueError` as any other undecodable string.
            try:
                if image.startswith("data:image/"):
                    image = image.split(",", 1)[1]
                b64 = base64.decodebytes(image.encode())
                image = PIL.Image.open(BytesIO(b64))
            except Exception as e:
                # Chain the original error so the underlying cause stays visible in the traceback
                raise ValueError(
                    f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
                ) from e
    elif not isinstance(image, PIL.Image.Image):
        raise TypeError(
            "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
        )
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
- def load_images(
- images: Union[list, tuple, str, "PIL.Image.Image"], timeout: float | None = None
- ) -> Union["PIL.Image.Image", list["PIL.Image.Image"], list[list["PIL.Image.Image"]]]:
- """Loads images, handling different levels of nesting.
- Args:
- images: A single image, a list of images, or a list of lists of images to load.
- timeout: Timeout for loading images.
- Returns:
- A single image, a list of images, a list of lists of images.
- """
- if isinstance(images, (list, tuple)):
- if len(images) and isinstance(images[0], (list, tuple)):
- return [[load_image(image, timeout=timeout) for image in image_group] for image_group in images]
- else:
- return [load_image(image, timeout=timeout) for image in images]
- else:
- return load_image(images, timeout=timeout)
- def validate_preprocess_arguments(
- do_rescale: bool | None = None,
- rescale_factor: float | None = None,
- do_normalize: bool | None = None,
- image_mean: float | list[float] | None = None,
- image_std: float | list[float] | None = None,
- do_pad: bool | None = None,
- pad_size: dict[str, int] | int | None = None,
- do_center_crop: bool | None = None,
- crop_size: dict[str, int] | None = None,
- do_resize: bool | None = None,
- size: dict[str, int] | None = None,
- resample: Union["PILImageResampling", "InterpolationMode", int] | None = None,
- ):
- """
- Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
- Raises `ValueError` if arguments incompatibility is caught.
- Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
- sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
- existing arguments when possible.
- """
- if do_rescale and rescale_factor is None:
- raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")
- if do_pad and pad_size is None:
- # Processors pad images using different args depending on the model, so the below check is pointless
- # but we keep it for BC for now. TODO: remove in v5
- # Usually padding can be called with:
- # - "pad_size/size" if we're padding to specific values
- # - "size_divisor" if we're padding to any value divisible by X
- # - "None" if we're padding to the maximum size image in batch
- raise ValueError(
- "Depending on the model, `size_divisor` or `pad_size` or `size` must be specified if `do_pad` is `True`."
- )
- if do_normalize and (image_mean is None or image_std is None):
- raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")
- if do_center_crop and crop_size is None:
- raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")
- if do_resize and not (size is not None and resample is not None):
- raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
- class ImageFeatureExtractionMixin:
- """
- Mixin that contain utilities for preparing image features.
- """
def _ensure_format_supported(self, image):
    """
    Raises a `ValueError` unless `image` is a `PIL.Image.Image`, a `np.ndarray` or a `torch.Tensor`.
    """
    is_supported = isinstance(image, (PIL.Image.Image, np.ndarray)) or is_torch_tensor(image)
    if is_supported:
        return
    raise ValueError(
        f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.ndarray` and "
        "`torch.Tensor` are."
    )
def to_pil_image(self, image, rescale=None):
    """
    Converts `image` to a PIL Image. Optionally rescales it and moves the channel dimension back to the last axis
    if needed. PIL images are returned unchanged.

    Args:
        image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
            The image to convert to the PIL Image format.
        rescale (`bool`, *optional*):
            Whether or not to multiply the pixel values by 255 before the cast to `uint8`. Will default to `True`
            if the image type is a floating type, `False` otherwise.
    """
    self._ensure_format_supported(image)

    if is_torch_tensor(image):
        image = image.numpy()

    # Nothing to convert when we were handed a PIL image
    if not isinstance(image, np.ndarray):
        return image

    if rescale is None:
        # rescale defaults to the array being of floating type (decided from its first element).
        rescale = isinstance(image.flat[0], np.floating)

    # If the channel has been moved to the first dim, we put it back at the end.
    channels_first = image.ndim == 3 and image.shape[0] in (1, 3)
    if channels_first:
        image = image.transpose(1, 2, 0)

    if rescale:
        image = image * 255
    return PIL.Image.fromarray(image.astype(np.uint8))
def convert_rgb(self, image):
    """
    Converts a `PIL.Image.Image` to RGB format. Arrays and tensors are returned unchanged.

    Args:
        image (`PIL.Image.Image`):
            The image to convert.
    """
    self._ensure_format_supported(image)
    if isinstance(image, PIL.Image.Image):
        return image.convert("RGB")
    return image
def rescale(self, image: np.ndarray, scale: float | int) -> np.ndarray:
    """
    Multiplies every value of `image` by `scale`, after checking that the image format is supported.
    """
    self._ensure_format_supported(image)
    rescaled = image * scale
    return rescaled
def to_numpy_array(self, image, rescale=None, channel_first=True):
    """
    Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
    dimension.

    Args:
        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
            The image to convert to a NumPy array.
        rescale (`bool`, *optional*):
            Whether or not to cast to `float32` and divide by 255 (to make pixel values floats between 0. and 1.).
            Will default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
        channel_first (`bool`, *optional*, defaults to `True`):
            Whether or not to permute the dimensions of a 3-dimensional image to put the last dimension first.
    """
    self._ensure_format_supported(image)

    if isinstance(image, PIL.Image.Image):
        image = np.array(image)
    elif is_torch_tensor(image):
        image = image.numpy()

    if rescale is None:
        # Integer data is assumed to be in [0, 255]; the decision is taken from the first element.
        rescale = isinstance(image.flat[0], np.integer)
    if rescale:
        image = self.rescale(image.astype(np.float32), 1 / 255.0)

    if channel_first and image.ndim == 3:
        image = image.transpose(2, 0, 1)
    return image
def expand_dims(self, image):
    """
    Adds a leading axis of size 1 to `image` (e.g. expands a 2-dimensional image to 3 dimensions). PIL images are
    returned unchanged.

    Args:
        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
            The image to expand.
    """
    self._ensure_format_supported(image)

    # Do nothing if PIL image
    if isinstance(image, PIL.Image.Image):
        return image
    if is_torch_tensor(image):
        return image.unsqueeze(0)
    return np.expand_dims(image, axis=0)
def normalize(self, image, mean, std, rescale=False):
    """
    Normalizes `image` with `mean` and `std`, i.e. computes `(image - mean) / std`. Note that this will trigger a
    conversion of `image` to a NumPy array if it's a PIL Image.

    Args:
        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
            The image to normalize.
        mean (`list[float]` or `np.ndarray` or `torch.Tensor`):
            The mean (per channel) to use for normalization.
        std (`list[float]` or `np.ndarray` or `torch.Tensor`):
            The standard deviation (per channel) to use for normalization.
        rescale (`bool`, *optional*, defaults to `False`):
            Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
            happen automatically.
    """
    self._ensure_format_supported(image)

    if isinstance(image, PIL.Image.Image):
        # PIL images are always rescaled as part of the conversion to an array
        image = self.to_numpy_array(image, rescale=True)
    elif rescale:
        # Other types are only rescaled on request
        if isinstance(image, np.ndarray):
            image = self.rescale(image.astype(np.float32), 1 / 255.0)
        elif is_torch_tensor(image):
            image = self.rescale(image.float(), 1 / 255.0)

    # Bring the statistics into the same container type as the image
    if isinstance(image, np.ndarray):
        mean = mean if isinstance(mean, np.ndarray) else np.array(mean).astype(image.dtype)
        std = std if isinstance(std, np.ndarray) else np.array(std).astype(image.dtype)
    elif is_torch_tensor(image):
        import torch

        def _as_tensor(stat):
            # Tensors pass through, arrays are wrapped, anything else is copied into a new tensor
            if isinstance(stat, torch.Tensor):
                return stat
            if isinstance(stat, np.ndarray):
                return torch.from_numpy(stat)
            return torch.tensor(stat)

        mean = _as_tensor(mean)
        std = _as_tensor(std)

    # For channels-first 3D images the per-channel statistics must broadcast over height and width
    if image.ndim == 3 and image.shape[0] in (1, 3):
        return (image - mean[:, None, None]) / std[:, None, None]
    return (image - mean) / std
def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
    """
    Resizes `image`. Enforces conversion of input to PIL.Image.

    Args:
        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
            The image to resize.
        size (`int` or `tuple[int, int]`):
            The size to use for resizing the image. A two-element sequence is passed as-is to
            `PIL.Image.Image.resize`.

            If `size` is an int (or a one-element sequence) and `default_to_square` is `True`, then image will be
            resized to (size, size). If `default_to_square` is `False`, then the smaller edge of the image will be
            matched to this number and the other edge scaled to keep the aspect ratio.
        resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            The filter to use for resampling.
        default_to_square (`bool`, *optional*, defaults to `True`):
            How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
            square (`size`,`size`). If set to `False`, will replicate
            [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
            with support for resizing only the smallest edge and providing an optional `max_size`.
        max_size (`int`, *optional*, defaults to `None`):
            The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
            greater than `max_size` after being resized according to `size`, then the image is resized again so
            that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
            edge may be shorter than `size`. Only used if `default_to_square` is `False`.

    Returns:
        image: A resized `PIL.Image.Image`.

    Raises:
        ValueError: If `max_size` is not strictly greater than the requested size of the smaller edge.
    """
    if resample is None:
        resample = PILImageResampling.BILINEAR

    self._ensure_format_supported(image)
    if not isinstance(image, PIL.Image.Image):
        image = self.to_pil_image(image)

    if isinstance(size, list):
        size = tuple(size)

    # A full target size was given: resize directly
    single_edge = isinstance(size, int) or len(size) == 1
    if not single_edge:
        return image.resize(size, resample=resample)

    requested_edge = size if isinstance(size, int) else size[0]
    if default_to_square:
        return image.resize((requested_edge, requested_edge), resample=resample)

    # Only the smallest edge is specified; the other one follows the aspect ratio
    width, height = image.size
    width_is_short = width <= height
    short, long = (width, height) if width_is_short else (height, width)

    if short == requested_edge:
        return image

    new_short = requested_edge
    new_long = int(requested_edge * long / short)

    if max_size is not None:
        if max_size <= requested_edge:
            raise ValueError(
                f"max_size = {max_size} must be strictly greater than the requested "
                f"size for the smaller edge size = {size}"
            )
        if new_long > max_size:
            # Shrink again so that the longer edge does not exceed `max_size`
            new_short, new_long = int(max_size * new_short / new_long), max_size

    # PIL expects (width, height)
    target = (new_short, new_long) if width_is_short else (new_long, new_short)
    return image.resize(target, resample=resample)
def center_crop(self, image, size):
    """
    Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
    size given, it will be padded with zeros (so the returned result has the size asked).

    Args:
        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
            The image to crop. 2-dimensional arrays/tensors get a leading axis added first.
        size (`int` or `tuple[int, int]` or `list[int]`):
            The size to which crop the image, as (height, width). An int requests a square crop.

    Returns:
        new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
        height, width).
    """
    self._ensure_format_supported(image)

    # Accept `[height, width]` lists like `resize` does; wrapping a list as-is would yield a pair of lists.
    if isinstance(size, list):
        size = tuple(size)
    if not isinstance(size, tuple):
        size = (size, size)

    # PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
    if is_torch_tensor(image) or isinstance(image, np.ndarray):
        if image.ndim == 2:
            image = self.expand_dims(image)
        # A leading axis of size 1 or 3 is taken to be the channel axis
        image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
    else:
        image_shape = (image.size[1], image.size[0])

    top = (image_shape[0] - size[0]) // 2
    bottom = top + size[0]  # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
    left = (image_shape[1] - size[1]) // 2
    right = left + size[1]  # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.

    # For PIL Images we have a method to crop directly.
    if isinstance(image, PIL.Image.Image):
        return image.crop((left, top, right, bottom))

    # Check if image is in (n_channels, height, width) or (height, width, n_channels) format
    channel_first = image.shape[0] in [1, 3]

    # Transpose (height, width, n_channels) format images
    if not channel_first:
        if isinstance(image, np.ndarray):
            image = image.transpose(2, 0, 1)
        if is_torch_tensor(image):
            image = image.permute(2, 0, 1)

    # Check if cropped area is within image boundaries
    if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
        return image[..., top:bottom, left:right]

    # Otherwise, we may need to pad if the image is too small: paste the image in the middle of a zero canvas
    # that is at least as large as the requested crop in both directions.
    new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
    if isinstance(image, np.ndarray):
        new_image = np.zeros_like(image, shape=new_shape)
    elif is_torch_tensor(image):
        new_image = image.new_zeros(new_shape)

    top_pad = (new_shape[-2] - image_shape[0]) // 2
    bottom_pad = top_pad + image_shape[0]
    left_pad = (new_shape[-1] - image_shape[1]) // 2
    right_pad = left_pad + image_shape[1]
    new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

    # Shift the crop window into the coordinates of the padded canvas
    top += top_pad
    bottom += top_pad
    left += left_pad
    right += left_pad

    new_image = new_image[
        ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
    ]

    return new_image
def flip_channel_order(self, image):
    """
    Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
    `image` to a NumPy array if it's a PIL Image.

    Args:
        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
            The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
            be first.

    Returns:
        The image with its first dimension reversed: a reversed view for NumPy arrays, a new tensor for torch
        tensors.
    """
    self._ensure_format_supported(image)

    if isinstance(image, PIL.Image.Image):
        image = self.to_numpy_array(image)

    # torch tensors do not support negative slicing steps (`tensor[::-1]` raises a ValueError), so the channel
    # dimension is reversed with `flip` instead.
    if is_torch_tensor(image):
        return image.flip(0)

    return image[::-1, :, :]
def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
    """
    Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
    counter clockwise around its centre.

    Args:
        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
            The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
            rotating.
        angle:
            The rotation angle, forwarded to `PIL.Image.Image.rotate`.
        resample (*optional*):
            The resampling filter forwarded to `PIL.Image.Image.rotate`. Defaults to `PIL.Image.NEAREST` when
            `None`.
        expand, center, translate, fillcolor (*optional*):
            Forwarded unchanged to `PIL.Image.Image.rotate`.

    Returns:
        image: A rotated `PIL.Image.Image`.
    """
    if resample is None:
        resample = PIL.Image.NEAREST

    self._ensure_format_supported(image)

    # Rotation is delegated to PIL, so arrays and tensors are converted first.
    pil_image = image if isinstance(image, PIL.Image.Image) else self.to_pil_image(image)

    rotate_kwargs = {
        "resample": resample,
        "expand": expand,
        "center": center,
        "translate": translate,
        "fillcolor": fillcolor,
    }
    return pil_image.rotate(angle, **rotate_kwargs)
def validate_annotations(
    annotation_format: AnnotationFormat,
    supported_annotation_formats: tuple[AnnotationFormat, ...],
    annotations: list[dict],
) -> None:
    """
    Checks that `annotation_format` is supported and that `annotations` are valid for that format.

    Args:
        annotation_format (`AnnotationFormat`):
            The format the annotations are expected to follow.
        supported_annotation_formats (`tuple[AnnotationFormat, ...]`):
            The formats accepted by the caller.
        annotations (`list[dict]`):
            The annotations to validate.

    Raises:
        ValueError: If `annotation_format` is not one of `supported_annotation_formats`, or if the annotations
            are rejected by the validator of the selected COCO format.
    """
    if annotation_format not in supported_annotation_formats:
        # Interpolate the offending value (not the `format` builtin) so the message is actionable.
        raise ValueError(
            f"Unsupported annotation format: {annotation_format} must be one of {supported_annotation_formats}"
        )

    if annotation_format is AnnotationFormat.COCO_DETECTION:
        if not valid_coco_detection_annotations(annotations):
            raise ValueError(
                "Invalid COCO detection annotations. Annotations must a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
                "being a list of annotations in the COCO format."
            )

    if annotation_format is AnnotationFormat.COCO_PANOPTIC:
        if not valid_coco_panoptic_annotations(annotations):
            raise ValueError(
                "Invalid COCO panoptic annotations. Annotations must a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
                "the latter being a list of annotations in the COCO format."
            )
def validate_kwargs(valid_processor_keys: list[str], captured_kwargs: list[str]):
    """
    Logs a warning listing every captured kwarg that is not a valid processor key.

    Args:
        valid_processor_keys (`list[str]`):
            The keys that are recognized.
        captured_kwargs (`list[str]`):
            The keys that were actually received.
    """
    unrecognized = set(captured_kwargs) - set(valid_processor_keys)
    if not unrecognized:
        return

    # TODO raise a warning here instead of simply logging?
    logger.warning(f"Unused or unrecognized kwargs: {', '.join(unrecognized)}.")
- @dataclass()
- class SizeDict:
- """
- Hashable dictionary to store image size information.
- """
- height: int | None = None
- width: int | None = None
- longest_edge: int | None = None
- shortest_edge: int | None = None
- max_height: int | None = None
- max_width: int | None = None
- def __getitem__(self, key):
- if hasattr(self, key):
- return getattr(self, key)
- raise KeyError(f"Key {key} not found in SizeDict.")
- def get(self, key, default=None):
- if hasattr(self, key) and getattr(self, key) is not None:
- return getattr(self, key)
- return default
- def __iter__(self):
- # Yield only non-None (key, value) pairs so dict(self) excludes missing values.
- for f in fields(self):
- val = getattr(self, f.name)
- if val is not None:
- yield f.name, val
- def __hash__(self):
- return hash((self.height, self.width, self.longest_edge, self.shortest_edge, self.max_height, self.max_width))
- def __contains__(self, key):
- return hasattr(self, key) and getattr(self, key) is not None
- def __setitem__(self, key, value):
- if not hasattr(self, key):
- raise KeyError(f"Key {key} is not a valid field of SizeDict.")
- object.__setattr__(self, key, value)
- def __eq__(self, other):
- if isinstance(other, dict):
- return dict(self) == other
- if isinstance(other, SizeDict):
- return tuple(getattr(self, f.name) for f in fields(self)) == tuple(
- getattr(other, f.name) for f in fields(self)
- )
- return NotImplemented
- def __or__(self, other) -> "SizeDict":
- if isinstance(other, dict | SizeDict):
- merged = dict(self)
- merged.update(dict(other))
- return SizeDict(**merged)
- return NotImplemented
- def __ror__(self, other) -> dict:
- if isinstance(other, dict):
- merged = dict(other)
- merged.update(dict(self))
- return merged
- return NotImplemented
|