image_processing_utils.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688
  1. # Copyright 2022 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from collections.abc import Iterable
  16. from copy import deepcopy
  17. from functools import partial
  18. from typing import Any
  19. import numpy as np
  20. from huggingface_hub.dataclasses import validate_typed_dict
  21. from .image_processing_base import BatchFeature, ImageProcessingMixin
  22. from .image_transforms import center_crop, normalize, rescale
  23. from .image_utils import (
  24. ChannelDimension,
  25. ImageInput,
  26. SizeDict,
  27. get_image_size,
  28. make_flat_list_of_images,
  29. validate_preprocess_arguments,
  30. )
  31. from .processing_utils import ImagesKwargs, Unpack
  32. from .utils import (
  33. auto_docstring,
  34. is_torchvision_available,
  35. is_vision_available,
  36. logging,
  37. )
  38. if is_vision_available():
  39. from .image_utils import PILImageResampling
  40. if is_torchvision_available():
  41. from torchvision.transforms.v2 import functional as tvF
  42. logger = logging.get_logger(__name__)
  43. INIT_SERVICE_KWARGS = [
  44. "processor_class",
  45. "image_processor_type",
  46. ]
  47. class BaseImageProcessor(ImageProcessingMixin):
  48. r"""
  49. Base class for image processors with an inheritance-based backend architecture.
  50. This class defines the preprocessing pipeline: kwargs validation, input preparation, and dispatching to the
  51. backend's `_preprocess` method. Backend subclasses (`TorchvisionBackend`, `PilBackend`) inherit from this class
  52. and implement the actual image operations (resize, crop, rescale, normalize, etc.). Model-specific image
  53. processors then inherit from the appropriate backend class.
  54. Architecture Overview
  55. ---------------------
  56. The class hierarchy is:
  57. BaseImageProcessor (this class)
  58. ├── TorchvisionBackend (GPU-accelerated, torch.Tensor)
  59. │ └── ModelImageProcessor (e.g. LlavaNextImageProcessor)
  60. └── PilBackend (portable CPU, np.ndarray)
  61. └── ModelImageProcessorPil (e.g. CLIPImageProcessorPil)
  62. The preprocessing flow is:
  63. __call__() → preprocess() → _preprocess_image_like_inputs() → _prepare_image_like_inputs()
  64. (calls process_image per image)
  65. → _preprocess()
  66. (batch operations: resize, crop, etc.)
  67. - `process_image`: Implemented by backends. Converts a single raw input (PIL, NumPy, or Tensor) to the
  68. backend's working format (torch.Tensor or np.ndarray), handles RGB conversion and channel reordering.
  69. - `_preprocess`: Implemented by backends. Performs the actual batch processing (resize, center crop, rescale,
  70. normalize, pad) and returns a `BatchFeature`.
  71. Basic Implementation
  72. --------------------
  73. For processors that only need standard operations (resize, center crop, rescale, normalize), inherit from
  74. a backend and define class attributes:
  75. from transformers.image_processing_backends import PilBackend
  76. class MyImageProcessorPil(PilBackend):
  77. resample = PILImageResampling.BILINEAR
  78. image_mean = IMAGENET_DEFAULT_MEAN
  79. image_std = IMAGENET_DEFAULT_STD
  80. size = {"height": 224, "width": 224}
  81. do_resize = True
  82. do_rescale = True
  83. do_normalize = True
  84. The backend's `_preprocess` method handles the standard pipeline automatically.
  85. Custom Processing
  86. -----------------
  87. For processors that need custom logic (e.g., patch-based processing, multiple input types), override
  88. `_preprocess` in your model-specific processor. The `_preprocess` method receives already-prepared images
  89. (converted to the backend format with channels-first ordering) and performs the actual processing:
  90. class MyImageProcessor(TorchvisionBackend):
  91. def _preprocess(self, images, do_resize, size, do_normalize, image_mean, image_std, **kwargs):
  92. # Group images by shape for efficient batched operations
  93. grouped_images, grouped_images_index = group_images_by_shape(images)
  94. processed_groups = {}
  95. for shape, stacked_images in grouped_images.items():
  96. if do_resize:
  97. stacked_images = self.resize(stacked_images, size=size)
  98. if do_normalize:
  99. stacked_images = self.normalize(stacked_images, mean=image_mean, std=image_std)
  100. processed_groups[shape] = stacked_images
  101. processed_images = reorder_images(processed_groups, grouped_images_index)
  102. return BatchFeature(data={"pixel_values": processed_images})
  103. For processors handling multiple input types (e.g., images + segmentation maps), override
  104. `_preprocess_image_like_inputs`:
  105. def _preprocess_image_like_inputs(
  106. self,
  107. images: ImageInput,
  108. segmentation_maps: ImageInput | None = None,
  109. **kwargs,
  110. ) -> BatchFeature:
  111. images = self._prepare_image_like_inputs(images, **kwargs)
  112. batch_feature = self._preprocess(images, **kwargs)
  113. if segmentation_maps is not None:
  114. maps = self._prepare_image_like_inputs(segmentation_maps, **kwargs)
  115. batch_feature["labels"] = self._preprocess(maps, **kwargs).pixel_values
  116. return batch_feature
  117. Extending Backend Behavior
  118. --------------------------
  119. To customize operations for a specific backend, subclass the backend and override its methods:
  120. from transformers.image_processing_backends import TorchvisionBackend, PilBackend
  121. class MyTorchvisionProcessor(TorchvisionBackend):
  122. def resize(self, image, size, **kwargs):
  123. # Custom resize logic for torchvision
  124. return super().resize(image, size, **kwargs)
  125. class MyPilProcessor(PilBackend):
  126. def resize(self, image, size, **kwargs):
  127. # Custom resize logic for PIL
  128. return super().resize(image, size, **kwargs)
  129. Custom Parameters
  130. -----------------
  131. To add parameters beyond `ImagesKwargs`, create a custom kwargs class and set it as `valid_kwargs`:
  132. class MyImageProcessorKwargs(ImagesKwargs):
  133. custom_param: int | None = None
  134. class MyImageProcessor(TorchvisionBackend):
  135. valid_kwargs = MyImageProcessorKwargs
  136. custom_param = 10 # default value
  137. Key Notes
  138. ---------
  139. - Backend selection is done at the class level: inherit from `TorchvisionBackend` or `PilBackend`
  140. - Backends receive images as `torch.Tensor` (Torchvision) or `np.ndarray` (PIL), always channels-first
  141. - All images have channel dimension first during processing, regardless of backend
  142. - Arguments not provided by users default to class attribute values
  143. - Backend classes encapsulate backend-specific logic (resize, normalize, etc.) and can be overridden
  144. """
  145. valid_kwargs = ImagesKwargs
  146. default_to_square = True
  147. rescale_factor = 1 / 255
  148. model_input_names = ["pixel_values"]
  149. def __init__(self, **kwargs: Unpack[ImagesKwargs]):
  150. super().__init__(**kwargs)
  151. # We don't call self._set_attributes in BaseImageProcessor for backward compatibility with remote code
  152. # We call it instead in the backend subclasses' __init__ methods.
  153. def _set_attributes(self, **kwargs):
  154. """Resolve and set instance attributes from kwargs and class-level defaults for all valid kwargs."""
  155. attributes = {}
  156. for key in self.valid_kwargs.__annotations__:
  157. kwarg = kwargs.pop(key, None)
  158. if kwarg is not None:
  159. attributes[key] = kwarg
  160. else:
  161. attributes[key] = deepcopy(getattr(self, key, None))
  162. attributes = self._standardize_kwargs(**attributes)
  163. for key, value in attributes.items():
  164. setattr(self, key, value)
  165. self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())
  166. def __call__(self, images: ImageInput, *args, **kwargs: Unpack[ImagesKwargs]) -> BatchFeature:
  167. """Preprocess an image or a batch of images."""
  168. return self.preprocess(images, *args, **kwargs)
  169. def process_image(self, *args, **kwargs):
  170. """
  171. Process a single raw image into the backend's working format.
  172. Implemented by backend subclasses (`TorchvisionBackend`, `PilBackend`). Converts a raw input
  173. (PIL Image, NumPy array, or torch Tensor) to the backend's internal format (`torch.Tensor` for
  174. Torchvision, `np.ndarray` for PIL), handles RGB conversion and ensures channels-first ordering.
  175. """
  176. raise NotImplementedError
  177. def _preprocess(self, *args, **kwargs):
  178. """
  179. Perform the actual batch image preprocessing (resize, center crop, rescale, normalize, pad).
  180. Implemented by backend subclasses (`TorchvisionBackend`, `PilBackend`). Receives a list of
  181. already-prepared images (in the backend's format, channels-first) and applies the configured
  182. preprocessing operations. Returns a `BatchFeature` with the processed pixel values.
  183. Model-specific processors can override this method to implement custom preprocessing logic
  184. (e.g., patch-based processing in LLaVA-NeXT).
  185. """
  186. raise NotImplementedError
  187. def _prepare_images_structure(
  188. self,
  189. images: ImageInput,
  190. expected_ndims: int = 3,
  191. ) -> ImageInput:
  192. """
  193. Prepare the images structure for processing.
  194. Args:
  195. images (`ImageInput`):
  196. The input images to process.
  197. Returns:
  198. `ImageInput`: The images with a valid nesting.
  199. """
  200. images = self.fetch_images(images)
  201. return make_flat_list_of_images(images, expected_ndims=expected_ndims)
  202. def _prepare_image_like_inputs(
  203. self,
  204. images: ImageInput,
  205. *args,
  206. expected_ndims: int = 3,
  207. **kwargs: Unpack[ImagesKwargs],
  208. ) -> list[Any]:
  209. """
  210. Prepare image-like inputs for processing by converting each image via `process_image`.
  211. Flattens the input structure and applies `process_image` (implemented by the backend) to each
  212. individual image, converting raw inputs (PIL, NumPy, Tensor) into the backend's working format
  213. with channels-first ordering.
  214. Args:
  215. images (`ImageInput`):
  216. The image-like inputs to process.
  217. expected_ndims (`int`, *optional*, defaults to 3):
  218. The expected number of dimensions for the images.
  219. Returns:
  220. `list[torch.Tensor]` or `list[np.ndarray]`: The prepared images in the backend's format,
  221. with channels-first ordering.
  222. """
  223. images = self._prepare_images_structure(images, expected_ndims=expected_ndims)
  224. process_image_partial = partial(self.process_image, *args, **kwargs)
  225. has_nested_structure = len(images) > 0 and isinstance(images[0], list | tuple)
  226. if has_nested_structure:
  227. processed_images = [[process_image_partial(img) for img in nested_list] for nested_list in images]
  228. else:
  229. processed_images = [process_image_partial(img) for img in images]
  230. return processed_images
  231. def _preprocess_image_like_inputs(
  232. self,
  233. images: ImageInput,
  234. *args,
  235. **kwargs: Unpack[ImagesKwargs],
  236. ) -> BatchFeature:
  237. """
  238. Preprocess image-like inputs by preparing them and dispatching to `_preprocess`.
  239. This method first calls `_prepare_image_like_inputs` to convert raw inputs into the backend's
  240. format, then calls `_preprocess` for the actual batch processing. Override this method in
  241. model-specific processors that need to handle multiple image-like input types (e.g., images
  242. and segmentation maps) or need custom orchestration of the preprocessing pipeline.
  243. """
  244. images = self._prepare_image_like_inputs(images, **kwargs)
  245. return self._preprocess(images, *args, **kwargs)
  246. def _standardize_kwargs(
  247. self,
  248. size: int | Iterable[int] | dict[str, int] | SizeDict | None = None,
  249. crop_size: int | Iterable[int] | dict[str, int] | SizeDict | None = None,
  250. pad_size: int | Iterable[int] | dict[str, int] | SizeDict | None = None,
  251. default_to_square: bool | None = None,
  252. image_mean: float | list[float] | None = None,
  253. image_std: float | list[float] | None = None,
  254. **kwargs,
  255. ) -> dict:
  256. """
  257. Standardize kwargs to canonical format before validation.
  258. Can be overridden by subclasses to customize the processing of kwargs.
  259. """
  260. if kwargs is None:
  261. kwargs = {}
  262. if size is not None and not isinstance(size, SizeDict):
  263. size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
  264. if crop_size is not None and not isinstance(crop_size, SizeDict):
  265. crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
  266. if pad_size is not None and not isinstance(pad_size, SizeDict):
  267. pad_size = SizeDict(**get_size_dict(size=pad_size, param_name="pad_size"))
  268. if isinstance(image_mean, list):
  269. image_mean = tuple(image_mean)
  270. if isinstance(image_std, list):
  271. image_std = tuple(image_std)
  272. kwargs["size"] = size
  273. kwargs["crop_size"] = crop_size
  274. kwargs["pad_size"] = pad_size
  275. kwargs["image_mean"] = image_mean
  276. kwargs["image_std"] = image_std
  277. return kwargs
  278. # Backwards compatibility for method that was renamed
  279. _further_process_kwargs = _standardize_kwargs
  280. def _validate_preprocess_kwargs(
  281. self,
  282. do_rescale: bool | None = None,
  283. rescale_factor: float | None = None,
  284. do_normalize: bool | None = None,
  285. image_mean: float | tuple[float] | None = None,
  286. image_std: float | tuple[float] | None = None,
  287. do_resize: bool | None = None,
  288. size: SizeDict | None = None,
  289. do_center_crop: bool | None = None,
  290. crop_size: SizeDict | None = None,
  291. resample: "PILImageResampling | tvF.InterpolationMode | int | None" = None,
  292. **kwargs,
  293. ):
  294. """
  295. Validate the kwargs for the preprocess method.
  296. """
  297. validate_preprocess_arguments(
  298. do_rescale=do_rescale,
  299. rescale_factor=rescale_factor,
  300. do_normalize=do_normalize,
  301. image_mean=image_mean,
  302. image_std=image_std,
  303. do_center_crop=do_center_crop,
  304. crop_size=crop_size,
  305. do_resize=do_resize,
  306. size=size,
  307. resample=resample,
  308. )
  309. @auto_docstring
  310. def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[ImagesKwargs]) -> BatchFeature:
  311. """
  312. Preprocess an image or a batch of images.
  313. """
  314. # Perform type validation on received kwargs
  315. validate_typed_dict(self.valid_kwargs, kwargs)
  316. # Set default kwargs from self
  317. for kwarg_name in self._valid_kwargs_names:
  318. kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
  319. # Update kwargs that need further processing before being validated
  320. kwargs = self._standardize_kwargs(**kwargs)
  321. # Validate kwargs
  322. self._validate_preprocess_kwargs(**kwargs)
  323. return self._preprocess_image_like_inputs(images, *args, **kwargs)
  324. def to_dict(self) -> dict[str, Any]:
  325. processor_dict = super().to_dict()
  326. # Filter out None values that are class defaults
  327. filtered_dict = {}
  328. for key, value in processor_dict.items():
  329. if isinstance(value, SizeDict):
  330. value = dict(value)
  331. if value is None:
  332. class_default = getattr(type(self), key, "NOT_FOUND")
  333. # Keep None if user explicitly set it (class default is non-None)
  334. if class_default != "NOT_FOUND" and class_default is not None:
  335. filtered_dict[key] = value
  336. else:
  337. filtered_dict[key] = value
  338. filtered_dict.pop("_valid_processor_keys", None)
  339. filtered_dict.pop("_valid_kwargs_names", None)
  340. return filtered_dict
  341. def rescale(
  342. self,
  343. image: np.ndarray,
  344. scale: float,
  345. data_format: str | ChannelDimension | None = None,
  346. input_data_format: str | ChannelDimension | None = None,
  347. **kwargs,
  348. ) -> np.ndarray:
  349. """
  350. Rescale an image by a scale factor. image = image * scale.
  351. Args:
  352. image (`np.ndarray`):
  353. Image to rescale.
  354. scale (`float`):
  355. The scaling factor to rescale pixel values by.
  356. data_format (`str` or `ChannelDimension`, *optional*):
  357. The channel dimension format for the output image. If unset, the channel dimension format of the input
  358. image is used. Can be one of:
  359. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  360. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  361. input_data_format (`ChannelDimension` or `str`, *optional*):
  362. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  363. from the input image. Can be one of:
  364. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  365. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  366. Returns:
  367. `np.ndarray`: The rescaled image.
  368. """
  369. return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)
  370. # The next methods are kept for backwards compatibility with remote code, but are overriden by backends.
  371. def normalize(
  372. self,
  373. image: np.ndarray,
  374. mean: float | Iterable[float],
  375. std: float | Iterable[float],
  376. data_format: str | ChannelDimension | None = None,
  377. input_data_format: str | ChannelDimension | None = None,
  378. **kwargs,
  379. ) -> np.ndarray:
  380. """
  381. Normalize an image. image = (image - image_mean) / image_std.
  382. Args:
  383. image (`np.ndarray`):
  384. Image to normalize.
  385. mean (`float` or `Iterable[float]`):
  386. Image mean to use for normalization.
  387. std (`float` or `Iterable[float]`):
  388. Image standard deviation to use for normalization.
  389. data_format (`str` or `ChannelDimension`, *optional*):
  390. The channel dimension format for the output image. If unset, the channel dimension format of the input
  391. image is used. Can be one of:
  392. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  393. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  394. input_data_format (`ChannelDimension` or `str`, *optional*):
  395. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  396. from the input image. Can be one of:
  397. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  398. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  399. Returns:
  400. `np.ndarray`: The normalized image.
  401. """
  402. return normalize(
  403. image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
  404. )
  405. def center_crop(
  406. self,
  407. image: np.ndarray,
  408. size: dict[str, int],
  409. data_format: str | ChannelDimension | None = None,
  410. input_data_format: str | ChannelDimension | None = None,
  411. **kwargs,
  412. ) -> np.ndarray:
  413. """
  414. Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
  415. any edge, the image is padded with 0's and then center cropped.
  416. Args:
  417. image (`np.ndarray`):
  418. Image to center crop.
  419. size (`dict[str, int]`):
  420. Size of the output image.
  421. data_format (`str` or `ChannelDimension`, *optional*):
  422. The channel dimension format for the output image. If unset, the channel dimension format of the input
  423. image is used. Can be one of:
  424. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  425. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  426. input_data_format (`ChannelDimension` or `str`, *optional*):
  427. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  428. from the input image. Can be one of:
  429. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  430. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  431. """
  432. size = get_size_dict(size)
  433. if "height" not in size or "width" not in size:
  434. raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
  435. return center_crop(
  436. image,
  437. size=(size["height"], size["width"]),
  438. data_format=data_format,
  439. input_data_format=input_data_format,
  440. **kwargs,
  441. )
  442. VALID_SIZE_DICT_KEYS = (
  443. {"height", "width"},
  444. {"shortest_edge"},
  445. {"shortest_edge", "longest_edge"},
  446. {"longest_edge"},
  447. {"max_height", "max_width"},
  448. )
  449. def is_valid_size_dict(size_dict):
  450. if not isinstance(size_dict, dict):
  451. return False
  452. size_dict_keys = set(size_dict.keys())
  453. for allowed_keys in VALID_SIZE_DICT_KEYS:
  454. if size_dict_keys == allowed_keys:
  455. return True
  456. return False
  457. def convert_to_size_dict(
  458. size: int | Iterable[int] | None = None,
  459. max_size: int | None = None,
  460. default_to_square: bool = True,
  461. height_width_order: bool = True,
  462. ) -> dict[str, int]:
  463. # By default, if size is an int we assume it represents a tuple of (size, size).
  464. if isinstance(size, int) and default_to_square:
  465. if max_size is not None:
  466. raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
  467. return {"height": size, "width": size}
  468. # In other configs, if size is an int and default_to_square is False, size represents the length of
  469. # the shortest edge after resizing.
  470. elif isinstance(size, int) and not default_to_square:
  471. size_dict = {"shortest_edge": size}
  472. if max_size is not None:
  473. size_dict["longest_edge"] = max_size
  474. return size_dict
  475. # Otherwise, if size is a tuple it's either (height, width) or (width, height)
  476. elif isinstance(size, (tuple, list)) and height_width_order:
  477. return {"height": size[0], "width": size[1]}
  478. elif isinstance(size, (tuple, list)) and not height_width_order:
  479. return {"height": size[1], "width": size[0]}
  480. elif size is None and max_size is not None:
  481. if default_to_square:
  482. raise ValueError("Cannot specify both default_to_square=True and max_size")
  483. return {"longest_edge": max_size}
  484. raise ValueError(f"Could not convert size input to size dict: {size}")
  485. def get_size_dict(
  486. size: int | Iterable[int] | dict[str, int] | SizeDict | None = None,
  487. max_size: int | None = None,
  488. height_width_order: bool = True,
  489. default_to_square: bool = True,
  490. param_name="size",
  491. ) -> dict:
  492. """
  493. Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
  494. compatibility with the old image processor configs and removes ambiguity over whether the tuple is in (height,
  495. width) or (width, height) format.
  496. - If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
  497. size[0]}` if `height_width_order` is `False`.
  498. - If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
  499. - If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
  500. is set, it is added to the dict as `{"longest_edge": max_size}`.
  501. - If `size` is `None` and `default_to_square` is False, the result is `{"longest_edge": max_size}` (requires
  502. `max_size` to be set). Tuple/list/SizeDict/dict `size` values do not use `max_size`.
  503. Args:
  504. size (`int | Iterable[int] | dict[str, int] | SizeDict`, *optional*):
  505. The `size` parameter to be cast into a size dictionary.
  506. max_size (`int | None`, *optional*):
  507. With `default_to_square=False`, sets `longest_edge` when `size` is an int or `None`; unused for dict,
  508. `SizeDict`, or tuple/list `size`. Raises if set with `default_to_square=True` when `size` is an int or `None`.
  509. height_width_order (`bool`, *optional*, defaults to `True`):
  510. If `size` is a tuple, whether it's in (height, width) or (width, height) order.
  511. default_to_square (`bool`, *optional*, defaults to `True`):
  512. If `size` is an int, whether to default to a square image or not.
  513. """
  514. if not isinstance(size, dict | SizeDict):
  515. size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
  516. logger.info(
  517. f"{param_name} should be a dictionary with one of the following sets of keys: {VALID_SIZE_DICT_KEYS}, got {size}."
  518. f" Converted to {size_dict}.",
  519. )
  520. # Some remote code bypasses or overrides `_standardize_kwargs`, so handle `SizeDict` `size` here too.
  521. elif isinstance(size, SizeDict):
  522. size_dict = dict(size)
  523. else:
  524. size_dict = size
  525. if not is_valid_size_dict(size_dict):
  526. raise ValueError(
  527. f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
  528. )
  529. return size_dict
  530. def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
  531. """
  532. Selects the best resolution from a list of possible resolutions based on the original size.
  533. This is done by calculating the effective and wasted resolution for each possible resolution.
  534. The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
  535. Args:
  536. original_size (tuple):
  537. The original size of the image in the format (height, width).
  538. possible_resolutions (list):
  539. A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].
  540. Returns:
  541. tuple: The best fit resolution in the format (height, width).
  542. """
  543. original_height, original_width = original_size
  544. best_fit = None
  545. max_effective_resolution = 0
  546. min_wasted_resolution = float("inf")
  547. for height, width in possible_resolutions:
  548. scale = min(width / original_width, height / original_height)
  549. downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
  550. effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
  551. wasted_resolution = (width * height) - effective_resolution
  552. if effective_resolution > max_effective_resolution or (
  553. effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
  554. ):
  555. max_effective_resolution = effective_resolution
  556. min_wasted_resolution = wasted_resolution
  557. best_fit = (height, width)
  558. return best_fit
  559. def get_patch_output_size(image, target_resolution, input_data_format):
  560. """
  561. Given an image and a target resolution, calculate the output size of the image after cropping to the target
  562. """
  563. original_height, original_width = get_image_size(image, channel_dim=input_data_format)
  564. target_height, target_width = target_resolution
  565. scale_w = target_width / original_width
  566. scale_h = target_height / original_height
  567. if scale_w < scale_h:
  568. new_width = target_width
  569. new_height = min(math.ceil(original_height * scale_w), target_height)
  570. else:
  571. new_height = target_height
  572. new_width = min(math.ceil(original_width * scale_h), target_width)
  573. return new_height, new_width