| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414 |
- """NaFlex data loader for dynamic sequence length training.
- This module provides a specialized data loader for Vision Transformer models that supports:
- - Dynamic sequence length sampling during training for improved efficiency
- - Variable patch size training with probabilistic selection
- - Patch-level random erasing augmentation
- - Efficient GPU prefetching with normalization
- Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
- """
- import math
- from contextlib import suppress
- from functools import partial
- from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
- import torch
- from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from .loader import _worker_init, adapt_to_chs
- from .naflex_dataset import NaFlexMapDatasetWrapper, NaFlexCollator
- from .naflex_random_erasing import PatchRandomErasing
- from .transforms_factory import create_transform
class NaFlexPrefetchLoader:
    """Data prefetcher for NaFlex format which normalizes patches.

    Wraps a loader that yields ``(input_dict, target)`` pairs. While iterating, every
    tensor in ``input_dict`` and the target are moved to ``device`` (on a side stream when
    the device is an available CUDA or NPU device), ``input_dict['patches']`` is normalized
    with the configured mean/std and, when ``re_prob > 0``, patch random erasing is applied.
    """

    def __init__(
            self,
            loader: torch.utils.data.DataLoader,
            mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
            std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
            channels: int = 3,
            device: Union[str, torch.device] = torch.device('cuda'),
            img_dtype: Optional[torch.dtype] = None,
            re_prob: float = 0.,
            re_mode: str = 'const',
            re_count: int = 1,
            re_num_splits: int = 0,
    ) -> None:
        """Initialize NaFlexPrefetchLoader.

        Args:
            loader: DataLoader to prefetch from.
            mean: Mean values for normalization.
            std: Standard deviation values for normalization.
            channels: Number of image channels.
            device: Device (or device string such as ``'cuda'``) to move tensors to.
            img_dtype: Data type for image tensors. Defaults to ``torch.float32`` when None.
            re_prob: Random erasing probability. Erasing is disabled when this is 0.
            re_mode: Random erasing mode.
            re_count: Maximum number of erasing rectangles.
            re_num_splits: Number of augmentation splits.
        """
        # Accept plain device strings as well as torch.device instances; the `.type`
        # checks below require a torch.device.
        device = torch.device(device)

        self.loader = loader
        self.device = device
        self.img_dtype = img_dtype or torch.float32
        self.channels = channels

        # Create mean/std tensors for normalization (will be applied to patches).
        # Values are scaled by 255, so patches are presumably in the 0-255 range when they
        # reach the prefetcher -- TODO confirm against the dataset transforms.
        mean = adapt_to_chs(mean, channels)
        std = adapt_to_chs(std, channels)
        normalization_shape = (1, 1, channels)  # broadcasts over [B, N, pixels, C]
        self.mean = torch.tensor(
            [x * 255 for x in mean], device=device, dtype=self.img_dtype).view(normalization_shape)
        self.std = torch.tensor(
            [x * 255 for x in std], device=device, dtype=self.img_dtype).view(normalization_shape)

        if re_prob > 0.:
            self.random_erasing = PatchRandomErasing(
                erase_prob=re_prob,
                mode=re_mode,
                max_count=re_count,
                num_splits=re_num_splits,
                device=device,
            )
        else:
            self.random_erasing = None

        # Check for CUDA/NPU availability. `torch.npu` is only touched when the device type
        # is 'npu' thanks to short-circuit evaluation.
        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
        self.is_npu = device.type == 'npu' and torch.npu.is_available()

    def _normalize_patches(self, input_dict: Dict[str, torch.Tensor]) -> None:
        """Normalize (and optionally random-erase) ``input_dict['patches']``.

        The result is stored back under the ``'patches'`` key with its original shape.

        Raises:
            ValueError: If the patches tensor is neither 3- nor 5-dimensional.
        """
        patches_tensor = input_dict['patches']
        original_shape = patches_tensor.shape

        if patches_tensor.ndim == 5:
            # Format: [B, N, Ph, Pw, C] - unflattened patches (variable patch size mode)
            channels = original_shape[-1]
            assert channels == self.channels, f"Expected {self.channels} channels, got {channels}"
        elif patches_tensor.ndim != 3:
            # The only other supported format is [B, N, P*P*C] - flattened patches
            raise ValueError(f"Unexpected patches tensor dimensions: {patches_tensor.ndim}. Expected 3 or 5.")

        # Both formats are viewed as [B, N, pixels_per_patch, C] so the (1, 1, C) mean/std
        # tensors broadcast over the trailing channel dimension.
        batch_size, num_patches = original_shape[:2]
        patches = patches_tensor.view(batch_size, num_patches, -1, self.channels)

        # Apply normalization
        patches = patches.sub(self.mean).div(self.std)

        if self.random_erasing is not None:
            patches = self.random_erasing(
                patches,
                patch_coord=input_dict['patch_coord'],
                patch_valid=input_dict.get('patch_valid', None),
            )

        # Reshape back to original format
        input_dict['patches'] = patches.view(original_shape)

    def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
        """Iterate through the loader with prefetching and normalization.

        Each batch is transferred and normalized one step ahead of the batch being yielded.
        Nothing is yielded if the underlying loader is empty.

        Yields:
            Tuple of (input_dict, targets) with normalized patches.
        """
        first = True
        input_dict, target = None, None

        if self.is_cuda:
            stream = torch.cuda.Stream(device=self.device)
            stream_context = partial(torch.cuda.stream, stream=stream)
        elif self.is_npu:
            stream = torch.npu.Stream(device=self.device)
            stream_context = partial(torch.npu.stream, stream=stream)
        else:
            # No side stream available; `suppress()` with no arguments acts as a no-op context.
            stream = None
            stream_context = suppress

        for next_input_dict, next_target in self.loader:
            with stream_context():
                # Move all tensors in input_dict to device; only 'patches' is cast to img_dtype.
                # Reassigning values of existing keys during iteration does not resize the dict.
                for k, v in next_input_dict.items():
                    if isinstance(v, torch.Tensor):
                        dtype = self.img_dtype if k == 'patches' else None
                        next_input_dict[k] = v.to(
                            device=self.device,
                            non_blocking=True,
                            dtype=dtype,
                        )
                next_target = next_target.to(device=self.device, non_blocking=True)

                # Normalize patch values - handles both [B, N, P*P*C] and [B, N, Ph, Pw, C] formats
                self._normalize_patches(next_input_dict)

            # Hand out the previously prepared batch while the new one is in flight.
            if not first:
                yield input_dict, target
            else:
                first = False

            # Make the current stream wait for the side stream before the batch is exposed.
            if stream is not None:
                if self.is_cuda:
                    torch.cuda.current_stream(device=self.device).wait_stream(stream)
                elif self.is_npu:
                    torch.npu.current_stream(device=self.device).wait_stream(stream)

            input_dict = next_input_dict
            target = next_target

        # Flush the last prepared batch; skipped when the loader produced no batches at all.
        if not first:
            yield input_dict, target

    def __len__(self) -> int:
        """Get length of underlying loader.

        Returns:
            Number of batches in the loader.
        """
        return len(self.loader)

    @property
    def sampler(self):
        """Get sampler from underlying loader.

        Returns:
            Sampler from the underlying DataLoader.
        """
        return self.loader.sampler

    @property
    def dataset(self):
        """Get dataset from underlying loader.

        Returns:
            Dataset from the underlying DataLoader.
        """
        return self.loader.dataset
def create_naflex_loader(
        dataset,
        patch_size: Optional[Union[Tuple[int, int], int]] = None,
        patch_size_choices: Optional[List[int]] = None,
        patch_size_choice_probs: Optional[List[float]] = None,
        train_seq_lens: Tuple[int, ...] = (128, 256, 576, 784, 1024),
        max_seq_len: int = 576,
        batch_size: int = 32,
        is_training: bool = False,
        mixup_fn: Optional[Callable] = None,
        no_aug: bool = False,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_split: bool = False,
        train_crop_mode: Optional[str] = None,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: float = 0.4,
        color_jitter_prob: Optional[float] = None,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        num_aug_repeats: int = 0,
        num_aug_splits: int = 0,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,
        num_workers: int = 4,
        distributed: bool = False,
        rank: int = 0,
        world_size: int = 1,
        seed: int = 42,
        epoch: int = 0,
        use_prefetcher: bool = True,
        pin_memory: bool = True,
        img_dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = torch.device('cuda'),
        persistent_workers: bool = True,
        worker_seeding: str = 'all',
) -> Union[torch.utils.data.DataLoader, NaFlexPrefetchLoader]:
    """Create a data loader with dynamic sequence length sampling for training.

    In training mode the dataset is wrapped in ``NaFlexMapDatasetWrapper``, which produces
    already-collated batches (the ``DataLoader`` is created with ``batch_size=None``). In
    validation mode a fixed ``max_seq_len`` transform is assigned to ``dataset.transform``
    (mutating the passed dataset) and batches are collated by ``NaFlexCollator``.

    Args:
        dataset: Dataset to load from.
        patch_size: Single patch size to use.
        patch_size_choices: List of patch sizes for variable patch size training.
        patch_size_choice_probs: Probabilities for each patch size choice.
        train_seq_lens: Training sequence lengths for dynamic batching.
        max_seq_len: Fixed sequence length for validation.
        batch_size: Batch size for validation. For training, the token budget per batch is
            ``batch_size * max(train_seq_lens)``.
        is_training: Whether this is for training (enables dynamic batching).
        mixup_fn: Optional mixup function.
        no_aug: Disable augmentation.
        re_prob: Random erasing probability.
        re_mode: Random erasing mode.
        re_count: Maximum number of erasing rectangles.
        re_split: Random erasing split flag. Accepted but not currently used by this function.
        train_crop_mode: Training crop mode.
        scale: Scale range for random resize crop.
        ratio: Aspect ratio range for random resize crop.
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Color jitter factor.
        color_jitter_prob: Color jitter probability.
        grayscale_prob: Grayscale conversion probability.
        gaussian_blur_prob: Gaussian blur probability.
        auto_augment: AutoAugment policy.
        num_aug_repeats: Number of augmentation repeats. Must be 0 for training.
        num_aug_splits: Number of augmentation splits. Accepted but not currently used by
            this function.
        interpolation: Interpolation method.
        mean: Normalization mean values.
        std: Normalization standard deviation values.
        crop_pct: Crop percentage. Only forwarded to the training transform.
        crop_mode: Crop mode. Only forwarded to the training transform.
        crop_border_pixels: Crop border pixels. Only forwarded to the training transform.
        num_workers: Number of data loading workers.
        distributed: Whether using distributed training.
        rank: Process rank for distributed training.
        world_size: Total number of processes.
        seed: Random seed.
        epoch: Starting epoch.
        use_prefetcher: Whether to wrap the loader in ``NaFlexPrefetchLoader``.
        pin_memory: Whether to pin memory.
        img_dtype: Image data type.
        device: Device (or device string) to move tensors to; only used with the prefetcher.
        persistent_workers: Whether to use persistent workers (training loader only).
            Ignored when ``num_workers`` is 0.
        worker_seeding: Worker seeding mode (training loader only).

    Returns:
        DataLoader or NaFlexPrefetchLoader instance.

    Raises:
        AssertionError: If training is requested with ``num_aug_repeats != 0`` or with an
            ``IterableDataset``.
    """
    if use_prefetcher:
        # The signature allows device strings, but the prefetcher inspects `device.type`.
        device = torch.device(device)

    if is_training:
        # For training, use the dynamic sequence length mechanism
        assert num_aug_repeats == 0, 'Augmentation repeats not currently supported in NaFlex loader'

        transform_factory = partial(
            create_transform,
            is_training=True,
            no_aug=no_aug,
            train_crop_mode=train_crop_mode,
            scale=scale,
            ratio=ratio,
            hflip=hflip,
            vflip=vflip,
            color_jitter=color_jitter,
            color_jitter_prob=color_jitter_prob,
            grayscale_prob=grayscale_prob,
            gaussian_blur_prob=gaussian_blur_prob,
            auto_augment=auto_augment,
            interpolation=interpolation,
            mean=mean,
            std=std,
            crop_pct=crop_pct,
            crop_mode=crop_mode,
            crop_border_pixels=crop_border_pixels,
            re_prob=re_prob,
            re_mode=re_mode,
            re_count=re_count,
            use_prefetcher=use_prefetcher,
            naflex=True,
        )

        # Token budget: the requested batch size applies at the longest sequence length.
        max_train_seq_len = max(train_seq_lens)
        max_tokens_per_batch = batch_size * max_train_seq_len

        if isinstance(dataset, torch.utils.data.IterableDataset):
            # Explicit raise (rather than `assert False`) so the guard survives `python -O`.
            raise AssertionError("IterableDataset Wrapper is a WIP")

        naflex_dataset = NaFlexMapDatasetWrapper(
            dataset,
            transform_factory=transform_factory,
            patch_size=patch_size,
            patch_size_choices=patch_size_choices,
            patch_size_choice_probs=patch_size_choice_probs,
            seq_lens=train_seq_lens,
            max_tokens_per_batch=max_tokens_per_batch,
            mixup_fn=mixup_fn,
            seed=seed,
            distributed=distributed,
            rank=rank,
            world_size=world_size,
            shuffle=True,
            epoch=epoch,
        )

        # NOTE: Collation is handled by the dataset wrapper for training
        loader = torch.utils.data.DataLoader(
            naflex_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            sampler=None,
            pin_memory=pin_memory,
            worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
            # DataLoader rejects persistent_workers=True when there are no worker processes.
            persistent_workers=persistent_workers and num_workers > 0,
        )

        if use_prefetcher:
            loader = NaFlexPrefetchLoader(
                loader,
                mean=mean,
                std=std,
                img_dtype=img_dtype,
                device=device,
                re_prob=re_prob,
                re_mode=re_mode,
                re_count=re_count,
            )
    else:
        # For validation, use fixed sequence length (unchanged).
        # NOTE: this replaces the `transform` attribute on the caller's dataset object.
        dataset.transform = create_transform(
            is_training=False,
            interpolation=interpolation,
            mean=mean,
            std=std,
            # FIXME add crop args when sequence transforms support crop modes
            use_prefetcher=use_prefetcher,
            naflex=True,
            patch_size=patch_size,
            max_seq_len=max_seq_len,
            patchify=True,
        )

        # Create the collator
        collate_fn = NaFlexCollator(max_seq_len=max_seq_len)

        # Handle distributed training
        sampler = None
        if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
            # For validation, use OrderedDistributedSampler
            from timm.data.distributed_sampler import OrderedDistributedSampler
            sampler = OrderedDistributedSampler(dataset)

        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=False,
        )

        if use_prefetcher:
            # No random erasing arguments are passed for validation.
            loader = NaFlexPrefetchLoader(
                loader,
                mean=mean,
                std=std,
                img_dtype=img_dtype,
                device=device,
            )

    return loader
|