| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244 |
- from __future__ import annotations
- import configparser
- import logging
- import os
- from typing import TYPE_CHECKING, Any
- from urllib.parse import urlparse, urlunparse
- import wandb
- try:
- from git import ( # type: ignore
- GitCommandError,
- InvalidGitRepositoryError,
- NoSuchPathError,
- Repo,
- )
- except ImportError:
- pass
- if TYPE_CHECKING:
- from git import Repo
- logger = logging.getLogger(__name__)
class GitRepo:
    """Lazy wrapper around a GitPython ``Repo`` that tolerates a missing git.

    The underlying repository is opened at most once: on first access of
    ``repo``, or in the constructor when ``lazy=False``. ``repo`` is ``None``
    when GitPython cannot be imported, when an explicit ``remote_url`` was
    supplied, or when the checkout cannot be opened. In that case the
    accessors fall back to ``None``/``False`` (``is_untracked`` falls back to
    ``True``), or to the explicitly supplied ``remote_url``/``commit``.
    """

    def __init__(
        self,
        root: str | None = None,
        remote: str = "origin",
        lazy: bool = True,
        remote_url: str | None = None,
        commit: str | None = None,
    ) -> None:
        """Store the configuration and optionally open the repository.

        Args:
            root: Directory to open the repository from (parent directories
                are searched); the current working directory is used when
                this is falsy.
            remote: Name of the git remote to inspect. Ignored when
                ``remote_url`` is given.
            lazy: When false, open the repository immediately.
            remote_url: Explicit remote URL. Supplying it disables opening
                the repository altogether.
            commit: Explicit commit returned by ``last_commit``.
        """
        # An explicit remote URL means nothing should be discovered from disk;
        # ``remote_name is None`` is what ``_init_repo`` keys off.
        self.remote_name = remote if remote_url is None else None
        self._root = root
        self._remote_url = remote_url
        self._commit = commit
        self._repo = None
        self._repo_initialized = False
        if not lazy:
            self._repo = self._init_repo()

    def _init_repo(self) -> Repo | None:
        """Open the repository, returning ``None`` when it is unavailable."""
        # Mark as initialized up front so a failed attempt is never retried.
        self._repo_initialized = True
        try:
            from git import Repo
        except ImportError:
            return None
        if self.remote_name is None:
            return None
        try:
            return Repo(self._root or os.getcwd(), search_parent_directories=True)
        except FileNotFoundError:
            message = "current working directory has been invalidated"
            wandb.termwarn(message)
            logger.warning(message)
        except InvalidGitRepositoryError:
            logger.debug("git repository is invalid")
        except NoSuchPathError:
            message = f"git root {self._root} does not exist"
            wandb.termwarn(message)
            logger.warning(message)
        return None

    @property
    def repo(self) -> Repo | None:
        """The GitPython repository, opened on first access; may be ``None``."""
        if not self._repo_initialized:
            self._repo = self._init_repo()
        return self._repo

    @property
    def auto(self) -> bool:
        """Whether no explicit remote URL was supplied."""
        return self._remote_url is None

    def is_untracked(self, file_name: str) -> bool | None:
        """Report whether ``file_name`` is listed among the untracked files.

        Returns:
            ``True`` when there is no repository, ``None`` when the git
            command fails, otherwise the membership result.
        """
        if not self.repo:
            return True
        try:
            return file_name in self.repo.untracked_files
        except GitCommandError:
            return None

    @property
    def enabled(self) -> bool:
        """Whether a repository is available."""
        return bool(self.repo)

    @property
    def root(self) -> Any:
        """Output of ``git rev-parse --show-toplevel``, or ``None``.

        A failing git command is logged with its traceback.
        """
        if not self.repo:
            return None
        try:
            return self.repo.git.rev_parse("--show-toplevel")
        except GitCommandError:
            # todo: collect telemetry on this
            logger.exception("git root error")
            return None

    @property
    def dirty(self) -> Any:
        """Result of ``Repo.is_dirty()``; ``False`` without a repo or on error."""
        if not self.repo:
            return False
        try:
            return self.repo.is_dirty()
        except GitCommandError:
            return False

    @property
    def email(self) -> str | None:
        """The ``user.email`` git config value, or ``None`` when unavailable."""
        if not self.repo:
            return None
        try:
            return self.repo.config_reader().get_value("user", "email")  # type: ignore
        except configparser.Error:
            return None

    @property
    def last_commit(self) -> Any:
        """Hex SHA of HEAD, or the explicitly supplied commit when given.

        Returns ``None`` when there is no repository, HEAD is not valid, or
        the lookup raises.
        """
        if self._commit:
            return self._commit
        if not self.repo:
            return None
        if not self.repo.head or not self.repo.head.is_valid():
            return None
        # TODO: Saw a user getting a Unicode decode error when parsing refs,
        # more details on implementing a real fix in [WB-4064]
        try:
            if len(self.repo.refs) > 0:  # type: ignore[arg-type]
                return self.repo.head.commit.hexsha
            else:
                return self.repo.git.show_ref("--head").split(" ")[0]
        except Exception:
            logger.exception("Unable to find most recent commit in git")
            return None

    @property
    def branch(self) -> Any:
        """Name of the reference HEAD points at, or ``None``.

        ``None`` is returned when there is no repository or when HEAD is
        detached.
        """
        if not self.repo:
            return None
        try:
            return self.repo.head.ref.name
        except TypeError:
            # GitPython raises TypeError from ``head.ref`` when HEAD is
            # detached; treat it as "no branch", consistent with
            # ``get_upstream_fork_point``.
            return None

    @property
    def remote(self) -> Any:
        """The remote named ``remote_name``, or ``None`` when it is missing."""
        if not self.repo:
            return None
        try:
            return self.repo.remotes[self.remote_name]  # type: ignore[index]
        except IndexError:
            return None

    # the --submodule=diff option doesn't exist in pre-2.11 versions of git (november 2016)
    # https://stackoverflow.com/questions/10757091/git-list-of-all-changed-files-including-those-in-submodules
    @property
    def has_submodule_diff(self) -> bool:
        """Whether the installed git is at least version 2.11.0."""
        if not self.repo:
            return False
        return bool(self.repo.git.version_info >= (2, 11, 0))

    @property
    def remote_url(self) -> Any:
        """Remote URL with any password removed.

        The explicitly supplied URL is returned unchanged. Otherwise the
        remote's URL is rebuilt from its parsed parts: when a password is
        present the netloc becomes ``user:@host[:port]``; when it is not, any
        user info is dropped and only ``host[:port]`` is kept. Returns
        ``None`` when no remote is available.
        """
        if self._remote_url:
            return self._remote_url
        remote = self.remote
        if not remote:
            return None
        parsed = urlparse(remote.url)
        host = parsed.hostname
        # ``hostname`` strips the brackets of an IPv6 literal; restore them so
        # the rebuilt netloc is still a valid authority (and the port, if
        # any, stays unambiguous).
        if host is not None and ":" in host:
            host = f"[{host}]"
        if parsed.port is not None:
            host = f"{host}:{parsed.port}"
        if parsed.password is not None:
            return urlunparse(parsed._replace(netloc=f"{parsed.username}:@{host}"))
        return urlunparse(parsed._replace(netloc=host))

    @property
    def root_dir(self) -> Any:
        """Output of ``git rev-parse --show-toplevel``, or ``None`` (no logging)."""
        if not self.repo:
            return None
        try:
            return self.repo.git.rev_parse("--show-toplevel")
        except GitCommandError:
            return None

    def get_upstream_fork_point(self) -> Any:
        """Get the most recent ancestor of HEAD that occurs on an upstream branch.

        First looks at the current branch's tracking branch, if applicable. If
        that doesn't work, looks at every other branch to find the most recent
        ancestor of HEAD that occurs on a tracking branch.

        Returns:
            git.Commit object or None
        """
        possible_relatives = []
        try:
            if not self.repo:
                return None
            try:
                active_branch = self.repo.active_branch
            except (TypeError, ValueError):
                logger.debug("git is in a detached head state")
                return None  # detached head
            else:
                tracking_branch = active_branch.tracking_branch()
                if tracking_branch:
                    possible_relatives.append(tracking_branch.commit)

            if not possible_relatives:
                for branch in self.repo.branches:  # type: ignore[attr-defined]
                    tracking_branch = branch.tracking_branch()
                    if tracking_branch is not None:
                        possible_relatives.append(tracking_branch.commit)

            head = self.repo.head
            most_recent_ancestor = None
            for possible_relative in possible_relatives:
                # at most one:
                for ancestor in self.repo.merge_base(head, possible_relative):
                    if most_recent_ancestor is None:
                        most_recent_ancestor = ancestor
                    elif self.repo.is_ancestor(most_recent_ancestor, ancestor):  # type: ignore
                        # keep whichever candidate is further along history
                        most_recent_ancestor = ancestor
        except GitCommandError as e:
            logger.debug("git remote upstream fork point could not be found")
            logger.debug(str(e))
            return None
        return most_recent_ancestor

    def tag(self, name: str, message: str | None) -> Any:
        """Force-create the tag ``wandb/<name>``.

        Returns:
            The result of ``Repo.create_tag``, or ``None`` when there is no
            repository or the git command fails.
        """
        if not self.repo:
            return None
        try:
            return self.repo.create_tag(f"wandb/{name}", message=message, force=True)
        except GitCommandError:
            logger.debug("Failed to tag repository.")
            return None

    def push(self, name: str) -> Any:
        """Force-push ``wandb/<name>`` to the configured remote.

        Returns:
            The result of ``Remote.push``, or ``None`` when no remote is
            available or the git command fails.
        """
        remote = self.remote
        if not remote:
            return None
        try:
            return remote.push(f"wandb/{name}", force=True)
        except GitCommandError:
            logger.debug("failed to push git")
            return None
|