# gitlib.py
  1. from __future__ import annotations
  2. import configparser
  3. import logging
  4. import os
  5. from typing import TYPE_CHECKING, Any
  6. from urllib.parse import urlparse, urlunparse
  7. import wandb
  8. try:
  9. from git import ( # type: ignore
  10. GitCommandError,
  11. InvalidGitRepositoryError,
  12. NoSuchPathError,
  13. Repo,
  14. )
  15. except ImportError:
  16. pass
  17. if TYPE_CHECKING:
  18. from git import Repo
  19. logger = logging.getLogger(__name__)
  20. class GitRepo:
  21. def __init__(
  22. self,
  23. root: str | None = None,
  24. remote: str = "origin",
  25. lazy: bool = True,
  26. remote_url: str | None = None,
  27. commit: str | None = None,
  28. ) -> None:
  29. self.remote_name = remote if remote_url is None else None
  30. self._root = root
  31. self._remote_url = remote_url
  32. self._commit = commit
  33. self._repo = None
  34. self._repo_initialized = False
  35. if not lazy:
  36. self._repo = self._init_repo()
  37. def _init_repo(self) -> Repo | None:
  38. self._repo_initialized = True
  39. try:
  40. from git import Repo
  41. except ImportError:
  42. return None
  43. if self.remote_name is None:
  44. return None
  45. try:
  46. return Repo(self._root or os.getcwd(), search_parent_directories=True)
  47. except FileNotFoundError:
  48. wandb.termwarn("current working directory has been invalidated")
  49. logger.warning("current working directory has been invalidated")
  50. except InvalidGitRepositoryError:
  51. logger.debug("git repository is invalid")
  52. except NoSuchPathError:
  53. wandb.termwarn(f"git root {self._root} does not exist")
  54. logger.warning(f"git root {self._root} does not exist")
  55. return None
  56. @property
  57. def repo(self) -> Repo | None:
  58. if not self._repo_initialized:
  59. self._repo = self._init_repo()
  60. return self._repo
  61. @property
  62. def auto(self) -> bool:
  63. return self._remote_url is None
  64. def is_untracked(self, file_name: str) -> bool | None:
  65. if not self.repo:
  66. return True
  67. try:
  68. return file_name in self.repo.untracked_files
  69. except GitCommandError:
  70. return None
  71. @property
  72. def enabled(self) -> bool:
  73. return bool(self.repo)
  74. @property
  75. def root(self) -> Any:
  76. if not self.repo:
  77. return None
  78. try:
  79. return self.repo.git.rev_parse("--show-toplevel")
  80. except GitCommandError:
  81. # todo: collect telemetry on this
  82. logger.exception("git root error")
  83. return None
  84. @property
  85. def dirty(self) -> Any:
  86. if not self.repo:
  87. return False
  88. try:
  89. return self.repo.is_dirty()
  90. except GitCommandError:
  91. return False
  92. @property
  93. def email(self) -> str | None:
  94. if not self.repo:
  95. return None
  96. try:
  97. return self.repo.config_reader().get_value("user", "email") # type: ignore
  98. except configparser.Error:
  99. return None
  100. @property
  101. def last_commit(self) -> Any:
  102. if self._commit:
  103. return self._commit
  104. if not self.repo:
  105. return None
  106. if not self.repo.head or not self.repo.head.is_valid():
  107. return None
  108. # TODO: Saw a user getting a Unicode decode error when parsing refs,
  109. # more details on implementing a real fix in [WB-4064]
  110. try:
  111. if len(self.repo.refs) > 0: # type: ignore[arg-type]
  112. return self.repo.head.commit.hexsha
  113. else:
  114. return self.repo.git.show_ref("--head").split(" ")[0]
  115. except Exception:
  116. logger.exception("Unable to find most recent commit in git")
  117. return None
  118. @property
  119. def branch(self) -> Any:
  120. if not self.repo:
  121. return None
  122. return self.repo.head.ref.name
  123. @property
  124. def remote(self) -> Any:
  125. if not self.repo:
  126. return None
  127. try:
  128. return self.repo.remotes[self.remote_name] # type: ignore[index]
  129. except IndexError:
  130. return None
  131. # the --submodule=diff option doesn't exist in pre-2.11 versions of git (november 2016)
  132. # https://stackoverflow.com/questions/10757091/git-list-of-all-changed-files-including-those-in-submodules
  133. @property
  134. def has_submodule_diff(self) -> bool:
  135. if not self.repo:
  136. return False
  137. return bool(self.repo.git.version_info >= (2, 11, 0))
  138. @property
  139. def remote_url(self) -> Any:
  140. if self._remote_url:
  141. return self._remote_url
  142. if not self.remote:
  143. return None
  144. parsed = urlparse(self.remote.url)
  145. hostname = parsed.hostname
  146. if parsed.port is not None:
  147. hostname = f"{hostname}:{parsed.port}"
  148. if parsed.password is not None:
  149. return urlunparse(parsed._replace(netloc=f"{parsed.username}:@{hostname}"))
  150. return urlunparse(parsed._replace(netloc=hostname))
  151. @property
  152. def root_dir(self) -> Any:
  153. if not self.repo:
  154. return None
  155. try:
  156. return self.repo.git.rev_parse("--show-toplevel")
  157. except GitCommandError:
  158. return None
  159. def get_upstream_fork_point(self) -> Any:
  160. """Get the most recent ancestor of HEAD that occurs on an upstream branch.
  161. First looks at the current branch's tracking branch, if applicable. If
  162. that doesn't work, looks at every other branch to find the most recent
  163. ancestor of HEAD that occurs on a tracking branch.
  164. Returns:
  165. git.Commit object or None
  166. """
  167. possible_relatives = []
  168. try:
  169. if not self.repo:
  170. return None
  171. try:
  172. active_branch = self.repo.active_branch
  173. except (TypeError, ValueError):
  174. logger.debug("git is in a detached head state")
  175. return None # detached head
  176. else:
  177. tracking_branch = active_branch.tracking_branch()
  178. if tracking_branch:
  179. possible_relatives.append(tracking_branch.commit)
  180. if not possible_relatives:
  181. for branch in self.repo.branches: # type: ignore[attr-defined]
  182. tracking_branch = branch.tracking_branch()
  183. if tracking_branch is not None:
  184. possible_relatives.append(tracking_branch.commit)
  185. head = self.repo.head
  186. most_recent_ancestor = None
  187. for possible_relative in possible_relatives:
  188. # at most one:
  189. for ancestor in self.repo.merge_base(head, possible_relative):
  190. if most_recent_ancestor is None:
  191. most_recent_ancestor = ancestor
  192. elif self.repo.is_ancestor(most_recent_ancestor, ancestor): # type: ignore
  193. most_recent_ancestor = ancestor
  194. except GitCommandError as e:
  195. logger.debug("git remote upstream fork point could not be found")
  196. logger.debug(str(e))
  197. return None
  198. return most_recent_ancestor
  199. def tag(self, name: str, message: str | None) -> Any:
  200. if not self.repo:
  201. return None
  202. try:
  203. return self.repo.create_tag(f"wandb/{name}", message=message, force=True)
  204. except GitCommandError:
  205. logger.debug("Failed to tag repository.")
  206. return None
  207. def push(self, name: str) -> Any:
  208. if not self.remote:
  209. return None
  210. try:
  211. return self.remote.push(f"wandb/{name}", force=True)
  212. except GitCommandError:
  213. logger.debug("failed to push git")
  214. return None