safetensors_conversion.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. from typing import Optional
  2. import httpx
  3. from huggingface_hub import Discussion, HfApi, get_repo_discussions
  4. from .utils import cached_file, http_user_agent, logging
# Module-level logger, obtained from the package's ``.utils.logging`` helper imported above.
logger = logging.get_logger(__name__)
  6. def previous_pr(api: HfApi, model_id: str, pr_title: str, token: str) -> Optional["Discussion"]:
  7. main_commit = api.list_repo_commits(model_id, token=token)[0].commit_id
  8. for discussion in get_repo_discussions(repo_id=model_id, token=token):
  9. if discussion.title == pr_title and discussion.status == "open" and discussion.is_pull_request:
  10. commits = api.list_repo_commits(model_id, revision=discussion.git_reference, token=token)
  11. if main_commit == commits[1].commit_id:
  12. return discussion
  13. return None
  14. def spawn_conversion(token: str, private: bool, model_id: str):
  15. logger.info("Attempting to convert .bin model on the fly to safetensors.")
  16. safetensors_convert_space_url = "https://safetensors-convert.hf.space"
  17. sse_url = f"{safetensors_convert_space_url}/call/run"
  18. def start(_sse_connection):
  19. for line in _sse_connection.iter_lines():
  20. if not isinstance(line, str):
  21. line = line.decode()
  22. if line.startswith("event:"):
  23. status = line[7:]
  24. logger.debug(f"Safetensors conversion status: {status}")
  25. if status == "complete":
  26. return
  27. elif status == "heartbeat":
  28. logger.debug("Heartbeat")
  29. else:
  30. logger.debug(f"Unknown status {status}")
  31. else:
  32. logger.debug(line)
  33. data = {"data": [model_id, private, token]}
  34. result = httpx.post(sse_url, follow_redirects=True, json=data).json()
  35. event_id = result["event_id"]
  36. with httpx.stream("GET", f"{sse_url}/{event_id}") as sse_connection:
  37. try:
  38. logger.debug("Spawning safetensors automatic conversion.")
  39. start(sse_connection)
  40. except Exception as e:
  41. logger.warning(f"Error during conversion: {repr(e)}")
  42. def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
  43. private = api.model_info(model_id).private
  44. logger.info("Attempting to create safetensors variant")
  45. pr_title = "Adding `safetensors` variant of this model"
  46. token = kwargs.get("token")
  47. # This looks into the current repo's open PRs to see if a PR for safetensors was already open. If so, it
  48. # returns it. It checks that the PR was opened by the bot and not by another user so as to prevent
  49. # security breaches.
  50. pr = previous_pr(api, model_id, pr_title, token=token)
  51. if pr is None or (not private and pr.author != "SFconvertbot"):
  52. spawn_conversion(token, private, model_id)
  53. pr = previous_pr(api, model_id, pr_title, token=token)
  54. else:
  55. logger.info("Safetensors PR exists")
  56. if pr is None:
  57. raise OSError(
  58. "Could not create safetensors conversion PR. The repo does not appear to have a file named pytorch_model.bin or model.safetensors."
  59. "If you are loading with variant, use `use_safetensors=False` to load the original model."
  60. )
  61. sha = f"refs/pr/{pr.num}"
  62. return sha
  63. def auto_conversion(
  64. pretrained_model_name_or_path: str,
  65. ignore_errors_during_conversion: bool = False,
  66. safe_weights_name: str = "model.safetensors",
  67. safe_weights_index_name: str = "model.safetensors.index.json",
  68. **cached_file_kwargs,
  69. ):
  70. try:
  71. api = HfApi(token=cached_file_kwargs.get("token"), headers={"user-agent": http_user_agent()})
  72. sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
  73. if sha is None:
  74. return None, None
  75. cached_file_kwargs["revision"] = sha
  76. del cached_file_kwargs["_commit_hash"]
  77. # This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
  78. # description.
  79. sharded = api.file_exists(
  80. pretrained_model_name_or_path,
  81. safe_weights_index_name,
  82. revision=sha,
  83. token=cached_file_kwargs.get("token"),
  84. )
  85. filename = safe_weights_index_name if sharded else safe_weights_name
  86. resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
  87. return resolved_archive_file, sha, sharded
  88. except Exception as e:
  89. if not ignore_errors_during_conversion:
  90. raise e