| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035 |
- # Copyright 2024 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Contains pytorch-specific helpers."""
- import importlib
- import importlib.util
- import json
- import os
- import re
- from collections import defaultdict, namedtuple
- from collections.abc import Iterable
- from functools import lru_cache
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, NamedTuple, Union
- from packaging import version
- from .. import constants, logging
- from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
- logger = logging.get_logger(__file__)
- if TYPE_CHECKING:
- import torch
- # SAVING
def save_torch_model(
    model: "torch.nn.Module",
    save_directory: str | Path,
    *,
    filename_pattern: str | None = None,
    force_contiguous: bool = True,
    max_shard_size: int | str = MAX_SHARD_SIZE,
    metadata: dict[str, str] | None = None,
    safe_serialization: bool = True,
    is_main_process: bool = True,
    shared_tensors_to_discard: list[str] | None = None,
):
    """
    Save a PyTorch model to disk, taking care of sharding and of tensors that share memory.

    This is a convenience wrapper: it takes `model.state_dict()` and forwards it, together with every option, to
    [`save_torch_state_dict`]. Use that function directly when you already hold a state dict or need more control.

    For background on shared tensors, see [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors).

    The weights are split into shards no larger than `max_shard_size` and written to `save_directory` using
    `filename_pattern`. When more than one shard is needed, an index file mapping each tensor to its shard is written
    as well. Sharding is performed by [`split_torch_state_dict_into_shards`]. Shards are written as safetensors when
    `safe_serialization` is `True` (the default) and as pickle otherwise. Previous shard files found in
    `save_directory` are removed before writing.

    > [!WARNING]
    > A single tensor larger than `max_shard_size` is placed alone in its own shard, which will therefore exceed
    > `max_shard_size`.

    > [!WARNING]
    > For a `transformers.PreTrainedModel`, pass `model._tied_weights_keys` as `shared_tensors_to_discard` so that the
    > right duplicates of shared (tied) tensors are dropped when saving.

    Args:
        model (`torch.nn.Module`):
            Model whose weights are written to disk.
        save_directory (`str` or `Path`):
            Destination directory.
        filename_pattern (`str`, *optional*):
            Pattern used to name the weight files. It must contain the `suffix` keyword so that
            `filename_pattern.format(suffix=...)` works. When omitted, `"model{suffix}.safetensors"` or
            `pytorch_model{suffix}.bin` is used, depending on `safe_serialization`.
        force_contiguous (`boolean`, *optional*):
            Store every tensor as contiguous. This never affects correctness, but may affect performance when a
            specific memory layout was chosen on purpose. Defaults to `True`.
        max_shard_size (`int` or `str`, *optional*):
            Upper bound for the size of each shard, in bytes. Defaults to 5GB.
        metadata (`dict[str, str]`, *optional*):
            Extra information stored alongside the weights. Entries describing dropped shared tensors are added
            to it. They are not enough to rebuild the full sharing structure, but help understanding it.
        safe_serialization (`bool`, *optional*):
            Write safetensors files (default). If `False`, pickle is used instead. Pickle is discouraged for security
            reasons, deprecated, and will be removed in a future version.
        is_main_process (`bool`, *optional*):
            Whether the caller is the main process. In distributed setups (e.g. TPUs) where every process calls this
            function, set it to `True` on the main process only to avoid race conditions. Defaults to `True`.
        shared_tensors_to_discard (`list[str]`, *optional*):
            Names of tensors to drop when shared tensors are found. When not given, the first name in alphabetical
            order is dropped.

    Example:

    ```py
    >>> from huggingface_hub import save_torch_model
    >>> model = ... # A PyTorch model

    # Write the weights to "path/to/folder" as safetensors shards of at most 5GB each.
    >>> save_torch_model(model, "path/to/folder")

    # Reload them into the model
    >>> from huggingface_hub import load_torch_model
    >>> load_torch_model(model, "path/to/folder")
    ```
    """
    weights = model.state_dict()
    # Everything else (cleaning, sharding, writing) is handled by the state-dict level helper.
    save_torch_state_dict(
        weights,
        save_directory,
        filename_pattern=filename_pattern,
        force_contiguous=force_contiguous,
        max_shard_size=max_shard_size,
        metadata=metadata,
        safe_serialization=safe_serialization,
        is_main_process=is_main_process,
        shared_tensors_to_discard=shared_tensors_to_discard,
    )
def save_torch_state_dict(
    state_dict: dict[str, "torch.Tensor"],
    save_directory: str | Path,
    *,
    filename_pattern: str | None = None,
    force_contiguous: bool = True,
    max_shard_size: int | str = MAX_SHARD_SIZE,
    metadata: dict[str, str] | None = None,
    safe_serialization: bool = True,
    is_main_process: bool = True,
    shared_tensors_to_discard: list[str] | None = None,
) -> None:
    """
    Save a model state dictionary to the disk, handling sharding and shared tensors issues.

    See also [`save_torch_model`] to directly save a PyTorch model.

    For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors).

    The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are
    saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard,
    an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses
    [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as
    safetensors (the default). Otherwise, the shards are saved as pickle.

    Before saving the model, the main process removes from `save_directory` any previous weight, shard or index file
    generated from the same `filename_pattern`. Other files in the directory are left untouched.

    > [!WARNING]
    > If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a
    > size greater than `max_shard_size`.

    > [!WARNING]
    > If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dictionary to save.
        save_directory (`str` or `Path`):
            The directory in which the model will be saved.
        filename_pattern (`str`, *optional*):
            The pattern to generate the files names in which the model will be saved. Pattern must be a string that
            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
            Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization`
            parameter.
        force_contiguous (`boolean`, *optional*):
            Forcing the state_dict to be saved as contiguous tensors. This has no effect on the correctness of the
            model, but it could potentially change performance if the layout of the tensor was chosen specifically for
            that reason. Defaults to `True`.
        max_shard_size (`int` or `str`, *optional*):
            The maximum size of each shard, in bytes. Defaults to 5GB.
        metadata (`dict[str, str]`, *optional*):
            Extra information to save along with the model. Some metadata will be added for each dropped tensors.
            This information will not be enough to recover the entire shared structure but might help understanding
            things. The dict passed by the caller is copied and is not modified.
        safe_serialization (`bool`, *optional*):
            Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
            Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
            in a future version.
        is_main_process (`bool`, *optional*):
            Whether the process calling this is the main process or not. Useful when in distributed training like
            TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
            the main process to avoid race conditions. Defaults to True.
        shared_tensors_to_discard (`list[str]`, *optional*):
            List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
            detected, it will drop the first name alphabetically.

    Raises:
        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
            If `safe_serialization=True` and `safetensors` is not installed.

    Example:

    ```py
    >>> from huggingface_hub import save_torch_state_dict
    >>> model = ... # A PyTorch model

    # Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors.
    >>> state_dict = model.state_dict()
    >>> save_torch_state_dict(state_dict, "path/to/folder")
    ```
    """
    save_directory = str(save_directory)

    if filename_pattern is None:
        filename_pattern = (
            constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
            if safe_serialization
            else constants.PYTORCH_WEIGHTS_FILE_PATTERN
        )

    # Work on a private copy. The dict is handed to `_clean_state_dict_for_safetensors` below, which adds entries
    # for dropped shared tensors (see docstring), and those additions must not leak back into the caller's dict.
    metadata = {} if metadata is None else dict(metadata)

    if safe_serialization:
        try:
            from safetensors.torch import save_file as save_file_fn
        except ImportError as e:
            raise ImportError(
                "Please install `safetensors` to use safe serialization. "
                "You can install it with `pip install safetensors`."
            ) from e

        # Clean state dict for safetensors (shared tensors are dropped, tensors optionally made contiguous)
        state_dict = _clean_state_dict_for_safetensors(
            state_dict,
            metadata,
            force_contiguous=force_contiguous,
            shared_tensors_to_discard=shared_tensors_to_discard,
        )
    else:
        from torch import save as save_file_fn  # type: ignore[assignment, no-redef]

        logger.warning(
            "You are using unsafe serialization. Due to security reasons, it is recommended not to load "
            "pickled models from untrusted sources. If you intend to share your model, we strongly recommend "
            "using safe serialization by installing `safetensors` with `pip install safetensors`."
        )

    # Split dict
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
    )

    # Only main process should clean up existing files to avoid race conditions in distributed environment
    if is_main_process:
        # Match exactly the names this pattern can produce: the pattern formatted with an empty suffix or with a
        # "-XXXXX-of-YYYYY" shard suffix, optionally followed by ".index.json". The literal parts are escaped (the
        # "." in ".safetensors" must not match any character) and `fullmatch` anchors both ends, so unrelated files
        # that merely share a prefix (e.g. "model.safetensors.bak") are never deleted.
        placeholder = "__HF_SHARD_SUFFIX__"
        escaped_pattern = re.escape(filename_pattern.format(suffix=placeholder))
        existing_files_regex = re.compile(
            escaped_pattern.replace(re.escape(placeholder), r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?"
        )
        for filename in os.listdir(save_directory):
            if existing_files_regex.fullmatch(filename):
                try:
                    logger.debug(f"Removing existing file '{filename}' from folder.")
                    os.remove(os.path.join(save_directory, filename))
                except OSError as e:
                    # Best effort: a file that cannot be removed must not abort the save.
                    logger.warning(
                        f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..."
                    )

    # Save each shard. User metadata goes into the file header only for single-file checkpoints;
    # for sharded checkpoints it is stored in the index instead.
    per_file_metadata = {"format": "pt"}
    if not state_dict_split.is_sharded:
        per_file_metadata.update(metadata)
    safe_file_kwargs = {"metadata": per_file_metadata} if safe_serialization else {}
    for filename, tensors in state_dict_split.filename_to_tensors.items():
        shard = {tensor: state_dict[tensor] for tensor in tensors}
        save_file_fn(shard, os.path.join(save_directory, filename), **safe_file_kwargs)  # ty: ignore[invalid-argument-type]
        logger.debug(f"Shard saved to {filename}")

    # Save the index (if any)
    if state_dict_split.is_sharded:
        index_path = filename_pattern.format(suffix="") + ".index.json"
        index = {
            "metadata": {**state_dict_split.metadata, **metadata},
            "weight_map": state_dict_split.tensor_to_filename,
        }
        with open(os.path.join(save_directory, index_path), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2)
        logger.info(
            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}). "
            f"Model weights have been saved in {len(state_dict_split.filename_to_tensors)} checkpoint shards. "
            f"You can find where each parameter has been saved in the index located at {index_path}."
        )

    logger.info(f"Model weights successfully saved to {save_directory}!")
def split_torch_state_dict_into_shards(
    state_dict: dict[str, "torch.Tensor"],
    *,
    filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
    max_shard_size: int | str = MAX_SHARD_SIZE,
) -> StateDictSplit:
    """
    Partition a model state dictionary into shards, each one below a given size.

    Shards are built by walking the `state_dict` in key order. No attempt is made to fill every shard as close to
    the limit as possible. For instance, with a 10GB limit and tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB], the
    shards are [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

    > [!TIP]
    > [`save_torch_state_dict`] relies on this function to write a state dictionary to disk.

    > [!WARNING]
    > A tensor bigger than `max_shard_size` gets a shard of its own, whose size therefore exceeds `max_shard_size`.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            State dictionary to partition.
        filename_pattern (`str`, *optional*):
            Pattern used to name the shard files. It must contain the `suffix` keyword so that
            `filename_pattern.format(suffix=...)` works. Defaults to `"model{suffix}.safetensors"`.
        max_shard_size (`int` or `str`, *optional*):
            Maximum size of a shard, in bytes. Defaults to 5GB.

    Returns:
        [`StateDictSplit`]: Object describing the shards and the index needed to locate each tensor.

    Example:

    ```py
    >>> import json
    >>> import os
    >>> import torch
    >>> from safetensors.torch import save_file as safe_save_file
    >>> from huggingface_hub import split_torch_state_dict_into_shards

    >>> def save_state_dict(state_dict: dict[str, torch.Tensor], save_directory: str):
    ...     state_dict_split = split_torch_state_dict_into_shards(state_dict)
    ...     for filename, tensors in state_dict_split.filename_to_tensors.items():
    ...         shard = {tensor: state_dict[tensor] for tensor in tensors}
    ...         safe_save_file(
    ...             shard,
    ...             os.path.join(save_directory, filename),
    ...             metadata={"format": "pt"},
    ...         )
    ...     if state_dict_split.is_sharded:
    ...         index = {
    ...             "metadata": state_dict_split.metadata,
    ...             "weight_map": state_dict_split.tensor_to_filename,
    ...         }
    ...         with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
    ...             f.write(json.dumps(index, indent=2))
    ```
    """
    # The generic splitter is framework-agnostic; torch specifics enter only through the
    # storage-size and storage-id callbacks handed to it here.
    return split_state_dict_into_shards_factory(
        state_dict,
        filename_pattern=filename_pattern,
        max_shard_size=max_shard_size,
        get_storage_id=get_torch_storage_id,
        get_storage_size=get_torch_storage_size,
    )
- # LOADING
def load_torch_model(
    model: "torch.nn.Module",
    checkpoint_path: str | os.PathLike,
    *,
    strict: bool = False,
    safe: bool = True,
    weights_only: bool = False,
    map_location: Union[str, "torch.device"] | None = None,
    mmap: bool = False,
    filename_pattern: str | None = None,
) -> NamedTuple:
    """
    Load a checkpoint into a model, handling both sharded and non-sharded checkpoints.

    Args:
        model (`torch.nn.Module`):
            The model in which to load the checkpoint.
        checkpoint_path (`str` or `os.PathLike`):
            Path to either the checkpoint file or directory containing the checkpoint(s).
        strict (`bool`, *optional*, defaults to `False`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the checkpoint.
        safe (`bool`, *optional*, defaults to `True`):
            Only used when `checkpoint_path` is a directory. If `safe` is True, only safetensors files are looked
            for. If `safe` is False, safetensors files are looked for first and pickle (`.bin`) files are used as a
            fallback. `filename_pattern` parameter takes precedence over `safe` parameter. An explicit file path is
            always loaded according to its extension.
        weights_only (`bool`, *optional*, defaults to `False`):
            If True, only loads the model weights without optimizer states and other metadata.
            Only supported in PyTorch >= 1.13.
        map_location (`str` or `torch.device`, *optional*):
            A `torch.device` object, string or a dict specifying how to remap storage locations. It
            indicates the location where all tensors should be loaded. Not applied to sharded checkpoints, whose
            shards are loaded on CPU.
        mmap (`bool`, *optional*, defaults to `False`):
            Whether to use memory-mapped file loading. Memory mapping can improve loading performance
            for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. Not applied to sharded checkpoints.
        filename_pattern (`str`, *optional*):
            The pattern to look for the index file. Pattern must be a string that
            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
            Defaults to `"model{suffix}.safetensors"`.

    Returns:
        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields.
            - `missing_keys` is a list of str containing the missing keys, i.e. keys that are in the model but not in the checkpoint.
            - `unexpected_keys` is a list of str containing the unexpected keys, i.e. keys that are in the checkpoint but not in the model.

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If `checkpoint_path` does not exist, or if it is a directory containing neither an index file nor exactly
            one candidate model file.
        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
            If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.

    Example:

    ```python
    >>> from huggingface_hub import load_torch_model
    >>> model = ... # A PyTorch model
    >>> load_torch_model(model, "path/to/checkpoint")
    ```
    """
    checkpoint_path = Path(checkpoint_path)

    if not checkpoint_path.exists():
        raise ValueError(f"Checkpoint path {checkpoint_path} does not exist")

    # 1. Check if checkpoint is a single file
    if checkpoint_path.is_file():
        state_dict = load_state_dict_from_file(
            checkpoint_file=checkpoint_path,
            map_location=map_location,
            weights_only=weights_only,
            mmap=mmap,
        )
        return model.load_state_dict(state_dict, strict=strict)

    # 2. If not, checkpoint_path is a directory: look for a sharded checkpoint (index file)
    if filename_pattern is None:
        filename_pattern = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
    index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
    # Only fallback to pickle format if safetensors index is not found and safe is False.
    if not index_path.is_file() and not safe:
        filename_pattern = constants.PYTORCH_WEIGHTS_FILE_PATTERN
        index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")

    if index_path.is_file():
        # NOTE: `map_location` and `mmap` are not forwarded: shards are loaded on CPU one by one.
        return _load_sharded_checkpoint(
            model=model,
            save_directory=checkpoint_path,
            strict=strict,
            weights_only=weights_only,
            filename_pattern=filename_pattern,
        )

    # 3. Look for a single model file. Safetensors is always tried first; pickle (`.bin`) is only considered when
    # `safe` is False, as documented. A format with zero or several candidate files is skipped.
    candidate_globs = ("*.safetensors",) if safe else ("*.safetensors", "*.bin")
    for glob_pattern in candidate_globs:
        model_files = [path for path in checkpoint_path.glob(glob_pattern) if path.is_file()]
        if len(model_files) == 1:
            state_dict = load_state_dict_from_file(
                checkpoint_file=model_files[0],
                map_location=map_location,
                weights_only=weights_only,
                mmap=mmap,
            )
            return model.load_state_dict(state_dict, strict=strict)

    raise ValueError(
        f"Directory '{checkpoint_path}' does not contain a valid checkpoint. "
        "Expected either a sharded checkpoint with an index file, or a single model file."
    )
def _load_sharded_checkpoint(
    model: "torch.nn.Module",
    save_directory: os.PathLike,
    *,
    strict: bool = False,
    weights_only: bool = False,
    filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
) -> NamedTuple:
    """
    Loads a sharded checkpoint into a model. This is the same as
    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
    but for a sharded checkpoint. Each shard is loaded one by one (on CPU) and removed from memory after being loaded
    into the model.

    Args:
        model (`torch.nn.Module`):
            The model in which to load the checkpoint.
        save_directory (`str` or `os.PathLike`):
            A path to a folder containing the sharded checkpoint.
        strict (`bool`, *optional*, defaults to `False`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
            The check is done once, against the full index, before any shard is loaded.
        weights_only (`bool`, *optional*, defaults to `False`):
            If True, only loads the model weights without optimizer states and other metadata.
            Only supported in PyTorch >= 1.13.
        filename_pattern (`str`, *optional*, defaults to `"model{suffix}.safetensors"`):
            The pattern to look for the index file. Pattern must be a string that
            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
            Defaults to `"model{suffix}.safetensors"`.

    Returns:
        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields,
            - `missing_keys` is a list of str containing the missing keys
            - `unexpected_keys` is a list of str containing the unexpected keys

    Raises:
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If a shard filename in the index is absolute, contains `..`, or has an extension that does not match the
            index format.
        [`RuntimeError`](https://docs.python.org/3/library/exceptions.html#RuntimeError)
            If `strict=True` and the index keys do not match the model keys.
    """
    # 1. Load the index file.
    # The index file contains mapping of parameter names to shard files
    index_path = filename_pattern.format(suffix="") + ".index.json"
    index_file = os.path.join(save_directory, index_path)
    with open(index_file, encoding="utf-8") as f:
        index = json.load(f)
    weight_map = index["weight_map"]

    # 2. Validate shard filenames from the index
    # This prevents path traversal attacks and extension confusion attacks
    # (e.g. a safetensors index referencing .bin pickle files)
    expected_extension = Path(filename_pattern.format(suffix="")).suffix  # e.g. ".safetensors"
    # Unique shard files (several parameters share a shard), sorted for a deterministic loading order.
    shard_files = sorted(set(weight_map.values()))
    for shard_file in shard_files:
        shard_file_path = Path(shard_file)
        # Reject path traversal (e.g. "../malicious.bin", absolute or drive-anchored paths)
        if os.path.isabs(shard_file) or shard_file_path.anchor or ".." in shard_file_path.parts:
            raise ValueError(
                f"Invalid shard filename '{shard_file}' in index file '{index_file}'. "
                "Shard filenames must be relative paths without '..' components."
            )
        # Reject extension mismatch (e.g. .bin shard in a .safetensors index)
        if not shard_file.endswith(expected_extension):
            raise ValueError(
                f"Invalid shard filename '{shard_file}' in index file '{index_file}'. "
                f"Expected '{expected_extension}' extension to match the index format."
            )

    # 3. Validate keys if in strict mode
    # This is done before loading any shards to fail fast
    if strict:
        _validate_keys_for_strict_loading(model, weight_map.keys())

    # 4. Load each shard using `load_state_dict`
    for shard_file in shard_files:
        # Load shard into memory
        shard_path = os.path.join(save_directory, shard_file)
        state_dict = load_state_dict_from_file(
            shard_path,
            map_location="cpu",
            weights_only=weights_only,
        )
        # A shard only holds a subset of the weights, so it must be loaded non-strictly: with `strict=True`,
        # torch would report every key stored in another shard as missing and raise. Strict key coverage was
        # already checked against the whole index in step 3.
        model.load_state_dict(state_dict, strict=False)
        # Explicitly remove the state dict from memory
        del state_dict

    # 5. Return compatibility info
    loaded_keys = set(weight_map.keys())
    model_keys = set(model.state_dict().keys())
    return _IncompatibleKeys(
        missing_keys=list(model_keys - loaded_keys), unexpected_keys=list(loaded_keys - model_keys)
    )
def load_state_dict_from_file(
    checkpoint_file: str | os.PathLike,
    map_location: Union[str, "torch.device"] | None = None,
    weights_only: bool = False,
    mmap: bool = False,
) -> dict[str, "torch.Tensor"] | Any:
    """
    Loads a checkpoint file, handling both safetensors and pickle checkpoint formats.

    The format is chosen from the file extension: `.safetensors` files are read with `safetensors`, any other file is
    read with `torch.load`.

    Args:
        checkpoint_file (`str` or `os.PathLike`):
            Path to the checkpoint file to load. Can be either a safetensors or pickle (`.bin`) checkpoint.
        map_location (`str` or `torch.device`, *optional*):
            A `torch.device` object, string or a dict specifying how to remap storage locations. It
            indicates the location where all tensors should be loaded. For pickle checkpoints the value is forwarded
            unchanged to `torch.load`. For safetensors files it is passed as `device` to
            `safetensors.torch.load_file`: a `torch.device` is converted with `str()` so an explicit index such as
            `cuda:1` is kept, and `"meta"` falls back to CPU.
        weights_only (`bool`, *optional*, defaults to `False`):
            If True, only loads the model weights without optimizer states and other metadata.
            Only supported for pickle (`.bin`) checkpoints with PyTorch >= 1.13. Has no effect when
            loading safetensors files.
        mmap (`bool`, *optional*, defaults to `False`):
            Whether to use memory-mapped file loading. Memory mapping can improve loading performance
            for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. Has no effect when
            loading safetensors files, as the `safetensors` library uses memory mapping by default.

    Returns:
        `Union[dict[str, "torch.Tensor"], Any]`: The loaded checkpoint.
            - For safetensors files: always returns a dictionary mapping parameter names to tensors.
            - For pickle files: returns any Python object that was pickled (commonly a state dict, but could be
              an entire model, optimizer state, or any other Python object).

    Raises:
        [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
            If the checkpoint file does not exist.
        [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
            If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.
        [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
            If the checkpoint file format is invalid or if git-lfs files are not properly downloaded.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
            If the checkpoint file path is empty or invalid.

    Example:

    ```python
    >>> from huggingface_hub import load_state_dict_from_file

    # Load a PyTorch checkpoint
    >>> state_dict = load_state_dict_from_file("path/to/model.bin", map_location="cpu")
    >>> model.load_state_dict(state_dict)

    # Load a safetensors checkpoint
    >>> state_dict = load_state_dict_from_file("path/to/model.safetensors")
    >>> model.load_state_dict(state_dict)
    ```
    """
    checkpoint_path = Path(checkpoint_file)

    # Check if file exists and is a regular file (not a directory)
    if not checkpoint_path.is_file():
        raise FileNotFoundError(
            f"No checkpoint file found at '{checkpoint_path}'. Please verify the path is correct and "
            "the file has been properly downloaded."
        )

    # Load safetensors checkpoint
    if checkpoint_path.suffix == ".safetensors":
        try:
            from safetensors import safe_open
            from safetensors.torch import load_file
        except ImportError as e:
            raise ImportError(
                "Please install `safetensors` to load safetensors checkpoint. "
                "You can install it with `pip install safetensors`."
            ) from e

        # Check format of the archive
        with safe_open(checkpoint_file, framework="pt") as f:  # type: ignore[attr-defined]
            metadata = f.metadata()
        # see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966
        if metadata is not None and metadata.get("format") not in ["pt", "mlx"]:
            raise OSError(
                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
                "you save your model with the `save_torch_model` method."
            )

        # A `torch.device` (recognized by its `type` attribute) is converted with `str()` rather than `.type` so that
        # an explicit index survives: torch.device("cuda:1") -> "cuda:1" instead of "cuda" (the default GPU).
        if map_location is not None and hasattr(map_location, "type"):
            device = str(map_location)
        else:
            device = map_location
        # meta device is not supported with safetensors, falling back to CPU
        if device == "meta":
            logger.warning("Meta device is not supported with safetensors. Falling back to CPU device.")
            device = "cpu"
        return load_file(checkpoint_file, device=device)  # type: ignore[arg-type]

    # Otherwise, load from pickle
    try:
        import torch
        from torch import load
    except ImportError as e:
        raise ImportError(
            "Please install `torch` to load torch tensors. You can install it with `pip install torch`."
        ) from e

    # Add additional kwargs, mmap is only supported in torch >= 2.1.0
    additional_kwargs = {}
    if version.parse(torch.__version__) >= version.parse("2.1.0"):
        additional_kwargs["mmap"] = mmap

    # weights_only is only supported in torch >= 1.13.0
    # SECURITY: with weights_only=False (this function's default), torch.load unpickles arbitrary objects, which can
    # execute code. Only load pickle checkpoints from trusted sources.
    if version.parse(torch.__version__) >= version.parse("1.13.0"):
        additional_kwargs["weights_only"] = weights_only

    return load(
        checkpoint_file,
        map_location=map_location,
        **additional_kwargs,
    )
- # HELPERS
- def _validate_keys_for_strict_loading(
- model: "torch.nn.Module",
- loaded_keys: Iterable[str],
- ) -> None:
- """
- Validate that model keys match loaded keys when strict loading is enabled.
- Args:
- model: The PyTorch model being loaded
- loaded_keys: The keys present in the checkpoint
- Raises:
- RuntimeError: If there are missing or unexpected keys in strict mode
- """
- loaded_keys_set = set(loaded_keys)
- model_keys = set(model.state_dict().keys())
- missing_keys = model_keys - loaded_keys_set # Keys in model but not in checkpoint
- unexpected_keys = loaded_keys_set - model_keys # Keys in checkpoint but not in model
- if missing_keys or unexpected_keys:
- error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
- if missing_keys:
- str_missing_keys = ",".join([f'"{k}"' for k in sorted(missing_keys)])
- error_message += f"\nMissing key(s): {str_missing_keys}."
- if unexpected_keys:
- str_unexpected_keys = ",".join([f'"{k}"' for k in sorted(unexpected_keys)])
- error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
- raise RuntimeError(error_message)
def _get_unique_id(tensor: "torch.Tensor") -> int | tuple[Any, ...]:
    """Return a unique id for the storage of a plain tensor.

    - Traceable wrapper tensor subclass (torch>=2.1): returns a (potentially nested) tuple of the ids of its
      flattened inner tensors.
    - DTensor: identified by the storage of its local shard.
    - XLA tensor (no storage): identified by its torch_xla tensor id.
    """
    try:
        from torch.distributed.tensor import DTensor
    except ImportError:
        DTensor = None  # type: ignore[assignment,misc]  # torch build without `torch.distributed.tensor`
    if DTensor is not None and isinstance(tensor, DTensor):
        # Identify the DTensor by its local shard's storage. Going through `storage_ptr` reads it via
        # `untyped_storage()`, not the deprecated `Tensor.storage()` (TypedStorage) API, which emits a
        # deprecation warning on recent torch versions. It also gets `storage_ptr`'s fallbacks (e.g. meta storage).
        return storage_ptr(tensor.to_local())

    try:
        # for torch 2.1 and above we can also handle tensor subclasses
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass

        if is_traceable_wrapper_subclass(tensor):
            attrs, _ = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
            return tuple(_get_unique_id(getattr(tensor, attr)) for attr in attrs)
    except ImportError:
        # for torch version less than 2.1, we can fall back to original implementation
        pass

    if tensor.device.type == "xla" and is_torch_tpu_available():
        # NOTE: xla tensors don't have storage, so use another unique id to distinguish them.
        # This is an XLA tensor, so it must have been created on torch_xla's device and the import is safe.
        import torch_xla  # type: ignore[import]

        return torch_xla._XLAC._xla_get_tensor_id(tensor)
    return storage_ptr(tensor)
- def get_torch_storage_id(tensor: "torch.Tensor") -> tuple["torch.device", int | tuple[Any, ...], int] | None:
- """
- Return unique identifier to a tensor storage.
- Multiple different tensors can share the same underlying storage. This identifier is
- guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
- non-overlapping lifetimes may have the same id.
- In the case of meta tensors, we return None since we can't tell if they share the same storage.
- Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
- """
- if tensor.device.type == "meta":
- return None
- else:
- return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)
def get_torch_storage_size(tensor: "torch.Tensor") -> int:
    """
    Return the size, in bytes, of the storage behind `tensor`.

    - DTensor: the byte size of the FULL tensor (`tensor.nbytes`).
    - Traceable wrapper subclass (torch>=2.1): the sum over its flattened inner tensors.
    - Otherwise: `untyped_storage().nbytes()`.
        - On torch versions without `untyped_storage` (torch==1.10), it is the typed-storage element count times
          the dtype size.
        - If that storage raises `NotImplementedError` (meta storage), it is the tensor's own element count times
          the dtype size.

    Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
    """
    try:
        from torch.distributed.tensor import DTensor

        if isinstance(tensor, DTensor):
            return tensor.nbytes  # size of the FULL tensor in bytes, not just the local shard
    except ImportError:
        pass

    try:
        # Wrapper subclasses (torch>=2.1) own no storage themselves; add up their inner tensors.
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass

        if is_traceable_wrapper_subclass(tensor):
            inner_names, _ctx = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
            total_bytes = 0
            for inner_name in inner_names:
                total_bytes += get_torch_storage_size(getattr(tensor, inner_name))
            return total_bytes
    except ImportError:
        # Older torch has no wrapper-subclass support; fall through to the plain-tensor path.
        pass

    try:
        return tensor.untyped_storage().nbytes()
    except AttributeError:
        pass  # torch==1.10: no `untyped_storage`

    element_size = None
    try:
        return tensor.storage().size() * _get_dtype_size(tensor.dtype)
    except NotImplementedError:
        # Meta storage: on torch>=2.0 the storage size equals the tensor size.
        element_size = _get_dtype_size(tensor.dtype)
    return tensor.nelement() * element_size
@lru_cache
def is_torch_tpu_available(check_device=True):
    """
    Report whether `torch_xla` can be used.

    Returns False when `torch_xla` is not installed. With `check_device=False`, an installed `torch_xla` is enough
    to return True. With `check_device=True`, `xm.xla_device()` must also succeed; it raises a RuntimeError when no
    XLA device is found, which maps to False. Results are memoized per `check_device` value.

    Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/utils/import_utils.py#L463.
    """
    if importlib.util.find_spec("torch_xla") is None:
        return False
    if not check_device:
        return True
    try:
        import torch_xla.core.xla_model as xm  # type: ignore[import]

        _ = xm.xla_device()
    except RuntimeError:
        return False
    return True
- def storage_ptr(tensor: "torch.Tensor") -> int | tuple[Any, ...]:
- """
- Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11.
- """
- try:
- # for torch 2.1 and above we can also handle tensor subclasses
- from torch.utils._python_dispatch import is_traceable_wrapper_subclass
- if is_traceable_wrapper_subclass(tensor):
- return _get_unique_id(tensor) # type: ignore
- except ImportError:
- # for torch version less than 2.1, we can fall back to original implementation
- pass
- try:
- return tensor.untyped_storage().data_ptr()
- except Exception:
- # Fallback for torch==1.10
- try:
- return tensor.storage().data_ptr()
- except NotImplementedError:
- # Fallback for meta storage
- return 0
def _clean_state_dict_for_safetensors(
    state_dict: dict[str, "torch.Tensor"],
    metadata: dict[str, str],
    force_contiguous: bool = True,
    shared_tensors_to_discard: list[str] | None = None,
):
    """Drop shared (aliased) tensors from `state_dict` and record each dropped name in `metadata`.

    The names to drop come from `_remove_duplicate_names`, with `shared_tensors_to_discard` as its
    `discard_names`. For each dropped name:
    - `metadata[dropped_name]` is set to the name of the tensor that was kept, unless the caller already set that
      key (for reloading);
    - the entry is deleted from `state_dict`.

    Returns the cleaned state dict. When `force_contiguous` is True this is a new dict whose values are
    `tensor.contiguous()`.

    Warning: `state_dict` and `metadata` are mutated in-place!

    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155.
    """
    removals = _remove_duplicate_names(state_dict, discard_names=shared_tensors_to_discard)
    for kept_name, dropped_names in removals.items():
        for dropped_name in dropped_names:
            if metadata is None:
                metadata = {}
            # Do not override user data
            metadata.setdefault(dropped_name, kept_name)
            del state_dict[dropped_name]

    if not force_contiguous:
        return state_dict
    return {name: tensor.contiguous() for name, tensor in state_dict.items()}
- def _end_ptr(tensor: "torch.Tensor") -> int:
- """
- Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L23.
- """
- if tensor.nelement():
- stop = tensor.view(-1)[-1].data_ptr() + _get_dtype_size(tensor.dtype)
- else:
- stop = tensor.data_ptr()
- return stop
- def _filter_shared_not_shared(tensors: list[set[str]], state_dict: dict[str, "torch.Tensor"]) -> list[set[str]]:
- """
- Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L44
- """
- filtered_tensors = []
- for shared in tensors:
- if len(shared) < 2:
- filtered_tensors.append(shared)
- continue
- areas = []
- for name in shared:
- tensor = state_dict[name]
- areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
- areas.sort()
- _, last_stop, last_name = areas[0]
- filtered_tensors.append({last_name})
- for start, stop, name in areas[1:]:
- if start >= last_stop:
- filtered_tensors.append({name})
- else:
- filtered_tensors[-1].add(name)
- last_stop = stop
- return filtered_tensors
def _find_shared_tensors(state_dict: dict[str, "torch.Tensor"]) -> list[set[str]]:
    """
    Group the keys of `state_dict` whose tensors share memory.

    Tensors are first bucketed by `(device, storage_ptr(tensor), get_torch_storage_size(tensor))`. Each bucket is
    then split into groups whose memory ranges actually overlap, via `_filter_shared_not_shared`. Tensors are
    skipped when they are on the meta device, have a storage pointer of 0, or have a storage size of 0.

    Adapted from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L69.
    """
    import torch

    meta_device = torch.device("meta")  # built once, not once per tensor
    tensors_dict = defaultdict(set)
    for name, tensor in state_dict.items():
        if tensor.device == meta_device:
            continue
        # The storage pointer and size are evaluated once per tensor and reused for the bucket key.
        ptr = storage_ptr(tensor)
        if ptr == 0:
            continue
        size = get_torch_storage_size(tensor)
        if size == 0:
            continue
        # Need to add device as key because of multiple GPU.
        tensors_dict[(tensor.device, ptr, size)].add(name)
    # NOTE: sets compare by inclusion, so this sort only imposes a partial order; kept for parity with safetensors.
    tensors = sorted(tensors_dict.values())
    return _filter_shared_not_shared(tensors, state_dict)
def _is_complete(tensor: "torch.Tensor") -> bool:
    """
    Tell whether `tensor` covers its entire storage.

    A plain tensor is complete when both conditions hold:
    - it starts where its storage starts (`data_ptr() == storage_ptr(tensor)`);
    - its `nelement() * dtype size` equals the storage size.

    A traceable wrapper subclass (torch>=2.1) is complete when every flattened inner tensor is complete.

    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
    """
    try:
        # for torch 2.1 and above we can also handle tensor subclasses
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass

        if is_traceable_wrapper_subclass(tensor):
            inner_names, _ctx = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
            return all(_is_complete(getattr(tensor, inner_name)) for inner_name in inner_names)
    except ImportError:
        # Older torch: no wrapper-subclass support, use the plain-tensor check below.
        pass

    if tensor.data_ptr() != storage_ptr(tensor):
        return False  # offset into its storage, so it cannot span all of it
    covered_bytes = tensor.nelement() * _get_dtype_size(tensor.dtype)
    return covered_bytes == get_torch_storage_size(tensor)
def _remove_duplicate_names(
    state_dict: dict[str, "torch.Tensor"],
    *,
    preferred_names: list[str] | None = None,
    discard_names: list[str] | None = None,
) -> dict[str, list[str]]:
    """
    For each group of memory-sharing tensors, pick one name to keep and list the others for removal.

    Only "complete" tensors (see `_is_complete`) are eligible to be kept. Among them, the kept name is the
    smallest one that is, in order of priority:
    1. a member of `preferred_names`;
    2. not in `discard_names`;
    3. any eligible name.

    Returns a mapping from each kept name to the sorted list of names sharing its memory that should be removed.
    Groups with nothing to remove produce no entry.

    Raises:
        RuntimeError: If no tensor in a sharing group covers the entire storage.

    Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
    """
    preferred_set = set(preferred_names) if preferred_names is not None else set()
    discard_set = set(discard_names) if discard_names is not None else set()

    to_remove = defaultdict(list)
    for group in _find_shared_tensors(state_dict):
        complete = {name for name in group if _is_complete(state_dict[name])}
        if not complete:
            raise RuntimeError(
                "Error while trying to find names to remove to save state dict, but found no suitable name to keep"
                f" for saving amongst: {group}. None is covering the entire storage. Refusing to save/load the model"
                " since you could be storing much more memory than needed. Please refer to"
                " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an"
                " issue."
            )

        keep_name = min(complete)
        # Favour names that are not marked for discarding, so that files saved with a different choice of kept
        # name can still be loaded.
        not_discarded = complete - discard_set
        if not_discarded:
            keep_name = min(not_discarded)
        if preferred_set:
            explicitly_preferred = preferred_set & complete
            if explicitly_preferred:
                keep_name = min(explicitly_preferred)

        for name in sorted(group):
            if name != keep_name:
                to_remove[keep_name].append(name)
    return to_remove
- @lru_cache
- def _get_dtype_size(dtype: "torch.dtype") -> int:
- """
- Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L344
- """
- import torch
- # torch.float8 formats require 2.1; we do not support these dtypes on earlier versions
- _float8_e4m3fn = getattr(torch, "float8_e4m3fn", None)
- _float8_e5m2 = getattr(torch, "float8_e5m2", None)
- _SIZE = {
- torch.int64: 8,
- torch.float32: 4,
- torch.int32: 4,
- torch.bfloat16: 2,
- torch.float16: 2,
- torch.int16: 2,
- torch.uint8: 1,
- torch.int8: 1,
- torch.bool: 1,
- torch.float64: 8,
- _float8_e4m3fn: 1,
- _float8_e5m2: 1,
- }
- return _SIZE[dtype]
- class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
- """
- This is used to report missing and unexpected keys in the state dict.
- Taken from https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L52.
- """
- def __repr__(self) -> str:
- if not self.missing_keys and not self.unexpected_keys:
- return "<All keys matched successfully>"
- return super().__repr__()
- __str__ = __repr__
|