video_utils.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891
  1. # Copyright 2025 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import warnings
  16. from collections.abc import Callable, Iterable, Mapping
  17. from contextlib import redirect_stdout
  18. from dataclasses import dataclass, fields
  19. from io import BytesIO
  20. from typing import NewType, Union
  21. from urllib.parse import urlparse
  22. import httpx
  23. import numpy as np
  24. from .image_transforms import PaddingMode, to_channel_dimension_format
  25. from .image_utils import ChannelDimension, infer_channel_dimension_format, is_valid_image
  26. from .utils import (
  27. is_av_available,
  28. is_cv2_available,
  29. is_decord_available,
  30. is_numpy_array,
  31. is_torch_available,
  32. is_torch_tensor,
  33. is_torchcodec_available,
  34. is_torchvision_available,
  35. is_vision_available,
  36. is_yt_dlp_available,
  37. logging,
  38. requires_backends,
  39. )
  40. if is_vision_available():
  41. import PIL.Image
  42. if is_torchvision_available():
  43. from torchvision import io as torchvision_io
  44. if is_torch_available():
  45. import torch
  46. logger = logging.get_logger(__name__)
  47. URL = NewType("URL", str)
  48. Path = NewType("Path", str)
  49. VideoInput = Union[
  50. list["PIL.Image.Image"],
  51. np.ndarray,
  52. "torch.Tensor",
  53. list[np.ndarray],
  54. list["torch.Tensor"],
  55. list[list["PIL.Image.Image"]],
  56. list[list[np.ndarray]],
  57. list[list["torch.Tensor"]],
  58. URL,
  59. list[URL],
  60. list[list[URL]],
  61. Path,
  62. list[Path],
  63. list[list[Path]],
  64. ]
  65. @dataclass
  66. class VideoMetadata(Mapping):
  67. total_num_frames: int
  68. fps: float | None = None
  69. width: int | None = None
  70. height: int | None = None
  71. duration: float | None = None
  72. video_backend: str | None = None
  73. frames_indices: list[int] | None = None
  74. def __iter__(self):
  75. return (f.name for f in fields(self))
  76. def __len__(self):
  77. return len(fields(self))
  78. def __getitem__(self, item):
  79. return getattr(self, item)
  80. def __setitem__(self, key, value):
  81. return setattr(self, key, value)
  82. @property
  83. def timestamps(self) -> list[float]:
  84. "Timestamps of the sampled frames in seconds."
  85. if self.fps is None or self.frames_indices is None:
  86. raise ValueError("Cannot infer video `timestamps` when `fps` or `frames_indices` is None.")
  87. return [frame_idx / self.fps for frame_idx in self.frames_indices]
  88. @property
  89. def sampled_fps(self) -> float:
  90. "FPS of the sampled video."
  91. if self.frames_indices is None or self.total_num_frames is None or self.fps is None:
  92. return self.fps or 24
  93. return len(self.frames_indices) / self.total_num_frames * self.fps
  94. def update(self, dictionary):
  95. for key, value in dictionary.items():
  96. if hasattr(self, key):
  97. setattr(self, key, value)
  98. VideoMetadataType = VideoMetadata | dict | list[dict | VideoMetadata] | list[list[dict | VideoMetadata]]
  99. def is_valid_video_frame(frame):
  100. return isinstance(frame, PIL.Image.Image) or (
  101. (is_numpy_array(frame) or is_torch_tensor(frame)) and frame.ndim == 3
  102. )
  103. def is_valid_video(video):
  104. if not isinstance(video, (list, tuple)):
  105. return (is_numpy_array(video) or is_torch_tensor(video)) and video.ndim == 4
  106. return video and all(is_valid_video_frame(frame) for frame in video)
  107. def valid_videos(videos):
  108. # If we have a list of videos, it could be either one video as list of frames or a batch
  109. if isinstance(videos, (list, tuple)):
  110. for video_or_frame in videos:
  111. if not (is_valid_video(video_or_frame) or is_valid_video_frame(video_or_frame)):
  112. return False
  113. # If not a list, then we have a single 4D video or 5D batched tensor
  114. elif not is_valid_video(videos) or videos.ndim == 5:
  115. return False
  116. return True
  117. def is_batched_video(videos):
  118. if isinstance(videos, (list, tuple)):
  119. return is_valid_video(videos[0])
  120. elif (is_numpy_array(videos) or is_torch_tensor(videos)) and videos.ndim == 5:
  121. return True
  122. return False
  123. def is_scaled_video(video: np.ndarray) -> bool:
  124. """
  125. Checks to see whether the pixel values have already been rescaled to [0, 1].
  126. """
  127. # It's possible the video has pixel values in [0, 255] but is of floating type
  128. return np.min(video) >= 0 and np.max(video) <= 1
  129. def convert_pil_frames_to_video(videos: list[VideoInput]) -> list[Union[np.ndarray, "torch.Tensor"]]:
  130. """
  131. Given a batch of videos, converts each video to a 4D array. If video is already in array type,
  132. it is simply returned. We assume that all inputs in the list are in the same format, based on the type of the first element.
  133. Args:
  134. videos (`VideoInput`):
  135. Video inputs to turn into a list of videos.
  136. """
  137. if not (isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0])):
  138. return videos
  139. video_converted = []
  140. for video in videos:
  141. video = [np.array(frame) for frame in video]
  142. video = np.stack(video)
  143. video_converted.append(video)
  144. return video_converted
  145. def make_batched_videos(videos) -> list[Union[np.ndarray, "torch.Tensor", "URL", "Path"]]:
  146. """
  147. Ensure that the input is a list of videos. If the input is a single video, it is converted to a list of length 1.
  148. If the input is a batch of videos, it is converted to a list of 4D video arrays. Videos passed as list `PIL.Image`
  149. frames are converted to 4D arrays.
  150. We assume that all inputs in the list are in the same format, based on the type of the first element.
  151. Args:
  152. videos (`VideoInput`):
  153. Video inputs to turn into a list of videos.
  154. """
  155. # Early exit for deeply nested list of image frame paths. We shouldn't flatten them
  156. try:
  157. if isinstance(videos[0][0], list) and isinstance(videos[0][0][0], str):
  158. return [image_paths for sublist in videos for image_paths in sublist]
  159. except (IndexError, TypeError):
  160. pass
  161. if is_batched_video(videos):
  162. return convert_pil_frames_to_video(list(videos))
  163. elif isinstance(videos, str) or is_valid_video(videos):
  164. return convert_pil_frames_to_video([videos])
  165. # only one frame passed, thus we unsqueeze time dim
  166. elif is_valid_image(videos):
  167. if isinstance(videos, PIL.Image.Image):
  168. videos = np.array(videos)
  169. return [videos[None, ...]]
  170. elif not isinstance(videos, list):
  171. raise ValueError(
  172. f"Invalid video input. Expected either a list of video frames or an input of 4 or 5 dimensions, but got"
  173. f" type {type(videos)}."
  174. )
  175. # Recursively flatten any nested structure
  176. flat_videos_list = []
  177. for item in videos:
  178. if isinstance(item, str) or is_valid_video(item):
  179. flat_videos_list.append(item)
  180. elif isinstance(item, list) and item:
  181. flat_videos_list.extend(make_batched_videos(item))
  182. flat_videos_list = convert_pil_frames_to_video(flat_videos_list)
  183. return flat_videos_list
  184. def make_batched_metadata(videos: VideoInput, video_metadata: VideoMetadataType) -> list[VideoMetadata]:
  185. if video_metadata is None:
  186. # Create default metadata and fill attributes we can infer from given video
  187. video_metadata = [
  188. {
  189. "total_num_frames": len(video),
  190. "fps": None,
  191. "duration": None,
  192. "frames_indices": list(range(len(video))),
  193. "height": get_video_size(video)[0] if is_valid_video(video) else None,
  194. "width": get_video_size(video)[1] if is_valid_video(video) else None,
  195. }
  196. for video in videos
  197. ]
  198. if isinstance(video_metadata, list):
  199. # Flatten if nested list
  200. if isinstance(video_metadata[0], list):
  201. video_metadata = [
  202. VideoMetadata(**metadata) for metadata_list in video_metadata for metadata in metadata_list
  203. ]
  204. # Simply wrap in VideoMetadata if simple dict
  205. elif isinstance(video_metadata[0], dict):
  206. video_metadata = [VideoMetadata(**metadata) for metadata in video_metadata]
  207. else:
  208. # Create a batched list from single object
  209. video_metadata = [VideoMetadata(**video_metadata)]
  210. return video_metadata
  211. def get_video_size(video: np.ndarray, channel_dim: ChannelDimension | None = None) -> tuple[int, int]:
  212. """
  213. Returns the (height, width) dimensions of the video.
  214. Args:
  215. video (`np.ndarray`):
  216. The video to get the dimensions of.
  217. channel_dim (`ChannelDimension`, *optional*):
  218. Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the video.
  219. Returns:
  220. A tuple of the video's height and width.
  221. """
  222. if channel_dim is None:
  223. channel_dim = infer_channel_dimension_format(video, num_channels=(1, 3, 4))
  224. if channel_dim == ChannelDimension.FIRST:
  225. return video.shape[-2], video.shape[-1]
  226. elif channel_dim == ChannelDimension.LAST:
  227. return video.shape[-3], video.shape[-2]
  228. else:
  229. raise ValueError(f"Unsupported data format: {channel_dim}")
  230. def get_uniform_frame_indices(total_num_frames: int, num_frames: int | None = None):
  231. """
  232. Creates a numpy array for uniform sampling of `num_frame` frames from `total_num_frames`
  233. when loading a video.
  234. Args:
  235. total_num_frames (`int`):
  236. Total number of frames that a video has.
  237. num_frames (`int`, *optional*):
  238. Number of frames to sample uniformly. If not specified, all frames are sampled.
  239. Returns:
  240. np.ndarray: np array of frame indices that will be sampled.
  241. """
  242. if num_frames is not None:
  243. indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
  244. else:
  245. indices = np.arange(0, total_num_frames).astype(int)
  246. return indices
  247. def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
  248. """
  249. A default sampling function that replicates the logic used in get_uniform_frame_indices,
  250. while optionally handling `fps` if `num_frames` is not provided.
  251. Args:
  252. metadata (`VideoMetadata`):
  253. `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
  254. num_frames (`int`, *optional*):
  255. Number of frames to sample uniformly.
  256. fps (`int` or `float`, *optional*):
  257. Desired frames per second. Takes priority over num_frames if both are provided.
  258. Returns:
  259. `np.ndarray`: Array of frame indices to sample.
  260. """
  261. total_num_frames = metadata.total_num_frames
  262. video_fps = metadata.fps
  263. # If num_frames is not given but fps is, calculate num_frames from fps
  264. if num_frames is None and fps is not None:
  265. num_frames = int(total_num_frames / video_fps * fps)
  266. if num_frames > total_num_frames:
  267. raise ValueError(
  268. f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
  269. f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
  270. )
  271. if num_frames is not None:
  272. indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
  273. else:
  274. indices = np.arange(0, total_num_frames, dtype=int)
  275. return indices
  276. def read_video_opencv(
  277. video_path: Union["URL", "Path"],
  278. sample_indices_fn: Callable,
  279. **kwargs,
  280. ) -> tuple[np.ndarray, VideoMetadata]:
  281. """
  282. Decode a video using the OpenCV backend.
  283. Args:
  284. video_path (`str`):
  285. Path to the video file.
  286. sample_indices_fn (`Callable`):
  287. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
  288. by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
  289. If not provided, simple uniform sampling with fps is performed.
  290. Example:
  291. def sample_indices_fn(metadata, **kwargs):
  292. return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
  293. Returns:
  294. tuple[`np.ndarray`, `VideoMetadata`]: A tuple containing:
  295. - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
  296. - `VideoMetadata` object.
  297. """
  298. # Lazy import cv2
  299. requires_backends(read_video_opencv, ["cv2"])
  300. import cv2
  301. video = cv2.VideoCapture(video_path)
  302. total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
  303. video_fps = video.get(cv2.CAP_PROP_FPS)
  304. duration = total_num_frames / video_fps if video_fps else 0
  305. metadata = VideoMetadata(
  306. total_num_frames=int(total_num_frames),
  307. fps=float(video_fps),
  308. duration=float(duration),
  309. video_backend="opencv",
  310. height=int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)),
  311. width=int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
  312. )
  313. indices = sample_indices_fn(metadata=metadata, **kwargs)
  314. index = 0
  315. frames = []
  316. while video.isOpened():
  317. success, frame = video.read()
  318. if not success:
  319. break
  320. if index in indices:
  321. height, width, channel = frame.shape
  322. frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
  323. frames.append(frame[0:height, 0:width, 0:channel])
  324. if success:
  325. index += 1
  326. if index >= total_num_frames:
  327. break
  328. video.release()
  329. metadata.frames_indices = indices
  330. return np.stack(frames), metadata
  331. def read_video_decord(
  332. video_path: Union["URL", "Path"],
  333. sample_indices_fn: Callable,
  334. **kwargs,
  335. ):
  336. """
  337. Decode a video using the Decord backend.
  338. Args:
  339. video_path (`str`):
  340. Path to the video file.
  341. sample_indices_fn (`Callable`):
  342. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
  343. by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
  344. If not provided, simple uniform sampling with fps is performed.
  345. Example:
  346. def sample_indices_fn(metadata, **kwargs):
  347. return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
  348. Returns:
  349. tuple[`np.array`, `VideoMetadata`]: A tuple containing:
  350. - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
  351. - `VideoMetadata` object.
  352. """
  353. # Lazy import from decord
  354. requires_backends(read_video_decord, ["decord"])
  355. from decord import VideoReader, cpu
  356. vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
  357. video_fps = vr.get_avg_fps()
  358. total_num_frames = len(vr)
  359. duration = total_num_frames / video_fps if video_fps else 0
  360. metadata = VideoMetadata(
  361. total_num_frames=int(total_num_frames),
  362. fps=float(video_fps),
  363. duration=float(duration),
  364. video_backend="decord",
  365. )
  366. indices = sample_indices_fn(metadata=metadata, **kwargs)
  367. video = vr.get_batch(indices).asnumpy()
  368. metadata.update(
  369. {
  370. "frames_indices": indices,
  371. "height": video.shape[1],
  372. "width": video.shape[2],
  373. }
  374. )
  375. return video, metadata
  376. def read_video_pyav(
  377. video_path: Union["URL", "Path"],
  378. sample_indices_fn: Callable,
  379. **kwargs,
  380. ):
  381. """
  382. Decode the video with PyAV decoder.
  383. Args:
  384. video_path (`str`):
  385. Path to the video file.
  386. sample_indices_fn (`Callable`, *optional*):
  387. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
  388. by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
  389. If not provided, simple uniform sampling with fps is performed.
  390. Example:
  391. def sample_indices_fn(metadata, **kwargs):
  392. return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
  393. Returns:
  394. tuple[`np.array`, `VideoMetadata`]: A tuple containing:
  395. - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
  396. - `VideoMetadata` object.
  397. """
  398. # Lazy import av
  399. requires_backends(read_video_pyav, ["av"])
  400. import av
  401. container = av.open(video_path)
  402. total_num_frames = container.streams.video[0].frames
  403. video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
  404. duration = total_num_frames / video_fps if video_fps else 0
  405. metadata = VideoMetadata(
  406. total_num_frames=int(total_num_frames),
  407. fps=float(video_fps),
  408. duration=float(duration),
  409. video_backend="pyav",
  410. height=container.streams.video[0].height,
  411. width=container.streams.video[0].width,
  412. )
  413. indices = sample_indices_fn(metadata=metadata, **kwargs)
  414. frames = []
  415. container.seek(0)
  416. end_index = indices[-1]
  417. for i, frame in enumerate(container.decode(video=0)):
  418. if i > end_index:
  419. break
  420. if i >= 0 and i in indices:
  421. frames.append(frame)
  422. video = np.stack([x.to_ndarray(format="rgb24") for x in frames])
  423. metadata.frames_indices = indices
  424. return video, metadata
  425. def read_video_torchvision(
  426. video_path: Union["URL", "Path"],
  427. sample_indices_fn: Callable,
  428. **kwargs,
  429. ):
  430. """
  431. Decode the video with torchvision decoder.
  432. Args:
  433. video_path (`str`):
  434. Path to the video file.
  435. sample_indices_fn (`Callable`, *optional*):
  436. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
  437. by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
  438. If not provided, simple uniform sampling with fps is performed.
  439. Example:
  440. def sample_indices_fn(metadata, **kwargs):
  441. return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
  442. Returns:
  443. tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
  444. - Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]).
  445. - `VideoMetadata` object.
  446. """
  447. warnings.warn(
  448. "Using `torchvision` for video decoding is deprecated and will be removed in future versions. "
  449. "Please use `torchcodec` instead."
  450. )
  451. video, _, info = torchvision_io.read_video(
  452. video_path,
  453. start_pts=0.0,
  454. end_pts=None,
  455. pts_unit="sec",
  456. output_format="TCHW",
  457. )
  458. video_fps = info["video_fps"]
  459. total_num_frames = video.size(0)
  460. duration = total_num_frames / video_fps if video_fps else 0
  461. metadata = VideoMetadata(
  462. total_num_frames=int(total_num_frames),
  463. fps=float(video_fps),
  464. duration=float(duration),
  465. video_backend="torchvision",
  466. )
  467. indices = sample_indices_fn(metadata=metadata, **kwargs)
  468. video = video[indices].contiguous()
  469. metadata.update(
  470. {
  471. "frames_indices": indices,
  472. "height": video.shape[2],
  473. "width": video.shape[3],
  474. }
  475. )
  476. return video, metadata
  477. def read_video_torchcodec(
  478. video_path: Union["URL", "Path"],
  479. sample_indices_fn: Callable,
  480. **kwargs,
  481. ):
  482. """
  483. Decode the video with torchcodec decoder.
  484. Args:
  485. video_path (`str`):
  486. Path to the video file.
  487. sample_indices_fn (`Callable`):
  488. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
  489. by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
  490. If not provided, simple uniform sampling with fps is performed.
  491. Example:
  492. def sample_indices_fn(metadata, **kwargs):
  493. return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
  494. Returns:
  495. Tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
  496. - Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]).
  497. - `VideoMetadata` object.
  498. """
  499. # Lazy import torchcodec
  500. requires_backends(read_video_torchcodec, ["torchcodec"])
  501. from torchcodec.decoders import VideoDecoder
  502. # VideoDecoder expects a string for device, default to "cpu" if None
  503. decoder = VideoDecoder(
  504. video_path,
  505. # Interestingly `exact` mode takes less than approximate when we load the whole video
  506. seek_mode="exact",
  507. # Allow FFmpeg decide on the number of threads for efficiency
  508. num_ffmpeg_threads=0,
  509. device=kwargs.get("device", "cpu"),
  510. )
  511. total_num_frames = decoder.metadata.num_frames
  512. video_fps = decoder.metadata.average_fps
  513. metadata = VideoMetadata(
  514. total_num_frames=total_num_frames,
  515. fps=video_fps,
  516. duration=decoder.metadata.duration_seconds,
  517. video_backend="torchcodec",
  518. height=decoder.metadata.height,
  519. width=decoder.metadata.width,
  520. )
  521. indices = sample_indices_fn(metadata=metadata, **kwargs)
  522. video = decoder.get_frames_at(indices=indices).data.contiguous()
  523. metadata.frames_indices = indices
  524. return video, metadata
  525. VIDEO_DECODERS = {
  526. "decord": read_video_decord,
  527. "opencv": read_video_opencv,
  528. "pyav": read_video_pyav,
  529. "torchvision": read_video_torchvision,
  530. "torchcodec": read_video_torchcodec,
  531. }
  532. def load_video(
  533. video: VideoInput,
  534. num_frames: int | None = None,
  535. fps: int | float | None = None,
  536. backend: str = "pyav",
  537. sample_indices_fn: Callable | None = None,
  538. **kwargs,
  539. ) -> np.ndarray:
  540. """
  541. Loads `video` to a numpy array.
  542. Args:
  543. video (`VideoInput`):
  544. The video to convert to the numpy array format. Can be a link to video or local path.
  545. num_frames (`int`, *optional*):
  546. Number of frames to sample uniformly. If not passed, the whole video is loaded.
  547. fps (`int` or `float`, *optional*):
  548. Number of frames to sample per second. Should be passed only when `num_frames=None`.
  549. If not specified and `num_frames==None`, all frames are sampled.
  550. backend (`str`, *optional*, defaults to `"pyav"`):
  551. The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision", "torchcodec"]. Defaults to "pyav".
  552. sample_indices_fn (`Callable`, *optional*):
  553. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
  554. by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
  555. If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args.
  556. The function expects at input the all args along with all kwargs passed to `load_video` and should output valid
  557. indices at which the video should be sampled. For example:
  558. Example:
  559. def sample_indices_fn(metadata, **kwargs):
  560. return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
  561. Returns:
  562. tuple[`np.ndarray`, Dict]: A tuple containing:
  563. - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
  564. - Metadata dictionary.
  565. """
  566. # If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
  567. if fps is not None and num_frames is not None and sample_indices_fn is None:
  568. raise ValueError(
  569. "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
  570. )
  571. # If user didn't pass a sampling function, create one on the fly with default logic
  572. if sample_indices_fn is None:
  573. def sample_indices_fn_func(metadata, **fn_kwargs):
  574. return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
  575. sample_indices_fn = sample_indices_fn_func
  576. # Early exit if provided an array or `PIL` frames
  577. if not isinstance(video, str):
  578. metadata = [None] * len(video)
  579. return video, metadata
  580. if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
  581. if not is_yt_dlp_available():
  582. raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
  583. # Lazy import from yt_dlp
  584. requires_backends(load_video, ["yt_dlp"])
  585. from yt_dlp import YoutubeDL
  586. buffer = BytesIO()
  587. with redirect_stdout(buffer), YoutubeDL() as f:
  588. f.download([video])
  589. bytes_obj = buffer.getvalue()
  590. file_obj = BytesIO(bytes_obj)
  591. elif video.startswith("http://") or video.startswith("https://"):
  592. file_obj = BytesIO(httpx.get(video, follow_redirects=True).content)
  593. elif os.path.isfile(video):
  594. file_obj = video
  595. else:
  596. raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")
  597. # can also load with decord, but not cv2/torchvision
  598. # both will fail in case of url links
  599. video_is_url = video.startswith("http://") or video.startswith("https://")
  600. if video_is_url and backend == "opencv":
  601. raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
  602. if (
  603. (not is_decord_available() and backend == "decord")
  604. or (not is_av_available() and backend == "pyav")
  605. or (not is_cv2_available() and backend == "opencv")
  606. or (not is_torchvision_available() and backend == "torchvision")
  607. or (not is_torchcodec_available() and backend == "torchcodec")
  608. ):
  609. raise ImportError(
  610. f"You chose backend={backend} for loading the video but the required library is not found in your environment "
  611. f"Make sure to install {backend} before loading the video."
  612. )
  613. video_decoder = VIDEO_DECODERS[backend]
  614. video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
  615. return video, metadata
  616. def convert_to_rgb(
  617. video: np.ndarray,
  618. input_data_format: str | ChannelDimension | None = None,
  619. ) -> np.ndarray:
  620. """
  621. Convert video to RGB by blending the transparency layer if it's in RGBA format, otherwise simply returns it.
  622. Args:
  623. video (`np.ndarray`):
  624. The video to convert.
  625. input_data_format (`ChannelDimension`, *optional*):
  626. The channel dimension format of the input video. If unset, will use the inferred format from the input.
  627. """
  628. if not isinstance(video, np.ndarray):
  629. raise TypeError(f"Video has to be a numpy array to convert to RGB format, but found {type(video)}")
  630. # np.array usually comes with ChannelDimension.LAST so let's convert it
  631. if input_data_format is None:
  632. input_data_format = infer_channel_dimension_format(video)
  633. video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_channel_dim=input_data_format)
  634. # 3 channels for RGB already
  635. if video.shape[-3] == 3:
  636. return video
  637. # Grayscale video so we repeat it 3 times for each channel
  638. if video.shape[-3] == 1:
  639. return video.repeat(3, -3)
  640. if not (video[..., 3, :, :] < 255).any():
  641. return video
  642. # There is a transparency layer, blend it with a white background.
  643. # Calculate the alpha proportion for blending.
  644. alpha = video[..., 3, :, :] / 255.0
  645. video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., 3, :, :]
  646. return video
  647. def pad(
  648. video: np.ndarray,
  649. padding: int | tuple[int, int] | Iterable[tuple[int, int]],
  650. mode: PaddingMode = PaddingMode.CONSTANT,
  651. constant_values: float | Iterable[float] = 0.0,
  652. data_format: str | ChannelDimension | None = None,
  653. input_data_format: str | ChannelDimension | None = None,
  654. ) -> np.ndarray:
  655. """
  656. Pads the `video` with the specified (height, width) `padding` and `mode`.
  657. Args:
  658. video (`np.ndarray`):
  659. The video to pad.
  660. padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
  661. Padding to apply to the edges of the height, width axes. Can be one of three formats:
  662. - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
  663. - `((before, after),)` yields same before and after pad for height and width.
  664. - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
  665. mode (`PaddingMode`):
  666. The padding mode to use. Can be one of:
  667. - `"constant"`: pads with a constant value.
  668. - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
  669. vector along each axis.
  670. - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
  671. - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
  672. constant_values (`float` or `Iterable[float]`, *optional*):
  673. The value to use for the padding if `mode` is `"constant"`.
  674. data_format (`str` or `ChannelDimension`, *optional*):
  675. The channel dimension format for the output video. Can be one of:
  676. - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
  677. - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
  678. If unset, will use same as the input video.
  679. input_data_format (`str` or `ChannelDimension`, *optional*):
  680. The channel dimension format for the input video. Can be one of:
  681. - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
  682. - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
  683. If unset, will use the inferred format of the input video.
  684. Returns:
  685. `np.ndarray`: The padded video.
  686. """
  687. if input_data_format is None:
  688. input_data_format = infer_channel_dimension_format(video)
  689. def _expand_for_data_format(values):
  690. """
  691. Convert values to be in the format expected by np.pad based on the data format.
  692. """
  693. if isinstance(values, (int, float)):
  694. values = ((values, values), (values, values))
  695. elif isinstance(values, tuple) and len(values) == 1:
  696. values = ((values[0], values[0]), (values[0], values[0]))
  697. elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int):
  698. values = (values, values)
  699. elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
  700. pass
  701. else:
  702. raise ValueError(f"Unsupported format: {values}")
  703. # add 0 for channel dimension
  704. values = (
  705. ((0, 0), (0, 0), *values) if input_data_format == ChannelDimension.FIRST else ((0, 0), *values, (0, 0))
  706. )
  707. # Add additional padding if there's a batch dimension
  708. values = (0, *values) if video.ndim == 5 else values
  709. return values
  710. padding_map = {
  711. PaddingMode.CONSTANT: "constant",
  712. PaddingMode.REFLECT: "reflect",
  713. PaddingMode.REPLICATE: "replicate",
  714. PaddingMode.SYMMETRIC: "symmetric",
  715. }
  716. padding = _expand_for_data_format(padding)
  717. pad_kwargs = {}
  718. if mode not in padding_map:
  719. raise ValueError(f"Invalid padding mode: {mode}")
  720. elif mode == PaddingMode.CONSTANT:
  721. pad_kwargs["constant_values"] = _expand_for_data_format(constant_values)
  722. video = np.pad(video, padding, mode=padding_map[mode], **pad_kwargs)
  723. video = to_channel_dimension_format(video, data_format, input_data_format) if data_format is not None else video
  724. return video
  725. def group_videos_by_shape(
  726. videos: list["torch.Tensor"],
  727. ) -> tuple[dict[tuple[int, int], "torch.Tensor"], dict[int, tuple[tuple[int, int], int]]]:
  728. """
  729. Groups videos by shape.
  730. Returns a dictionary with the shape as key and a list of videos with that shape as value,
  731. and a dictionary with the index of the video in the original list as key and the shape and index in the grouped list as value.
  732. """
  733. grouped_videos = {}
  734. grouped_videos_index = {}
  735. for i, video in enumerate(videos):
  736. shape = video.shape[-2::]
  737. num_frames = video.shape[-4] # video format BTCHW
  738. shape = (num_frames, *shape)
  739. if shape not in grouped_videos:
  740. grouped_videos[shape] = []
  741. grouped_videos[shape].append(video)
  742. grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1)
  743. # stack videos with the same size and number of frames
  744. grouped_videos = {shape: torch.stack(videos, dim=0) for shape, videos in grouped_videos.items()}
  745. return grouped_videos, grouped_videos_index
  746. def reorder_videos(
  747. processed_videos: dict[tuple[int, int], "torch.Tensor"],
  748. grouped_videos_index: dict[int, tuple[tuple[int, int], int]],
  749. ) -> list["torch.Tensor"]:
  750. """
  751. Reconstructs a list of videos in the original order.
  752. """
  753. return [
  754. processed_videos[grouped_videos_index[i][0]][grouped_videos_index[i][1]]
  755. for i in range(len(grouped_videos_index))
  756. ]