_tensorboard_logger.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. # Copyright 2023 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Contains a logger to push training logs to the Hub, using Tensorboard."""
  15. from pathlib import Path
  16. from ._commit_scheduler import CommitScheduler
  17. from .errors import EntryNotFoundError
  18. from .repocard import ModelCard
  19. from .utils import experimental
  # `SummaryWriter` may be provided by either 'tensorboardX' or 'torch.utils.tensorboard',
  # depending on the user's setup. Try them in that order and record whether one was found.
  is_summary_writer_available = True
  try:
      from tensorboardX import SummaryWriter as _RuntimeSummaryWriter
  except ImportError:
      try:
          from torch.utils.tensorboard import SummaryWriter as _RuntimeSummaryWriter
      except ImportError:
          # Neither backend could be imported.
          is_summary_writer_available = False

          # Placeholder base class so that importing this module never fails.
          # `HFSummaryWriter.__new__` raises when no backend is available.
          class _DummySummaryWriter:
              pass

          _RuntimeSummaryWriter = _DummySummaryWriter  # type: ignore[assignment]
  class HFSummaryWriter(_RuntimeSummaryWriter):
      """
      Wrapper around the tensorboard's `SummaryWriter` to push training logs to the Hub.

      Data is logged locally and then pushed to the Hub asynchronously. Pushing data to the Hub is done in a separate
      thread to avoid blocking the training script. In particular, if the upload fails for any reason (e.g. a connection
      issue), the main script will not be interrupted. Data is automatically pushed to the Hub every `commit_every`
      minutes (default to every 5 minutes).

      > [!WARNING]
      > `HFSummaryWriter` is experimental. Its API is subject to change in the future without prior notice.

      Args:
          repo_id (`str`):
              The id of the repo to which the logs will be pushed.
          logdir (`str`, *optional*):
              The directory where the logs will be written. If not specified, a local directory will be created by the
              underlying `SummaryWriter` object.
          commit_every (`int` or `float`, *optional*):
              The frequency (in minutes) at which the logs will be pushed to the Hub. Defaults to 5 minutes.
          squash_history (`bool`, *optional*):
              Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
              useful to avoid degraded performances on the repo when it grows too large.
          repo_type (`str`, *optional*):
              The type of the repo to which the logs will be pushed. Defaults to "model".
          repo_revision (`str`, *optional*):
              The revision of the repo to which the logs will be pushed. Defaults to "main".
          repo_private (`bool`, *optional*):
              Whether to make the repo private. If `None` (default), the repo will be public unless the organization's
              default is private. This value is ignored if the repo already exists.
          path_in_repo (`str`, *optional*):
              The path to the folder in the repo where the logs will be pushed. Defaults to "tensorboard/".
          repo_allow_patterns (`list[str]` or `str`, *optional*):
              A list of patterns to include in the upload. Defaults to `"*.tfevents.*"`. Check out the
              [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details.
          repo_ignore_patterns (`list[str]` or `str`, *optional*):
              A list of patterns to exclude in the upload. Check out the
              [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details.
          token (`str`, *optional*):
              Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
              details. The same token is used for the commit scheduler and for reading/updating the model card.
          kwargs:
              Additional keyword arguments passed to `SummaryWriter`.

      Examples:
      ```diff
      # Taken from https://pytorch.org/docs/stable/tensorboard.html
      - from torch.utils.tensorboard import SummaryWriter
      + from huggingface_hub import HFSummaryWriter

      import numpy as np

      - writer = SummaryWriter()
      + writer = HFSummaryWriter(repo_id="username/my-trained-model")

      for n_iter in range(100):
          writer.add_scalar('Loss/train', np.random.random(), n_iter)
          writer.add_scalar('Loss/test', np.random.random(), n_iter)
          writer.add_scalar('Accuracy/train', np.random.random(), n_iter)
          writer.add_scalar('Accuracy/test', np.random.random(), n_iter)
      ```

      ```py
      >>> from huggingface_hub import HFSummaryWriter

      # Logs are automatically pushed every 15 minutes (5 by default) + when exiting the context manager
      >>> with HFSummaryWriter(repo_id="test_hf_logger", commit_every=15) as logger:
      ...     logger.add_scalar("a", 1)
      ...     logger.add_scalar("b", 2)
      ```
      """

      @experimental
      def __new__(cls, *args, **kwargs) -> "HFSummaryWriter":
          """Create the instance, failing early if no `SummaryWriter` backend could be imported.

          Raises:
              `ImportError`: If neither `tensorboardX` nor `torch.utils.tensorboard` is importable.
          """
          if not is_summary_writer_available:
              raise ImportError(
                  "You must have `tensorboard` installed to use `HFSummaryWriter`. Please run `pip install --upgrade"
                  " tensorboardX` first."
              )
          return super().__new__(cls)

      def __init__(
          self,
          repo_id: str,
          *,
          logdir: str | None = None,
          commit_every: int | float = 5,
          squash_history: bool = False,
          repo_type: str | None = None,
          repo_revision: str | None = None,
          repo_private: bool | None = None,
          path_in_repo: str | None = "tensorboard",
          repo_allow_patterns: list[str] | str | None = "*.tfevents.*",
          repo_ignore_patterns: list[str] | str | None = None,
          token: str | None = None,
          **kwargs,
      ):
          """Initialize the writer, start the commit scheduler and tag the model card.

          See the class docstring for a description of the arguments.

          Raises:
              `ValueError`: If `self.logdir` is not a string after the parent writer has been initialized.
          """
          # Initialize SummaryWriter
          # NOTE(review): `logdir` / `self.logdir` follow the tensorboardX naming. The torch.utils.tensorboard
          # writer appears to name them `log_dir` -- confirm that the torch fallback works with this call.
          super().__init__(logdir=logdir, **kwargs)

          # Check logdir has been correctly initialized and fail early otherwise. In practice, SummaryWriter takes
          # care of it.
          if not isinstance(self.logdir, str):
              raise ValueError(f"`self.logdir` must be a string. Got '{self.logdir}' of type {type(self.logdir)}.")

          # Append logdir name to `path_in_repo`
          logdir_name = Path(self.logdir).name
          if path_in_repo is None or path_in_repo == "":
              path_in_repo = logdir_name
          else:
              path_in_repo = path_in_repo.strip("/") + "/" + logdir_name

          # Initialize scheduler
          self.scheduler = CommitScheduler(
              folder_path=self.logdir,
              path_in_repo=path_in_repo,
              repo_id=repo_id,
              repo_type=repo_type,
              revision=repo_revision,
              private=repo_private,
              token=token,
              allow_patterns=repo_allow_patterns,
              ignore_patterns=repo_ignore_patterns,
              every=commit_every,
              squash_history=squash_history,
          )

          # Exposing some high-level info at root level
          self.repo_id = self.scheduler.repo_id
          self.repo_type = self.scheduler.repo_type
          self.repo_revision = self.scheduler.revision

          # Add `hf-summary-writer` tag to the model card metadata.
          # The user-provided token is forwarded so that the card is read and written with the same credentials
          # as the commit scheduler (e.g. for private repos or when no token is stored locally).
          try:
              card = ModelCard.load(repo_id_or_path=self.repo_id, repo_type=self.repo_type, token=token)
          except EntryNotFoundError:
              # No card found: start from an empty one.
              card = ModelCard("")
          tags = card.data.get("tags", [])
          if "hf-summary-writer" not in tags:
              tags.append("hf-summary-writer")
              card.data["tags"] = tags
              # Only push when the tag was actually added.
              card.push_to_hub(repo_id=self.repo_id, repo_type=self.repo_type, token=token)

      def __exit__(self, exc_type, exc_val, exc_tb):
          """Trigger a final push to the Hub when exiting the logger's context manager.

          The parent `__exit__` is called first, then a push is triggered on the scheduler. Unlike the periodic
          background pushes, this call waits on the future returned by `self.scheduler.trigger()` before returning.
          """
          super().__exit__(exc_type, exc_val, exc_tb)
          future = self.scheduler.trigger()
          future.result()