- """Patch-level random erasing augmentation for NaFlex Vision Transformers.
- This module implements random erasing specifically designed for patchified images,
- operating at the patch granularity rather than pixel level. It supports two modes:
- - 'patch': Randomly erases individual patches (speckle-like noise)
- - 'region': Erases contiguous rectangular regions of patches (similar to original RandomErasing)
- The implementation is coordinate-aware, respecting valid patch boundaries and supporting
- variable patch sizes in NaFlex training.
- Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
- """
- import random
- import math
- from typing import Optional, Union, Tuple
- import torch
- class PatchRandomErasing:
- """Random erasing for patchified images in NaFlex format.
- Supports two modes:
- 1. 'patch': Simple mode that erases randomly selected valid patches
- 2. 'region': Erases rectangular regions at patch granularity
- """
- def __init__(
- self,
- erase_prob: float = 0.5,
- patch_drop_prob: float = 0.0,
- min_count: int = 1,
- max_count: Optional[int] = None,
- min_area: float = 0.02,
- max_area: float = 1 / 3,
- min_aspect: float = 0.3,
- max_aspect: Optional[float] = None,
- mode: str = 'const',
- value: float = 0.,
- spatial_mode: str = 'region',
- num_splits: int = 0,
- device: Union[str, torch.device] = 'cuda',
- ) -> None:
- """Initialize PatchRandomErasing.
- Args:
- erase_prob: Probability that the Random Erasing operation will be performed.
- patch_drop_prob: Patch dropout probability. Removes random patches instead of erasing them (WIP, `__call__` currently asserts when this is non-zero).
- min_count: Minimum number of erasing operations.
- max_count: Maximum number of erasing operations.
- min_area: Minimum fraction (0-1) of valid patches/area to erase.
- max_area: Maximum fraction (0-1) of valid patches/area to erase.
- min_aspect: Minimum aspect ratio of erased area (only used in 'region' mode).
- max_aspect: Maximum aspect ratio of erased area (only used in 'region' mode).
- mode: Patch content mode, one of 'const', 'rand', or 'pixel'.
- value: Constant value for 'const' mode.
- spatial_mode: Erasing strategy, one of 'patch' or 'region'.
- num_splits: Number of splits to apply erasing to (0 for all).
- device: Computation device.
- """
- self.erase_prob = erase_prob
- self.patch_drop_prob = patch_drop_prob
- self.min_count = min_count
- self.max_count = max_count or min_count
- self.min_area = min_area
- self.max_area = max_area
- # Aspect ratio params (for region mode)
- max_aspect = max_aspect or 1 / min_aspect
- self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
- # Number of splits
- self.num_splits = num_splits
- self.device = device
- # Strategy mode
- self.spatial_mode = spatial_mode
- assert self.spatial_mode in ('patch', 'region')
- # Value generation mode flags
- self.erase_mode = mode.lower()
- assert self.erase_mode in ('rand', 'pixel', 'const')
- self.const_value = value
- self.unique_noise_per_patch = True
- def _get_values(
- self,
- shape: Union[Tuple[int, ...], torch.Size],
- value: Optional[torch.Tensor] = None,
- dtype: torch.dtype = torch.float32,
- device: Optional[Union[str, torch.device]] = None
- ) -> torch.Tensor:
- """Generate values for erased patches based on the specified mode.
- Args:
- shape: Shape of patches to erase.
- value: Value to use in const (or rand) mode.
- dtype: Data type to use.
- device: Device to use.
- Returns:
- Tensor with values for erasing patches.
- """
- device = device or self.device
- if self.erase_mode == 'pixel':
- # only mode with erase shape that includes pixels
- return torch.empty(shape, dtype=dtype, device=device).normal_()
- else:
- shape = (1, 1, shape[-1]) if len(shape) == 3 else (1, shape[-1])
- if self.erase_mode == 'const' or value is not None:
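- # `value or ...` raises for multi-element tensors, so compare against None explicitly
- erase_value = value if value is not None else self.const_value
- if isinstance(erase_value, (int, float)):
- values = torch.full(shape, erase_value, dtype=dtype, device=device)
- else:
- # as_tensor avoids the copy-construct warning when erase_value is already a tensor
- erase_value = torch.as_tensor(erase_value, dtype=dtype, device=device)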
- values = torch.expand_copy(erase_value, shape)
- else:
- values = torch.empty(shape, dtype=dtype, device=device).normal_()
- return values
- def _drop_patches(
- self,
- patches: torch.Tensor,
- patch_coord: torch.Tensor,
- patch_valid: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Patch Dropout.
- Fully drops patches from datastream. Only mode that saves compute BUT requires support
- for non-contiguous patches and associated patch coordinate and valid handling.
- Args:
- patches: Tensor of patches.
- patch_coord: Tensor of patch coordinates.
- patch_valid: Tensor indicating which patches are valid.
- Returns:
- Tuple of (patches, patch_coord, patch_valid) with some patches dropped.
- """
- # FIXME WIP, not completed. Downstream support in model needed for non-contiguous valid patches
- if random.random() > self.erase_prob:
- return patches, patch_coord, patch_valid
- # Get indices of valid patches
- valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0].tolist()
- # Skip if no valid patches
- if not valid_indices:
- return patches, patch_coord, patch_valid
- num_valid = len(valid_indices)
- if self.patch_drop_prob:
- # patch dropout mode, completely remove dropped patches (FIXME needs downstream support in model)
- num_keep = max(1, int(num_valid * (1. - self.patch_drop_prob)))
- keep_indices = torch.argsort(torch.randn(1, num_valid, device=self.device), dim=-1)[:, :num_keep]
- # maintain patch order, possibly useful for debug / visualization
- keep_indices = keep_indices.sort(dim=-1)[0]
- patches = patches.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + patches.shape[2:]))
- return patches, patch_coord, patch_valid
- def _erase_patches(
- self,
- patches: torch.Tensor,
- patch_coord: torch.Tensor,
- patch_valid: torch.Tensor,
- patch_shape: torch.Size,
- dtype: torch.dtype = torch.float32,
- ) -> None:
- """Apply erasing by selecting individual patches randomly.
- The simplest mode, aligned on patch boundaries. Behaves similarly to speckle or 'sprinkles'
- noise augmentation at patch size.
- Args:
- patches: Tensor of patches to modify in-place.
- patch_coord: Tensor of patch coordinates.
- patch_valid: Tensor indicating which patches are valid.
- patch_shape: Shape of individual patches.
- dtype: Data type for generated values.
- """
- if random.random() > self.erase_prob:
- return
- # Get indices of valid patches
- valid_indices = torch.nonzero(patch_valid, as_tuple=True)[0]
- num_valid = len(valid_indices)
- if num_valid == 0:
- return
- count = random.randint(self.min_count, self.max_count)
- # Determine how many valid patches to erase from RE min/max count and area args
- max_erase = min(num_valid, max(1, int(num_valid * count * self.max_area)))
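- # Clamp so min_erase never exceeds max_erase (possible when count * min_area > 1)
- min_erase = min(max_erase, max(1, int(num_valid * count * self.min_area)))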
- num_erase = random.randint(min_erase, max_erase)
- # Randomly select valid patches to erase
- erase_idx = valid_indices[torch.randperm(num_valid, device=patches.device)[:num_erase]]
- if self.unique_noise_per_patch and self.erase_mode == 'pixel':
- # generate unique noise for the whole selection of patches
- fill_shape = (num_erase,) + patch_shape
- else:
- fill_shape = patch_shape
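- # Generate fill values on the same device as the patches being overwritten
- patches[erase_idx] = self._get_values(fill_shape, dtype=dtype, device=patches.device)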
- def _erase_region(
- self,
- patches: torch.Tensor,
- patch_coord: torch.Tensor,
- patch_valid: torch.Tensor,
- patch_shape: torch.Size,
- dtype: torch.dtype = torch.float32,
- ) -> None:
- """Apply erasing by selecting rectangular regions of patches randomly.
- Closer to the original RandomErasing implementation. Erases
- spatially contiguous rectangular regions of patches (aligned with patches).
- Args:
- patches: Tensor of patches to modify in-place.
- patch_coord: Tensor of patch coordinates.
- patch_valid: Tensor indicating which patches are valid.
- patch_shape: Shape of individual patches.
- dtype: Data type for generated values.
- """
- if random.random() > self.erase_prob:
- return
- # Determine grid dimensions from coordinates
- valid_coord = patch_coord[patch_valid]
- if len(valid_coord) == 0:
- return # No valid patches
- max_y = valid_coord[:, 0].max().item() + 1
- max_x = valid_coord[:, 1].max().item() + 1
- grid_h, grid_w = max_y, max_x
- total_area = grid_h * grid_w
- ys, xs = patch_coord[:, 0], patch_coord[:, 1]
- count = random.randint(self.min_count, self.max_count)
- for _ in range(count):
- # Try to select a valid region to erase (multiple attempts)
- for attempt in range(10):
- # Sample random area and aspect ratio
- target_area = random.uniform(self.min_area, self.max_area) * total_area
- aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
- # Calculate region height and width
- h = int(round(math.sqrt(target_area * aspect_ratio)))
- w = int(round(math.sqrt(target_area / aspect_ratio)))
- if h > grid_h or w > grid_w:
- continue # try again
- # Calculate region patch bounds
- top = random.randint(0, grid_h - h)
- left = random.randint(0, grid_w - w)
- bottom, right = top + h, left + w
- # Region test
- region_mask = (
- (ys >= top) & (ys < bottom) &
- (xs >= left) & (xs < right) &
- patch_valid
- )
- num_selected = int(region_mask.sum().item())
- if not num_selected:
- continue # no patch actually falls inside – try again
- if self.unique_noise_per_patch and self.erase_mode == 'pixel':
- # generate unique noise for the whole region
- fill_shape = (num_selected,) + patch_shape
- else:
- fill_shape = patch_shape
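- # Generate fill values on the same device as the patches being overwritten
- patches[region_mask] = self._get_values(fill_shape, dtype=dtype, device=patches.device)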
- # Successfully applied erasing, exit the loop
- break
- def __call__(
- self,
- patches: torch.Tensor,
- patch_coord: torch.Tensor,
- patch_valid: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- """Apply random patch erasing.
- Args:
- patches: Tensor of shape [B, N, P*P, C] or [B, N, Ph, Pw, C].
- patch_coord: Tensor of shape [B, N, 2] with (y, x) coordinates.
- patch_valid: Boolean tensor of shape [B, N] indicating which patches are valid.
- Returns:
- Erased patches tensor of same shape as input.
- """
- if patches.ndim == 4:
- batch_size, num_patches, patch_dim, channels = patches.shape
- elif patches.ndim == 5:
- batch_size, num_patches, patch_h, patch_w, channels = patches.shape
- else:
- assert False, f'Expected patches with 4 or 5 dims, got {patches.ndim}'
- patch_shape = patches.shape[2:]
- # patch_shape ==> shape of patches to fill (h, w, c) or (h * w, c)
- # Create default valid mask if not provided
- if patch_valid is None:
- patch_valid = torch.ones((batch_size, num_patches), dtype=torch.bool, device=patches.device)
- # Skip the first part of the batch if num_splits is set
- batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
- # Apply erasing to each batch element
- for i in range(batch_start, batch_size):
- if self.patch_drop_prob:
- assert False, "WIP, not completed"
- self._drop_patches(
- patches[i],
- patch_coord[i],
- patch_valid[i],
- )
- elif self.spatial_mode == 'patch':
- # FIXME we could vectorize patch mode across batch, worth the effort?
- self._erase_patches(
- patches[i],
- patch_coord[i],
- patch_valid[i],
- patch_shape,
- patches.dtype
- )
- elif self.spatial_mode == 'region':
- self._erase_region(
- patches[i],
- patch_coord[i],
- patch_valid[i],
- patch_shape,
- patches.dtype
- )
- else:
- assert False
- return patches
- def __repr__(self) -> str:
- """Return string representation of PatchRandomErasing.
- Returns:
- String representation of the object.
- """
- fs = self.__class__.__name__ + f'(p={self.erase_prob}, mode={self.erase_mode}'
- fs += f', spatial={self.spatial_mode}, area=({self.min_area}, {self.max_area})'
- fs += f', count=({self.min_count}, {self.max_count}))'
- return fs
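The region sampling inside `_erase_region` can be looked at in isolation. The following is a minimal sketch, assuming a standalone helper named `sample_region` (hypothetical, not part of the module) that uses only the standard library. It mirrors the area and log-uniform aspect ratio sampling from the method above and returns the rectangle in patch units. One deliberate difference: a zero-sized rectangle is rejected up front, whereas the method above reaches the same retry through its empty-mask check.

```python
import math
import random
from typing import Optional, Tuple


def sample_region(
    grid_h: int,
    grid_w: int,
    min_area: float = 0.02,
    max_area: float = 1 / 3,
    min_aspect: float = 0.3,
    max_aspect: Optional[float] = None,
    attempts: int = 10,
    rng: Optional[random.Random] = None,
) -> Optional[Tuple[int, int, int, int]]:
    """Hypothetical helper: return (top, left, h, w) in patch units, or None."""
    rng = rng or random.Random()
    max_aspect = max_aspect or 1 / min_aspect
    # Sampling in log space makes ratios r and 1/r equally likely
    log_lo, log_hi = math.log(min_aspect), math.log(max_aspect)
    total_area = grid_h * grid_w
    for _ in range(attempts):
        target_area = rng.uniform(min_area, max_area) * total_area
        aspect_ratio = math.exp(rng.uniform(log_lo, log_hi))
        h = int(round(math.sqrt(target_area * aspect_ratio)))
        w = int(round(math.sqrt(target_area / aspect_ratio)))
        # Reject rectangles that are empty or do not fit the patch grid
        if h < 1 or w < 1 or h > grid_h or w > grid_w:
            continue
        top = rng.randint(0, grid_h - h)
        left = rng.randint(0, grid_w - w)
        return top, left, h, w
    return None  # no fitting rectangle found within the attempt budget


if __name__ == "__main__":
    print(sample_region(14, 14, rng=random.Random(0)))
```

A patch at coordinate `(y, x)` falls inside the sampled rectangle when `top <= y < top + h` and `left <= x < left + w`, which is the same test the region mask applies before also requiring the patch to be valid.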
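In 'patch' mode the number of erased patches is derived from the count and area arguments rather than from a sampled rectangle. A minimal sketch of that arithmetic follows, assuming a standalone helper named `sample_num_erase` (hypothetical, not part of the module). The extra `min(max_erase, ...)` clamp is an added guard for configurations where `count * min_area` exceeds 1; the rest follows the bounds computed in `_erase_patches`.

```python
import random
from typing import Optional


def sample_num_erase(
    num_valid: int,
    min_count: int = 1,
    max_count: Optional[int] = None,
    min_area: float = 0.02,
    max_area: float = 1 / 3,
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical helper: how many valid patches 'patch' mode would erase."""
    rng = rng or random.Random()
    if num_valid <= 0:
        return 0  # nothing valid to erase
    max_count = max_count or min_count
    count = rng.randint(min_count, max_count)
    # Upper bound: area fraction scaled by count, never more than all valid patches
    max_erase = min(num_valid, max(1, int(num_valid * count * max_area)))
    # Lower bound: at least one patch, clamped so it cannot exceed the upper bound
    min_erase = min(max_erase, max(1, int(num_valid * count * min_area)))
    return rng.randint(min_erase, max_erase)


if __name__ == "__main__":
    # A 14 x 14 grid has 196 valid patches when nothing is padded
    print(sample_num_erase(196, rng=random.Random(0)))
```

With the default arguments and 196 valid patches, the result always lies between 3 and 65 patches, since `int(196 * 0.02)` is 3 and `int(196 / 3)` is 65.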