# _commit_scheduler.py
  1. import atexit
  2. import logging
  3. import os
  4. import time
  5. from concurrent.futures import Future
  6. from dataclasses import dataclass
  7. from io import SEEK_END, SEEK_SET, BytesIO
  8. from pathlib import Path
  9. from threading import Lock, Thread
  10. from typing import Optional
  11. from .hf_api import DEFAULT_IGNORE_PATTERNS, CommitInfo, CommitOperationAdd, HfApi
  12. from .utils import filter_repo_objects
  13. logger = logging.getLogger(__name__)
  14. @dataclass(frozen=True)
  15. class _FileToUpload:
  16. """Temporary dataclass to store info about files to upload. Not meant to be used directly."""
  17. local_path: Path
  18. path_in_repo: str
  19. size_limit: int
  20. last_modified: float
  21. class CommitScheduler:
  22. """
  23. Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes).
  24. The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is
  25. properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually
  26. with the `stop` method. Checkout the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads)
  27. to learn more about how to use it.
  28. Args:
  29. repo_id (`str`):
  30. The id of the repo to commit to.
  31. folder_path (`str` or `Path`):
  32. Path to the local folder to upload regularly.
  33. every (`int` or `float`, *optional*):
  34. The number of minutes between each commit. Defaults to 5 minutes.
  35. path_in_repo (`str`, *optional*):
  36. Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder
  37. of the repository.
  38. repo_type (`str`, *optional*):
  39. The type of the repo to commit to. Defaults to `model`.
  40. revision (`str`, *optional*):
  41. The revision of the repo to commit to. Defaults to `main`.
  42. private (`bool`, *optional*):
  43. Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
  44. token (`str`, *optional*):
  45. The token to use to commit to the repo. Defaults to the token saved on the machine.
  46. allow_patterns (`list[str]` or `str`, *optional*):
  47. If provided, only files matching at least one pattern are uploaded.
  48. ignore_patterns (`list[str]` or `str`, *optional*):
  49. If provided, files matching any of the patterns are not uploaded.
  50. squash_history (`bool`, *optional*):
  51. Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
  52. useful to avoid degraded performances on the repo when it grows too large.
  53. hf_api (`HfApi`, *optional*):
  54. The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...).
  55. Example:
  56. ```py
  57. >>> from pathlib import Path
  58. >>> from huggingface_hub import CommitScheduler
  59. # Scheduler uploads every 10 minutes
  60. >>> csv_path = Path("watched_folder/data.csv")
  61. >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10)
  62. >>> with csv_path.open("a") as f:
  63. ... f.write("first line")
  64. # Some time later (...)
  65. >>> with csv_path.open("a") as f:
  66. ... f.write("second line")
  67. ```
  68. Example using a context manager:
  69. ```py
  70. >>> from pathlib import Path
  71. >>> from huggingface_hub import CommitScheduler
  72. >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler:
  73. ... csv_path = Path("watched_folder/data.csv")
  74. ... with csv_path.open("a") as f:
  75. ... f.write("first line")
  76. ... (...)
  77. ... with csv_path.open("a") as f:
  78. ... f.write("second line")
  79. # Scheduler is now stopped and last commit have been triggered
  80. ```
  81. """
  82. def __init__(
  83. self,
  84. *,
  85. repo_id: str,
  86. folder_path: str | Path,
  87. every: int | float = 5,
  88. path_in_repo: str | None = None,
  89. repo_type: str | None = None,
  90. revision: str | None = None,
  91. private: bool | None = None,
  92. token: str | None = None,
  93. allow_patterns: list[str] | str | None = None,
  94. ignore_patterns: list[str] | str | None = None,
  95. squash_history: bool = False,
  96. hf_api: Optional["HfApi"] = None,
  97. ) -> None:
  98. self.api = hf_api or HfApi(token=token)
  99. # Folder
  100. self.folder_path = Path(folder_path).expanduser().resolve()
  101. self.path_in_repo = path_in_repo or ""
  102. self.allow_patterns = allow_patterns
  103. if ignore_patterns is None:
  104. ignore_patterns = []
  105. elif isinstance(ignore_patterns, str):
  106. ignore_patterns = [ignore_patterns]
  107. self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS
  108. if self.folder_path.is_file():
  109. raise ValueError(f"'folder_path' must be a directory, not a file: '{self.folder_path}'.")
  110. self.folder_path.mkdir(parents=True, exist_ok=True)
  111. # Repository
  112. repo_url = self.api.create_repo(repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True)
  113. self.repo_id = repo_url.repo_id
  114. self.repo_type = repo_type
  115. self.revision = revision
  116. self.token = token
  117. # Keep track of already uploaded files
  118. self.last_uploaded: dict[Path, float] = {} # key is local path, value is timestamp
  119. # Scheduler
  120. if not every > 0:
  121. raise ValueError(f"'every' must be a positive integer, not '{every}'.")
  122. self.lock = Lock()
  123. self.every = every
  124. self.squash_history = squash_history
  125. logger.info(f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes.")
  126. self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True)
  127. self._scheduler_thread.start()
  128. atexit.register(self._push_to_hub)
  129. self.__stopped = False
  130. def stop(self) -> None:
  131. """Stop the scheduler.
  132. A stopped scheduler cannot be restarted. Mostly for tests purposes.
  133. """
  134. self.__stopped = True
  135. def __enter__(self) -> "CommitScheduler":
  136. return self
  137. def __exit__(self, exc_type, exc_value, traceback) -> None:
  138. # Upload last changes before exiting
  139. self.trigger().result()
  140. self.stop()
  141. return
  142. def _run_scheduler(self) -> None:
  143. """Dumb thread waiting between each scheduled push to Hub."""
  144. while True:
  145. self.last_future = self.trigger()
  146. time.sleep(self.every * 60)
  147. if self.__stopped:
  148. break
  149. def trigger(self) -> Future:
  150. """Trigger a `push_to_hub` and return a future.
  151. This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
  152. immediately, without waiting for the next scheduled commit.
  153. """
  154. return self.api.run_as_future(self._push_to_hub)
  155. def _push_to_hub(self) -> CommitInfo | None:
  156. if self.__stopped: # If stopped, already scheduled commits are ignored
  157. return None
  158. logger.info("(Background) scheduled commit triggered.")
  159. try:
  160. value = self.push_to_hub()
  161. if self.squash_history:
  162. logger.info("(Background) squashing repo history.")
  163. self.api.super_squash_history(repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision)
  164. return value
  165. except Exception as e:
  166. logger.error(f"Error while pushing to Hub: {e}") # Depending on the setup, error might be silenced
  167. raise
  168. def push_to_hub(self) -> CommitInfo | None:
  169. """
  170. Push folder to the Hub and return the commit info.
  171. > [!WARNING]
  172. > This method is not meant to be called directly. It is run in the background by the scheduler, respecting a
  173. > queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency
  174. > issues.
  175. The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and
  176. uploads only changed files. If no changes are found, the method returns without committing anything. If you want
  177. to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful
  178. for example to compress data together in a single file before committing. For more details and examples, check
  179. out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads).
  180. """
  181. # Check files to upload (with lock)
  182. with self.lock:
  183. logger.debug("Listing files to upload for scheduled commit.")
  184. # List files from folder (taken from `_prepare_upload_folder_additions`)
  185. relpath_to_abspath = {
  186. path.relative_to(self.folder_path).as_posix(): path
  187. for path in sorted(self.folder_path.glob("**/*")) # sorted to be deterministic
  188. if path.is_file()
  189. }
  190. prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else ""
  191. # Filter with pattern + filter out unchanged files + retrieve current file size
  192. files_to_upload: list[_FileToUpload] = []
  193. for relpath in filter_repo_objects(
  194. relpath_to_abspath.keys(), allow_patterns=self.allow_patterns, ignore_patterns=self.ignore_patterns
  195. ):
  196. local_path = relpath_to_abspath[relpath]
  197. stat = local_path.stat()
  198. if self.last_uploaded.get(local_path) is None or self.last_uploaded[local_path] != stat.st_mtime:
  199. files_to_upload.append(
  200. _FileToUpload(
  201. local_path=local_path,
  202. path_in_repo=prefix + relpath,
  203. size_limit=stat.st_size,
  204. last_modified=stat.st_mtime,
  205. )
  206. )
  207. # Return if nothing to upload
  208. if len(files_to_upload) == 0:
  209. logger.debug("Dropping schedule commit: no changed file to upload.")
  210. return None
  211. # Convert `_FileToUpload` as `CommitOperationAdd` (=> compute file shas + limit to file size)
  212. logger.debug("Removing unchanged files since previous scheduled commit.")
  213. add_operations = [
  214. CommitOperationAdd(
  215. # Cap the file to its current size, even if the user append data to it while a scheduled commit is happening
  216. path_or_fileobj=PartialFileIO(file_to_upload.local_path, size_limit=file_to_upload.size_limit),
  217. path_in_repo=file_to_upload.path_in_repo,
  218. )
  219. for file_to_upload in files_to_upload
  220. ]
  221. # Upload files (append mode expected - no need for lock)
  222. logger.debug("Uploading files for scheduled commit.")
  223. commit_info = self.api.create_commit(
  224. repo_id=self.repo_id,
  225. repo_type=self.repo_type,
  226. operations=add_operations,
  227. commit_message="Scheduled Commit",
  228. revision=self.revision,
  229. )
  230. # Successful commit: keep track of the latest "last_modified" for each file
  231. for file in files_to_upload:
  232. self.last_uploaded[file.local_path] = file.last_modified
  233. return commit_info
  234. class PartialFileIO(BytesIO):
  235. """A file-like object that reads only the first part of a file.
  236. Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the
  237. file is uploaded (i.e. the part that was available when the filesystem was first scanned).
  238. In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal
  239. disturbance for the user. The object is passed to `CommitOperationAdd`.
  240. Only supports `read`, `tell` and `seek` methods.
  241. Args:
  242. file_path (`str` or `Path`):
  243. Path to the file to read.
  244. size_limit (`int`):
  245. The maximum number of bytes to read from the file. If the file is larger than this, only the first part
  246. will be read (and uploaded).
  247. """
  248. def __init__(self, file_path: str | Path, size_limit: int) -> None:
  249. self._file_path = Path(file_path)
  250. self._file = self._file_path.open("rb")
  251. self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size)
  252. def __del__(self) -> None:
  253. self._file.close()
  254. return super().__del__()
  255. def __repr__(self) -> str:
  256. return f"<PartialFileIO file_path={self._file_path} size_limit={self._size_limit}>"
  257. def __len__(self) -> int:
  258. return self._size_limit
  259. def __getattribute__(self, name: str):
  260. if name.startswith("_") or name in ("read", "tell", "seek", "fileno"): # only 4 public methods supported
  261. return super().__getattribute__(name)
  262. raise NotImplementedError(f"PartialFileIO does not support '{name}'.")
  263. def fileno(self):
  264. raise AttributeError("PartialFileIO does not have a fileno.")
  265. def tell(self) -> int:
  266. """Return the current file position."""
  267. return self._file.tell()
  268. def seek(self, __offset: int, __whence: int = SEEK_SET) -> int:
  269. """Change the stream position to the given offset.
  270. Behavior is the same as a regular file, except that the position is capped to the size limit.
  271. """
  272. if __whence == SEEK_END:
  273. # SEEK_END => set from the truncated end
  274. __offset = len(self) + __offset
  275. __whence = SEEK_SET
  276. pos = self._file.seek(__offset, __whence)
  277. if pos > self._size_limit:
  278. return self._file.seek(self._size_limit)
  279. return pos
  280. def read(self, __size: int | None = -1) -> bytes:
  281. """Read at most `__size` bytes from the file.
  282. Behavior is the same as a regular file, except that it is capped to the size limit.
  283. """
  284. current = self._file.tell()
  285. if __size is None or __size < 0:
  286. # Read until file limit
  287. truncated_size = self._size_limit - current
  288. else:
  289. # Read until file limit or __size
  290. truncated_size = min(__size, self._size_limit - current)
  291. return self._file.read(truncated_size)