naflex_dataset.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. """ Dynamic Sequence Length Datasets for Variable Resolution Image Processing
  2. Implements two dataset wrappers:
  3. 1. NaFlexMapDatasetWrapper - Map-style dataset that returns batches with variable sequence lengths
  4. TODO: 2. NaFlexIterableDatasetWrapper - Iterable dataset that yields batches with variable sequence lengths
  5. Both support:
  6. - Pre-initialized transforms for efficiency
  7. - Distributed training
  8. - Multiple workers
  9. - Variable batch sizes based on sequence length
  10. Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
  11. """
  12. import math
  13. import random
  14. import warnings
  15. from functools import partial
  16. from typing import Any, Iterator, List, Tuple, Dict, Optional, Union, Callable
  17. import torch
  18. from torch.utils.data import Dataset, IterableDataset, DataLoader
  19. from PIL import Image
  20. from .naflex_transforms import Patchify
  21. from timm.layers import to_2tuple
  22. def calculate_naflex_batch_size(
  23. tokens_per_batch: int,
  24. seq_len: int,
  25. max_size: Optional[int] = None,
  26. divisor: int = 1,
  27. rounding: str = 'floor',
  28. ) -> int:
  29. """Calculate batch size based on sequence length with divisibility constraints.
  30. Args:
  31. tokens_per_batch: Target number of tokens per batch.
  32. seq_len: Sequence length for this batch.
  33. max_size: Optional maximum batch size.
  34. divisor: Ensure batch size is divisible by this value.
  35. rounding: Rounding method ('floor', 'ceil', 'round').
  36. Returns:
  37. Calculated batch size.
  38. """
  39. # Calculate raw batch size based on sequence length
  40. raw_batch_size = tokens_per_batch / seq_len
  41. # Apply divisibility with specified rounding method
  42. if divisor > 1:
  43. if rounding == 'floor':
  44. batch_size = math.floor(raw_batch_size / divisor) * divisor
  45. elif rounding == 'ceil':
  46. batch_size = math.ceil(raw_batch_size / divisor) * divisor
  47. else: # 'round' is the default
  48. batch_size = round(raw_batch_size / divisor) * divisor
  49. else:
  50. # If no divisor specified, just use integer division
  51. batch_size = int(raw_batch_size)
  52. # Ensure batch size is valid
  53. batch_size = max(1, batch_size) # At least 1
  54. if max_size is not None:
  55. batch_size = min(batch_size, max_size)
  56. return batch_size
  57. class NaFlexCollator:
  58. """Custom collator for batching NaFlex-style variable-resolution images."""
  59. def __init__(
  60. self,
  61. max_seq_len: Optional[int] = None,
  62. ) -> None:
  63. """Initialize NaFlexCollator.
  64. Args:
  65. max_seq_len: Maximum sequence length for batching.
  66. """
  67. self.max_seq_len = max_seq_len or 576 # Default ViT-B/16 sequence length (577 = 24*24)
  68. def __call__(self, batch: List[Tuple[Dict[str, torch.Tensor], Union[int, torch.Tensor]]]) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
  69. """Collate batch of NaFlex samples.
  70. Args:
  71. batch: List of tuples (patch_dict, target).
  72. Returns:
  73. A tuple of (input_dict, targets) where input_dict contains:
  74. - patches: Padded tensor of patches
  75. - patch_coord: Coordinates for each patch (y, x)
  76. - patch_valid: Valid indicators
  77. """
  78. assert isinstance(batch[0], tuple)
  79. batch_size = len(batch)
  80. # Extract targets
  81. targets = [item[1] for item in batch]
  82. if isinstance(targets[0], torch.Tensor):
  83. targets = torch.stack(targets)
  84. else:
  85. targets = torch.tensor(targets, dtype=torch.int64)
  86. # Get patch dictionaries
  87. patch_dicts = [item[0] for item in batch]
  88. # If we have a maximum sequence length constraint, ensure we don't exceed it
  89. if self.max_seq_len is not None:
  90. max_patches = self.max_seq_len
  91. else:
  92. # Find the maximum number of patches in this batch
  93. max_patches = max(item['patches'].shape[0] for item in patch_dicts)
  94. # Check if patches are flattened or unflattened
  95. patches_tensor = patch_dicts[0]['patches']
  96. is_unflattened = patches_tensor.ndim == 4 # [N, Ph, Pw, C]
  97. if is_unflattened:
  98. # Patches are [N, Ph, Pw, C] - variable patch size mode
  99. _, ph, pw, c = patches_tensor.shape
  100. patches = torch.zeros((batch_size, max_patches, ph, pw, c), dtype=torch.float32)
  101. else:
  102. # Patches are [N, P*P*C] - normal mode
  103. patch_dim = patches_tensor.shape[1]
  104. patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
  105. # Prepare other tensors
  106. patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64) # [B, N, 2] for (y, x)
  107. patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
  108. # Fill in the tensors
  109. for i, patch_dict in enumerate(patch_dicts):
  110. num_patches = min(patch_dict['patches'].shape[0], max_patches)
  111. patches[i, :num_patches] = patch_dict['patches'][:num_patches]
  112. patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
  113. patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
  114. result = {
  115. 'patches': patches,
  116. 'patch_coord': patch_coord,
  117. 'patch_valid': patch_valid,
  118. 'seq_len': max_patches,
  119. }
  120. return result, targets
  121. def _resolve_patch_cfg(
  122. patch_size: Optional[Union[int, Tuple[int, int]]],
  123. patch_size_choices: Optional[List[int]],
  124. patch_size_choice_probs: Optional[List[float]],
  125. ) -> Tuple[List[Tuple[int, int]], List[float], bool]:
  126. """Resolve patch size configuration.
  127. Args:
  128. patch_size: Single patch size to use.
  129. patch_size_choices: List of patch sizes to choose from.
  130. patch_size_choice_probs: Probabilities for each patch size choice.
  131. Returns:
  132. Tuple of (sizes, probs, variable) where sizes is list of patch size tuples,
  133. probs is list of probabilities, and variable indicates if patch size varies.
  134. """
  135. # If both are None, default to patch_size=16
  136. if patch_size is None and patch_size_choices is None:
  137. patch_size = 16
  138. if (patch_size is None) == (patch_size_choices is None):
  139. raise ValueError(
  140. "Specify exactly one of `patch_size` or `patch_size_choices`."
  141. )
  142. if patch_size is not None:
  143. sizes = [to_2tuple(patch_size)]
  144. probs = [1.0]
  145. variable = False
  146. else:
  147. sizes = [to_2tuple(p) for p in patch_size_choices]
  148. if patch_size_choice_probs is None:
  149. probs = [1.0 / len(sizes)] * len(sizes)
  150. else:
  151. if len(patch_size_choice_probs) != len(sizes):
  152. raise ValueError("`patch_size_choice_probs` length mismatch.")
  153. s = float(sum(patch_size_choice_probs))
  154. if s <= 0:
  155. raise ValueError("`patch_size_choice_probs` sum to zero.")
  156. probs = [p / s for p in patch_size_choice_probs]
  157. variable = True
  158. return sizes, probs, variable
  159. class NaFlexMapDatasetWrapper(IterableDataset):
  160. """
  161. IterableDataset wrapper for a map-style base dataset.
  162. Yields batches with variable sequence lengths. It calculates a canonical
  163. batch schedule (sequence length, batch size pairs) once based on the
  164. total dataset size (padded for distribution). Each epoch, it shuffles
  165. the order of this canonical schedule and the dataset indices.
  166. This ensures a consistent number of batches and samples per epoch
  167. across all ranks. Handles distributed training and multiple workers.
  168. """
  169. def __init__(
  170. self,
  171. base_dataset: Dataset,
  172. patch_size: Optional[Union[int, Tuple[int, int]]] = None,
  173. patch_size_choices: Optional[List[int]] = None,
  174. patch_size_choice_probs: Optional[List[float]] = None,
  175. seq_lens: Tuple[int, ...] = (128, 256, 576, 784, 1024),
  176. max_tokens_per_batch: int = 4096 * 4,
  177. transform_factory: Optional[Callable] = None,
  178. mixup_fn: Optional[Callable] = None,
  179. seed: int = 42,
  180. shuffle: bool = True,
  181. distributed: bool = False,
  182. rank: int = 0,
  183. world_size: int = 1,
  184. epoch: int = 0,
  185. batch_divisor: int = 8,
  186. ) -> None:
  187. """Initialize NaFlexMapDatasetWrapper.
  188. Args:
  189. base_dataset: Map-style dataset to wrap.
  190. patch_size: Single patch size to use.
  191. patch_size_choices: List of patch sizes to randomly select from.
  192. patch_size_choice_probs: Probabilities for each patch size.
  193. seq_lens: Sequence lengths to use for batching.
  194. max_tokens_per_batch: Target tokens per batch.
  195. transform_factory: Factory function for creating transforms.
  196. mixup_fn: Optional mixup function.
  197. seed: Random seed.
  198. shuffle: Whether to shuffle data.
  199. distributed: Whether using distributed training.
  200. rank: Process rank for distributed training.
  201. world_size: Total number of processes.
  202. epoch: Starting epoch.
  203. batch_divisor: Ensure batch size is divisible by this.
  204. """
  205. super().__init__()
  206. if not hasattr(base_dataset, '__len__') or not hasattr(base_dataset, '__getitem__'):
  207. raise TypeError("base_dataset must be a map-style dataset (implement __len__ and __getitem__)")
  208. self.base_dataset = base_dataset
  209. self.seq_lens = sorted(list(set(seq_lens))) # Ensure unique and sorted
  210. self.max_tokens_per_batch = max_tokens_per_batch
  211. self.seed = seed
  212. self.shuffle = shuffle
  213. self.distributed = distributed
  214. self.rank = rank if distributed else 0
  215. self.world_size = world_size if distributed else 1
  216. self.epoch = epoch
  217. self.batch_divisor = batch_divisor
  218. # Resolve patch size configuration
  219. self.patch_sizes, self.patch_size_probs, self.variable_patch_size = _resolve_patch_cfg(
  220. patch_size,
  221. patch_size_choices,
  222. patch_size_choice_probs
  223. )
  224. # Pre-initialize transforms and collate fns for each (seq_len, patch_idx) combination
  225. self.transforms: Dict[Tuple[int, int], Optional[Callable]] = {}
  226. self.collate_fns: Dict[int, Callable] = {}
  227. self.patchifiers: List[Callable] = []
  228. for seq_len in self.seq_lens:
  229. self.collate_fns[seq_len] = NaFlexCollator(seq_len)
  230. for patch_idx, patch_size_tuple in enumerate(self.patch_sizes):
  231. # Pre-initialize patchifiers for each patch size (indexed by patch_idx)
  232. self.patchifiers.append(Patchify(
  233. patch_size=patch_size_tuple,
  234. flatten_patches=not self.variable_patch_size
  235. ))
  236. # Create transforms for each (seq_len, patch_idx) combination
  237. for seq_len in self.seq_lens:
  238. key = (seq_len, patch_idx)
  239. if transform_factory:
  240. self.transforms[key] = transform_factory(max_seq_len=seq_len, patch_size=patch_size_tuple)
  241. else:
  242. self.transforms[key] = None # No transform
  243. self.mixup_fn = mixup_fn
  244. # Canonical Schedule Calculation (Done Once)
  245. self._canonical_batch_schedule: List[Tuple[int, int]] = []
  246. self._num_batches_per_rank: int = 0
  247. self._padded_samples_per_rank: int = 0
  248. self._create_canonical_schedule() # Calculate schedule based on padded size
  249. # Per-Epoch State
  250. # Stores (seq_len, list_of_indices) for the current epoch, specific to this rank
  251. self._epoch_batches: List[Tuple[int, List[int]]] = []
  252. self._prepare_epoch_batches(self.epoch) # setup for initial epoch
  253. def _create_canonical_schedule(self):
  254. """
  255. Calculates the canonical batch schedule (seq_len, batch_size pairs)
  256. based on the dataset size, padded for distributed training.
  257. This schedule is the *same* for all ranks and ensures consistent
  258. epoch length. It is calculated once during initialization.
  259. """
  260. total_len = len(self.base_dataset)
  261. padded_total_len = total_len
  262. num_samples_per_rank = total_len
  263. if self.distributed and self.world_size > 1:
  264. # Calculate padding needed for even distribution
  265. if total_len % self.world_size != 0:
  266. pad_size = self.world_size - (total_len % self.world_size)
  267. padded_total_len += pad_size
  268. print(f"Rank {self.rank}: Padding dataset with {pad_size} samples for distributed training (total size {padded_total_len}).")
  269. else:
  270. pad_size = 0
  271. if padded_total_len % self.world_size != 0:
  272. # This should not happen with the padding logic, but safeguard
  273. raise RuntimeError(f"Internal Error: Padded total length {padded_total_len} not divisible by world size {self.world_size}")
  274. num_samples_per_rank = padded_total_len // self.world_size
  275. elif self.distributed and self.world_size <= 1:
  276. # Distributed flag set but world_size is 1, treat as non-distributed
  277. pass # num_samples_per_rank remains total_len
  278. self._padded_samples_per_rank = num_samples_per_rank
  279. if num_samples_per_rank == 0:
  280. self._canonical_batch_schedule = []
  281. self._num_batches_per_rank = 0
  282. return
  283. # Use a fixed seed for generating the canonical schedule structure
  284. g = torch.Generator()
  285. g.manual_seed(self.seed) # Use base seed, NOT epoch seed
  286. current_schedule: List[Tuple[int, int]] = []
  287. remaining_samples = num_samples_per_rank
  288. total_scheduled_samples = 0
  289. while remaining_samples > 0:
  290. # Sample sequence length deterministically based on base seed
  291. seq_idx = torch.randint(0, len(self.seq_lens), (1,), generator=g).item()
  292. seq_len = self.seq_lens[seq_idx]
  293. # Calculate batch size
  294. batch_size = calculate_naflex_batch_size(
  295. tokens_per_batch=self.max_tokens_per_batch,
  296. seq_len=seq_len,
  297. # max_size should be remaining_samples to avoid overshooting
  298. max_size=remaining_samples,
  299. divisor=self.batch_divisor,
  300. rounding='floor',
  301. )
  302. # Ensure batch size is positive and doesn't exceed remaining samples
  303. batch_size = max(1, batch_size)
  304. batch_size = min(batch_size, remaining_samples)
  305. if batch_size <= 0:
  306. warnings.warn(f"Calculated batch size <= 0 (seq_len={seq_len}, remaining={remaining_samples}). Stopping schedule generation early.")
  307. break # Avoid infinite loop if something goes wrong
  308. current_schedule.append((seq_len, batch_size))
  309. remaining_samples -= batch_size
  310. total_scheduled_samples += batch_size
  311. # Sanity check: Ensure the schedule covers all samples for the rank
  312. if total_scheduled_samples != num_samples_per_rank:
  313. warnings.warn(
  314. f"Rank {self.rank}: Canonical schedule accounts for {total_scheduled_samples} samples, "
  315. f"but expected {num_samples_per_rank} samples per rank. "
  316. f"This might happen if min_batch_size or batch_divisor constraints prevent utilizing all samples. "
  317. f"Check parameters. Remaining samples: {remaining_samples}"
  318. )
  319. # Adjust if needed? Could add a final small batch, but might violate constraints.
  320. # Current behavior: some samples might be dropped if schedule logic fails.
  321. self._canonical_batch_schedule = current_schedule
  322. self._num_batches_per_rank = len(current_schedule)
  323. print(f"Rank {self.rank}: Created canonical schedule with {self._num_batches_per_rank} batches for {self._padded_samples_per_rank} samples/rank.")
  324. def _prepare_epoch_batches(self, epoch: int):
  325. """
  326. Prepares the batches for the current epoch by:
  327. 1. Shuffling the full dataset indices (using epoch seed).
  328. 2. Applying padding if in distributed mode.
  329. 3. Selecting indices for the current rank.
  330. 4. Shuffling the *order* of the canonical batch schedule (using epoch seed).
  331. 5. Assigning the rank's indices to the shuffled batches.
  332. """
  333. g = torch.Generator()
  334. g.manual_seed(self.seed + epoch) # Epoch-specific seed for shuffling
  335. # 1. Get shuffled global indices
  336. total_len = len(self.base_dataset)
  337. if self.shuffle:
  338. all_indices_shuffled = torch.randperm(total_len, generator=g).tolist()
  339. else:
  340. all_indices_shuffled = list(range(total_len))
  341. # 2. Apply padding for distributed mode
  342. indices_for_ranks = all_indices_shuffled
  343. if self.distributed and self.world_size > 1:
  344. padded_total_len = self._padded_samples_per_rank * self.world_size
  345. if padded_total_len > total_len:
  346. pad_size = padded_total_len - total_len
  347. # Repeat initial elements from the *shuffled* list for padding
  348. indices_for_ranks = all_indices_shuffled + all_indices_shuffled[:pad_size]
  349. # Ensure length matches expectation
  350. if len(indices_for_ranks) != padded_total_len:
  351. raise RuntimeError(f"Internal Error: Padded index list length {len(indices_for_ranks)} does not match expected {padded_total_len}")
  352. # 3. Select indices for the current rank
  353. if self.distributed and self.world_size > 1:
  354. indices_this_rank = indices_for_ranks[self.rank::self.world_size]
  355. else: # Non-distributed or world_size=1
  356. indices_this_rank = indices_for_ranks
  357. # Sanity check length
  358. if len(indices_this_rank) != self._padded_samples_per_rank:
  359. # This might happen if canonical schedule generation had warnings/issues
  360. warnings.warn(
  361. f"Rank {self.rank}: Number of indices for this rank ({len(indices_this_rank)}) "
  362. f"does not match expected padded samples per rank ({self._padded_samples_per_rank}). "
  363. f"Epoch generation might be inconsistent."
  364. )
  365. # Adjust expected samples? Or truncate/pad indices? Let's proceed but warn.
  366. # Using min() prevents IndexError later if indices are fewer than expected.
  367. effective_samples_this_rank = min(len(indices_this_rank), self._padded_samples_per_rank)
  368. indices_this_rank = indices_this_rank[:effective_samples_this_rank]
  369. else:
  370. effective_samples_this_rank = self._padded_samples_per_rank
  371. # 4. Shuffle the order of the canonical batch schedule for this epoch
  372. if self.shuffle:
  373. schedule_perm = torch.randperm(self._num_batches_per_rank, generator=g).tolist()
  374. shuffled_schedule = [self._canonical_batch_schedule[i] for i in schedule_perm]
  375. else:
  376. shuffled_schedule = list(self._canonical_batch_schedule) # Keep original order
  377. # 5. Assign indices to the shuffled batches
  378. self._epoch_batches = []
  379. idx_pos = 0
  380. scheduled_samples_count = 0
  381. for seq_len, bs in shuffled_schedule:
  382. # Ensure we don't try to grab more indices than available for the rank
  383. actual_bs = min(bs, effective_samples_this_rank - idx_pos)
  384. if actual_bs <= 0:
  385. if scheduled_samples_count < effective_samples_this_rank:
  386. # This indicates mismatch between schedule total and actual samples
  387. warnings.warn(f"Rank {self.rank}: Ran out of samples ({idx_pos}/{effective_samples_this_rank}) before processing entire schedule. Check schedule generation.")
  388. break # Stop if no more indices or batch size is zero
  389. batch_indices = indices_this_rank[idx_pos : idx_pos + actual_bs]
  390. self._epoch_batches.append((seq_len, batch_indices))
  391. idx_pos += actual_bs
  392. scheduled_samples_count += actual_bs
  393. # Final check
  394. if scheduled_samples_count != effective_samples_this_rank:
  395. warnings.warn(
  396. f"Rank {self.rank}: Assigned {scheduled_samples_count} samples to batches, "
  397. f"but expected {effective_samples_this_rank} effective samples this epoch. "
  398. f"Indices remaining: {effective_samples_this_rank - scheduled_samples_count}."
  399. )
  400. def set_epoch(self, epoch: int) -> None:
  401. """Updates the epoch, regenerating the epoch-specific batches.
  402. Args:
  403. epoch: New epoch number.
  404. """
  405. # Only regenerate if the epoch actually changes
  406. if epoch != self.epoch:
  407. self.epoch = epoch
  408. self._prepare_epoch_batches(epoch)
  409. def __len__(self) -> int:
  410. """Returns the number of batches per worker for the current epoch.
  411. Returns:
  412. Number of batches this worker will process.
  413. """
  414. return self._num_batches_per_rank
  415. def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
  416. """Iterates through pre-calculated batches for the current epoch.
  417. Yields:
  418. Tuple of (input_dict, targets) for each batch.
  419. """
  420. worker_info = torch.utils.data.get_worker_info()
  421. num_workers = worker_info.num_workers if worker_info else 1
  422. worker_id = worker_info.id if worker_info else 0
  423. # Distribute pre-calculated batches among workers for this rank
  424. # Each worker processes a slice of the batches prepared in _prepare_epoch_batches
  425. batches_for_worker = self._epoch_batches[worker_id::num_workers]
  426. for seq_len, indices in batches_for_worker:
  427. if not indices: # Skip if a batch ended up with no indices (shouldn't happen often)
  428. continue
  429. # Select patch size for this batch
  430. patch_idx = 0
  431. if self.variable_patch_size:
  432. # Use torch multinomial for weighted random choice
  433. patch_idx = torch.multinomial(torch.tensor(self.patch_size_probs), 1).item()
  434. # Get the pre-initialized transform and patchifier using patch_idx
  435. transform_key = (seq_len, patch_idx)
  436. transform = self.transforms.get(transform_key)
  437. batch_patchifier = self.patchifiers[patch_idx]
  438. batch_imgs = []
  439. batch_targets = []
  440. for idx in indices:
  441. try:
  442. # Get original image and label from map-style dataset
  443. img, label = self.base_dataset[idx]
  444. # Apply transform if available
  445. # Handle cases where transform might return None or fail
  446. processed_img = transform(img) if transform else img
  447. if processed_img is None:
  448. warnings.warn(f"Transform returned None for index {idx}. Skipping sample.")
  449. continue
  450. batch_imgs.append(processed_img)
  451. batch_targets.append(label)
  452. except IndexError:
  453. warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.")
  454. continue
  455. except Exception as e:
  456. # Log other potential errors during data loading/processing
  457. warnings.warn(f"Error processing sample index {idx}. Error: {e}. Skipping sample.")
  458. continue # Skip problematic sample
  459. if self.mixup_fn is not None:
  460. batch_imgs, batch_targets = self.mixup_fn(batch_imgs, batch_targets)
  461. batch_imgs = [batch_patchifier(img) for img in batch_imgs]
  462. batch_samples = list(zip(batch_imgs, batch_targets))
  463. if batch_samples: # Only yield if we successfully processed samples
  464. # Collate the processed samples into a batch
  465. yield self.collate_fns[seq_len](batch_samples)
  466. # If batch_samples is empty after processing 'indices', an empty batch is skipped.