_skills.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. """Internal helpers for Hugging Face marketplace skill installation and upgrades."""
  2. import base64
  3. import io
  4. import json
  5. import shutil
  6. import tarfile
  7. import tempfile
  8. from dataclasses import dataclass, replace
  9. from pathlib import Path, PurePosixPath
  10. from typing import Any, Literal
  11. from huggingface_hub.errors import CLIError
  12. from huggingface_hub.utils import get_session
  13. DEFAULT_SKILLS_REPO_ID = "huggingface/skills"
  14. DEFAULT_SKILLS_REPO_OWNER, DEFAULT_SKILLS_REPO_NAME = DEFAULT_SKILLS_REPO_ID.split("/")
  15. DEFAULT_SKILLS_REF = "main"
  16. MARKETPLACE_PATH = ".claude-plugin/marketplace.json"
  17. GITHUB_API_TIMEOUT = 10
  18. SKILL_MANIFEST_FILENAME = ".hf-skill-manifest.json"
  19. SKILL_MANIFEST_SCHEMA_VERSION = 1
  20. SkillUpdateStatus = Literal[
  21. "up_to_date",
  22. "update_available",
  23. "updated",
  24. "unmanaged",
  25. "invalid_metadata",
  26. "source_unreachable",
  27. ]
  28. @dataclass(frozen=True)
  29. class MarketplaceSkill:
  30. name: str
  31. repo_path: str
  32. @dataclass(frozen=True)
  33. class InstalledSkillManifest:
  34. schema_version: int
  35. installed_revision: str
  36. @dataclass(frozen=True)
  37. class SkillUpdateInfo:
  38. name: str
  39. skill_dir: Path
  40. status: SkillUpdateStatus
  41. detail: str | None = None
  42. current_revision: str | None = None
  43. available_revision: str | None = None
  44. def load_marketplace_skills() -> list[MarketplaceSkill]:
  45. """Load skills from the default Hugging Face marketplace."""
  46. payload = _load_marketplace_payload()
  47. plugins = payload.get("plugins")
  48. if not isinstance(plugins, list):
  49. raise CLIError("Invalid marketplace payload: expected a top-level 'plugins' list.")
  50. skills: list[MarketplaceSkill] = []
  51. for plugin in plugins:
  52. if not isinstance(plugin, dict):
  53. continue
  54. name = plugin.get("name")
  55. source = plugin.get("source")
  56. if not isinstance(name, str) or not isinstance(source, str):
  57. continue
  58. skills.append(MarketplaceSkill(name=name, repo_path=_normalize_repo_path(source)))
  59. return skills
  60. def get_marketplace_skill(selector: str) -> MarketplaceSkill:
  61. """Resolve a marketplace skill by name."""
  62. selected = _select_marketplace_skill(load_marketplace_skills(), selector)
  63. if selected is None:
  64. raise CLIError(
  65. f"Skill '{selector}' not found in {DEFAULT_SKILLS_REPO_ID}. "
  66. "Try `hf skills add` to install `hf-cli` or use a known skill name."
  67. )
  68. return selected
  69. def install_marketplace_skill(skill: MarketplaceSkill, destination_root: Path, force: bool = False) -> Path:
  70. """Install a marketplace skill into a local skills directory."""
  71. destination_root = destination_root.expanduser().resolve()
  72. destination_root.mkdir(parents=True, exist_ok=True)
  73. install_dir = destination_root / skill.name
  74. if install_dir.exists() and not force:
  75. raise FileExistsError(f"Skill already exists: {install_dir}")
  76. if install_dir.exists():
  77. with tempfile.TemporaryDirectory(dir=destination_root, prefix=f".{install_dir.name}.install-") as tmp_dir_str:
  78. tmp_dir = Path(tmp_dir_str)
  79. staged_dir = tmp_dir / install_dir.name
  80. _populate_install_dir(skill=skill, install_dir=staged_dir)
  81. _atomic_replace_directory(existing_dir=install_dir, staged_dir=staged_dir)
  82. return install_dir
  83. try:
  84. _populate_install_dir(skill=skill, install_dir=install_dir)
  85. except Exception:
  86. if install_dir.exists():
  87. shutil.rmtree(install_dir)
  88. raise
  89. return install_dir
  90. def check_for_updates(
  91. roots: list[Path],
  92. selector: str | None = None,
  93. ) -> list[SkillUpdateInfo]:
  94. """Check managed skill installs for newer upstream revisions."""
  95. marketplace_skills = {skill.name.lower(): skill for skill in load_marketplace_skills()}
  96. updates = [_evaluate_update(skill_dir, marketplace_skills) for skill_dir in _iter_unique_skill_dirs(roots)]
  97. filtered = _filter_updates(updates, selector)
  98. if selector is not None and not filtered:
  99. raise CLIError(f"No installed skills match '{selector}'.")
  100. return filtered
  101. def apply_updates(
  102. roots: list[Path],
  103. selector: str | None = None,
  104. ) -> list[SkillUpdateInfo]:
  105. """Upgrade managed skills in place when the upstream revision changes."""
  106. updates = check_for_updates(roots, selector)
  107. results: list[SkillUpdateInfo] = []
  108. for update in updates:
  109. results.append(_apply_single_update(update))
  110. return results
  111. def read_installed_skill_manifest(skill_dir: Path) -> tuple[InstalledSkillManifest | None, str | None]:
  112. """Read local skill metadata written by `hf skills add`."""
  113. manifest_path = skill_dir / SKILL_MANIFEST_FILENAME
  114. if not manifest_path.exists():
  115. return None, None
  116. try:
  117. payload = json.loads(manifest_path.read_text(encoding="utf-8"))
  118. except Exception as exc: # noqa: BLE001
  119. return None, f"invalid json: {exc}"
  120. if not isinstance(payload, dict):
  121. return None, "metadata root must be an object"
  122. try:
  123. return _parse_installed_skill_manifest(payload), None
  124. except ValueError as exc:
  125. return None, str(exc)
  126. def write_installed_skill_manifest(skill_dir: Path, manifest: InstalledSkillManifest) -> None:
  127. payload = {
  128. "schema_version": manifest.schema_version,
  129. "installed_revision": manifest.installed_revision,
  130. }
  131. (skill_dir / SKILL_MANIFEST_FILENAME).write_text(
  132. json.dumps(payload, indent=2, sort_keys=True) + "\n",
  133. encoding="utf-8",
  134. )
  135. def _load_marketplace_payload() -> dict[str, Any]:
  136. response = _fetch_from_skills_repo(
  137. f"contents/{MARKETPLACE_PATH}",
  138. params={"ref": DEFAULT_SKILLS_REF},
  139. )
  140. try:
  141. payload = response.json()
  142. except Exception as exc: # noqa: BLE001
  143. raise CLIError(f"Failed to decode GitHub API response for 'contents/{MARKETPLACE_PATH}': {exc}") from exc
  144. if not isinstance(payload, dict):
  145. raise CLIError("Invalid marketplace response: expected a JSON object.")
  146. content = payload.get("content")
  147. encoding = payload.get("encoding")
  148. if not isinstance(content, str) or encoding != "base64":
  149. raise CLIError("Invalid marketplace payload: expected base64-encoded content.")
  150. try:
  151. decoded = base64.b64decode(content).decode("utf-8")
  152. parsed = json.loads(decoded)
  153. except Exception as exc: # noqa: BLE001
  154. raise CLIError(f"Failed to decode marketplace payload: {exc}") from exc
  155. if not isinstance(parsed, dict):
  156. raise CLIError("Invalid marketplace payload: expected a JSON object.")
  157. return parsed
  158. def _select_marketplace_skill(skills: list[MarketplaceSkill], selector: str) -> MarketplaceSkill | None:
  159. selector_lower = selector.strip().lower()
  160. for skill in skills:
  161. if skill.name.lower() == selector_lower:
  162. return skill
  163. return None
  164. def _normalize_repo_path(path: str) -> str:
  165. normalized = path.strip()
  166. while normalized.startswith("./"):
  167. normalized = normalized[2:]
  168. normalized = normalized.strip("/")
  169. if not normalized:
  170. raise CLIError("Invalid marketplace entry: empty source path.")
  171. return normalized
  172. def _populate_install_dir(skill: MarketplaceSkill, install_dir: Path) -> None:
  173. installed_revision = _resolve_available_revision(skill)
  174. install_dir.mkdir(parents=True, exist_ok=True)
  175. _extract_remote_github_path(
  176. revision=installed_revision,
  177. source_path=skill.repo_path,
  178. install_dir=install_dir,
  179. )
  180. _validate_installed_skill_dir(install_dir)
  181. write_installed_skill_manifest(
  182. install_dir,
  183. InstalledSkillManifest(
  184. schema_version=SKILL_MANIFEST_SCHEMA_VERSION,
  185. installed_revision=installed_revision,
  186. ),
  187. )
  188. def _validate_installed_skill_dir(skill_dir: Path) -> None:
  189. skill_file = skill_dir / "SKILL.md"
  190. if not skill_file.is_file():
  191. raise RuntimeError(f"Installed skill is missing SKILL.md: {skill_file}")
  192. def _extract_remote_github_path(revision: str, source_path: str, install_dir: Path) -> None:
  193. tar_bytes = _fetch_from_skills_repo(f"tarball/{revision}").content
  194. _extract_tar_subpath(tar_bytes, source_path=source_path, install_dir=install_dir)
  195. def _extract_tar_subpath(tar_bytes: bytes, source_path: str, install_dir: Path) -> None:
  196. """Extract a skill subdirectory from a tar archive.
  197. GitHub tarballs include a leading `<repo>-<revision>/` directory. The helper also
  198. accepts archives that start directly at `skills/<name>/...` to keep tests simple.
  199. """
  200. source_parts = PurePosixPath(source_path).parts
  201. with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode="r:*") as archive:
  202. members = archive.getmembers()
  203. matched = False
  204. for member in members:
  205. relative_parts = _member_relative_parts(member_name=member.name, source_parts=source_parts)
  206. if relative_parts is None:
  207. continue
  208. if not relative_parts:
  209. matched = True
  210. continue
  211. matched = True
  212. relative_path = Path(*relative_parts)
  213. if ".." in relative_path.parts:
  214. raise RuntimeError(f"Invalid path found in archive for {source_path}.")
  215. destination_path = install_dir / relative_path
  216. if member.isdir():
  217. destination_path.mkdir(parents=True, exist_ok=True)
  218. continue
  219. if not member.isfile():
  220. continue
  221. destination_path.parent.mkdir(parents=True, exist_ok=True)
  222. extracted = archive.extractfile(member)
  223. if extracted is None:
  224. raise RuntimeError(f"Failed to extract {member.name}.")
  225. destination_path.write_bytes(extracted.read())
  226. if not matched:
  227. raise FileNotFoundError(f"Path '{source_path}' not found in source archive.")
  228. def _member_relative_parts(member_name: str, source_parts: tuple[str, ...]) -> tuple[str, ...] | None:
  229. path_parts = PurePosixPath(member_name).parts
  230. if tuple(path_parts[: len(source_parts)]) == source_parts:
  231. return path_parts[len(source_parts) :]
  232. if len(path_parts) > len(source_parts) and tuple(path_parts[1 : 1 + len(source_parts)]) == source_parts:
  233. return path_parts[1 + len(source_parts) :]
  234. return None
  235. def _atomic_replace_directory(existing_dir: Path, staged_dir: Path) -> None:
  236. backup_dir = staged_dir.parent / f"{existing_dir.name}.backup"
  237. try:
  238. existing_dir.rename(backup_dir)
  239. staged_dir.rename(existing_dir)
  240. shutil.rmtree(backup_dir)
  241. except Exception:
  242. if backup_dir.exists() and not existing_dir.exists():
  243. backup_dir.rename(existing_dir)
  244. raise
  245. def _iter_unique_skill_dirs(roots: list[Path]) -> list[Path]:
  246. seen: set[Path] = set()
  247. discovered: list[Path] = []
  248. for root in roots:
  249. root = root.expanduser().resolve()
  250. if not root.is_dir():
  251. continue
  252. for child in sorted(root.iterdir()):
  253. if child.name.startswith("."):
  254. continue
  255. if not child.is_dir() and not child.is_symlink():
  256. continue
  257. resolved = child.resolve()
  258. if resolved in seen or not resolved.is_dir():
  259. continue
  260. seen.add(resolved)
  261. discovered.append(resolved)
  262. return discovered
  263. def _evaluate_update(skill_dir: Path, marketplace_skills: dict[str, MarketplaceSkill]) -> SkillUpdateInfo:
  264. base = SkillUpdateInfo(name=skill_dir.name, skill_dir=skill_dir, status="unmanaged")
  265. manifest, error = read_installed_skill_manifest(skill_dir)
  266. if manifest is None:
  267. return replace(base, status="invalid_metadata" if error else "unmanaged", detail=error)
  268. skill = marketplace_skills.get(skill_dir.name.lower())
  269. if skill is None:
  270. return replace(
  271. base,
  272. status="source_unreachable",
  273. detail=f"Skill '{skill_dir.name}' is no longer available in {DEFAULT_SKILLS_REPO_ID}.",
  274. current_revision=manifest.installed_revision,
  275. )
  276. current_revision = manifest.installed_revision
  277. try:
  278. available_revision = _resolve_available_revision(skill)
  279. except Exception as exc:
  280. return replace(base, status="source_unreachable", detail=str(exc), current_revision=current_revision)
  281. status: SkillUpdateStatus = "up_to_date" if available_revision == current_revision else "update_available"
  282. return replace(
  283. base,
  284. status=status,
  285. detail="update available" if status == "update_available" else None,
  286. current_revision=current_revision,
  287. available_revision=available_revision,
  288. )
  289. def _apply_single_update(update: SkillUpdateInfo) -> SkillUpdateInfo:
  290. if update.status != "update_available":
  291. return update
  292. try:
  293. skill = get_marketplace_skill(update.skill_dir.name)
  294. install_marketplace_skill(skill, update.skill_dir.parent, force=True)
  295. except Exception as exc:
  296. return replace(update, status="source_unreachable", detail=str(exc))
  297. return replace(update, status="updated", detail="updated")
  298. def _filter_updates(updates: list[SkillUpdateInfo], selector: str | None) -> list[SkillUpdateInfo]:
  299. if selector is None:
  300. return updates
  301. selector_lower = selector.strip().lower()
  302. return [update for update in updates if update.name.lower() == selector_lower]
  303. def _resolve_available_revision(skill: MarketplaceSkill) -> str:
  304. response = _fetch_from_skills_repo(
  305. "commits",
  306. params={"sha": DEFAULT_SKILLS_REF, "path": skill.repo_path, "per_page": 1},
  307. )
  308. try:
  309. payload = response.json()
  310. except Exception as exc: # noqa: BLE001
  311. raise CLIError(f"Failed to decode GitHub API response for 'commits': {exc}") from exc
  312. if not isinstance(payload, list) or not payload:
  313. raise CLIError(f"Unable to resolve the current revision for skill '{skill.name}'.")
  314. latest = payload[0]
  315. if not isinstance(latest, dict):
  316. raise CLIError(f"Invalid commit response while resolving skill '{skill.name}'.")
  317. revision = latest.get("sha")
  318. if not isinstance(revision, str) or not revision:
  319. raise CLIError(f"Invalid commit response while resolving skill '{skill.name}'.")
  320. return revision
  321. def _parse_installed_skill_manifest(payload: dict[str, Any]) -> InstalledSkillManifest:
  322. if payload.get("schema_version") != SKILL_MANIFEST_SCHEMA_VERSION:
  323. raise ValueError(f"unsupported schema_version: {payload.get('schema_version')}")
  324. installed_revision = payload.get("installed_revision")
  325. if not isinstance(installed_revision, str) or not installed_revision:
  326. raise ValueError("missing installed_revision")
  327. return InstalledSkillManifest(
  328. schema_version=SKILL_MANIFEST_SCHEMA_VERSION,
  329. installed_revision=installed_revision,
  330. )
  331. def _fetch_from_skills_repo(endpoint: str, params: dict[str, Any] | None = None) -> Any:
  332. url = f"https://api.github.com/repos/{DEFAULT_SKILLS_REPO_OWNER}/{DEFAULT_SKILLS_REPO_NAME}/{endpoint.lstrip('/')}"
  333. try:
  334. response = get_session().get(
  335. url,
  336. params=params,
  337. headers={"Accept": "application/vnd.github+json"},
  338. follow_redirects=True,
  339. timeout=GITHUB_API_TIMEOUT,
  340. )
  341. response.raise_for_status()
  342. except Exception as exc: # noqa: BLE001
  343. raise CLIError(f"Failed to fetch '{endpoint}' from {DEFAULT_SKILLS_REPO_ID}: {exc}") from exc
  344. return response