transforms.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583
  1. import math
  2. import numbers
  3. import random
  4. import warnings
  5. from typing import List, Sequence, Tuple, Union
  6. import torch
  7. import torchvision.transforms as transforms
  8. import torchvision.transforms.functional as F
  9. try:
  10. from torchvision.transforms.functional import InterpolationMode
  11. has_interpolation_mode = True
  12. except ImportError:
  13. has_interpolation_mode = False
  14. from PIL import Image
  15. import numpy as np
# Public API of this module: the names exported by a star-import.
__all__ = [
    "ToNumpy", "ToTensor", "str_to_interp_mode", "str_to_pil_interp", "interp_mode_to_str",
    "RandomResizedCropAndInterpolation", "CenterCropOrPad", "center_crop_or_pad", "crop_or_pad",
    "RandomCropOrPad", "RandomPad", "ResizeKeepRatio", "TrimBorder", "MaybeToTensor", "MaybePILToTensor"
]
  21. class ToNumpy:
  22. def __call__(self, pil_img):
  23. np_img = np.array(pil_img, dtype=np.uint8)
  24. if np_img.ndim < 3:
  25. np_img = np.expand_dims(np_img, axis=-1)
  26. np_img = np.rollaxis(np_img, 2) # HWC to CHW
  27. return np_img
  28. class ToTensor:
  29. """ ToTensor with no rescaling of values"""
  30. def __init__(self, dtype=torch.float32):
  31. self.dtype = dtype
  32. def __call__(self, pil_img):
  33. return F.pil_to_tensor(pil_img).to(dtype=self.dtype)
  34. class MaybeToTensor(transforms.ToTensor):
  35. """Convert a PIL Image or ndarray to tensor if it's not already one.
  36. """
  37. def __init__(self) -> None:
  38. super().__init__()
  39. def __call__(self, pic) -> torch.Tensor:
  40. """
  41. Args:
  42. pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
  43. Returns:
  44. Tensor: Converted image.
  45. """
  46. if isinstance(pic, torch.Tensor):
  47. return pic
  48. return F.to_tensor(pic)
  49. def __repr__(self) -> str:
  50. return f"{self.__class__.__name__}()"
  51. class MaybePILToTensor:
  52. """Convert a PIL Image to a tensor of the same type - this does not scale values.
  53. """
  54. def __init__(self) -> None:
  55. super().__init__()
  56. def __call__(self, pic):
  57. """
  58. Note: A deep copy of the underlying array is performed.
  59. Args:
  60. pic (PIL Image): Image to be converted to tensor.
  61. Returns:
  62. Tensor: Converted image.
  63. """
  64. if isinstance(pic, torch.Tensor):
  65. return pic
  66. return F.pil_to_tensor(pic)
  67. def __repr__(self) -> str:
  68. return f"{self.__class__.__name__}()"
  69. # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
  70. # favor of the Image.Resampling enum. The top-level resampling attributes will be
  71. # removed in Pillow 10.
  72. if hasattr(Image, "Resampling"):
  73. _pil_interpolation_to_str = {
  74. Image.Resampling.NEAREST: 'nearest',
  75. Image.Resampling.BILINEAR: 'bilinear',
  76. Image.Resampling.BICUBIC: 'bicubic',
  77. Image.Resampling.BOX: 'box',
  78. Image.Resampling.HAMMING: 'hamming',
  79. Image.Resampling.LANCZOS: 'lanczos',
  80. }
  81. else:
  82. _pil_interpolation_to_str = {
  83. Image.NEAREST: 'nearest',
  84. Image.BILINEAR: 'bilinear',
  85. Image.BICUBIC: 'bicubic',
  86. Image.BOX: 'box',
  87. Image.HAMMING: 'hamming',
  88. Image.LANCZOS: 'lanczos',
  89. }
  90. _str_to_pil_interpolation = {b: a for a, b in _pil_interpolation_to_str.items()}
  91. if has_interpolation_mode:
  92. _torch_interpolation_to_str = {
  93. InterpolationMode.NEAREST: 'nearest',
  94. InterpolationMode.BILINEAR: 'bilinear',
  95. InterpolationMode.BICUBIC: 'bicubic',
  96. InterpolationMode.BOX: 'box',
  97. InterpolationMode.HAMMING: 'hamming',
  98. InterpolationMode.LANCZOS: 'lanczos',
  99. }
  100. _str_to_torch_interpolation = {b: a for a, b in _torch_interpolation_to_str.items()}
  101. else:
  102. _pil_interpolation_to_torch = {}
  103. _torch_interpolation_to_str = {}
  104. def str_to_pil_interp(mode_str):
  105. return _str_to_pil_interpolation[mode_str]
  106. def str_to_interp_mode(mode_str):
  107. if has_interpolation_mode:
  108. return _str_to_torch_interpolation[mode_str]
  109. else:
  110. return _str_to_pil_interpolation[mode_str]
  111. def interp_mode_to_str(mode):
  112. if has_interpolation_mode:
  113. return _torch_interpolation_to_str[mode]
  114. else:
  115. return _pil_interpolation_to_str[mode]
  116. _RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
  117. def _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."):
  118. if isinstance(size, numbers.Number):
  119. return int(size), int(size)
  120. if isinstance(size, Sequence) and len(size) == 1:
  121. return size[0], size[0]
  122. if len(size) != 2:
  123. raise ValueError(error_msg)
  124. return size
  125. class RandomResizedCropAndInterpolation:
  126. """Crop the given PIL Image to random size and aspect ratio with random interpolation.
  127. A crop of random size (default: of 0.08 to 1.0) of the original size and a random
  128. aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
  129. is finally resized to given size.
  130. This is popularly used to train the Inception networks.
  131. Args:
  132. size: expected output size of each edge
  133. scale: range of size of the origin size cropped
  134. ratio: range of aspect ratio of the origin aspect ratio cropped
  135. interpolation: Default: PIL.Image.BILINEAR
  136. """
  137. def __init__(
  138. self,
  139. size,
  140. scale=(0.08, 1.0),
  141. ratio=(3. / 4., 4. / 3.),
  142. interpolation='bilinear',
  143. ):
  144. if isinstance(size, (list, tuple)):
  145. self.size = tuple(size)
  146. else:
  147. self.size = (size, size)
  148. if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
  149. warnings.warn("range should be of kind (min, max)")
  150. if interpolation == 'random':
  151. self.interpolation = _RANDOM_INTERPOLATION
  152. else:
  153. self.interpolation = str_to_interp_mode(interpolation)
  154. self.scale = scale
  155. self.ratio = ratio
  156. @staticmethod
  157. def get_params(img, scale, ratio):
  158. """Get parameters for ``crop`` for a random sized crop.
  159. Args:
  160. img (PIL Image): Image to be cropped.
  161. scale (tuple): range of size of the origin size cropped
  162. ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
  163. Returns:
  164. tuple: params (i, j, h, w) to be passed to ``crop`` for a random
  165. sized crop.
  166. """
  167. img_w, img_h = F.get_image_size(img)
  168. area = img_w * img_h
  169. for attempt in range(10):
  170. target_area = random.uniform(*scale) * area
  171. log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
  172. aspect_ratio = math.exp(random.uniform(*log_ratio))
  173. target_w = int(round(math.sqrt(target_area * aspect_ratio)))
  174. target_h = int(round(math.sqrt(target_area / aspect_ratio)))
  175. if target_w <= img_w and target_h <= img_h:
  176. i = random.randint(0, img_h - target_h)
  177. j = random.randint(0, img_w - target_w)
  178. return i, j, target_h, target_w
  179. # Fallback to central crop
  180. in_ratio = img_w / img_h
  181. if in_ratio < min(ratio):
  182. target_w = img_w
  183. target_h = int(round(target_w / min(ratio)))
  184. elif in_ratio > max(ratio):
  185. target_h = img_h
  186. target_w = int(round(target_h * max(ratio)))
  187. else: # whole image
  188. target_w = img_w
  189. target_h = img_h
  190. i = (img_h - target_h) // 2
  191. j = (img_w - target_w) // 2
  192. return i, j, target_h, target_w
  193. def __call__(self, img):
  194. """
  195. Args:
  196. img (PIL Image): Image to be cropped and resized.
  197. Returns:
  198. PIL Image: Randomly cropped and resized image.
  199. """
  200. i, j, h, w = self.get_params(img, self.scale, self.ratio)
  201. if isinstance(self.interpolation, (tuple, list)):
  202. interpolation = random.choice(self.interpolation)
  203. else:
  204. interpolation = self.interpolation
  205. return F.resized_crop(img, i, j, h, w, self.size, interpolation)
  206. def __repr__(self):
  207. if isinstance(self.interpolation, (tuple, list)):
  208. interpolate_str = ' '.join([interp_mode_to_str(x) for x in self.interpolation])
  209. else:
  210. interpolate_str = interp_mode_to_str(self.interpolation)
  211. format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
  212. format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
  213. format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
  214. format_string += ', interpolation={0})'.format(interpolate_str)
  215. return format_string
  216. def center_crop_or_pad(
  217. img: torch.Tensor,
  218. output_size: Union[int, List[int]],
  219. fill: Union[int, Tuple[int, int, int]] = 0,
  220. padding_mode: str = 'constant',
  221. ) -> torch.Tensor:
  222. """Center crops and/or pads the given image.
  223. If the image is torch Tensor, it is expected
  224. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
  225. If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
  226. Args:
  227. img (PIL Image or Tensor): Image to be cropped.
  228. output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
  229. it is used for both directions.
  230. fill (int, Tuple[int]): Padding color
  231. Returns:
  232. PIL Image or Tensor: Cropped image.
  233. """
  234. output_size = _setup_size(output_size)
  235. crop_height, crop_width = output_size
  236. _, image_height, image_width = F.get_dimensions(img)
  237. if crop_width > image_width or crop_height > image_height:
  238. padding_ltrb = [
  239. (crop_width - image_width) // 2 if crop_width > image_width else 0,
  240. (crop_height - image_height) // 2 if crop_height > image_height else 0,
  241. (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
  242. (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
  243. ]
  244. img = F.pad(img, padding_ltrb, fill=fill, padding_mode=padding_mode)
  245. _, image_height, image_width = F.get_dimensions(img)
  246. if crop_width == image_width and crop_height == image_height:
  247. return img
  248. crop_top = int(round((image_height - crop_height) / 2.0))
  249. crop_left = int(round((image_width - crop_width) / 2.0))
  250. return F.crop(img, crop_top, crop_left, crop_height, crop_width)
  251. class CenterCropOrPad(torch.nn.Module):
  252. """Crops the given image at the center.
  253. If the image is torch Tensor, it is expected
  254. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
  255. If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
  256. Args:
  257. size (sequence or int): Desired output size of the crop. If size is an
  258. int instead of sequence like (h, w), a square crop (size, size) is
  259. made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
  260. """
  261. def __init__(
  262. self,
  263. size: Union[int, List[int]],
  264. fill: Union[int, Tuple[int, int, int]] = 0,
  265. padding_mode: str = 'constant',
  266. ):
  267. super().__init__()
  268. self.size = _setup_size(size)
  269. self.fill = fill
  270. self.padding_mode = padding_mode
  271. def forward(self, img):
  272. """
  273. Args:
  274. img (PIL Image or Tensor): Image to be cropped.
  275. Returns:
  276. PIL Image or Tensor: Cropped image.
  277. """
  278. return center_crop_or_pad(img, self.size, fill=self.fill, padding_mode=self.padding_mode)
  279. def __repr__(self) -> str:
  280. return f"{self.__class__.__name__}(size={self.size})"
  281. def crop_or_pad(
  282. img: torch.Tensor,
  283. top: int,
  284. left: int,
  285. height: int,
  286. width: int,
  287. fill: Union[int, Tuple[int, int, int]] = 0,
  288. padding_mode: str = 'constant',
  289. ) -> torch.Tensor:
  290. """ Crops and/or pads image to meet target size, with control over fill and padding_mode.
  291. """
  292. _, image_height, image_width = F.get_dimensions(img)
  293. right = left + width
  294. bottom = top + height
  295. if left < 0 or top < 0 or right > image_width or bottom > image_height:
  296. padding_ltrb = [
  297. max(-left + min(0, right), 0),
  298. max(-top + min(0, bottom), 0),
  299. max(right - max(image_width, left), 0),
  300. max(bottom - max(image_height, top), 0),
  301. ]
  302. img = F.pad(img, padding_ltrb, fill=fill, padding_mode=padding_mode)
  303. top = max(top, 0)
  304. left = max(left, 0)
  305. return F.crop(img, top, left, height, width)
  306. class RandomCropOrPad(torch.nn.Module):
  307. """ Crop and/or pad image with random placement within the crop or pad margin.
  308. """
  309. def __init__(
  310. self,
  311. size: Union[int, List[int]],
  312. fill: Union[int, Tuple[int, int, int]] = 0,
  313. padding_mode: str = 'constant',
  314. ):
  315. super().__init__()
  316. self.size = _setup_size(size)
  317. self.fill = fill
  318. self.padding_mode = padding_mode
  319. @staticmethod
  320. def get_params(img, size):
  321. _, image_height, image_width = F.get_dimensions(img)
  322. delta_height = image_height - size[0]
  323. delta_width = image_width - size[1]
  324. top = int(math.copysign(random.randint(0, abs(delta_height)), delta_height))
  325. left = int(math.copysign(random.randint(0, abs(delta_width)), delta_width))
  326. return top, left
  327. def forward(self, img):
  328. """
  329. Args:
  330. img (PIL Image or Tensor): Image to be cropped.
  331. Returns:
  332. PIL Image or Tensor: Cropped image.
  333. """
  334. top, left = self.get_params(img, self.size)
  335. return crop_or_pad(
  336. img,
  337. top=top,
  338. left=left,
  339. height=self.size[0],
  340. width=self.size[1],
  341. fill=self.fill,
  342. padding_mode=self.padding_mode,
  343. )
  344. def __repr__(self) -> str:
  345. return f"{self.__class__.__name__}(size={self.size})"
  346. class RandomPad:
  347. def __init__(self, input_size, fill=0):
  348. self.input_size = input_size
  349. self.fill = fill
  350. @staticmethod
  351. def get_params(img, input_size):
  352. width, height = F.get_image_size(img)
  353. delta_width = max(input_size[1] - width, 0)
  354. delta_height = max(input_size[0] - height, 0)
  355. pad_left = random.randint(0, delta_width)
  356. pad_top = random.randint(0, delta_height)
  357. pad_right = delta_width - pad_left
  358. pad_bottom = delta_height - pad_top
  359. return pad_left, pad_top, pad_right, pad_bottom
  360. def __call__(self, img):
  361. padding = self.get_params(img, self.input_size)
  362. img = F.pad(img, padding, self.fill)
  363. return img
  364. class ResizeKeepRatio:
  365. """ Resize and Keep Aspect Ratio
  366. """
  367. def __init__(
  368. self,
  369. size,
  370. longest=0.,
  371. interpolation='bilinear',
  372. random_scale_prob=0.,
  373. random_scale_range=(0.85, 1.05),
  374. random_scale_area=False,
  375. random_aspect_prob=0.,
  376. random_aspect_range=(0.9, 1.11),
  377. ):
  378. """
  379. Args:
  380. size:
  381. longest:
  382. interpolation:
  383. random_scale_prob:
  384. random_scale_range:
  385. random_scale_area:
  386. random_aspect_prob:
  387. random_aspect_range:
  388. """
  389. if isinstance(size, (list, tuple)):
  390. self.size = tuple(size)
  391. else:
  392. self.size = (size, size)
  393. if interpolation == 'random':
  394. self.interpolation = _RANDOM_INTERPOLATION
  395. else:
  396. self.interpolation = str_to_interp_mode(interpolation)
  397. self.longest = float(longest)
  398. self.random_scale_prob = random_scale_prob
  399. self.random_scale_range = random_scale_range
  400. self.random_scale_area = random_scale_area
  401. self.random_aspect_prob = random_aspect_prob
  402. self.random_aspect_range = random_aspect_range
  403. @staticmethod
  404. def get_params(
  405. img,
  406. target_size,
  407. longest,
  408. random_scale_prob=0.,
  409. random_scale_range=(1.0, 1.33),
  410. random_scale_area=False,
  411. random_aspect_prob=0.,
  412. random_aspect_range=(0.9, 1.11)
  413. ):
  414. """Get parameters
  415. """
  416. img_h, img_w = img_size = F.get_dimensions(img)[1:]
  417. target_h, target_w = target_size
  418. ratio_h = img_h / target_h
  419. ratio_w = img_w / target_w
  420. ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
  421. if random_scale_prob > 0 and random.random() < random_scale_prob:
  422. ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
  423. if random_scale_area:
  424. # make ratio factor equivalent to RRC area crop where < 1.0 = area zoom,
  425. # otherwise like affine scale where < 1.0 = linear zoom out
  426. ratio_factor = 1. / math.sqrt(ratio_factor)
  427. ratio_factor = (ratio_factor, ratio_factor)
  428. else:
  429. ratio_factor = (1., 1.)
  430. if random_aspect_prob > 0 and random.random() < random_aspect_prob:
  431. log_aspect = (math.log(random_aspect_range[0]), math.log(random_aspect_range[1]))
  432. aspect_factor = math.exp(random.uniform(*log_aspect))
  433. aspect_factor = math.sqrt(aspect_factor)
  434. # currently applying random aspect adjustment equally to both dims,
  435. # could change to keep output sizes above their target where possible
  436. ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
  437. size = [round(x * f / ratio) for x, f in zip(img_size, ratio_factor)]
  438. return size
  439. def __call__(self, img):
  440. """
  441. Args:
  442. img (PIL Image): Image to be cropped and resized.
  443. Returns:
  444. PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
  445. """
  446. size = self.get_params(
  447. img, self.size, self.longest,
  448. self.random_scale_prob, self.random_scale_range, self.random_scale_area,
  449. self.random_aspect_prob, self.random_aspect_range
  450. )
  451. if isinstance(self.interpolation, (tuple, list)):
  452. interpolation = random.choice(self.interpolation)
  453. else:
  454. interpolation = self.interpolation
  455. img = F.resize(img, size, interpolation)
  456. return img
  457. def __repr__(self):
  458. if isinstance(self.interpolation, (tuple, list)):
  459. interpolate_str = ' '.join([interp_mode_to_str(x) for x in self.interpolation])
  460. else:
  461. interpolate_str = interp_mode_to_str(self.interpolation)
  462. format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
  463. format_string += f', interpolation={interpolate_str}'
  464. format_string += f', longest={self.longest:.3f}'
  465. format_string += f', random_scale_prob={self.random_scale_prob:.3f}'
  466. format_string += f', random_scale_range=(' \
  467. f'{self.random_scale_range[0]:.3f}, {self.random_scale_range[1]:.3f})'
  468. format_string += f', random_aspect_prob={self.random_aspect_prob:.3f}'
  469. format_string += f', random_aspect_range=(' \
  470. f'{self.random_aspect_range[0]:.3f}, {self.random_aspect_range[1]:.3f}))'
  471. return format_string
  472. class TrimBorder(torch.nn.Module):
  473. def __init__(
  474. self,
  475. border_size: int,
  476. ):
  477. super().__init__()
  478. self.border_size = border_size
  479. def forward(self, img):
  480. w, h = F.get_image_size(img)
  481. top = left = self.border_size
  482. top = min(top, h)
  483. left = min(left, h)
  484. height = max(0, h - 2 * self.border_size)
  485. width = max(0, w - 2 * self.border_size)
  486. return F.crop(img, top, left, height, width)