git_reference.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. """Support for parsing GitHub URLs (which might be user provided) into constituent parts."""
  2. from __future__ import annotations
  3. import re
  4. from dataclasses import dataclass
  5. from enum import IntEnum
  6. from wandb.sdk.launch.errors import LaunchError
# String markers for the textual forms a git remote can take.
# NOTE(review): none of these is referenced in the visible part of this
# module; presumably they are used by URL-parsing code elsewhere - confirm.
PREFIX_HTTPS = "https://"
PREFIX_SSH = "git@"
SUFFIX_GIT = ".git"

# Matches 40 consecutive lowercase hexadecimal characters. The pattern is
# unanchored, so whether it must cover the whole string depends on the caller
# using match / fullmatch / search.
GIT_COMMIT_REGEX = re.compile(r"[0-9a-f]{40}")
  11. class ReferenceType(IntEnum):
  12. BRANCH = 1
  13. COMMIT = 2
  14. def _parse_netloc(netloc: str) -> tuple[str | None, str | None, str]:
  15. """Parse netloc into username, password, and host.
  16. github.com => None, None, "@github.com"
  17. username@github.com => "username", None, "github.com"
  18. username:password@github.com => "username", "password", "github.com"
  19. """
  20. parts = netloc.split("@", 1)
  21. if len(parts) == 1:
  22. return None, None, parts[0]
  23. auth, host = parts
  24. parts = auth.split(":", 1)
  25. if len(parts) == 1:
  26. return parts[0], None, host
  27. return parts[0], parts[1], host
  28. @dataclass
  29. class GitReference:
  30. def __init__(self, remote: str, ref: str | None = None) -> None:
  31. """Initialize a reference from a remote and ref.
  32. Arguments:
  33. remote: A remote URL or URI.
  34. ref: A branch, tag, or commit hash.
  35. """
  36. self.uri = remote
  37. self.ref = ref
  38. @property
  39. def url(self) -> str | None:
  40. return self.uri
  41. def fetch(self, dst_dir: str) -> None:
  42. """Fetch the repo into dst_dir and refine githubref based on what we learn."""
  43. # We defer importing git until the last moment, because the import requires that the git
  44. # executable is available on the PATH, so we only want to fail if we actually need it.
  45. import git # type: ignore
  46. repo = git.Repo.init(dst_dir)
  47. self.path = repo.working_dir
  48. origin = repo.create_remote("origin", self.uri or "")
  49. try:
  50. # We fetch the origin so that we have branch and tag references
  51. origin.fetch()
  52. except git.exc.GitCommandError as e:
  53. raise LaunchError(
  54. f"Unable to fetch from git remote repository {self.url}:\n{e}"
  55. )
  56. ref: git.RemoteReference | str
  57. if self.ref:
  58. if self.ref in origin.refs:
  59. ref = origin.refs[self.ref]
  60. else:
  61. ref = self.ref
  62. head = repo.create_head(self.ref, ref)
  63. head.checkout()
  64. self.commit_hash = head.commit.hexsha
  65. else:
  66. # TODO: Is there a better way to do this?
  67. default_branch = None
  68. for ref in repo.references:
  69. if hasattr(ref, "tag"): # Skip tag references
  70. continue
  71. refname = ref.name
  72. if refname.startswith("origin/"): # Trim off "origin/"
  73. refname = refname[7:]
  74. if refname == "main":
  75. default_branch = "main"
  76. break
  77. if refname == "master":
  78. default_branch = "master"
  79. # Keep looking in case we also have a main, which we let take precedence
  80. # (While the references appear to be sorted, not clear if that's guaranteed.)
  81. if not default_branch:
  82. raise LaunchError(
  83. f"Unable to determine branch or commit to checkout from {self.url}"
  84. )
  85. self.default_branch = default_branch
  86. self.ref = default_branch
  87. head = repo.create_head(default_branch, origin.refs[default_branch])
  88. head.checkout()
  89. self.commit_hash = head.commit.hexsha
  90. repo.submodule_update(init=True, recursive=True)