naflex_random_erasing.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. """Patch-level random erasing augmentation for NaFlex Vision Transformers.
  2. This module implements random erasing specifically designed for patchified images,
  3. operating at the patch granularity rather than pixel level. It supports two modes:
  4. - 'patch': Randomly erases individual patches (speckle-like noise)
  5. - 'region': Erases contiguous rectangular regions of patches (similar to original RandomErasing)
  6. The implementation is coordinate-aware, respecting valid patch boundaries and supporting
  7. variable patch sizes in NaFlex training.
  8. Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
  9. """
  10. import random
  11. import math
  12. from typing import Optional, Union, Tuple
  13. import torch
  14. class PatchRandomErasing:
  15. """Random erasing for patchified images in NaFlex format.
  16. Supports two modes:
  17. 1. 'patch': Simple mode that erases randomly selected valid patches
  18. 2. 'region': Erases rectangular regions at patch granularity
  19. """
  20. def __init__(
  21. self,
  22. erase_prob: float = 0.5,
  23. patch_drop_prob: float = 0.0,
  24. min_count: int = 1,
  25. max_count: Optional[int] = None,
  26. min_area: float = 0.02,
  27. max_area: float = 1 / 3,
  28. min_aspect: float = 0.3,
  29. max_aspect: Optional[float] = None,
  30. mode: str = 'const',
  31. value: float = 0.,
  32. spatial_mode: str = 'region',
  33. num_splits: int = 0,
  34. device: Union[str, torch.device] = 'cuda',
  35. ) -> None:
  36. """Initialize PatchRandomErasing.
  37. Args:
  38. erase_prob: Probability that the Random Erasing operation will be performed.
  39. patch_drop_prob: Patch dropout probability. Remove random patches instead of erasing.
  40. min_count: Minimum number of erasing operations.
  41. max_count: Maximum number of erasing operations.
  42. min_area: Minimum percentage of valid patches/area to erase.
  43. max_area: Maximum percentage of valid patches/area to erase.
  44. min_aspect: Minimum aspect ratio of erased area (only used in 'region' mode).
  45. max_aspect: Maximum aspect ratio of erased area (only used in 'region' mode).
  46. mode: Patch content mode, one of 'const', 'rand', or 'pixel'.
  47. value: Constant value for 'const' mode.
  48. spatial_mode: Erasing strategy, one of 'patch' or 'region'.
  49. num_splits: Number of splits to apply erasing to (0 for all).
  50. device: Computation device.
  51. """
  52. self.erase_prob = erase_prob
  53. self.patch_drop_prob = patch_drop_prob
  54. self.min_count = min_count
  55. self.max_count = max_count or min_count
  56. self.min_area = min_area
  57. self.max_area = max_area
  58. # Aspect ratio params (for region mode)
  59. max_aspect = max_aspect or 1 / min_aspect
  60. self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
  61. # Number of splits
  62. self.num_splits = num_splits
  63. self.device = device
  64. # Strategy mode
  65. self.spatial_mode = spatial_mode
  66. assert self.spatial_mode in ('patch', 'region')
  67. # Value generation mode flags
  68. self.erase_mode = mode.lower()
  69. assert self.erase_mode in ('rand', 'pixel', 'const')
  70. self.const_value = value
  71. self.unique_noise_per_patch = True
  72. def _get_values(
  73. self,
  74. shape: Union[Tuple[int, ...], torch.Size],
  75. value: Optional[torch.Tensor] = None,
  76. dtype: torch.dtype = torch.float32,
  77. device: Optional[Union[str, torch.device]] = None
  78. ) -> torch.Tensor:
  79. """Generate values for erased patches based on the specified mode.
  80. Args:
  81. shape: Shape of patches to erase.
  82. value: Value to use in const (or rand) mode.
  83. dtype: Data type to use.
  84. device: Device to use.
  85. Returns:
  86. Tensor with values for erasing patches.
  87. """
  88. device = device or self.device
  89. if self.erase_mode == 'pixel':
  90. # only mode with erase shape that includes pixels
  91. return torch.empty(shape, dtype=dtype, device=device).normal_()
  92. else:
  93. shape = (1, 1, shape[-1]) if len(shape) == 3 else (1, shape[-1])
  94. if self.erase_mode == 'const' or value is not None:
  95. erase_value = value or self.const_value
  96. if isinstance(erase_value, (int, float)):
  97. values = torch.full(shape, erase_value, dtype=dtype, device=device)
  98. else:
  99. erase_value = torch.tensor(erase_value, dtype=dtype, device=device)
  100. values = torch.expand_copy(erase_value, shape)
  101. else:
  102. values = torch.empty(shape, dtype=dtype, device=device).normal_()
  103. return values
  104. def _drop_patches(
  105. self,
  106. patches: torch.Tensor,
  107. patch_coord: torch.Tensor,
  108. patch_valid: torch.Tensor,
  109. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  110. """Patch Dropout.
  111. Fully drops patches from datastream. Only mode that saves compute BUT requires support
  112. for non-contiguous patches and associated patch coordinate and valid handling.
  113. Args:
  114. patches: Tensor of patches.
  115. patch_coord: Tensor of patch coordinates.
  116. patch_valid: Tensor indicating which patches are valid.
  117. Returns:
  118. Tuple of (patches, patch_coord, patch_valid) with some patches dropped.
  119. """
  120. # FIXME WIP, not completed. Downstream support in model needed for non-contiguous valid patches
  121. if random.random() > self.erase_prob:
  122. return
  123. # Get indices of valid patches
  124. valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist()
  125. # Skip if no valid patches
  126. if not valid_indices:
  127. return patches, patch_coord, patch_valid
  128. num_valid = len(valid_indices)
  129. if self.patch_drop_prob:
  130. # patch dropout mode, completely remove dropped patches (FIXME needs downstream support in model)
  131. num_keep = max(1, int(num_valid * (1. - self.patch_drop_prob)))
  132. keep_indices = torch.argsort(torch.randn(1, num_valid, device=self.device), dim=-1)[:, :num_keep]
  133. # maintain patch order, possibly useful for debug / visualization
  134. keep_indices = keep_indices.sort(dim=-1)[0]
  135. patches = patches.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + patches.shape[2:]))
  136. return patches, patch_coord, patch_valid
  137. def _erase_patches(
  138. self,
  139. patches: torch.Tensor,
  140. patch_coord: torch.Tensor,
  141. patch_valid: torch.Tensor,
  142. patch_shape: torch.Size,
  143. dtype: torch.dtype = torch.float32,
  144. ) -> None:
  145. """Apply erasing by selecting individual patches randomly.
  146. The simplest mode, aligned on patch boundaries. Behaves similarly to speckle or 'sprinkles'
  147. noise augmentation at patch size.
  148. Args:
  149. patches: Tensor of patches to modify in-place.
  150. patch_coord: Tensor of patch coordinates.
  151. patch_valid: Tensor indicating which patches are valid.
  152. patch_shape: Shape of individual patches.
  153. dtype: Data type for generated values.
  154. """
  155. if random.random() > self.erase_prob:
  156. return
  157. # Get indices of valid patches
  158. valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0]
  159. num_valid = len(valid_indices)
  160. if num_valid == 0:
  161. return
  162. count = random.randint(self.min_count, self.max_count)
  163. # Determine how many valid patches to erase from RE min/max count and area args
  164. max_erase = min(num_valid, max(1, int(num_valid * count * self.max_area)))
  165. min_erase = max(1, int(num_valid * count * self.min_area))
  166. num_erase = random.randint(min_erase, max_erase)
  167. # Randomly select valid patches to erase
  168. erase_idx = valid_indices[torch.randperm(num_valid, device=patches.device)[:num_erase]]
  169. if self.unique_noise_per_patch and self.erase_mode == 'pixel':
  170. # generate unique noise for the whole selection of patches
  171. fill_shape = (num_erase,) + patch_shape
  172. else:
  173. fill_shape = patch_shape
  174. patches[erase_idx] = self._get_values(fill_shape, dtype=dtype)
  175. def _erase_region(
  176. self,
  177. patches: torch.Tensor,
  178. patch_coord: torch.Tensor,
  179. patch_valid: torch.Tensor,
  180. patch_shape: torch.Size,
  181. dtype: torch.dtype = torch.float32,
  182. ) -> None:
  183. """Apply erasing by selecting rectangular regions of patches randomly.
  184. Closer to the original RandomErasing implementation. Erases
  185. spatially contiguous rectangular regions of patches (aligned with patches).
  186. Args:
  187. patches: Tensor of patches to modify in-place.
  188. patch_coord: Tensor of patch coordinates.
  189. patch_valid: Tensor indicating which patches are valid.
  190. patch_shape: Shape of individual patches.
  191. dtype: Data type for generated values.
  192. """
  193. if random.random() > self.erase_prob:
  194. return
  195. # Determine grid dimensions from coordinates
  196. valid_coord = patch_coord[patch_valid]
  197. if len(valid_coord) == 0:
  198. return # No valid patches
  199. max_y = valid_coord[:, 0].max().item() + 1
  200. max_x = valid_coord[:, 1].max().item() + 1
  201. grid_h, grid_w = max_y, max_x
  202. total_area = grid_h * grid_w
  203. ys, xs = patch_coord[:, 0], patch_coord[:, 1]
  204. count = random.randint(self.min_count, self.max_count)
  205. for _ in range(count):
  206. # Try to select a valid region to erase (multiple attempts)
  207. for attempt in range(10):
  208. # Sample random area and aspect ratio
  209. target_area = random.uniform(self.min_area, self.max_area) * total_area
  210. aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
  211. # Calculate region height and width
  212. h = int(round(math.sqrt(target_area * aspect_ratio)))
  213. w = int(round(math.sqrt(target_area / aspect_ratio)))
  214. if h > grid_h or w > grid_w:
  215. continue # try again
  216. # Calculate region patch bounds
  217. top = random.randint(0, grid_h - h)
  218. left = random.randint(0, grid_w - w)
  219. bottom, right = top + h, left + w
  220. # Region test
  221. region_mask = (
  222. (ys >= top) & (ys < bottom) &
  223. (xs >= left) & (xs < right) &
  224. patch_valid
  225. )
  226. num_selected = int(region_mask.sum().item())
  227. if not num_selected:
  228. continue # no patch actually falls inside – try again
  229. if self.unique_noise_per_patch and self.erase_mode == 'pixel':
  230. # generate unique noise for the whole region
  231. fill_shape = (num_selected,) + patch_shape
  232. else:
  233. fill_shape = patch_shape
  234. patches[region_mask] = self._get_values(fill_shape, dtype=dtype)
  235. # Successfully applied erasing, exit the loop
  236. break
  237. def __call__(
  238. self,
  239. patches: torch.Tensor,
  240. patch_coord: torch.Tensor,
  241. patch_valid: Optional[torch.Tensor] = None,
  242. ) -> torch.Tensor:
  243. """Apply random patch erasing.
  244. Args:
  245. patches: Tensor of shape [B, N, P*P, C] or [B, N, Ph, Pw, C].
  246. patch_coord: Tensor of shape [B, N, 2] with (y, x) coordinates.
  247. patch_valid: Boolean tensor of shape [B, N] indicating which patches are valid.
  248. Returns:
  249. Erased patches tensor of same shape as input.
  250. """
  251. if patches.ndim == 4:
  252. batch_size, num_patches, patch_dim, channels = patches.shape
  253. elif patches.ndim == 5:
  254. batch_size, num_patches, patch_h, patch_w, channels = patches.shape
  255. else:
  256. assert False
  257. patch_shape = patches.shape[2:]
  258. # patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c)
  259. # Create default valid mask if not provided
  260. if patch_valid is None:
  261. patch_valid = torch.ones((batch_size, num_patches), dtype=torch.bool, device=patches.device)
  262. # Skip the first part of the batch if num_splits is set
  263. batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
  264. # Apply erasing to each batch element
  265. for i in range(batch_start, batch_size):
  266. if self.patch_drop_prob:
  267. assert False, "WIP, not completed"
  268. self._drop_patches(
  269. patches[i],
  270. patch_coord[i],
  271. patch_valid[i],
  272. )
  273. elif self.spatial_mode == 'patch':
  274. # FIXME we could vectorize patch mode across batch, worth the effort?
  275. self._erase_patches(
  276. patches[i],
  277. patch_coord[i],
  278. patch_valid[i],
  279. patch_shape,
  280. patches.dtype
  281. )
  282. elif self.spatial_mode == 'region':
  283. self._erase_region(
  284. patches[i],
  285. patch_coord[i],
  286. patch_valid[i],
  287. patch_shape,
  288. patches.dtype
  289. )
  290. else:
  291. assert False
  292. return patches
  293. def __repr__(self) -> str:
  294. """Return string representation of PatchRandomErasing.
  295. Returns:
  296. String representation of the object.
  297. """
  298. fs = self.__class__.__name__ + f'(p={self.erase_prob}, mode={self.erase_mode}'
  299. fs += f', spatial={self.spatial_mode}, area=({self.min_area}, {self.max_area}))'
  300. fs += f', count=({self.min_count}, {self.max_count}))'
  301. return fs