_verification.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import re
  2. from dataclasses import dataclass
  3. from pathlib import Path
  4. from typing import TYPE_CHECKING, Literal, TypedDict
  5. from .. import constants
  6. from ..file_download import repo_folder_name
  7. from .sha import git_hash, sha_fileobj
  8. if TYPE_CHECKING:
  9. from ..hf_api import RepoFile
  10. # using fullmatch for clarity and strictness
  11. _REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")
  12. # Typed structure describing a checksum mismatch
  13. class Mismatch(TypedDict):
  14. path: str
  15. expected: str
  16. actual: str
  17. algorithm: str
  18. HashAlgo = Literal["sha256", "git-sha1"]
  19. @dataclass(frozen=True)
  20. class FolderVerification:
  21. revision: str
  22. checked_count: int
  23. mismatches: list[Mismatch]
  24. missing_paths: list[str]
  25. extra_paths: list[str]
  26. verified_path: Path
  27. def collect_local_files(root: Path) -> dict[str, Path]:
  28. """
  29. Return a mapping of repo-relative path -> absolute path for all files under `root`.
  30. """
  31. return {p.relative_to(root).as_posix(): p for p in root.rglob("*") if p.is_file()}
  32. def _resolve_commit_hash_from_cache(storage_folder: Path, revision: str | None) -> str:
  33. """
  34. Resolve a commit hash from a cache repo folder and an optional revision.
  35. """
  36. if revision and _REGEX_COMMIT_HASH.fullmatch(revision):
  37. return revision
  38. refs_dir = storage_folder / "refs"
  39. snapshots_dir = storage_folder / "snapshots"
  40. if revision:
  41. ref_path = refs_dir / revision
  42. if ref_path.is_file():
  43. return ref_path.read_text(encoding="utf-8").strip()
  44. raise ValueError(f"Revision '{revision}' could not be resolved in cache (expected file '{ref_path}').")
  45. # No revision provided: try common defaults
  46. main_ref = refs_dir / "main"
  47. if main_ref.is_file():
  48. return main_ref.read_text(encoding="utf-8").strip()
  49. if not snapshots_dir.is_dir():
  50. raise ValueError(f"Cache repo is missing snapshots directory: {snapshots_dir}. Provide --revision explicitly.")
  51. candidates = [p.name for p in snapshots_dir.iterdir() if p.is_dir() and _REGEX_COMMIT_HASH.fullmatch(p.name)]
  52. if len(candidates) == 1:
  53. return candidates[0]
  54. raise ValueError(
  55. "Ambiguous cached revision: multiple snapshots found and no refs to disambiguate. Please pass --revision."
  56. )
  57. def compute_file_hash(path: Path, algorithm: HashAlgo) -> str:
  58. """
  59. Compute the checksum of a local file using the requested algorithm.
  60. """
  61. with path.open("rb") as stream:
  62. if algorithm == "sha256":
  63. return sha_fileobj(stream).hex()
  64. if algorithm == "git-sha1":
  65. return git_hash(stream.read())
  66. raise ValueError(f"Unsupported hash algorithm: {algorithm}")
  67. def verify_maps(
  68. *,
  69. remote_by_path: dict[str, "RepoFile"],
  70. local_by_path: dict[str, Path],
  71. revision: str,
  72. verified_path: Path,
  73. ) -> FolderVerification:
  74. """Compare remote entries and local files and return a verification result."""
  75. remote_paths = set(remote_by_path)
  76. local_paths = set(local_by_path)
  77. missing = sorted(remote_paths - local_paths)
  78. extra = sorted(local_paths - remote_paths)
  79. both = sorted(remote_paths & local_paths)
  80. mismatches: list[Mismatch] = []
  81. for rel_path in both:
  82. remote_entry = remote_by_path[rel_path]
  83. local_path = local_by_path[rel_path]
  84. lfs = getattr(remote_entry, "lfs", None)
  85. lfs_sha = getattr(lfs, "sha256", None) if lfs is not None else None
  86. if lfs_sha is None and isinstance(lfs, dict):
  87. lfs_sha = lfs.get("sha256")
  88. if lfs_sha:
  89. algorithm: HashAlgo = "sha256"
  90. expected = str(lfs_sha).lower()
  91. else:
  92. blob_id = remote_entry.blob_id # type: ignore
  93. algorithm = "git-sha1"
  94. expected = str(blob_id).lower()
  95. actual = compute_file_hash(local_path, algorithm)
  96. if actual != expected:
  97. mismatches.append(Mismatch(path=rel_path, expected=expected, actual=actual, algorithm=algorithm))
  98. return FolderVerification(
  99. revision=revision,
  100. checked_count=len(both),
  101. mismatches=mismatches,
  102. missing_paths=missing,
  103. extra_paths=extra,
  104. verified_path=verified_path,
  105. )
  106. def resolve_local_root(
  107. *,
  108. repo_id: str,
  109. repo_type: str,
  110. revision: str | None,
  111. cache_dir: Path | None,
  112. local_dir: Path | None,
  113. ) -> tuple[Path, str]:
  114. """
  115. Resolve the root directory to scan locally and the remote revision to verify.
  116. """
  117. if local_dir is not None:
  118. root = Path(local_dir).expanduser().resolve()
  119. if not root.is_dir():
  120. raise ValueError(f"Local directory does not exist or is not a directory: {root}")
  121. return root, (revision or constants.DEFAULT_REVISION)
  122. cache_root = Path(cache_dir or constants.HF_HUB_CACHE).expanduser().resolve()
  123. storage_folder = cache_root / repo_folder_name(repo_id=repo_id, repo_type=repo_type)
  124. if not storage_folder.exists():
  125. raise ValueError(
  126. f"Repo is not present in cache: {storage_folder}. Use 'hf download' first or pass --local-dir."
  127. )
  128. commit = _resolve_commit_hash_from_cache(storage_folder, revision)
  129. snapshot_dir = storage_folder / "snapshots" / commit
  130. if not snapshot_dir.is_dir():
  131. raise ValueError(f"Snapshot directory does not exist for revision '{commit}': {snapshot_dir}.")
  132. return snapshot_dir, commit