naflex_transforms.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821
  1. """ NaFlex (NaViT + FlexiViT) Transforms and Collation
  2. Implements PyTorch versions of the transforms described in the NaViT and FlexiViT papers:
  3. - NaViT: https://arxiv.org/abs/2307.14995
  4. - FlexiViT: https://arxiv.org/abs/2212.08013
  5. Enables variable resolution/aspect ratio image handling with efficient patching.
  6. Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
  7. """
  8. import math
  9. import random
  10. import warnings
  11. from typing import Dict, List, Optional, Sequence, Tuple, Union
  12. import torch
  13. from PIL import Image
  14. from torchvision import transforms
  15. from torchvision.transforms import functional as F
  16. from torchvision.transforms.functional import InterpolationMode
  17. from .transforms import str_to_interp_mode, crop_or_pad, center_crop_or_pad
  18. def get_image_size_for_seq(
  19. image_hw: Tuple[int, int],
  20. patch_size: Union[int, Tuple[int, int]] = 16,
  21. max_seq_len: int = 1024,
  22. divisible_by_patch: bool = True,
  23. max_ratio: Optional[float] = None,
  24. eps: float = 1e-5,
  25. ) -> Tuple[float, Tuple[int, int]]:
  26. """Determine scaling ratio and image size for sequence length constraint.
  27. Calculates the scaling ratio needed so that when image_hw is scaled,
  28. the total number of resulting patches does not exceed max_seq_len.
  29. Args:
  30. image_hw: Original image dimensions (height, width).
  31. patch_size: Patch dimensions. If int, patches are square.
  32. max_seq_len: Maximum allowed sequence length.
  33. divisible_by_patch: Whether resulting dimensions must be divisible by patch_size.
  34. max_ratio: Optional cap on scaling ratio to prevent excessive upsampling.
  35. eps: Convergence threshold for binary search.
  36. Returns:
  37. Tuple of (ratio, target_hw) where ratio is the scaling factor and
  38. target_hw is the resulting (height, width) after scaling.
  39. """
  40. # Handle patch size input, extract patch_h, patch_w
  41. if isinstance(patch_size, int):
  42. patch_h, patch_w = patch_size, patch_size
  43. else:
  44. # Assume it's a tuple/list: (patch_h, patch_w)
  45. if len(patch_size) != 2:
  46. raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).")
  47. patch_h, patch_w = patch_size
  48. # Safety checks
  49. if patch_h <= 0 or patch_w <= 0:
  50. raise ValueError("patch_size dimensions must be positive.")
  51. def prepare_target_hw(ratio):
  52. """Scale image_hw by ratio and optionally round dimensions to multiples of patch_h, patch_w."""
  53. scaled_h = image_hw[0] * ratio
  54. scaled_w = image_hw[1] * ratio
  55. # If we need the result to be divisible by patch_size
  56. if divisible_by_patch:
  57. scaled_h = patch_h * math.ceil(scaled_h / patch_h)
  58. scaled_w = patch_w * math.ceil(scaled_w / patch_w)
  59. # Ensure at least one patch in each dimension
  60. scaled_h = int(max(scaled_h, patch_h))
  61. scaled_w = int(max(scaled_w, patch_w))
  62. return scaled_h, scaled_w
  63. def is_feasible(ratio):
  64. """Check if scaling by 'ratio' keeps patch count within max_seq_len."""
  65. t_h, t_w = prepare_target_hw(ratio)
  66. # Each dimension is already a multiple of patch_h, patch_w if divisible_by_patch=True.
  67. # Use integer division to count patches.
  68. num_patches_h = t_h // patch_h
  69. num_patches_w = t_w // patch_w
  70. seq_len = num_patches_h * num_patches_w
  71. return seq_len <= max_seq_len
  72. # Binary search boundaries
  73. lb = eps / 10.0
  74. rb = 100.0
  75. # Standard binary search loop
  76. while (rb - lb) >= eps:
  77. mid = (lb + rb) / 2.0
  78. if is_feasible(mid):
  79. lb = mid
  80. else:
  81. rb = mid
  82. # The final ratio from the binary search
  83. ratio = lb
  84. # If max_ratio is provided, clamp it to prevent upsampling beyond that threshold
  85. if max_ratio is not None:
  86. ratio = min(ratio, max_ratio)
  87. # Final checks
  88. if ratio <= eps:
  89. raise ValueError("Binary search failed - image might be too large?")
  90. if ratio >= 100.0:
  91. raise ValueError("Binary search failed - image might be too small?")
  92. # Prepare the final target dimensions with the possibly clamped ratio
  93. target_hw = prepare_target_hw(ratio)
  94. return ratio, target_hw
  95. _RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
  96. class ResizeToSequence(torch.nn.Module):
  97. """Resize image to fit within a maximum sequence length constraint when patchified.
  98. This maintains aspect ratio while ensuring the resulting image, when divided into patches,
  99. will not exceed the specified maximum sequence length.
  100. """
  101. def __init__(
  102. self,
  103. patch_size: int,
  104. max_seq_len: int = 1024,
  105. divisible_by_patch: bool = True,
  106. max_ratio: Optional[float] = None,
  107. interpolation: Union[str, InterpolationMode, Tuple[InterpolationMode, ...]] = 'bicubic',
  108. ) -> None:
  109. """Initialize ResizeToSequence transform.
  110. Args:
  111. patch_size: Size of patches.
  112. max_seq_len: Maximum sequence length constraint.
  113. divisible_by_patch: Whether dimensions must be divisible by patch_size.
  114. max_ratio: Optional cap on scaling ratio.
  115. interpolation: Interpolation method or methods.
  116. """
  117. super().__init__()
  118. self.patch_size = patch_size
  119. self.max_seq_len = max_seq_len
  120. self.divisible_by_patch = divisible_by_patch
  121. self.max_ratio = max_ratio
  122. if isinstance(interpolation, str):
  123. if interpolation == 'random':
  124. self.interpolation = _RANDOM_INTERPOLATION
  125. else:
  126. self.interpolation = str_to_interp_mode(interpolation)
  127. else:
  128. self.interpolation = interpolation
  129. def forward(self, img: torch.Tensor) -> torch.Tensor:
  130. """Resize image to maintain aspect ratio and fit sequence constraint.
  131. Args:
  132. img: Input image tensor.
  133. Returns:
  134. Resized image tensor.
  135. """
  136. _, h, w = transforms.functional.get_dimensions(img)
  137. _, target_hw = get_image_size_for_seq(
  138. (h, w),
  139. self.patch_size,
  140. self.max_seq_len,
  141. divisible_by_patch=self.divisible_by_patch,
  142. max_ratio=self.max_ratio,
  143. )
  144. if isinstance(self.interpolation, (tuple, list)):
  145. interpolation = random.choice(self.interpolation)
  146. else:
  147. interpolation = self.interpolation
  148. resized_img = transforms.functional.resize(img, target_hw, interpolation=interpolation, antialias=True)
  149. return resized_img
  150. class ResizeKeepRatioToSequence(torch.nn.Module):
  151. """
  152. Resize and Keep Aspect Ratio, adapted to fit sequence length constraints.
  153. """
  154. def __init__(
  155. self,
  156. patch_size=16,
  157. max_sequence_len=1024,
  158. divisible_by_patch=True,
  159. longest=0.,
  160. interpolation='bilinear',
  161. random_scale_prob=0.,
  162. random_scale_range=(0.85, 1.05),
  163. random_scale_area=False,
  164. random_aspect_prob=0.,
  165. random_aspect_range=(0.9, 1.11),
  166. max_ratio=None,
  167. ):
  168. """
  169. Args:
  170. patch_size: Size of patches (int or tuple of (patch_h, patch_w))
  171. max_sequence_len: Maximum allowed sequence length for the resulting image
  172. divisible_by_patch: If True, ensure dimensions are divisible by patch_size
  173. longest: Float between 0-1 where 0=shortest side, 1=longest side determines scale
  174. interpolation: Interpolation method for resizing
  175. random_scale_prob: Probability of applying random scaling
  176. random_scale_range: Range for random scaling factor (min, max)
  177. random_scale_area: If True, scale factors affect area (√ factor)
  178. random_aspect_prob: Probability of applying random aspect ratio jittering
  179. random_aspect_range: Range for random aspect ratio (min, max)
  180. max_ratio: Maximum allowed scaling ratio
  181. """
  182. super().__init__()
  183. self.patch_size = patch_size
  184. self.max_sequence_len = max_sequence_len
  185. self.divisible_by_patch = divisible_by_patch
  186. self.longest = float(longest)
  187. if interpolation == 'random':
  188. self.interpolation = _RANDOM_INTERPOLATION
  189. else:
  190. self.interpolation = str_to_interp_mode(interpolation)
  191. self.random_scale_prob = random_scale_prob
  192. self.random_scale_range = random_scale_range
  193. self.random_scale_area = random_scale_area
  194. self.random_aspect_prob = random_aspect_prob
  195. self.random_aspect_range = random_aspect_range
  196. self.max_ratio = max_ratio
  197. @staticmethod
  198. def get_params(
  199. img,
  200. patch_size,
  201. max_sequence_len,
  202. divisible_by_patch,
  203. longest,
  204. random_scale_prob=0.,
  205. random_scale_range=(1.0, 1.33),
  206. random_scale_area=False,
  207. random_aspect_prob=0.,
  208. random_aspect_range=(0.9, 1.11),
  209. max_ratio=None,
  210. ):
  211. """Get parameters for resizing."""
  212. # Get image dimensions
  213. img_h, img_w = F.get_dimensions(img)[1:]
  214. # Step 1: Get the maximum allowed dimensions from sequence length constraint
  215. _, target_hw = get_image_size_for_seq(
  216. (img_h, img_w),
  217. patch_size,
  218. max_sequence_len,
  219. divisible_by_patch,
  220. max_ratio,
  221. )
  222. target_h, target_w = target_hw
  223. # Calculate ratio based on sequence constraint
  224. ratio_h = target_h / img_h
  225. ratio_w = target_w / img_w
  226. # Apply longest blending
  227. ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
  228. # Apply random scaling
  229. if random_scale_prob > 0 and random.random() < random_scale_prob:
  230. ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
  231. if random_scale_area:
  232. # Make ratio factor equivalent to area change
  233. ratio_factor = 1. / math.sqrt(ratio_factor)
  234. ratio_factor = (ratio_factor, ratio_factor)
  235. else:
  236. ratio_factor = (1., 1.)
  237. # Apply random aspect
  238. if random_aspect_prob > 0 and random.random() < random_aspect_prob:
  239. log_aspect = (math.log(random_aspect_range[0]), math.log(random_aspect_range[1]))
  240. aspect_factor = math.exp(random.uniform(*log_aspect))
  241. aspect_factor = math.sqrt(aspect_factor)
  242. # Apply aspect ratio jittering
  243. ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
  244. # Calculate final dimensions
  245. size = [round(dim * ratio * f) for dim, f in zip((img_h, img_w), ratio_factor)]
  246. # Ensure dimensions satisfy sequence constraint and are divisible by patch size
  247. if isinstance(patch_size, int):
  248. ph, pw = patch_size, patch_size
  249. else:
  250. ph, pw = patch_size
  251. # Ensure dimensions are at least one patch
  252. size[0] = max(size[0], ph)
  253. size[1] = max(size[1], pw)
  254. # Make divisible by patch size if needed
  255. if divisible_by_patch:
  256. size[0] = ph * math.ceil(size[0] / ph)
  257. size[1] = pw * math.ceil(size[1] / pw)
  258. # Verify we haven't exceeded sequence length
  259. num_patches_h = size[0] // ph
  260. num_patches_w = size[1] // pw
  261. seq_len = num_patches_h * num_patches_w
  262. if seq_len > max_sequence_len:
  263. # Scale back down to fit sequence constraint
  264. scale_back = math.sqrt(max_sequence_len / seq_len)
  265. size[0] = int(size[0] * scale_back)
  266. size[1] = int(size[1] * scale_back)
  267. # Ensure divisible by patch size after scaling back
  268. if divisible_by_patch:
  269. size[0] = ph * math.ceil(size[0] / ph)
  270. size[1] = pw * math.ceil(size[1] / pw)
  271. return size
  272. def forward(self, img):
  273. """
  274. Resize the image with aspect ratio preservation and sequence length constraints.
  275. """
  276. size = self.get_params(
  277. img,
  278. self.patch_size,
  279. self.max_sequence_len,
  280. self.divisible_by_patch,
  281. self.longest,
  282. self.random_scale_prob,
  283. self.random_scale_range,
  284. self.random_scale_area,
  285. self.random_aspect_prob,
  286. self.random_aspect_range,
  287. self.max_ratio,
  288. )
  289. if isinstance(self.interpolation, (tuple, list)):
  290. interpolation = random.choice(self.interpolation)
  291. else:
  292. interpolation = self.interpolation
  293. return F.resize(img, size, interpolation)
  294. def __repr__(self):
  295. interpolate_str = "random" if isinstance(self.interpolation, (tuple, list)) else str(self.interpolation)
  296. return (f"{self.__class__.__name__}(patch_size={self.patch_size}, "
  297. f"max_sequence_len={self.max_sequence_len}, "
  298. f"longest={self.longest:.3f}, "
  299. f"random_scale_prob={self.random_scale_prob:.3f}, "
  300. f"random_aspect_prob={self.random_aspect_prob:.3f})")
  301. class CenterCropToSequence(torch.nn.Module):
  302. """Center crop the image such that the resulting patch sequence length meets constraints."""
  303. def __init__(
  304. self,
  305. patch_size: int,
  306. max_seq_len: int,
  307. divisible_by_patch: bool = True,
  308. fill: Union[int, Tuple[int, int, int]] = 0,
  309. padding_mode: str = 'constant'
  310. ):
  311. super().__init__()
  312. self.patch_size = patch_size
  313. self.max_seq_len = max_seq_len
  314. self.divisible_by_patch = divisible_by_patch
  315. self.fill = fill
  316. self.padding_mode = padding_mode
  317. def forward(self, img):
  318. """Center crop the image to maintain aspect ratio and fit sequence constraint."""
  319. _, h, w = transforms.functional.get_dimensions(img)
  320. _, target_hw = get_image_size_for_seq(
  321. (h, w),
  322. self.patch_size,
  323. self.max_seq_len,
  324. self.divisible_by_patch
  325. )
  326. # Use center crop
  327. return center_crop_or_pad(img, target_hw, fill=self.fill, padding_mode=self.padding_mode)
  328. class RandomCropToSequence(torch.nn.Module):
  329. """Randomly crop and/or pad the image to fit sequence length constraints.
  330. This maintains aspect ratio while ensuring the resulting image, when divided into patches,
  331. will not exceed the specified maximum sequence length. Similar to CentralCropToSequence
  332. but with randomized positioning.
  333. """
  334. def __init__(
  335. self,
  336. patch_size: int,
  337. max_sequence_len: int,
  338. divisible_by_patch: bool = True,
  339. fill: Union[int, Tuple[int, int, int]] = 0,
  340. padding_mode: str = 'constant'
  341. ):
  342. """
  343. Args:
  344. patch_size: Size of patches (int or tuple of (patch_h, patch_w))
  345. max_sequence_len: Maximum allowed sequence length for the resulting image
  346. divisible_by_patch: If True, resulting image dimensions will be multiples of patch_size
  347. fill: Fill value for padding
  348. padding_mode: Padding mode ('constant', 'edge', 'reflect', 'symmetric')
  349. """
  350. super().__init__()
  351. self.patch_size = patch_size
  352. self.max_sequence_len = max_sequence_len
  353. self.divisible_by_patch = divisible_by_patch
  354. self.fill = fill
  355. self.padding_mode = padding_mode
  356. @staticmethod
  357. def get_params(img, target_size):
  358. """Get random position for crop/pad."""
  359. _, image_height, image_width = transforms.functional.get_dimensions(img)
  360. delta_height = image_height - target_size[0]
  361. delta_width = image_width - target_size[1]
  362. # Handle both positive (crop) and negative (pad) deltas
  363. if delta_height == 0:
  364. top = 0
  365. else:
  366. top = int(math.copysign(random.randint(0, abs(delta_height)), delta_height))
  367. if delta_width == 0:
  368. left = 0
  369. else:
  370. left = int(math.copysign(random.randint(0, abs(delta_width)), delta_width))
  371. return top, left
  372. def forward(self, img):
  373. """Randomly crop or pad the image to maintain aspect ratio and fit sequence constraint."""
  374. # Get current dimensions
  375. _, img_h, img_w = transforms.functional.get_dimensions(img)
  376. # Calculate target dimensions that satisfy sequence length
  377. # We use max_ratio=1.0 to prevent upscaling - we only want to crop or maintain current size
  378. _, target_hw = get_image_size_for_seq(
  379. (img_h, img_w),
  380. self.patch_size,
  381. self.max_sequence_len,
  382. self.divisible_by_patch,
  383. max_ratio=1.0 # Prevent upscaling
  384. )
  385. # Get random position for crop/pad
  386. top, left = self.get_params(img, target_hw)
  387. # Apply crop or pad
  388. return crop_or_pad(
  389. img,
  390. top=top,
  391. left=left,
  392. height=target_hw[0],
  393. width=target_hw[1],
  394. fill=self.fill,
  395. padding_mode=self.padding_mode,
  396. )
  397. def __repr__(self) -> str:
  398. return (f"{self.__class__.__name__}(patch_size={self.patch_size}, "
  399. f"max_sequence_len={self.max_sequence_len}, "
  400. f"divisible_by_patch={self.divisible_by_patch})")
  401. def _validate_range(value, name, length=2):
  402. # Validate type and length
  403. if not isinstance(value, Sequence) or len(value) != length:
  404. raise ValueError(f"{name} should be a sequence of length {length}.")
  405. # Validate order
  406. if value[0] > value[1]:
  407. warnings.warn(f"{name.capitalize()} range reversed. Swapping.")
  408. return value[1], value[0]
  409. return value
  410. class RandomResizedCropToSequence(torch.nn.Module):
  411. """
  412. Randomly crop the input image to a subregion with varying area and aspect ratio
  413. (relative to the original), then resize that crop to a target size. The target size
  414. is determined such that patchifying the resized image (with `patch_size`)
  415. does not exceed `max_seq_len` patches, while maintaining the aspect ratio of the crop.
  416. This combines aspects of torchvision's RandomResizedCrop with sequence length constraints.
  417. Args:
  418. patch_size (int or tuple[int, int]):
  419. Patch dimensions (patch_h, patch_w) for sequence length calculation.
  420. max_seq_len (int):
  421. Maximum number of patches allowed in the final image.
  422. scale (tuple[float, float]):
  423. Range (min, max) of area fraction of the original image to crop.
  424. ratio (tuple[float, float]):
  425. Range (min, max) of aspect ratio *multipliers* for the crop, relative
  426. to the original image's aspect ratio. E.g., (0.75, 1.333) means the
  427. crop's aspect ratio will be sampled between 0.75*orig_ar and 1.333*orig_ar.
  428. Uses log-uniform sampling.
  429. interpolation (str or InterpolationMode):
  430. Interpolation mode for resizing. Can be 'bilinear', 'bicubic', 'nearest',
  431. or 'random' (chooses between bilinear and bicubic).
  432. Defaults to 'bicubic'.
  433. divisible_by_patch (bool):
  434. If True, the final image height and width will be multiples of the
  435. respective patch dimensions. Defaults to True.
  436. max_ratio (float, optional):
  437. An optional upper limit on the scaling ratio applied during resizing.
  438. Prevents excessive upsampling of the initial crop. `max_ratio=1.0`
  439. prevents any upsampling beyond the cropped size. Defaults to None (no limit).
  440. final_scale_range (tuple[float, float], optional):
  441. If provided, applies an *additional* random scaling factor to the
  442. final target size. The factor is sampled uniformly from this range,
  443. and multiplied by the size determined by `get_image_size_for_seq`.
  444. E.g., (0.8, 1.0) means the final size will be between 80% and 100%
  445. of the maximum feasible size. Defaults to None (use maximum feasible size).
  446. attempts (int):
  447. Number of attempts to sample a valid crop geometry before falling back
  448. to a center crop strategy. Defaults to 10.
  449. """
  450. def __init__(
  451. self,
  452. patch_size: Union[int, Tuple[int, int]] = 16,
  453. max_seq_len: int = 1024,
  454. scale: Tuple[float, float] = (0.08, 1.0),
  455. ratio: Tuple[float, float] = (.8, 1.25),
  456. interpolation: Union[str, InterpolationMode] = 'bicubic',
  457. divisible_by_patch: bool = True,
  458. max_ratio: Optional[float] = None,
  459. final_scale_range: Optional[Tuple[float, float]] = None,
  460. attempts: int = 10,
  461. ):
  462. super().__init__()
  463. if isinstance(patch_size, int):
  464. self.patch_h, self.patch_w = patch_size, patch_size
  465. else:
  466. # Assume it's a tuple/list: (patch_h, patch_w)
  467. if len(patch_size) != 2:
  468. raise ValueError("patch_size tuple must have exactly two elements (patch_h, patch_w).")
  469. self.patch_h, self.patch_w = patch_size
  470. self.max_seq_len = max_seq_len
  471. self.scale = scale
  472. self.ratio = ratio
  473. self.divisible_by_patch = divisible_by_patch
  474. self.max_ratio = max_ratio
  475. self.final_scale_range = final_scale_range
  476. self.attempts = attempts
  477. if isinstance(interpolation, str):
  478. if interpolation == 'random':
  479. self.interpolation = _RANDOM_INTERPOLATION
  480. else:
  481. self.interpolation = str_to_interp_mode(interpolation)
  482. else:
  483. self.interpolation = interpolation
  484. # Validate scale and ratio
  485. self.scale = _validate_range(self.scale, "scale")
  486. self.ratio = _validate_range(self.ratio, "ratio")
  487. # Validate final_scale_range if provided
  488. if self.final_scale_range is not None:
  489. self.final_scale_range = _validate_range(self.final_scale_range, "final_scale_range")
  490. # Additional validation for final_scale_range values
  491. if not (0.0 <= self.final_scale_range[0] <= self.final_scale_range[1] <= 1.0):
  492. warnings.warn("final_scale_range values should ideally be between 0.0 and 1.0.")
  493. @staticmethod
  494. def get_params(
  495. img: torch.Tensor,
  496. scale: Tuple[float, float],
  497. ratio: Tuple[float, float],
  498. crop_attempts: int = 10,
  499. patch_h: int = 16,
  500. patch_w: int = 16,
  501. max_seq_len: int = 1024,
  502. divisible_by_patch: bool = True,
  503. max_ratio: Optional[float] = None,
  504. final_scale_range: Optional[Tuple[float, float]] = None,
  505. interpolation: Union[List[InterpolationMode], InterpolationMode] = _RANDOM_INTERPOLATION,
  506. ) -> Tuple[Tuple[int, int, int, int], Tuple[int, int], InterpolationMode]:
  507. """ Get parameters for a random sized crop relative to image aspect ratio.
  508. """
  509. _, height, width = F.get_dimensions(img)
  510. if height <= 0 or width <= 0:
  511. raise ValueError(f"Input image must have positive dimensions, got H={height}, W={width}")
  512. area = height * width
  513. orig_aspect = width / height
  514. log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
  515. for _ in range(crop_attempts):
  516. target_area = area * random.uniform(scale[0], scale[1])
  517. aspect_ratio_factor = math.exp(random.uniform(log_ratio[0], log_ratio[1]))
  518. aspect_ratio = orig_aspect * aspect_ratio_factor
  519. # Calculate target dimensions for the crop
  520. # target_area = crop_w * crop_h, aspect_ratio = crop_w / crop_h
  521. # => crop_h = sqrt(target_area / aspect_ratio)
  522. # => crop_w = sqrt(target_area * aspect_ratio)
  523. crop_h = int(round(math.sqrt(target_area / aspect_ratio)))
  524. crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
  525. if 0 < crop_w <= width and 0 < crop_h <= height:
  526. top = random.randint(0, height - crop_h)
  527. left = random.randint(0, width - crop_w)
  528. break
  529. else:
  530. # Fallback strategy, use center crop trying to respect ratio range
  531. min_aspect_ratio = orig_aspect * ratio[0]
  532. max_aspect_ratio = orig_aspect * ratio[1]
  533. if orig_aspect < min_aspect_ratio:
  534. # Original is narrower than target min, clamp width
  535. crop_w = width
  536. crop_h = min(int(round(crop_w / min_aspect_ratio)), height)
  537. elif orig_aspect > max_aspect_ratio:
  538. # Original is wider than target max, clamp height
  539. crop_h = height
  540. crop_w = min(int(round(crop_h * max_aspect_ratio)), width)
  541. else:
  542. # Aspect ratio is within range, take the largest possible crop (full image)
  543. crop_w = width
  544. crop_h = height
  545. # Ensure valid dimensions after fallback calculation
  546. crop_h = max(1, crop_h)
  547. crop_w = max(1, crop_w)
  548. top = (height - crop_h) // 2
  549. left = (width - crop_w) // 2
  550. # Determine max feasible size for scaling of the *cropped* region
  551. feasible_ratio, feasible_size = get_image_size_for_seq(
  552. (crop_h, crop_w),
  553. patch_size=(patch_h, patch_w), # Pass as tuple
  554. max_seq_len=max_seq_len,
  555. divisible_by_patch=divisible_by_patch,
  556. max_ratio=max_ratio,
  557. )
  558. # Optionally apply final scale randomization
  559. final_size = feasible_size
  560. if final_scale_range is not None:
  561. min_sc, max_sc = final_scale_range
  562. scale_factor = random.uniform(min_sc, max_sc)
  563. scale_factor = min(max(scale_factor, 0.0), 1.0) # Clamp factor just in case
  564. # Calculate raw scaled size
  565. # Note: feasible_ratio already accounts for max_ratio clamp if any
  566. raw_h = crop_h * feasible_ratio * scale_factor
  567. raw_w = crop_w * feasible_ratio * scale_factor
  568. # Re-apply divisibility constraint if needed
  569. if divisible_by_patch:
  570. # Use ceil to avoid going under minimum patch size
  571. target_h = patch_h * math.ceil(raw_h / patch_h)
  572. target_w = patch_w * math.ceil(raw_w / patch_w)
  573. else:
  574. target_h = int(round(raw_h))
  575. target_w = int(round(raw_w))
  576. # Ensure final size is at least one patch dimension
  577. target_h = max(target_h, patch_h)
  578. target_w = max(target_w, patch_w)
  579. final_size = (target_h, target_w)
  580. # Final check: Ensure this randomized size still fits max_seq_len
  581. # (It should, as we scaled down, but rounding might theoretically push it over)
  582. num_patches_h = final_size[0] // patch_h
  583. num_patches_w = final_size[1] // patch_w
  584. if (num_patches_h * num_patches_w) > max_seq_len:
  585. # If it exceeds, revert to the original feasible_size (safest)
  586. final_size = feasible_size
  587. warnings.warn(f"Final scale randomization ({scale_factor:.2f}) resulted in size {final_size} exceeding max_seq_len={max_seq_len} after rounding. Reverting to feasible size {feasible_size}.")
  588. # Select interpolation mode
  589. if isinstance(interpolation, (tuple, list)):
  590. interpolation = random.choice(interpolation)
  591. else:
  592. interpolation = interpolation
  593. return (top, left, crop_h, crop_w), final_size, interpolation
  594. def forward(self, img: torch.Tensor) -> torch.Tensor:
  595. # Sample crop, resize, and interpolation parameters
  596. crop_params, final_size, interpolation = self.get_params(
  597. img,
  598. scale=self.scale,
  599. ratio=self.ratio,
  600. crop_attempts=self.attempts,
  601. patch_h=self.patch_h,
  602. patch_w=self.patch_w,
  603. divisible_by_patch=self.divisible_by_patch,
  604. max_seq_len=self.max_seq_len,
  605. final_scale_range=self.final_scale_range,
  606. interpolation=self.interpolation,
  607. )
  608. top, left, crop_h, crop_w = crop_params
  609. output = F.resized_crop(
  610. img,
  611. top=top,
  612. left=left,
  613. height=crop_h,
  614. width=crop_w,
  615. size=final_size,
  616. interpolation=interpolation,
  617. antialias=True,
  618. )
  619. return output
  620. def __repr__(self) -> str:
  621. if isinstance(self.interpolation, (tuple, list)):
  622. interpolate_str = ', '.join(str(m).split('.')[-1] for m in self.interpolation)
  623. else:
  624. interpolate_str = str(self.interpolation)
  625. format_string = self.__class__.__name__ + '('
  626. format_string += f"patch_size=({self.patch_h}, {self.patch_w})"
  627. format_string += f", max_seq_len={self.max_seq_len}"
  628. format_string += f", scale={self.scale}"
  629. format_string += f", ratio={self.ratio}"
  630. format_string += f", interpolation=[{interpolate_str}]"
  631. format_string += f", divisible_by_patch={self.divisible_by_patch}"
  632. format_string += f", max_ratio={self.max_ratio}"
  633. format_string += f", final_scale_range={self.final_scale_range}"
  634. format_string += f", attempts={self.attempts}"
  635. format_string += ')'
  636. return format_string
  637. def patchify_image(
  638. img: torch.Tensor,
  639. patch_size: Tuple[int, int],
  640. pad: bool = True,
  641. include_info: bool = True,
  642. flatten_patches: bool = True,
  643. ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
  644. c, h, w = img.shape
  645. ph, pw = patch_size
  646. # Ensure the image is divisible by patch size
  647. if pad and (h % ph != 0 or w % pw != 0):
  648. pad_h = (ph - h % ph) % ph # amount to add on bottom
  649. pad_w = (pw - w % pw) % pw # amount to add on right
  650. img = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h))
  651. c, h, w = img.shape
  652. # Calculate number of patches in each dimension
  653. nh, nw = h // ph, w // pw
  654. # Reshape image to patches
  655. patches = img.view(c, nh, ph, nw, pw).permute(1, 3, 2, 4, 0)
  656. # [nh, nw, ph, pw, c] -> [nh * nw, ph * pw * c] or [nh * nw, ph, pw, c]
  657. patches = patches.reshape(-1, ph * pw * c) if flatten_patches else patches.reshape(-1, ph, pw, c)
  658. if include_info:
  659. # Create coordinate indices
  660. y_idx, x_idx = torch.meshgrid(torch.arange(nh), torch.arange(nw), indexing='ij')
  661. # Stack into a single coords tensor [N, 2] with (y, x) order
  662. coord = torch.stack([y_idx.reshape(-1), x_idx.reshape(-1)], dim=1)
  663. # Create type indicators (all 1s for regular patches)
  664. valid = torch.ones(nh * nw, dtype=torch.bool)
  665. return patches, coord, valid
  666. return patches
  667. class Patchify(torch.nn.Module):
  668. """Transform an image into patches with corresponding coordinates and type indicators."""
  669. def __init__(
  670. self,
  671. patch_size: Union[int, Tuple[int, int]],
  672. flatten_patches: bool = True
  673. ):
  674. super().__init__()
  675. self.patch_size = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
  676. self.flatten_patches = flatten_patches
  677. def forward(self, img):
  678. """
  679. Args:
  680. img: A PIL Image or tensor of shape [C, H, W]
  681. Returns:
  682. A dictionary containing:
  683. - patches: Tensor of shape [N, P*P*C] if flatten_patches=True,
  684. or [N, Ph, Pw, C] if flatten_patches=False
  685. - patch_coord: Tensor of shape [N, 2] with (y, x) coordinates
  686. - patch_valid: Valid indicator (all 1s for non-padding patches)
  687. """
  688. if isinstance(img, Image.Image):
  689. # Convert PIL Image to tensor [C, H, W]
  690. img = transforms.functional.to_tensor(img)
  691. patches, coord, valid = patchify_image(img, self.patch_size, flatten_patches=self.flatten_patches)
  692. return {
  693. 'patches': patches,
  694. 'patch_coord': coord,
  695. 'patch_valid': valid,
  696. }