mixup.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. """ Mixup and Cutmix
  2. Papers:
  3. mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
  4. CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
  5. Code Reference:
  6. CutMix: https://github.com/clovaai/CutMix-PyTorch
  7. Hacked together by / Copyright 2019, Ross Wightman
  8. """
  9. import numpy as np
  10. import torch
  11. def one_hot(x, num_classes, on_value=1., off_value=0.):
  12. x = x.long().view(-1, 1)
  13. return torch.full((x.size()[0], num_classes), off_value, device=x.device).scatter_(1, x, on_value)
  14. def mixup_target(target, num_classes, lam=1., smoothing=0.0):
  15. off_value = smoothing / num_classes
  16. on_value = 1. - smoothing + off_value
  17. y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
  18. y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
  19. return y1 * lam + y2 * (1. - lam)
  20. def rand_bbox(img_shape, lam, margin=0., count=None):
  21. """ Standard CutMix bounding-box
  22. Generates a random square bbox based on lambda value. This impl includes
  23. support for enforcing a border margin as percent of bbox dimensions.
  24. Args:
  25. img_shape (tuple): Image shape as tuple
  26. lam (float): Cutmix lambda value
  27. margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
  28. count (int): Number of bbox to generate
  29. """
  30. ratio = np.sqrt(1 - lam)
  31. img_h, img_w = img_shape[-2:]
  32. cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
  33. margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
  34. cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
  35. cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
  36. yl = np.clip(cy - cut_h // 2, 0, img_h)
  37. yh = np.clip(cy + cut_h // 2, 0, img_h)
  38. xl = np.clip(cx - cut_w // 2, 0, img_w)
  39. xh = np.clip(cx + cut_w // 2, 0, img_w)
  40. return yl, yh, xl, xh
  41. def rand_bbox_minmax(img_shape, minmax, count=None):
  42. """ Min-Max CutMix bounding-box
  43. Inspired by Darknet cutmix impl, generates a random rectangular bbox
  44. based on min/max percent values applied to each dimension of the input image.
  45. Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
  46. Args:
  47. img_shape (tuple): Image shape as tuple
  48. minmax (tuple or list): Min and max bbox ratios (as percent of image size)
  49. count (int): Number of bbox to generate
  50. """
  51. assert len(minmax) == 2
  52. img_h, img_w = img_shape[-2:]
  53. cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
  54. cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
  55. yl = np.random.randint(0, img_h - cut_h, size=count)
  56. xl = np.random.randint(0, img_w - cut_w, size=count)
  57. yu = yl + cut_h
  58. xu = xl + cut_w
  59. return yl, yu, xl, xu
  60. def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
  61. """ Generate bbox and apply lambda correction.
  62. """
  63. if ratio_minmax is not None:
  64. yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
  65. else:
  66. yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
  67. if correct_lam or ratio_minmax is not None:
  68. bbox_area = (yu - yl) * (xu - xl)
  69. lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
  70. return (yl, yu, xl, xu), lam
  71. class Mixup:
  72. """ Mixup/Cutmix that applies different params to each element or whole batch
  73. Args:
  74. mixup_alpha (float): mixup alpha value, mixup is active if > 0.
  75. cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
  76. cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
  77. prob (float): probability of applying mixup or cutmix per batch or element
  78. switch_prob (float): probability of switching to cutmix instead of mixup when both are active
  79. mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
  80. correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
  81. label_smoothing (float): apply label smoothing to the mixed target tensor
  82. num_classes (int): number of classes for target
  83. """
  84. def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
  85. mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
  86. self.mixup_alpha = mixup_alpha
  87. self.cutmix_alpha = cutmix_alpha
  88. self.cutmix_minmax = cutmix_minmax
  89. if self.cutmix_minmax is not None:
  90. assert len(self.cutmix_minmax) == 2
  91. # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
  92. self.cutmix_alpha = 1.0
  93. self.mix_prob = prob
  94. self.switch_prob = switch_prob
  95. self.label_smoothing = label_smoothing
  96. self.num_classes = num_classes
  97. self.mode = mode
  98. self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
  99. self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop)
  100. def _params_per_elem(self, batch_size):
  101. lam = np.ones(batch_size, dtype=np.float32)
  102. use_cutmix = np.zeros(batch_size, dtype=bool)
  103. if self.mixup_enabled:
  104. if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
  105. use_cutmix = np.random.rand(batch_size) < self.switch_prob
  106. lam_mix = np.where(
  107. use_cutmix,
  108. np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
  109. np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
  110. elif self.mixup_alpha > 0.:
  111. lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
  112. elif self.cutmix_alpha > 0.:
  113. use_cutmix = np.ones(batch_size, dtype=bool)
  114. lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
  115. else:
  116. assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
  117. lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
  118. return lam, use_cutmix
  119. def _params_per_batch(self):
  120. lam = 1.
  121. use_cutmix = False
  122. if self.mixup_enabled and np.random.rand() < self.mix_prob:
  123. if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
  124. use_cutmix = np.random.rand() < self.switch_prob
  125. lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
  126. np.random.beta(self.mixup_alpha, self.mixup_alpha)
  127. elif self.mixup_alpha > 0.:
  128. lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
  129. elif self.cutmix_alpha > 0.:
  130. use_cutmix = True
  131. lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
  132. else:
  133. assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
  134. lam = float(lam_mix)
  135. return lam, use_cutmix
  136. def _mix_elem(self, x):
  137. batch_size = len(x)
  138. lam_batch, use_cutmix = self._params_per_elem(batch_size)
  139. x_orig = x.clone() # need to keep an unmodified original for mixing source
  140. for i in range(batch_size):
  141. j = batch_size - i - 1
  142. lam = lam_batch[i]
  143. if lam != 1.:
  144. if use_cutmix[i]:
  145. (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
  146. x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
  147. x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
  148. lam_batch[i] = lam
  149. else:
  150. x[i] = x[i] * lam + x_orig[j] * (1 - lam)
  151. return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
  152. def _mix_pair(self, x):
  153. batch_size = len(x)
  154. lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
  155. x_orig = x.clone() # need to keep an unmodified original for mixing source
  156. for i in range(batch_size // 2):
  157. j = batch_size - i - 1
  158. lam = lam_batch[i]
  159. if lam != 1.:
  160. if use_cutmix[i]:
  161. (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
  162. x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
  163. x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
  164. x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
  165. lam_batch[i] = lam
  166. else:
  167. x[i] = x[i] * lam + x_orig[j] * (1 - lam)
  168. x[j] = x[j] * lam + x_orig[i] * (1 - lam)
  169. lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
  170. return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
  171. def _mix_batch(self, x):
  172. lam, use_cutmix = self._params_per_batch()
  173. if lam == 1.:
  174. return 1.
  175. if use_cutmix:
  176. (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
  177. x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
  178. x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
  179. else:
  180. x_flipped = x.flip(0).mul_(1. - lam)
  181. x.mul_(lam).add_(x_flipped)
  182. return lam
  183. def __call__(self, x, target):
  184. assert len(x) % 2 == 0, 'Batch size should be even when using this'
  185. if self.mode == 'elem':
  186. lam = self._mix_elem(x)
  187. elif self.mode == 'pair':
  188. lam = self._mix_pair(x)
  189. else:
  190. lam = self._mix_batch(x)
  191. target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
  192. return x, target
  193. class FastCollateMixup(Mixup):
  194. """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
  195. A Mixup impl that's performed while collating the batches.
  196. """
  197. def _mix_elem_collate(self, output, batch, half=False):
  198. batch_size = len(batch)
  199. num_elem = batch_size // 2 if half else batch_size
  200. assert len(output) == num_elem
  201. lam_batch, use_cutmix = self._params_per_elem(num_elem)
  202. is_np = isinstance(batch[0][0], np.ndarray)
  203. for i in range(num_elem):
  204. j = batch_size - i - 1
  205. lam = lam_batch[i]
  206. mixed = batch[i][0]
  207. if lam != 1.:
  208. if use_cutmix[i]:
  209. if not half:
  210. mixed = mixed.copy() if is_np else mixed.clone()
  211. (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
  212. output.shape,
  213. lam,
  214. ratio_minmax=self.cutmix_minmax,
  215. correct_lam=self.correct_lam,
  216. )
  217. mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
  218. lam_batch[i] = lam
  219. else:
  220. if is_np:
  221. mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
  222. np.rint(mixed, out=mixed)
  223. else:
  224. mixed = mixed.float() * lam + batch[j][0].float() * (1 - lam)
  225. torch.round(mixed, out=mixed)
  226. output[i] += torch.from_numpy(mixed.astype(np.uint8)) if is_np else mixed.byte()
  227. if half:
  228. lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
  229. return torch.tensor(lam_batch).unsqueeze(1)
  230. def _mix_pair_collate(self, output, batch):
  231. batch_size = len(batch)
  232. lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
  233. is_np = isinstance(batch[0][0], np.ndarray)
  234. for i in range(batch_size // 2):
  235. j = batch_size - i - 1
  236. lam = lam_batch[i]
  237. mixed_i = batch[i][0]
  238. mixed_j = batch[j][0]
  239. assert 0 <= lam <= 1.0
  240. if lam < 1.:
  241. if use_cutmix[i]:
  242. (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
  243. output.shape,
  244. lam,
  245. ratio_minmax=self.cutmix_minmax,
  246. correct_lam=self.correct_lam,
  247. )
  248. patch_i = mixed_i[:, yl:yh, xl:xh].copy() if is_np else mixed_i[:, yl:yh, xl:xh].clone()
  249. mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
  250. mixed_j[:, yl:yh, xl:xh] = patch_i
  251. lam_batch[i] = lam
  252. else:
  253. if is_np:
  254. mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
  255. mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
  256. mixed_i = mixed_temp
  257. np.rint(mixed_j, out=mixed_j)
  258. np.rint(mixed_i, out=mixed_i)
  259. else:
  260. mixed_temp = mixed_i.float() * lam + mixed_j.float() * (1 - lam)
  261. mixed_j = mixed_j.float() * lam + mixed_i.float() * (1 - lam)
  262. mixed_i = mixed_temp
  263. torch.round(mixed_j, out=mixed_j)
  264. torch.round(mixed_i, out=mixed_i)
  265. output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) if is_np else mixed_i.byte()
  266. output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) if is_np else mixed_j.byte()
  267. lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
  268. return torch.tensor(lam_batch).unsqueeze(1)
  269. def _mix_batch_collate(self, output, batch):
  270. batch_size = len(batch)
  271. lam, use_cutmix = self._params_per_batch()
  272. is_np = isinstance(batch[0][0], np.ndarray)
  273. if use_cutmix:
  274. (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
  275. output.shape,
  276. lam,
  277. ratio_minmax=self.cutmix_minmax,
  278. correct_lam=self.correct_lam,
  279. )
  280. for i in range(batch_size):
  281. j = batch_size - i - 1
  282. mixed = batch[i][0]
  283. if lam != 1.:
  284. if use_cutmix:
  285. mixed = mixed.copy() if is_np else mixed.clone() # don't want to modify the original while iterating
  286. mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
  287. else:
  288. if is_np:
  289. mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
  290. np.rint(mixed, out=mixed)
  291. else:
  292. mixed = mixed.float() * lam + batch[j][0].float() * (1 - lam)
  293. torch.round(mixed, out=mixed)
  294. output[i] += torch.from_numpy(mixed.astype(np.uint8)) if is_np else mixed.byte()
  295. return lam
  296. def __call__(self, batch, _=None):
  297. batch_size = len(batch)
  298. assert batch_size % 2 == 0, 'Batch size should be even when using this'
  299. half = 'half' in self.mode
  300. if half:
  301. batch_size //= 2
  302. output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
  303. if self.mode == 'elem' or self.mode == 'half':
  304. lam = self._mix_elem_collate(output, batch, half=half)
  305. elif self.mode == 'pair':
  306. lam = self._mix_pair_collate(output, batch)
  307. else:
  308. lam = self._mix_batch_collate(output, batch)
  309. target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
  310. target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
  311. target = target[:batch_size]
  312. return output, target