| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558 |
- """ Dynamic Sequence Length Datasets for Variable Resolution Image Processing
- Implements two dataset wrappers:
- 1. NaFlexMapDatasetWrapper - Map-style dataset that returns batches with variable sequence lengths
- TODO: 2. NaFlexIterableDatasetWrapper - Iterable dataset that yields batches with variable sequence lengths
- Both support:
- - Pre-initialized transforms for efficiency
- - Distributed training
- - Multiple workers
- - Variable batch sizes based on sequence length
- Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
- """
- import math
- import random
- import warnings
- from functools import partial
- from typing import Any, Iterator, List, Tuple, Dict, Optional, Union, Callable
- import torch
- from torch.utils.data import Dataset, IterableDataset, DataLoader
- from PIL import Image
- from .naflex_transforms import Patchify
- from timm.layers import to_2tuple
- def calculate_naflex_batch_size(
- tokens_per_batch: int,
- seq_len: int,
- max_size: Optional[int] = None,
- divisor: int = 1,
- rounding: str = 'floor',
- ) -> int:
- """Calculate batch size based on sequence length with divisibility constraints.
- Args:
- tokens_per_batch: Target number of tokens per batch.
- seq_len: Sequence length for this batch.
- max_size: Optional maximum batch size.
- divisor: Ensure batch size is divisible by this value.
- rounding: Rounding method ('floor', 'ceil', 'round').
- Returns:
- Calculated batch size.
- """
- # Calculate raw batch size based on sequence length
- raw_batch_size = tokens_per_batch / seq_len
- # Apply divisibility with specified rounding method
- if divisor > 1:
- if rounding == 'floor':
- batch_size = math.floor(raw_batch_size / divisor) * divisor
- elif rounding == 'ceil':
- batch_size = math.ceil(raw_batch_size / divisor) * divisor
- else: # 'round' is the default
- batch_size = round(raw_batch_size / divisor) * divisor
- else:
- # If no divisor specified, just use integer division
- batch_size = int(raw_batch_size)
- # Ensure batch size is valid
- batch_size = max(1, batch_size) # At least 1
- if max_size is not None:
- batch_size = min(batch_size, max_size)
- return batch_size
- class NaFlexCollator:
- """Custom collator for batching NaFlex-style variable-resolution images."""
- def __init__(
- self,
- max_seq_len: Optional[int] = None,
- ) -> None:
- """Initialize NaFlexCollator.
- Args:
- max_seq_len: Maximum sequence length for batching.
- """
- self.max_seq_len = max_seq_len or 576 # Default ViT-B/16 sequence length (577 = 24*24)
- def __call__(self, batch: List[Tuple[Dict[str, torch.Tensor], Union[int, torch.Tensor]]]) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
- """Collate batch of NaFlex samples.
- Args:
- batch: List of tuples (patch_dict, target).
- Returns:
- A tuple of (input_dict, targets) where input_dict contains:
- - patches: Padded tensor of patches
- - patch_coord: Coordinates for each patch (y, x)
- - patch_valid: Valid indicators
- """
- assert isinstance(batch[0], tuple)
- batch_size = len(batch)
- # Extract targets
- targets = [item[1] for item in batch]
- if isinstance(targets[0], torch.Tensor):
- targets = torch.stack(targets)
- else:
- targets = torch.tensor(targets, dtype=torch.int64)
- # Get patch dictionaries
- patch_dicts = [item[0] for item in batch]
- # If we have a maximum sequence length constraint, ensure we don't exceed it
- if self.max_seq_len is not None:
- max_patches = self.max_seq_len
- else:
- # Find the maximum number of patches in this batch
- max_patches = max(item['patches'].shape[0] for item in patch_dicts)
- # Check if patches are flattened or unflattened
- patches_tensor = patch_dicts[0]['patches']
- is_unflattened = patches_tensor.ndim == 4 # [N, Ph, Pw, C]
- if is_unflattened:
- # Patches are [N, Ph, Pw, C] - variable patch size mode
- _, ph, pw, c = patches_tensor.shape
- patches = torch.zeros((batch_size, max_patches, ph, pw, c), dtype=torch.float32)
- else:
- # Patches are [N, P*P*C] - normal mode
- patch_dim = patches_tensor.shape[1]
- patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
- # Prepare other tensors
- patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64) # [B, N, 2] for (y, x)
- patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
- # Fill in the tensors
- for i, patch_dict in enumerate(patch_dicts):
- num_patches = min(patch_dict['patches'].shape[0], max_patches)
- patches[i, :num_patches] = patch_dict['patches'][:num_patches]
- patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
- patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
- result = {
- 'patches': patches,
- 'patch_coord': patch_coord,
- 'patch_valid': patch_valid,
- 'seq_len': max_patches,
- }
- return result, targets
- def _resolve_patch_cfg(
- patch_size: Optional[Union[int, Tuple[int, int]]],
- patch_size_choices: Optional[List[int]],
- patch_size_choice_probs: Optional[List[float]],
- ) -> Tuple[List[Tuple[int, int]], List[float], bool]:
- """Resolve patch size configuration.
- Args:
- patch_size: Single patch size to use.
- patch_size_choices: List of patch sizes to choose from.
- patch_size_choice_probs: Probabilities for each patch size choice.
- Returns:
- Tuple of (sizes, probs, variable) where sizes is list of patch size tuples,
- probs is list of probabilities, and variable indicates if patch size varies.
- """
- # If both are None, default to patch_size=16
- if patch_size is None and patch_size_choices is None:
- patch_size = 16
- if (patch_size is None) == (patch_size_choices is None):
- raise ValueError(
- "Specify exactly one of `patch_size` or `patch_size_choices`."
- )
- if patch_size is not None:
- sizes = [to_2tuple(patch_size)]
- probs = [1.0]
- variable = False
- else:
- sizes = [to_2tuple(p) for p in patch_size_choices]
- if patch_size_choice_probs is None:
- probs = [1.0 / len(sizes)] * len(sizes)
- else:
- if len(patch_size_choice_probs) != len(sizes):
- raise ValueError("`patch_size_choice_probs` length mismatch.")
- s = float(sum(patch_size_choice_probs))
- if s <= 0:
- raise ValueError("`patch_size_choice_probs` sum to zero.")
- probs = [p / s for p in patch_size_choice_probs]
- variable = True
- return sizes, probs, variable
- class NaFlexMapDatasetWrapper(IterableDataset):
- """
- IterableDataset wrapper for a map-style base dataset.
- Yields batches with variable sequence lengths. It calculates a canonical
- batch schedule (sequence length, batch size pairs) once based on the
- total dataset size (padded for distribution). Each epoch, it shuffles
- the order of this canonical schedule and the dataset indices.
- This ensures a consistent number of batches and samples per epoch
- across all ranks. Handles distributed training and multiple workers.
- """
- def __init__(
- self,
- base_dataset: Dataset,
- patch_size: Optional[Union[int, Tuple[int, int]]] = None,
- patch_size_choices: Optional[List[int]] = None,
- patch_size_choice_probs: Optional[List[float]] = None,
- seq_lens: Tuple[int, ...] = (128, 256, 576, 784, 1024),
- max_tokens_per_batch: int = 4096 * 4,
- transform_factory: Optional[Callable] = None,
- mixup_fn: Optional[Callable] = None,
- seed: int = 42,
- shuffle: bool = True,
- distributed: bool = False,
- rank: int = 0,
- world_size: int = 1,
- epoch: int = 0,
- batch_divisor: int = 8,
- ) -> None:
- """Initialize NaFlexMapDatasetWrapper.
- Args:
- base_dataset: Map-style dataset to wrap.
- patch_size: Single patch size to use.
- patch_size_choices: List of patch sizes to randomly select from.
- patch_size_choice_probs: Probabilities for each patch size.
- seq_lens: Sequence lengths to use for batching.
- max_tokens_per_batch: Target tokens per batch.
- transform_factory: Factory function for creating transforms.
- mixup_fn: Optional mixup function.
- seed: Random seed.
- shuffle: Whether to shuffle data.
- distributed: Whether using distributed training.
- rank: Process rank for distributed training.
- world_size: Total number of processes.
- epoch: Starting epoch.
- batch_divisor: Ensure batch size is divisible by this.
- """
- super().__init__()
- if not hasattr(base_dataset, '__len__') or not hasattr(base_dataset, '__getitem__'):
- raise TypeError("base_dataset must be a map-style dataset (implement __len__ and __getitem__)")
- self.base_dataset = base_dataset
- self.seq_lens = sorted(list(set(seq_lens))) # Ensure unique and sorted
- self.max_tokens_per_batch = max_tokens_per_batch
- self.seed = seed
- self.shuffle = shuffle
- self.distributed = distributed
- self.rank = rank if distributed else 0
- self.world_size = world_size if distributed else 1
- self.epoch = epoch
- self.batch_divisor = batch_divisor
- # Resolve patch size configuration
- self.patch_sizes, self.patch_size_probs, self.variable_patch_size = _resolve_patch_cfg(
- patch_size,
- patch_size_choices,
- patch_size_choice_probs
- )
- # Pre-initialize transforms and collate fns for each (seq_len, patch_idx) combination
- self.transforms: Dict[Tuple[int, int], Optional[Callable]] = {}
- self.collate_fns: Dict[int, Callable] = {}
- self.patchifiers: List[Callable] = []
- for seq_len in self.seq_lens:
- self.collate_fns[seq_len] = NaFlexCollator(seq_len)
- for patch_idx, patch_size_tuple in enumerate(self.patch_sizes):
- # Pre-initialize patchifiers for each patch size (indexed by patch_idx)
- self.patchifiers.append(Patchify(
- patch_size=patch_size_tuple,
- flatten_patches=not self.variable_patch_size
- ))
- # Create transforms for each (seq_len, patch_idx) combination
- for seq_len in self.seq_lens:
- key = (seq_len, patch_idx)
- if transform_factory:
- self.transforms[key] = transform_factory(max_seq_len=seq_len, patch_size=patch_size_tuple)
- else:
- self.transforms[key] = None # No transform
- self.mixup_fn = mixup_fn
- # Canonical Schedule Calculation (Done Once)
- self._canonical_batch_schedule: List[Tuple[int, int]] = []
- self._num_batches_per_rank: int = 0
- self._padded_samples_per_rank: int = 0
- self._create_canonical_schedule() # Calculate schedule based on padded size
- # Per-Epoch State
- # Stores (seq_len, list_of_indices) for the current epoch, specific to this rank
- self._epoch_batches: List[Tuple[int, List[int]]] = []
- self._prepare_epoch_batches(self.epoch) # setup for initial epoch
- def _create_canonical_schedule(self):
- """
- Calculates the canonical batch schedule (seq_len, batch_size pairs)
- based on the dataset size, padded for distributed training.
- This schedule is the *same* for all ranks and ensures consistent
- epoch length. It is calculated once during initialization.
- """
- total_len = len(self.base_dataset)
- padded_total_len = total_len
- num_samples_per_rank = total_len
- if self.distributed and self.world_size > 1:
- # Calculate padding needed for even distribution
- if total_len % self.world_size != 0:
- pad_size = self.world_size - (total_len % self.world_size)
- padded_total_len += pad_size
- print(f"Rank {self.rank}: Padding dataset with {pad_size} samples for distributed training (total size {padded_total_len}).")
- else:
- pad_size = 0
- if padded_total_len % self.world_size != 0:
- # This should not happen with the padding logic, but safeguard
- raise RuntimeError(f"Internal Error: Padded total length {padded_total_len} not divisible by world size {self.world_size}")
- num_samples_per_rank = padded_total_len // self.world_size
- elif self.distributed and self.world_size <= 1:
- # Distributed flag set but world_size is 1, treat as non-distributed
- pass # num_samples_per_rank remains total_len
- self._padded_samples_per_rank = num_samples_per_rank
- if num_samples_per_rank == 0:
- self._canonical_batch_schedule = []
- self._num_batches_per_rank = 0
- return
- # Use a fixed seed for generating the canonical schedule structure
- g = torch.Generator()
- g.manual_seed(self.seed) # Use base seed, NOT epoch seed
- current_schedule: List[Tuple[int, int]] = []
- remaining_samples = num_samples_per_rank
- total_scheduled_samples = 0
- while remaining_samples > 0:
- # Sample sequence length deterministically based on base seed
- seq_idx = torch.randint(0, len(self.seq_lens), (1,), generator=g).item()
- seq_len = self.seq_lens[seq_idx]
- # Calculate batch size
- batch_size = calculate_naflex_batch_size(
- tokens_per_batch=self.max_tokens_per_batch,
- seq_len=seq_len,
- # max_size should be remaining_samples to avoid overshooting
- max_size=remaining_samples,
- divisor=self.batch_divisor,
- rounding='floor',
- )
- # Ensure batch size is positive and doesn't exceed remaining samples
- batch_size = max(1, batch_size)
- batch_size = min(batch_size, remaining_samples)
- if batch_size <= 0:
- warnings.warn(f"Calculated batch size <= 0 (seq_len={seq_len}, remaining={remaining_samples}). Stopping schedule generation early.")
- break # Avoid infinite loop if something goes wrong
- current_schedule.append((seq_len, batch_size))
- remaining_samples -= batch_size
- total_scheduled_samples += batch_size
- # Sanity check: Ensure the schedule covers all samples for the rank
- if total_scheduled_samples != num_samples_per_rank:
- warnings.warn(
- f"Rank {self.rank}: Canonical schedule accounts for {total_scheduled_samples} samples, "
- f"but expected {num_samples_per_rank} samples per rank. "
- f"This might happen if min_batch_size or batch_divisor constraints prevent utilizing all samples. "
- f"Check parameters. Remaining samples: {remaining_samples}"
- )
- # Adjust if needed? Could add a final small batch, but might violate constraints.
- # Current behavior: some samples might be dropped if schedule logic fails.
- self._canonical_batch_schedule = current_schedule
- self._num_batches_per_rank = len(current_schedule)
- print(f"Rank {self.rank}: Created canonical schedule with {self._num_batches_per_rank} batches for {self._padded_samples_per_rank} samples/rank.")
- def _prepare_epoch_batches(self, epoch: int):
- """
- Prepares the batches for the current epoch by:
- 1. Shuffling the full dataset indices (using epoch seed).
- 2. Applying padding if in distributed mode.
- 3. Selecting indices for the current rank.
- 4. Shuffling the *order* of the canonical batch schedule (using epoch seed).
- 5. Assigning the rank's indices to the shuffled batches.
- """
- g = torch.Generator()
- g.manual_seed(self.seed + epoch) # Epoch-specific seed for shuffling
- # 1. Get shuffled global indices
- total_len = len(self.base_dataset)
- if self.shuffle:
- all_indices_shuffled = torch.randperm(total_len, generator=g).tolist()
- else:
- all_indices_shuffled = list(range(total_len))
- # 2. Apply padding for distributed mode
- indices_for_ranks = all_indices_shuffled
- if self.distributed and self.world_size > 1:
- padded_total_len = self._padded_samples_per_rank * self.world_size
- if padded_total_len > total_len:
- pad_size = padded_total_len - total_len
- # Repeat initial elements from the *shuffled* list for padding
- indices_for_ranks = all_indices_shuffled + all_indices_shuffled[:pad_size]
- # Ensure length matches expectation
- if len(indices_for_ranks) != padded_total_len:
- raise RuntimeError(f"Internal Error: Padded index list length {len(indices_for_ranks)} does not match expected {padded_total_len}")
- # 3. Select indices for the current rank
- if self.distributed and self.world_size > 1:
- indices_this_rank = indices_for_ranks[self.rank::self.world_size]
- else: # Non-distributed or world_size=1
- indices_this_rank = indices_for_ranks
- # Sanity check length
- if len(indices_this_rank) != self._padded_samples_per_rank:
- # This might happen if canonical schedule generation had warnings/issues
- warnings.warn(
- f"Rank {self.rank}: Number of indices for this rank ({len(indices_this_rank)}) "
- f"does not match expected padded samples per rank ({self._padded_samples_per_rank}). "
- f"Epoch generation might be inconsistent."
- )
- # Adjust expected samples? Or truncate/pad indices? Let's proceed but warn.
- # Using min() prevents IndexError later if indices are fewer than expected.
- effective_samples_this_rank = min(len(indices_this_rank), self._padded_samples_per_rank)
- indices_this_rank = indices_this_rank[:effective_samples_this_rank]
- else:
- effective_samples_this_rank = self._padded_samples_per_rank
- # 4. Shuffle the order of the canonical batch schedule for this epoch
- if self.shuffle:
- schedule_perm = torch.randperm(self._num_batches_per_rank, generator=g).tolist()
- shuffled_schedule = [self._canonical_batch_schedule[i] for i in schedule_perm]
- else:
- shuffled_schedule = list(self._canonical_batch_schedule) # Keep original order
- # 5. Assign indices to the shuffled batches
- self._epoch_batches = []
- idx_pos = 0
- scheduled_samples_count = 0
- for seq_len, bs in shuffled_schedule:
- # Ensure we don't try to grab more indices than available for the rank
- actual_bs = min(bs, effective_samples_this_rank - idx_pos)
- if actual_bs <= 0:
- if scheduled_samples_count < effective_samples_this_rank:
- # This indicates mismatch between schedule total and actual samples
- warnings.warn(f"Rank {self.rank}: Ran out of samples ({idx_pos}/{effective_samples_this_rank}) before processing entire schedule. Check schedule generation.")
- break # Stop if no more indices or batch size is zero
- batch_indices = indices_this_rank[idx_pos : idx_pos + actual_bs]
- self._epoch_batches.append((seq_len, batch_indices))
- idx_pos += actual_bs
- scheduled_samples_count += actual_bs
- # Final check
- if scheduled_samples_count != effective_samples_this_rank:
- warnings.warn(
- f"Rank {self.rank}: Assigned {scheduled_samples_count} samples to batches, "
- f"but expected {effective_samples_this_rank} effective samples this epoch. "
- f"Indices remaining: {effective_samples_this_rank - scheduled_samples_count}."
- )
- def set_epoch(self, epoch: int) -> None:
- """Updates the epoch, regenerating the epoch-specific batches.
- Args:
- epoch: New epoch number.
- """
- # Only regenerate if the epoch actually changes
- if epoch != self.epoch:
- self.epoch = epoch
- self._prepare_epoch_batches(epoch)
- def __len__(self) -> int:
- """Returns the number of batches per worker for the current epoch.
- Returns:
- Number of batches this worker will process.
- """
- return self._num_batches_per_rank
- def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
- """Iterates through pre-calculated batches for the current epoch.
- Yields:
- Tuple of (input_dict, targets) for each batch.
- """
- worker_info = torch.utils.data.get_worker_info()
- num_workers = worker_info.num_workers if worker_info else 1
- worker_id = worker_info.id if worker_info else 0
- # Distribute pre-calculated batches among workers for this rank
- # Each worker processes a slice of the batches prepared in _prepare_epoch_batches
- batches_for_worker = self._epoch_batches[worker_id::num_workers]
- for seq_len, indices in batches_for_worker:
- if not indices: # Skip if a batch ended up with no indices (shouldn't happen often)
- continue
- # Select patch size for this batch
- patch_idx = 0
- if self.variable_patch_size:
- # Use torch multinomial for weighted random choice
- patch_idx = torch.multinomial(torch.tensor(self.patch_size_probs), 1).item()
- # Get the pre-initialized transform and patchifier using patch_idx
- transform_key = (seq_len, patch_idx)
- transform = self.transforms.get(transform_key)
- batch_patchifier = self.patchifiers[patch_idx]
- batch_imgs = []
- batch_targets = []
- for idx in indices:
- try:
- # Get original image and label from map-style dataset
- img, label = self.base_dataset[idx]
- # Apply transform if available
- # Handle cases where transform might return None or fail
- processed_img = transform(img) if transform else img
- if processed_img is None:
- warnings.warn(f"Transform returned None for index {idx}. Skipping sample.")
- continue
- batch_imgs.append(processed_img)
- batch_targets.append(label)
- except IndexError:
- warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.")
- continue
- except Exception as e:
- # Log other potential errors during data loading/processing
- warnings.warn(f"Error processing sample index {idx}. Error: {e}. Skipping sample.")
- continue # Skip problematic sample
- if self.mixup_fn is not None:
- batch_imgs, batch_targets = self.mixup_fn(batch_imgs, batch_targets)
- batch_imgs = [batch_patchifier(img) for img in batch_imgs]
- batch_samples = list(zip(batch_imgs, batch_targets))
- if batch_samples: # Only yield if we successfully processed samples
- # Collate the processed samples into a batch
- yield self.collate_fns[seq_len](batch_samples)
- # If batch_samples is empty after processing 'indices', an empty batch is skipped.
|