| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834 |
- import inspect
- import json
- import os
- from collections.abc import Callable
- from dataclasses import Field, asdict, dataclass, is_dataclass
- from pathlib import Path
- from typing import Any, ClassVar, Protocol, TypeVar
- import packaging.version
- from . import constants
- from .errors import EntryNotFoundError, HfHubHTTPError
- from .file_download import hf_hub_download
- from .hf_api import HfApi
- from .repocard import ModelCard, ModelCardData
- from .utils import (
- SoftTemporaryDirectory,
- is_jsonable,
- is_safetensors_available,
- is_simple_optional_type,
- is_torch_available,
- logging,
- unwrap_simple_optional_type,
- validate_hf_hub_args,
- )
# `torch` is imported only when `is_torch_available()` reports it as installed.
# It is used by `PyTorchModelHubMixin._load_as_pickle`.
if is_torch_available():
    import torch  # type: ignore

# `safetensors` and its torch load/save helpers are imported only when available.
# They are used by `PyTorchModelHubMixin` to save and load model weights.
if is_safetensors_available():
    import safetensors
    from safetensors.torch import load_model as load_model_as_safetensor
    from safetensors.torch import save_model as save_model_as_safetensor

# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349
class DataclassInstance(Protocol):
    """Structural type matching any object that exposes `__dataclass_fields__`.

    Used in this module's annotations wherever a config may be given as a dataclass instance.
    """

    # Mapping of field name -> `dataclasses.Field`, as declared by the annotation below.
    __dataclass_fields__: ClassVar[dict[str, Field]]
# Generic variable that is either ModelHubMixin or a subclass thereof
T = TypeVar("T", bound="ModelHubMixin")

# Generic variable to represent an args type
ARGS_T = TypeVar("ARGS_T")

# An encoder takes a value of the custom type and returns something to store in the JSON config.
ENCODER_T = Callable[[ARGS_T], Any]
# A decoder takes the stored value and returns an instance of the custom type.
DECODER_T = Callable[[Any], ARGS_T]
# A coder is the `(encoder, decoder)` pair registered for one custom type.
CODER_T = tuple[ENCODER_T, DECODER_T]
# Default model card template, used when a subclass does not pass `model_card_template`.
# The `{{ ... }}` placeholders are filled with `card_data`, `repo_url`, `paper_url` and `docs_url`
# by `ModelHubMixin.generate_model_card`.
# NOTE(review): the source this was recovered from had all blank lines stripped, so any blank line
# that belongs inside this literal (e.g. after the closing `---`) cannot be seen here. Confirm
# against the canonical template before relying on the exact text.
DEFAULT_MODEL_CARD = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---
This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
- Code: {{ repo_url | default("[More Information Needed]", true) }}
- Paper: {{ paper_url | default("[More Information Needed]", true) }}
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""
@dataclass
class MixinInfo:
    """Library-level metadata attached to each `ModelHubMixin` subclass.

    An instance is built in `ModelHubMixin.__init_subclass__` and read by
    `ModelHubMixin.generate_model_card` to render the model card.
    """

    # Template string the model card is rendered from.
    model_card_template: str
    # Metadata (language, license, tags, ...) injected into the template as `card_data`.
    model_card_data: ModelCardData
    # Optional URLs forwarded to the template; `None` when not provided.
    docs_url: str | None = None
    paper_url: str | None = None
    repo_url: str | None = None
class ModelHubMixin:
    """
    A generic mixin to integrate ANY machine learning framework with the Hub.

    To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models
    have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example
    of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions.

    When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to
    `__init__` but to the class definition itself. This is useful to define metadata about the library integrating
    [`ModelHubMixin`].

    For more details on how to integrate the mixin with your library, checkout the [integration guide](../guides/integrations).

    Args:
        repo_url (`str`, *optional*):
            URL of the library repository. Used to generate model card.
        paper_url (`str`, *optional*):
            URL of the library paper. Used to generate model card.
        docs_url (`str`, *optional*):
            URL of the library documentation. Used to generate model card.
        model_card_template (`str`, *optional*):
            Template of the model card. Used to generate model card. Defaults to a generic template.
        language (`str` or `list[str]`, *optional*):
            Language supported by the library. Used to generate model card.
        library_name (`str`, *optional*):
            Name of the library integrating ModelHubMixin. Used to generate model card.
        license (`str`, *optional*):
            License of the library integrating ModelHubMixin. Used to generate model card.
            E.g: "apache-2.0"
        license_name (`str`, *optional*):
            Name of the license of the library integrating ModelHubMixin. Used to generate model card.
            Only used if `license` is set to `other`.
            E.g: "coqui-public-model-license".
        license_link (`str`, *optional*):
            URL to the license of the library integrating ModelHubMixin. Used to generate model card.
            Only used if `license` is set to `other` and `license_name` is set.
            E.g: "https://coqui.ai/cpml".
        pipeline_tag (`str`, *optional*):
            Tag of the pipeline. Used to generate model card. E.g. "text-classification".
        tags (`list[str]`, *optional*):
            Tags to be added to the model card. Used to generate model card. E.g. ["computer-vision"].
            The list passed by the caller is never modified.
        coders (`dict[Type, tuple[Callable, Callable]]`, *optional*):
            Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not
            jsonable by default. E.g. dataclasses, argparse.Namespace, OmegaConf, etc.

    Example:

    ```python
    >>> from huggingface_hub import ModelHubMixin

    # Inherit from ModelHubMixin
    >>> class MyCustomModel(
    ...         ModelHubMixin,
    ...         library_name="my-library",
    ...         tags=["computer-vision"],
    ...         repo_url="https://github.com/huggingface/my-cool-library",
    ...         paper_url="https://arxiv.org/abs/2304.12244",
    ...         docs_url="https://huggingface.co/docs/my-cool-library",
    ...         # ^ optional metadata to generate model card
    ...     ):
    ...     def __init__(self, size: int = 512, device: str = "cpu"):
    ...         # define how to initialize your model
    ...         super().__init__()
    ...         ...
    ...
    ...     def _save_pretrained(self, save_directory: Path) -> None:
    ...         # define how to serialize your model
    ...         ...
    ...
    ...     @classmethod
    ...     def from_pretrained(
    ...         cls: type[T],
    ...         pretrained_model_name_or_path: Union[str, Path],
    ...         *,
    ...         force_download: bool = False,
    ...         token: Optional[Union[str, bool]] = None,
    ...         cache_dir: Optional[Union[str, Path]] = None,
    ...         local_files_only: bool = False,
    ...         revision: Optional[str] = None,
    ...         **model_kwargs,
    ...     ) -> T:
    ...         # define how to deserialize your model
    ...         ...

    >>> model = MyCustomModel(size=256, device="gpu")

    # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")

    # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")

    # Download and initialize weights from the Hub
    >>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model")
    >>> reloaded_model.size
    256

    # Model card has been correctly populated (tags are deduplicated and sorted)
    >>> from huggingface_hub import ModelCard
    >>> card = ModelCard.load("username/my-awesome-model")
    >>> card.data.tags
    ["computer-vision", "model_hub_mixin"]
    >>> card.data.library_name
    "my-library"
    ```
    """

    _hub_mixin_config: dict | DataclassInstance | None = None
    # ^ optional config attribute automatically set in `from_pretrained`
    _hub_mixin_info: MixinInfo
    # ^ information about the library integrating ModelHubMixin (used to generate model card)
    _hub_mixin_inject_config: bool  # whether `_from_pretrained` expects `config` or not
    _hub_mixin_init_parameters: dict[str, inspect.Parameter]  # __init__ parameters
    _hub_mixin_jsonable_default_values: dict[str, Any]  # default values for __init__ parameters
    _hub_mixin_jsonable_custom_types: tuple[type, ...]  # custom types that can be encoded/decoded
    _hub_mixin_coders: dict[type, CODER_T]  # encoders/decoders for custom types
    # ^ internal values to handle config

    def __init_subclass__(
        cls,
        *,
        # Generic info for model card
        repo_url: str | None = None,
        paper_url: str | None = None,
        docs_url: str | None = None,
        # Model card template
        model_card_template: str = DEFAULT_MODEL_CARD,
        # Model card metadata
        language: list[str] | None = None,
        library_name: str | None = None,
        license: str | None = None,
        license_name: str | None = None,
        license_link: str | None = None,
        pipeline_tag: str | None = None,
        tags: list[str] | None = None,
        # How to encode/decode arguments with custom type into a JSON config?
        # Key is a type.
        # Value is a tuple (encoder, decoder).
        # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))}
        coders: dict[type, CODER_T] | None = None,
    ) -> None:
        """Inspect __init__ signature only once when subclassing + handle modelcard.

        The keyword arguments are the class-definition keywords documented on the class. Metadata already
        registered on a parent class is copied, then overridden by any value explicitly passed here.
        """
        super().__init_subclass__()

        # Will be reused when creating modelcard.
        # Build a new list rather than appending to `tags`: the caller's list may be shared between
        # several class definitions and must not be modified.
        own_tags = [*(tags or []), "model_hub_mixin"]

        # Initialize MixinInfo if not existent
        info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData())

        # If parent class has a MixinInfo, inherit from it as a copy
        if hasattr(cls, "_hub_mixin_info"):
            parent_info = cls._hub_mixin_info

            # Inherit model card template from parent class if not explicitly set
            if model_card_template == DEFAULT_MODEL_CARD:
                info.model_card_template = parent_info.model_card_template

            # Inherit from parent model card data
            info.model_card_data = ModelCardData(**parent_info.model_card_data.to_dict())

            # Inherit other info
            info.docs_url = parent_info.docs_url
            info.paper_url = parent_info.paper_url
            info.repo_url = parent_info.repo_url
        cls._hub_mixin_info = info

        # Update MixinInfo with metadata
        if model_card_template is not None and model_card_template != DEFAULT_MODEL_CARD:
            info.model_card_template = model_card_template
        if repo_url is not None:
            info.repo_url = repo_url
        if paper_url is not None:
            info.paper_url = paper_url
        if docs_url is not None:
            info.docs_url = docs_url
        if language is not None:
            info.model_card_data.language = language
        if library_name is not None:
            info.model_card_data.library_name = library_name
        if license is not None:
            info.model_card_data.license = license
        if license_name is not None:
            info.model_card_data.license_name = license_name
        if license_link is not None:
            info.model_card_data.license_link = license_link
        if pipeline_tag is not None:
            info.model_card_data.pipeline_tag = pipeline_tag

        # Merge inherited tags with the ones declared on this class; deduplicate and sort them.
        inherited_tags = info.model_card_data.tags or []
        info.model_card_data.tags = sorted({*inherited_tags, *own_tags})

        # Handle encoders/decoders for args
        cls._hub_mixin_coders = coders or {}
        cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys())

        # Inspect __init__ signature to handle config
        cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters)
        cls._hub_mixin_jsonable_default_values = {
            param.name: cls._encode_arg(param.default)
            for param in cls._hub_mixin_init_parameters.values()
            if param.default is not inspect.Parameter.empty and cls._is_jsonable(param.default)
        }
        cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters

    def __new__(cls: type[T], *args, **kwargs) -> T:
        """Create a new instance of the class and handle config.

        3 cases:
        - If `self._hub_mixin_config` is already set, do nothing.
        - If `config` is passed as a dataclass, set it as `self._hub_mixin_config`.
        - Otherwise, build `self._hub_mixin_config` from default values and passed values.
        """
        instance = super().__new__(cls)

        # If `config` is already set, return early
        if instance._hub_mixin_config is not None:
            return instance

        # Infer passed values: positional arguments are matched to `__init__` parameter names
        # ([1:] to skip `self` parameter), keyword arguments are taken as-is.
        positional_names = list(cls._hub_mixin_init_parameters)[1:]
        passed_values = {**dict(zip(positional_names, args)), **kwargs}

        # If config passed as dataclass => set it and return early
        if is_dataclass(passed_values.get("config")):
            instance._hub_mixin_config = passed_values["config"]
            return instance

        # Otherwise, build config from default + passed values
        init_config = {
            # default values
            **cls._hub_mixin_jsonable_default_values,
            # passed values
            **{
                key: cls._encode_arg(value)  # Encode custom types as jsonable value
                for key, value in passed_values.items()
                if instance._is_jsonable(value)  # Only if jsonable or we have a custom encoder
            },
        }
        passed_config = init_config.pop("config", {})

        # Populate `init_config` with provided config
        if isinstance(passed_config, dict):
            init_config.update(passed_config)

        # Set `config` attribute and return
        if init_config != {}:
            instance._hub_mixin_config = init_config
        return instance

    @classmethod
    def _is_jsonable(cls, value: Any) -> bool:
        """Check if a value is JSON serializable.

        Dataclass values and instances of a type registered in `coders` count as serializable, since
        `_encode_arg` knows how to convert them.
        """
        if is_dataclass(value):
            return True
        if isinstance(value, cls._hub_mixin_jsonable_custom_types):
            return True
        return is_jsonable(value)

    @classmethod
    def _encode_arg(cls, arg: Any) -> Any:
        """Encode an argument into a JSON serializable format.

        Dataclasses are converted with `asdict`; values matching a registered coder type go through
        that coder's encoder; anything else is returned unchanged.
        """
        if is_dataclass(arg):
            return asdict(arg)  # type: ignore[arg-type]
        for type_, (encoder, _) in cls._hub_mixin_coders.items():
            if isinstance(arg, type_):
                if arg is None:
                    return None
                return encoder(arg)
        return arg

    @classmethod
    def _decode_arg(cls, expected_type: type[ARGS_T], value: Any) -> ARGS_T | None:
        """Decode a JSON serializable value into an argument.

        Args:
            expected_type: Annotation of the `__init__` parameter the value is destined for.
            value: Value read from the JSON config.

        Returns:
            The decoded value, or `value` unchanged when no dataclass/custom decoder applies.
        """
        if is_simple_optional_type(expected_type):
            if value is None:
                return None
            expected_type = unwrap_simple_optional_type(expected_type)  # type: ignore

        # Dataclass => handle it
        if is_dataclass(expected_type):
            return _load_dataclass(expected_type, value)  # type: ignore

        # Otherwise => check custom decoders
        for type_, (_, decoder) in cls._hub_mixin_coders.items():
            if inspect.isclass(expected_type) and issubclass(expected_type, type_):
                return decoder(value)

        # Otherwise => don't decode
        return value

    def save_pretrained(
        self,
        save_directory: str | Path,
        *,
        config: dict | DataclassInstance | None = None,
        repo_id: str | None = None,
        push_to_hub: bool = False,
        model_card_kwargs: dict[str, Any] | None = None,
        **push_to_hub_kwargs,
    ) -> str | None:
        """
        Save weights in local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Huggingface Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            model_card_kwargs (`dict[str, Any]`, *optional*):
                Additional arguments passed to the model card template to customize the model card.
            push_to_hub_kwargs:
                Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.

        Returns:
            `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # Remove config.json if already exists. After `_save_pretrained` we don't want to overwrite config.json
        # as it might have been saved by the custom `_save_pretrained` already. However we do want to overwrite
        # an existing config.json if it was not saved by `_save_pretrained`.
        config_path = save_directory / constants.CONFIG_NAME
        config_path.unlink(missing_ok=True)

        # save model weights/files (framework-specific)
        self._save_pretrained(save_directory)

        # save config (if provided and if not serialized yet in `_save_pretrained`)
        if config is None:
            config = self._hub_mixin_config
        if config is not None:
            if is_dataclass(config):
                config = asdict(config)  # type: ignore[arg-type]
            if not config_path.exists():
                config_str = json.dumps(config, sort_keys=True, indent=2)
                # Explicit encoding, matching how `from_pretrained` reads the file back.
                config_path.write_text(config_str, encoding="utf-8")

        # save model card
        model_card_path = save_directory / "README.md"
        model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {}
        if not model_card_path.exists():  # do not overwrite if already exists
            self.generate_model_card(**model_card_kwargs).save(model_card_path)

        # push to the Hub if required
        if push_to_hub:
            kwargs = push_to_hub_kwargs.copy()  # soft-copy to avoid mutating input
            if config is not None:  # kwarg for `push_to_hub`
                kwargs["config"] = config
            if repo_id is None:
                repo_id = save_directory.name  # Defaults to `save_directory` name
            return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs)
        return None

    def _save_pretrained(self, save_directory: Path) -> None:
        """
        Overwrite this method in subclass to define how to save your model.
        Check out our [integration guide](../guides/integrations) for instructions.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
        """
        raise NotImplementedError

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls: type[T],
        pretrained_model_name_or_path: str | Path,
        *,
        force_download: bool = False,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **model_kwargs,
    ) -> T:
        """
        Download a model from the Huggingface Hub and instantiate it.

        Args:
            pretrained_model_name_or_path (`str`, `Path`):
                - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`.
                - Or a path to a `directory` containing model weights saved using
                    [`~ModelHubMixin.save_pretrained`], e.g., `../path/to/my_model_directory/`.
            revision (`str`, *optional*):
                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id.
                Defaults to the latest commit on `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `hf auth login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            model_kwargs (`dict`, *optional*):
                Additional kwargs to pass to the model during initialization.
        """
        model_id = str(pretrained_model_name_or_path)
        config_file: str | None = None
        if os.path.isdir(model_id):
            if constants.CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, constants.CONFIG_NAME)
            else:
                logger.warning(f"{constants.CONFIG_NAME} not found in {Path(model_id).resolve()}")
        else:
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=constants.CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                logger.info(f"{constants.CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}")

        # Read config
        config = None
        if config_file is not None:
            with open(config_file, encoding="utf-8") as f:
                config = json.load(f)

            # Decode custom types in config
            for key, value in config.items():
                if key in cls._hub_mixin_init_parameters:
                    expected_type = cls._hub_mixin_init_parameters[key].annotation
                    if expected_type is not inspect.Parameter.empty:
                        config[key] = cls._decode_arg(expected_type, value)

            # Keep a handle on the plain dict. `config` may be rebound below to a decoded object
            # (e.g. a dataclass instance) on which `in` and `.items()` would fail.
            config_dict = config

            # Populate model_kwargs from config
            for param in cls._hub_mixin_init_parameters.values():
                if param.name not in model_kwargs and param.name in config_dict:
                    model_kwargs[param.name] = config_dict[param.name]

            # Check if `config` argument was passed at init
            if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs:
                # Decode `config` argument if it was passed
                config_annotation = cls._hub_mixin_init_parameters["config"].annotation
                config = cls._decode_arg(config_annotation, config_dict)

                # Forward config to model initialization
                model_kwargs["config"] = config

            # Inject config if `**kwargs` are expected (always from the plain dict, see above)
            if is_dataclass(cls):
                for key in cls.__dataclass_fields__:
                    if key not in model_kwargs and key in config_dict:
                        model_kwargs[key] = config_dict[key]
            elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
                for key, value in config_dict.items():
                    if key not in model_kwargs:
                        model_kwargs[key] = value

            # Finally, also inject if `_from_pretrained` expects it
            if cls._hub_mixin_inject_config and "config" not in model_kwargs:
                model_kwargs["config"] = config

        instance = cls._from_pretrained(
            model_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            **model_kwargs,
        )

        # Implicitly set the config as instance attribute if not already set by the class
        # This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
        if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})):
            instance._hub_mixin_config = config

        return instance

    @classmethod
    def _from_pretrained(
        cls: type[T],
        *,
        model_id: str,
        revision: str | None,
        cache_dir: str | Path | None,
        force_download: bool,
        local_files_only: bool,
        token: str | bool | None,
        **model_kwargs,
    ) -> T:
        """Overwrite this method in subclass to define how to load your model from pretrained.

        Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most
        args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this
        method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location`
        parameter to set on which device the model should be loaded.

        Check out our [integration guide](../guides/integrations) for more instructions.

        Args:
            model_id (`str`):
                ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`).
            revision (`str`, *optional*):
                Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
                latest commit on `main` branch.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
                the existing cache.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `hf auth login`.
            cache_dir (`str`, `Path`, *optional*):
                Path to the folder where cached files are stored.
            local_files_only (`bool`, *optional*, defaults to `False`):
                If `True`, avoid downloading the file and return the path to the local cached file if it exists.
            model_kwargs:
                Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method.
        """
        raise NotImplementedError

    @validate_hf_hub_args
    def push_to_hub(
        self,
        repo_id: str,
        *,
        config: dict | DataclassInstance | None = None,
        commit_message: str = "Push model using huggingface_hub.",
        private: bool | None = None,
        token: str | None = None,
        branch: str | None = None,
        create_pr: bool | None = None,
        allow_patterns: list[str] | str | None = None,
        ignore_patterns: list[str] | str | None = None,
        delete_patterns: list[str] | str | None = None,
        model_card_kwargs: dict[str, Any] | None = None,
    ) -> str:
        """
        Upload model checkpoint to the Hub.

        Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
        `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
        details.

        Args:
            repo_id (`str`):
                ID of the repository to push to (example: `"username/my-model"`).
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
            commit_message (`str`, *optional*):
                Message to commit while pushing.
            private (`bool`, *optional*):
                Whether the repository created should be private.
                If `None` (default), the repo will be public unless the organization's default is private.
            token (`str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. By default, it will use the token
                cached when running `hf auth login`.
            branch (`str`, *optional*):
                The git branch on which to push the model. This defaults to `"main"`.
            create_pr (`boolean`, *optional*):
                Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`.
            allow_patterns (`list[str]` or `str`, *optional*):
                If provided, only files matching at least one pattern are pushed.
            ignore_patterns (`list[str]` or `str`, *optional*):
                If provided, files matching any of the patterns are not pushed.
            delete_patterns (`list[str]` or `str`, *optional*):
                If provided, remote files matching any of the patterns will be deleted from the repo.
            model_card_kwargs (`dict[str, Any]`, *optional*):
                Additional arguments passed to the model card template to customize the model card.

        Returns:
            The url of the commit of your model in the given repository.
        """
        api = HfApi(token=token)
        repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id

        # Push the files to the repo in a single commit
        with SoftTemporaryDirectory() as tmp:
            saved_path = Path(tmp) / repo_id
            self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs)
            return api.upload_folder(
                repo_id=repo_id,
                repo_type="model",
                folder_path=saved_path,
                commit_message=commit_message,
                revision=branch,
                create_pr=create_pr,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                delete_patterns=delete_patterns,
            )

    def generate_model_card(self, *args, **kwargs) -> ModelCard:
        """Build the model card for this model from the class-level mixin metadata.

        Args:
            *args:
                Accepted for signature compatibility; positional arguments are not used.
            **kwargs:
                Extra variables forwarded to the model card template.

        Returns:
            `ModelCard`: card rendered from the registered template, card data and repo/paper/docs URLs.
        """
        card = ModelCard.from_template(
            card_data=self._hub_mixin_info.model_card_data,
            template_str=self._hub_mixin_info.model_card_template,
            repo_url=self._hub_mixin_info.repo_url,
            paper_url=self._hub_mixin_info.paper_url,
            docs_url=self._hub_mixin_info.docs_url,
            **kwargs,
        )
        return card
class PyTorchModelHubMixin(ModelHubMixin):
    """
    Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model
    is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model,
    you should first set it back in training mode with `model.train()`.

    See [`ModelHubMixin`] for more details on how to use the mixin.

    Example:

    ```python
    >>> import torch
    >>> import torch.nn as nn
    >>> from huggingface_hub import PyTorchModelHubMixin

    >>> class MyModel(
    ...         nn.Module,
    ...         PyTorchModelHubMixin,
    ...         library_name="keras-nlp",
    ...         repo_url="https://github.com/keras-team/keras-nlp",
    ...         paper_url="https://arxiv.org/abs/2304.12244",
    ...         docs_url="https://keras.io/keras_nlp/",
    ...         # ^ optional metadata to generate model card
    ...     ):
    ...     def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4):
    ...         super().__init__()
    ...         self.param = nn.Parameter(torch.rand(hidden_size, vocab_size))
    ...         self.linear = nn.Linear(output_size, vocab_size)

    ...     def forward(self, x):
    ...         return self.linear(x + self.param)
    >>> model = MyModel(hidden_size=256)

    # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")

    # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")

    # Download and initialize weights from the Hub
    >>> model = MyModel.from_pretrained("username/my-awesome-model")
    >>> model.hidden_size
    256
    ```
    """

    def __init_subclass__(cls, *args, tags: list[str] | None = None, **kwargs) -> None:
        """Add the `pytorch_model_hub_mixin` tag, then defer to `ModelHubMixin.__init_subclass__`.

        A new list is built so that the list passed by the caller is left untouched.
        """
        kwargs["tags"] = [*(tags or []), "pytorch_model_hub_mixin"]
        return super().__init_subclass__(*args, **kwargs)

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights from a Pytorch model to a local directory."""
        # When the instance exposes a `module` attribute, that inner object is the one saved.
        model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
        save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE))  # type: ignore [arg-type]

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: str | None,
        cache_dir: str | Path | None,
        force_download: bool,
        local_files_only: bool,
        token: str | bool | None,
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model.

        `model_id` is treated as a local directory when it exists on disk, otherwise as a Hub repo id.
        For a Hub repo, the safetensors file is tried first; if that entry does not exist in the repo,
        the pickled weights file is downloaded instead.

        Args:
            map_location (`str`, defaults to `"cpu"`):
                Device on which the weights are loaded.
            strict (`bool`, defaults to `False`):
                Forwarded to the underlying weight-loading call.
            model_kwargs:
                Keyword arguments used to instantiate the model (`cls(**model_kwargs)`).
        """
        model = cls(**model_kwargs)

        if os.path.isdir(model_id):
            # Library code reports through the module logger, like the rest of this file.
            logger.info("Loading weights from local directory")
            model_file = os.path.join(model_id, constants.SAFETENSORS_SINGLE_FILE)
            return cls._load_as_safetensor(model, model_file, map_location, strict)

        hub_kwargs = {
            "repo_id": model_id,
            "revision": revision,
            "cache_dir": cache_dir,
            "force_download": force_download,
            "token": token,
            "local_files_only": local_files_only,
        }
        try:
            model_file = hf_hub_download(filename=constants.SAFETENSORS_SINGLE_FILE, **hub_kwargs)
            return cls._load_as_safetensor(model, model_file, map_location, strict)
        except EntryNotFoundError:
            model_file = hf_hub_download(filename=constants.PYTORCH_WEIGHTS_NAME, **hub_kwargs)
            return cls._load_as_pickle(model, model_file, map_location, strict)

    @classmethod
    def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        """Load a `torch.load`-able state dict into `model` and switch it to eval mode."""
        state_dict = torch.load(model_file, map_location=torch.device(map_location), weights_only=True)
        model.load_state_dict(state_dict, strict=strict)  # type: ignore
        model.eval()  # type: ignore
        return model

    @classmethod
    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        """Load a safetensors file into `model` and switch it to eval mode.

        With safetensors older than 0.4.3 the weights are loaded without a `device` argument and the
        model is moved to `map_location` afterwards; newer versions receive `device=map_location` directly.
        """
        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):  # type: ignore [attr-defined]
            load_model_as_safetensor(model, model_file, strict=strict)  # type: ignore [arg-type]
            if map_location != "cpu":
                logger.warning(
                    "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
                    " This means that the model is loaded on 'cpu' first and then copied to the device."
                    " This leads to a slower loading time."
                    " Please update safetensors to version 0.4.3 or above for improved performance."
                )
                model.to(map_location)  # type: ignore [attr-defined]
        else:
            # Same function as `safetensors.torch.load_model`, via the alias imported at module level.
            load_model_as_safetensor(model, model_file, strict=strict, device=map_location)  # type: ignore [arg-type]
        model.eval()  # type: ignore
        return model
def _load_dataclass(datacls: type[DataclassInstance], data: dict) -> DataclassInstance:
    """Instantiate `datacls` from the entries of `data` that match its fields.

    Keys of `data` that are not declared fields of the dataclass are dropped before the
    constructor is called.
    """
    declared_fields = datacls.__dataclass_fields__
    accepted = {name: value for name, value in data.items() if name in declared_fields}
    return datacls(**accepted)
|