# transforms_factory.py
""" Transforms Factory
Factory methods for building image transforms for use with TIMM (PyTorch Image Models)

Hacked together by / Copyright 2019, Ross Wightman
"""
  5. import math
  6. from typing import Optional, Tuple, Union
  7. import torch
  8. from torchvision import transforms
  9. from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
  10. from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
  11. from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
  12. ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, MaybeToTensor, MaybePILToTensor
  13. from timm.data.naflex_transforms import RandomResizedCropToSequence, ResizeToSequence, Patchify
  14. from timm.data.random_erasing import RandomErasing
  15. def transforms_noaug_train(
  16. img_size: Union[int, Tuple[int, int]] = 224,
  17. interpolation: str = 'bilinear',
  18. mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
  19. std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
  20. use_prefetcher: bool = False,
  21. normalize: bool = True,
  22. ):
  23. """ No-augmentation image transforms for training.
  24. Args:
  25. img_size: Target image size.
  26. interpolation: Image interpolation mode.
  27. mean: Image normalization mean.
  28. std: Image normalization standard deviation.
  29. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
  30. normalize: Normalization tensor output w/ provided mean/std (if prefetcher not used).
  31. Returns:
  32. """
  33. if interpolation == 'random':
  34. # random interpolation not supported with no-aug
  35. interpolation = 'bilinear'
  36. tfl = [
  37. transforms.Resize(img_size, interpolation=str_to_interp_mode(interpolation)),
  38. transforms.CenterCrop(img_size)
  39. ]
  40. if use_prefetcher:
  41. # prefetcher and collate will handle tensor conversion and norm
  42. tfl += [MaybePILToTensor()]
  43. elif not normalize:
  44. # when normalize disabled, converted to tensor without scaling, keep original dtype
  45. tfl += [MaybePILToTensor()]
  46. else:
  47. tfl += [
  48. MaybeToTensor(),
  49. transforms.Normalize(
  50. mean=torch.tensor(mean),
  51. std=torch.tensor(std)
  52. )
  53. ]
  54. return transforms.Compose(tfl)
  55. def transforms_imagenet_train(
  56. img_size: Union[int, Tuple[int, int]] = 224,
  57. scale: Optional[Tuple[float, float]] = None,
  58. ratio: Optional[Tuple[float, float]] = None,
  59. train_crop_mode: Optional[str] = None,
  60. hflip: float = 0.5,
  61. vflip: float = 0.,
  62. color_jitter: Union[float, Tuple[float, ...]] = 0.4,
  63. color_jitter_prob: Optional[float] = None,
  64. force_color_jitter: bool = False,
  65. grayscale_prob: float = 0.,
  66. gaussian_blur_prob: float = 0.,
  67. auto_augment: Optional[str] = None,
  68. interpolation: str = 'random',
  69. mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
  70. std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
  71. re_prob: float = 0.,
  72. re_mode: str = 'const',
  73. re_count: int = 1,
  74. re_num_splits: int = 0,
  75. use_prefetcher: bool = False,
  76. normalize: bool = True,
  77. separate: bool = False,
  78. naflex: bool = False,
  79. patch_size: Union[int, Tuple[int, int]] = 16,
  80. max_seq_len: int = 576, # 24x24 for 16x16 patch
  81. patchify: bool = False,
  82. ):
  83. """ ImageNet-oriented image transforms for training.
  84. Args:
  85. img_size: Target image size.
  86. train_crop_mode: Training random crop mode ('rrc', 'rkrc', 'rkrr').
  87. scale: Random resize scale range (crop area, < 1.0 => zoom in).
  88. ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
  89. hflip: Horizontal flip probability.
  90. vflip: Vertical flip probability.
  91. color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
  92. Scalar is applied as (scalar,) * 3 (no hue).
  93. color_jitter_prob: Apply color jitter with this probability if not None (for SimlCLR-like aug).
  94. force_color_jitter: Force color jitter where it is normally disabled (ie with RandAugment on).
  95. grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
  96. gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
  97. auto_augment: Auto augment configuration string (see auto_augment.py).
  98. interpolation: Image interpolation mode.
  99. mean: Image normalization mean.
  100. std: Image normalization standard deviation.
  101. re_prob: Random erasing probability.
  102. re_mode: Random erasing fill mode.
  103. re_count: Number of random erasing regions.
  104. re_num_splits: Control split of random erasing across batch size.
  105. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
  106. normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
  107. separate: Output transforms in 3-stage tuple.
  108. naflex: Enable NaFlex mode, sequence constrained patch output
  109. patch_size: Patch size for NaFlex mode.
  110. max_seq_len: Max sequence length for NaFlex mode.
  111. Returns:
  112. If separate==True, the transforms are returned as a tuple of 3 separate transforms
  113. for use in a mixing dataset that passes
  114. * all data through the first (primary) transform, called the 'clean' data
  115. * a portion of the data through the secondary transform
  116. * normalizes and converts the branches above with the third, final transform
  117. """
  118. train_crop_mode = train_crop_mode or 'rrc'
  119. assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}
  120. primary_tfl = []
  121. if naflex:
  122. scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
  123. ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
  124. primary_tfl += [RandomResizedCropToSequence(
  125. patch_size=patch_size,
  126. max_seq_len=max_seq_len,
  127. scale=scale,
  128. ratio=ratio,
  129. interpolation=interpolation
  130. )]
  131. else:
  132. if train_crop_mode in ('rkrc', 'rkrr'):
  133. # FIXME integration of RKR is a WIP
  134. scale = tuple(scale or (0.8, 1.00))
  135. ratio = tuple(ratio or (0.9, 1/.9))
  136. primary_tfl += [
  137. ResizeKeepRatio(
  138. img_size,
  139. interpolation=interpolation,
  140. random_scale_prob=0.5,
  141. random_scale_range=scale,
  142. random_scale_area=True, # scale compatible with RRC
  143. random_aspect_prob=0.5,
  144. random_aspect_range=ratio,
  145. ),
  146. CenterCropOrPad(img_size, padding_mode='reflect')
  147. if train_crop_mode == 'rkrc' else
  148. RandomCropOrPad(img_size, padding_mode='reflect')
  149. ]
  150. else:
  151. scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
  152. ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
  153. primary_tfl += [
  154. RandomResizedCropAndInterpolation(
  155. img_size,
  156. scale=scale,
  157. ratio=ratio,
  158. interpolation=interpolation,
  159. )
  160. ]
  161. if hflip > 0.:
  162. primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
  163. if vflip > 0.:
  164. primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
  165. secondary_tfl = []
  166. disable_color_jitter = False
  167. if auto_augment:
  168. assert isinstance(auto_augment, str)
  169. # color jitter is typically disabled if AA/RA on,
  170. # this allows override without breaking old hparm cfgs
  171. disable_color_jitter = not (force_color_jitter or '3a' in auto_augment)
  172. if isinstance(img_size, (tuple, list)):
  173. img_size_min = min(img_size)
  174. else:
  175. img_size_min = img_size
  176. aa_params = dict(
  177. translate_const=int(img_size_min * 0.45),
  178. img_mean=tuple([min(255, round(255 * x)) for x in mean]),
  179. )
  180. if interpolation and interpolation != 'random':
  181. aa_params['interpolation'] = str_to_pil_interp(interpolation)
  182. if auto_augment.startswith('rand'):
  183. secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
  184. elif auto_augment.startswith('augmix'):
  185. aa_params['translate_pct'] = 0.3
  186. secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
  187. else:
  188. secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
  189. if color_jitter is not None and not disable_color_jitter:
  190. # color jitter is enabled when not using AA or when forced
  191. if isinstance(color_jitter, (list, tuple)):
  192. # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
  193. # or 4 if also augmenting hue
  194. assert len(color_jitter) in (3, 4)
  195. else:
  196. # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
  197. color_jitter = (float(color_jitter),) * 3
  198. if color_jitter_prob is not None:
  199. secondary_tfl += [
  200. transforms.RandomApply([
  201. transforms.ColorJitter(*color_jitter),
  202. ],
  203. p=color_jitter_prob
  204. )
  205. ]
  206. else:
  207. secondary_tfl += [transforms.ColorJitter(*color_jitter)]
  208. if grayscale_prob:
  209. secondary_tfl += [transforms.RandomGrayscale(p=grayscale_prob)]
  210. if gaussian_blur_prob:
  211. secondary_tfl += [
  212. transforms.RandomApply([
  213. transforms.GaussianBlur(kernel_size=23), # hardcoded for now
  214. ],
  215. p=gaussian_blur_prob,
  216. )
  217. ]
  218. final_tfl = []
  219. if use_prefetcher:
  220. # prefetcher and collate will handle tensor conversion and norm
  221. final_tfl += [MaybePILToTensor()]
  222. elif not normalize:
  223. # when normalize disable, converted to tensor without scaling, keeps original dtype
  224. final_tfl += [MaybePILToTensor()]
  225. else:
  226. final_tfl += [
  227. MaybeToTensor(),
  228. transforms.Normalize(
  229. mean=torch.tensor(mean),
  230. std=torch.tensor(std),
  231. ),
  232. ]
  233. if re_prob > 0.:
  234. final_tfl += [
  235. RandomErasing(
  236. re_prob,
  237. mode=re_mode,
  238. max_count=re_count,
  239. num_splits=re_num_splits,
  240. device='cpu',
  241. )
  242. ]
  243. if patchify:
  244. final_tfl += [Patchify(patch_size=patch_size)]
  245. if separate:
  246. return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
  247. else:
  248. return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
  249. def transforms_imagenet_eval(
  250. img_size: Union[int, Tuple[int, int]] = 224,
  251. crop_pct: Optional[float] = None,
  252. crop_mode: Optional[str] = None,
  253. crop_border_pixels: Optional[int] = None,
  254. interpolation: str = 'bilinear',
  255. mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
  256. std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
  257. use_prefetcher: bool = False,
  258. normalize: bool = True,
  259. naflex: bool = False,
  260. patch_size: Union[int, Tuple[int, int]] = 16,
  261. max_seq_len: int = 576, # 24x24 for 16x16 patch
  262. patchify: bool = False,
  263. ):
  264. """ ImageNet-oriented image transform for evaluation and inference.
  265. Args:
  266. img_size: Target image size.
  267. crop_pct: Crop percentage. Defaults to 0.875 when None.
  268. crop_mode: Crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
  269. crop_border_pixels: Trim a border of specified # pixels around edge of original image.
  270. interpolation: Image interpolation mode.
  271. mean: Image normalization mean.
  272. std: Image normalization standard deviation.
  273. use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
  274. normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
  275. naflex: Enable NaFlex mode, sequence constrained patch output
  276. patch_size: Patch size for NaFlex mode.
  277. max_seq_len: Max sequence length for NaFlex mode.
  278. patchify: Patchify the output instead of relying on prefetcher
  279. Returns:
  280. Composed transform pipeline
  281. """
  282. crop_pct = crop_pct or DEFAULT_CROP_PCT
  283. if isinstance(img_size, (tuple, list)):
  284. assert len(img_size) == 2
  285. scale_size = tuple([math.floor(x / crop_pct) for x in img_size])
  286. else:
  287. scale_size = math.floor(img_size / crop_pct)
  288. scale_size = (scale_size, scale_size)
  289. tfl = []
  290. if crop_border_pixels:
  291. tfl += [TrimBorder(crop_border_pixels)]
  292. if naflex:
  293. tfl += [ResizeToSequence(
  294. patch_size=patch_size,
  295. max_seq_len=max_seq_len,
  296. interpolation=interpolation,
  297. )]
  298. else:
  299. if crop_mode == 'squash':
  300. # squash mode scales each edge to 1/pct of target, then crops
  301. # aspect ratio is not preserved, no img lost if crop_pct == 1.0
  302. tfl += [
  303. transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
  304. transforms.CenterCrop(img_size),
  305. ]
  306. elif crop_mode == 'border':
  307. # scale the longest edge of image to 1/pct of target edge, add borders to pad, then crop
  308. # no image lost if crop_pct == 1.0
  309. fill = [round(255 * v) for v in mean]
  310. tfl += [
  311. ResizeKeepRatio(scale_size, interpolation=interpolation, longest=1.0),
  312. CenterCropOrPad(img_size, fill=fill),
  313. ]
  314. else:
  315. # default crop model is center
  316. # aspect ratio is preserved, crops center within image, no borders are added, image is lost
  317. if scale_size[0] == scale_size[1]:
  318. # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
  319. tfl += [
  320. transforms.Resize(scale_size[0], interpolation=str_to_interp_mode(interpolation))
  321. ]
  322. else:
  323. # resize the shortest edge to matching target dim for non-square target
  324. tfl += [ResizeKeepRatio(scale_size)]
  325. tfl += [transforms.CenterCrop(img_size)]
  326. if use_prefetcher:
  327. # prefetcher and collate will handle tensor conversion and norm
  328. tfl += [MaybePILToTensor()]
  329. elif not normalize:
  330. # when normalize disabled, converted to tensor without scaling, keeps original dtype
  331. tfl += [MaybePILToTensor()]
  332. else:
  333. tfl += [
  334. MaybeToTensor(),
  335. transforms.Normalize(
  336. mean=torch.tensor(mean),
  337. std=torch.tensor(std),
  338. ),
  339. ]
  340. if patchify:
  341. tfl += [Patchify(patch_size=patch_size)]
  342. return transforms.Compose(tfl)
  343. def create_transform(
  344. input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224,
  345. is_training: bool = False,
  346. no_aug: bool = False,
  347. train_crop_mode: Optional[str] = None,
  348. scale: Optional[Tuple[float, float]] = None,
  349. ratio: Optional[Tuple[float, float]] = None,
  350. hflip: float = 0.5,
  351. vflip: float = 0.,
  352. color_jitter: Union[float, Tuple[float, ...]] = 0.4,
  353. color_jitter_prob: Optional[float] = None,
  354. grayscale_prob: float = 0.,
  355. gaussian_blur_prob: float = 0.,
  356. auto_augment: Optional[str] = None,
  357. interpolation: str = 'bilinear',
  358. mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
  359. std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
  360. re_prob: float = 0.,
  361. re_mode: str = 'const',
  362. re_count: int = 1,
  363. re_num_splits: int = 0,
  364. crop_pct: Optional[float] = None,
  365. crop_mode: Optional[str] = None,
  366. crop_border_pixels: Optional[int] = None,
  367. tf_preprocessing: bool = False,
  368. use_prefetcher: bool = False,
  369. normalize: bool = True,
  370. separate: bool = False,
  371. naflex: bool = False,
  372. patch_size: Union[int, Tuple[int, int]] = 16,
  373. max_seq_len: int = 576, # 24x24 for 16x16 patch
  374. patchify: bool = False
  375. ):
  376. """
  377. Args:
  378. input_size: Target input size (channels, height, width) tuple or size scalar.
  379. is_training: Return training (random) transforms.
  380. no_aug: Disable augmentation for training (useful for debug).
  381. train_crop_mode: Training random crop mode ('rrc', 'rkrc', 'rkrr').
  382. scale: Random resize scale range (crop area, < 1.0 => zoom in).
  383. ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
  384. hflip: Horizontal flip probability.
  385. vflip: Vertical flip probability.
  386. color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
  387. Scalar is applied as (scalar,) * 3 (no hue).
  388. color_jitter_prob: Apply color jitter with this probability if not None (for SimlCLR-like aug).
  389. grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
  390. gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
  391. auto_augment: Auto augment configuration string (see auto_augment.py).
  392. interpolation: Image interpolation mode.
  393. mean: Image normalization mean.
  394. std: Image normalization standard deviation.
  395. re_prob: Random erasing probability.
  396. re_mode: Random erasing fill mode.
  397. re_count: Number of random erasing regions.
  398. re_num_splits: Control split of random erasing across batch size.
  399. crop_pct: Inference crop percentage (output size / resize size).
  400. crop_mode: Inference crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
  401. crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
  402. tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports
  403. use_prefetcher: Pre-fetcher enabled. Do not convert image to tensor or normalize.
  404. normalize: Normalization tensor output w/ provided mean/std (if prefetcher not used).
  405. separate: Output transforms in 3-stage tuple.
  406. Returns:
  407. Composed transforms or tuple thereof
  408. """
  409. if isinstance(input_size, (tuple, list)):
  410. img_size = input_size[-2:]
  411. else:
  412. img_size = input_size
  413. if tf_preprocessing and use_prefetcher:
  414. assert not separate, "Separate transforms not supported for TF preprocessing"
  415. from timm.data.tf_preprocessing import TfPreprocessTransform
  416. transform = TfPreprocessTransform(
  417. is_training=is_training,
  418. size=img_size,
  419. interpolation=interpolation,
  420. )
  421. else:
  422. if is_training and no_aug:
  423. assert not separate, "Cannot perform split augmentation with no_aug"
  424. transform = transforms_noaug_train(
  425. img_size,
  426. interpolation=interpolation,
  427. mean=mean,
  428. std=std,
  429. use_prefetcher=use_prefetcher,
  430. normalize=normalize,
  431. )
  432. elif is_training:
  433. transform = transforms_imagenet_train(
  434. img_size,
  435. train_crop_mode=train_crop_mode,
  436. scale=scale,
  437. ratio=ratio,
  438. hflip=hflip,
  439. vflip=vflip,
  440. color_jitter=color_jitter,
  441. color_jitter_prob=color_jitter_prob,
  442. grayscale_prob=grayscale_prob,
  443. gaussian_blur_prob=gaussian_blur_prob,
  444. auto_augment=auto_augment,
  445. interpolation=interpolation,
  446. mean=mean,
  447. std=std,
  448. re_prob=re_prob,
  449. re_mode=re_mode,
  450. re_count=re_count,
  451. re_num_splits=re_num_splits,
  452. use_prefetcher=use_prefetcher,
  453. normalize=normalize,
  454. separate=separate,
  455. naflex=naflex,
  456. patch_size=patch_size,
  457. max_seq_len=max_seq_len,
  458. patchify=patchify,
  459. )
  460. else:
  461. assert not separate, "Separate transforms not supported for validation preprocessing"
  462. transform = transforms_imagenet_eval(
  463. img_size,
  464. interpolation=interpolation,
  465. mean=mean,
  466. std=std,
  467. crop_pct=crop_pct,
  468. crop_mode=crop_mode,
  469. crop_border_pixels=crop_border_pixels,
  470. use_prefetcher=use_prefetcher,
  471. normalize=normalize,
  472. naflex=naflex,
  473. patch_size=patch_size,
  474. max_seq_len=max_seq_len,
  475. patchify=patchify,
  476. )
  477. return transform