naflex_mixup.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. """Variable‑size Mixup / CutMix utilities for NaFlex data loaders.
  2. This module provides:
  3. * `mix_batch_variable_size` – pixel‑level Mixup/CutMix that operates on a
  4. list of images whose spatial sizes differ, mixing only their central overlap
  5. so no resizing is required.
  6. * `pairwise_mixup_target` – builds soft‑label targets that exactly match the
  7. per‑sample pixel provenance produced by the mixer.
  8. * `NaFlexMixup` – a callable functor that wraps the two helpers and stores
  9. all augmentation hyper‑parameters in one place, making it easy to plug into
  10. different dataset wrappers.
  11. Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
  12. """
  13. import math
  14. import random
  15. from typing import Dict, List, Tuple, Union
  16. import torch
  17. def mix_batch_variable_size(
  18. imgs: List[torch.Tensor],
  19. *,
  20. mixup_alpha: float = 0.8,
  21. cutmix_alpha: float = 1.0,
  22. switch_prob: float = 0.5,
  23. local_shuffle: int = 4,
  24. ) -> Tuple[List[torch.Tensor], List[float], Dict[int, int]]:
  25. """Apply Mixup or CutMix on a batch of variable-sized images.
  26. Sorts images by aspect ratio and pairs neighboring samples. Only the mutual
  27. central overlap region of each pair is mixed.
  28. Args:
  29. imgs: List of transformed images shaped (C, H, W).
  30. mixup_alpha: Beta distribution alpha for Mixup. Set to 0 to disable.
  31. cutmix_alpha: Beta distribution alpha for CutMix. Set to 0 to disable.
  32. switch_prob: Probability of using CutMix when both modes are enabled.
  33. local_shuffle: Size of local windows for shuffling after aspect sorting.
  34. Returns:
  35. Tuple of (mixed_imgs, lam_list, pair_to) where:
  36. - mixed_imgs: List of mixed images
  37. - lam_list: Per-sample lambda values representing mixing degree
  38. - pair_to: Mapping i -> j of which sample was mixed with which
  39. """
  40. if len(imgs) < 2:
  41. raise ValueError("Need at least two images to perform Mixup/CutMix.")
  42. # Decide augmentation mode and raw λ
  43. if mixup_alpha > 0.0 and cutmix_alpha > 0.0:
  44. use_cutmix = torch.rand(()).item() < switch_prob
  45. alpha = cutmix_alpha if use_cutmix else mixup_alpha
  46. elif mixup_alpha > 0.0:
  47. use_cutmix = False
  48. alpha = mixup_alpha
  49. elif cutmix_alpha > 0.0:
  50. use_cutmix = True
  51. alpha = cutmix_alpha
  52. else:
  53. raise ValueError("Both mixup_alpha and cutmix_alpha are zero – nothing to do.")
  54. lam_raw = torch.distributions.Beta(alpha, alpha).sample().item()
  55. lam_raw = max(0.0, min(1.0, lam_raw)) # numerical safety
  56. # Pair images by nearest aspect ratio
  57. order = sorted(range(len(imgs)), key=lambda i: imgs[i].shape[2] / imgs[i].shape[1])
  58. if local_shuffle > 1:
  59. for start in range(0, len(order), local_shuffle):
  60. random.shuffle(order[start:start + local_shuffle])
  61. pair_to: Dict[int, int] = {}
  62. for a, b in zip(order[::2], order[1::2]):
  63. pair_to[a] = b
  64. pair_to[b] = a
  65. odd_one = order[-1] if len(imgs) % 2 else None
  66. mixed_imgs: List[torch.Tensor] = [None] * len(imgs)
  67. lam_list: List[float] = [1.0] * len(imgs)
  68. for i in range(len(imgs)):
  69. if i == odd_one:
  70. mixed_imgs[i] = imgs[i]
  71. continue
  72. j = pair_to[i]
  73. xi, xj = imgs[i], imgs[j]
  74. _, hi, wi = xi.shape
  75. _, hj, wj = xj.shape
  76. dest_area = hi * wi
  77. # Central overlap common to both images
  78. oh, ow = min(hi, hj), min(wi, wj)
  79. overlap_area = oh * ow
  80. top_i, left_i = (hi - oh) // 2, (wi - ow) // 2
  81. top_j, left_j = (hj - oh) // 2, (wj - ow) // 2
  82. xi = xi.clone()
  83. if use_cutmix:
  84. # CutMix: random rectangle inside the overlap
  85. cut_ratio = math.sqrt(1.0 - lam_raw)
  86. ch, cw = int(oh * cut_ratio), int(ow * cut_ratio)
  87. cut_area = ch * cw
  88. y_off = random.randint(0, oh - ch)
  89. x_off = random.randint(0, ow - cw)
  90. yl_i, xl_i = top_i + y_off, left_i + x_off
  91. yl_j, xl_j = top_j + y_off, left_j + x_off
  92. xi[:, yl_i: yl_i + ch, xl_i: xl_i + cw] = xj[:, yl_j: yl_j + ch, xl_j: xl_j + cw]
  93. mixed_imgs[i] = xi
  94. corrected_lam = 1.0 - cut_area / float(dest_area)
  95. lam_list[i] = corrected_lam
  96. else:
  97. # Mixup: blend the entire overlap region
  98. patch_i = xi[:, top_i:top_i + oh, left_i:left_i + ow]
  99. patch_j = xj[:, top_j:top_j + oh, left_j:left_j + ow]
  100. blended = patch_i.mul(lam_raw).add_(patch_j, alpha=1.0 - lam_raw)
  101. xi[:, top_i:top_i + oh, left_i:left_i + ow] = blended
  102. mixed_imgs[i] = xi
  103. corrected_lam = (dest_area - overlap_area) / dest_area + lam_raw * overlap_area / dest_area
  104. lam_list[i] = corrected_lam
  105. return mixed_imgs, lam_list, pair_to
  106. def smoothed_sparse_target(
  107. targets: torch.Tensor,
  108. *,
  109. num_classes: int,
  110. smoothing: float = 0.0,
  111. ) -> torch.Tensor:
  112. off_val = smoothing / num_classes
  113. on_val = 1.0 - smoothing + off_val
  114. y_onehot = torch.full(
  115. (targets.size(0), num_classes),
  116. off_val,
  117. dtype=torch.float32,
  118. device=targets.device
  119. )
  120. y_onehot.scatter_(1, targets.unsqueeze(1), on_val)
  121. return y_onehot
  122. def pairwise_mixup_target(
  123. targets: torch.Tensor,
  124. pair_to: Dict[int, int],
  125. lam_list: List[float],
  126. *,
  127. num_classes: int,
  128. smoothing: float = 0.0,
  129. ) -> torch.Tensor:
  130. """Create soft targets that match the pixel‑level mixing performed.
  131. Args:
  132. targets: (B,) tensor of integer class indices.
  133. pair_to: Mapping of sample index to its mixed partner as returned by mix_batch_variable_size().
  134. lam_list: Per‑sample fractions of own pixels, also from the mixer.
  135. num_classes: Total number of classes in the dataset.
  136. smoothing: Label‑smoothing value in the range [0, 1).
  137. Returns:
  138. Tensor of shape (B, num_classes) whose rows sum to 1.
  139. """
  140. y_onehot = smoothed_sparse_target(targets, num_classes=num_classes, smoothing=smoothing)
  141. targets = y_onehot.clone()
  142. for i, j in pair_to.items():
  143. lam = lam_list[i]
  144. targets[i].mul_(lam).add_(y_onehot[j], alpha=1.0 - lam)
  145. return targets
  146. class NaFlexMixup:
  147. """Callable wrapper that combines mixing and target generation."""
  148. def __init__(
  149. self,
  150. *,
  151. num_classes: int,
  152. mixup_alpha: float = 0.8,
  153. cutmix_alpha: float = 1.0,
  154. switch_prob: float = 0.5,
  155. prob: float = 1.0,
  156. local_shuffle: int = 4,
  157. label_smoothing: float = 0.0,
  158. ) -> None:
  159. """Configure the augmentation.
  160. Args:
  161. num_classes: Total number of classes.
  162. mixup_alpha: Beta α for Mixup. 0 disables Mixup.
  163. cutmix_alpha: Beta α for CutMix. 0 disables CutMix.
  164. switch_prob: Probability of selecting CutMix when both modes are enabled.
  165. prob: Probability of applying any mixing per batch.
  166. local_shuffle: Window size used to shuffle images after aspect sorting so pairings vary between epochs.
  167. smoothing: Label‑smoothing value. 0 disables smoothing.
  168. """
  169. self.num_classes = num_classes
  170. self.mixup_alpha = mixup_alpha
  171. self.cutmix_alpha = cutmix_alpha
  172. self.switch_prob = switch_prob
  173. self.prob = prob
  174. self.local_shuffle = local_shuffle
  175. self.smoothing = label_smoothing
  176. def __call__(
  177. self,
  178. imgs: List[torch.Tensor],
  179. targets: torch.Tensor,
  180. ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
  181. """Apply the augmentation and generate matching targets.
  182. Args:
  183. imgs: List of already transformed images shaped (C, H, W).
  184. targets: Hard labels with shape (B,).
  185. Returns:
  186. mixed_imgs: List of mixed images in the same order and shapes as the input.
  187. targets: Soft‑label tensor shaped (B, num_classes) suitable for cross‑entropy with soft targets.
  188. """
  189. if not isinstance(targets, torch.Tensor):
  190. targets = torch.tensor(targets)
  191. if random.random() > self.prob:
  192. targets = smoothed_sparse_target(targets, num_classes=self.num_classes, smoothing=self.smoothing)
  193. return imgs, targets.unbind(0)
  194. mixed_imgs, lam_list, pair_to = mix_batch_variable_size(
  195. imgs,
  196. mixup_alpha=self.mixup_alpha,
  197. cutmix_alpha=self.cutmix_alpha,
  198. switch_prob=self.switch_prob,
  199. local_shuffle=self.local_shuffle,
  200. )
  201. targets = pairwise_mixup_target(
  202. targets,
  203. pair_to,
  204. lam_list,
  205. num_classes=self.num_classes,
  206. smoothing=self.smoothing,
  207. )
  208. return mixed_imgs, targets.unbind(0)