clip_sampler.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import math
  2. from collections.abc import Iterator, Sized
  3. from typing import cast, Optional, Union
  4. import torch
  5. import torch.distributed as dist
  6. from torch.utils.data import Sampler
  7. from torchvision.datasets.video_utils import VideoClips
  8. class DistributedSampler(Sampler):
  9. """
  10. Extension of DistributedSampler, as discussed in
  11. https://github.com/pytorch/pytorch/issues/23430
  12. Example:
  13. dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
  14. num_replicas: 4
  15. shuffle: False
  16. when group_size = 1
  17. RANK | shard_dataset
  18. =========================
  19. rank_0 | [0, 4, 8, 12]
  20. rank_1 | [1, 5, 9, 13]
  21. rank_2 | [2, 6, 10, 0]
  22. rank_3 | [3, 7, 11, 1]
  23. when group_size = 2
  24. RANK | shard_dataset
  25. =========================
  26. rank_0 | [0, 1, 8, 9]
  27. rank_1 | [2, 3, 10, 11]
  28. rank_2 | [4, 5, 12, 13]
  29. rank_3 | [6, 7, 0, 1]
  30. """
  31. def __init__(
  32. self,
  33. dataset: Sized,
  34. num_replicas: Optional[int] = None,
  35. rank: Optional[int] = None,
  36. shuffle: bool = False,
  37. group_size: int = 1,
  38. ) -> None:
  39. if num_replicas is None:
  40. if not dist.is_available():
  41. raise RuntimeError("Requires distributed package to be available")
  42. num_replicas = dist.get_world_size()
  43. if rank is None:
  44. if not dist.is_available():
  45. raise RuntimeError("Requires distributed package to be available")
  46. rank = dist.get_rank()
  47. if len(dataset) % group_size != 0:
  48. raise ValueError(
  49. f"dataset length must be a multiplier of group size dataset length: {len(dataset)}, group size: {group_size}"
  50. )
  51. self.dataset = dataset
  52. self.group_size = group_size
  53. self.num_replicas = num_replicas
  54. self.rank = rank
  55. self.epoch = 0
  56. dataset_group_length = len(dataset) // group_size
  57. self.num_group_samples = int(math.ceil(dataset_group_length * 1.0 / self.num_replicas))
  58. self.num_samples = self.num_group_samples * group_size
  59. self.total_size = self.num_samples * self.num_replicas
  60. self.shuffle = shuffle
  61. def __iter__(self) -> Iterator[int]:
  62. # deterministically shuffle based on epoch
  63. g = torch.Generator()
  64. g.manual_seed(self.epoch)
  65. indices: Union[torch.Tensor, list[int]]
  66. if self.shuffle:
  67. indices = torch.randperm(len(self.dataset), generator=g).tolist()
  68. else:
  69. indices = list(range(len(self.dataset)))
  70. # add extra samples to make it evenly divisible
  71. indices += indices[: (self.total_size - len(indices))]
  72. assert len(indices) == self.total_size
  73. total_group_size = self.total_size // self.group_size
  74. indices = torch.reshape(torch.LongTensor(indices), (total_group_size, self.group_size))
  75. # subsample
  76. indices = indices[self.rank : total_group_size : self.num_replicas, :]
  77. indices = torch.reshape(indices, (-1,)).tolist()
  78. assert len(indices) == self.num_samples
  79. if isinstance(self.dataset, Sampler):
  80. orig_indices = list(iter(self.dataset))
  81. indices = [orig_indices[i] for i in indices]
  82. return iter(indices)
  83. def __len__(self) -> int:
  84. return self.num_samples
  85. def set_epoch(self, epoch: int) -> None:
  86. self.epoch = epoch
  87. class UniformClipSampler(Sampler):
  88. """
  89. Sample `num_video_clips_per_video` clips for each video, equally spaced.
  90. When number of unique clips in the video is fewer than num_video_clips_per_video,
  91. repeat the clips until `num_video_clips_per_video` clips are collected
  92. Args:
  93. video_clips (VideoClips): video clips to sample from
  94. num_clips_per_video (int): number of clips to be sampled per video
  95. """
  96. def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None:
  97. if not isinstance(video_clips, VideoClips):
  98. raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
  99. self.video_clips = video_clips
  100. self.num_clips_per_video = num_clips_per_video
  101. def __iter__(self) -> Iterator[int]:
  102. idxs = []
  103. s = 0
  104. # select num_clips_per_video for each video, uniformly spaced
  105. for c in self.video_clips.clips:
  106. length = len(c)
  107. if length == 0:
  108. # corner case where video decoding fails
  109. continue
  110. sampled = torch.linspace(s, s + length - 1, steps=self.num_clips_per_video).floor().to(torch.int64)
  111. s += length
  112. idxs.append(sampled)
  113. return iter(cast(list[int], torch.cat(idxs).tolist()))
  114. def __len__(self) -> int:
  115. return sum(self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0)
  116. class RandomClipSampler(Sampler):
  117. """
  118. Samples at most `max_video_clips_per_video` clips for each video randomly
  119. Args:
  120. video_clips (VideoClips): video clips to sample from
  121. max_clips_per_video (int): maximum number of clips to be sampled per video
  122. """
  123. def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None:
  124. if not isinstance(video_clips, VideoClips):
  125. raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
  126. self.video_clips = video_clips
  127. self.max_clips_per_video = max_clips_per_video
  128. def __iter__(self) -> Iterator[int]:
  129. idxs = []
  130. s = 0
  131. # select at most max_clips_per_video for each video, randomly
  132. for c in self.video_clips.clips:
  133. length = len(c)
  134. size = min(length, self.max_clips_per_video)
  135. sampled = torch.randperm(length)[:size] + s
  136. s += length
  137. idxs.append(sampled)
  138. idxs_ = torch.cat(idxs)
  139. # shuffle all clips randomly
  140. perm = torch.randperm(len(idxs_))
  141. return iter(idxs_[perm].tolist())
  142. def __len__(self) -> int:
  143. return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)