| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550 |
- import os
- import sys
- from collections import defaultdict
- from typing import Any, Dict, List, Optional, Set, Tuple, Union
- from packaging.version import Version
- import torch
- from safetensors import deserialize, safe_open, serialize, serialize_file
def storage_ptr(tensor: torch.Tensor) -> int:
    """
    Returns the data pointer of the storage backing `tensor`.

    The untyped storage is queried first. If that fails for any reason the
    legacy `tensor.storage()` accessor is tried, and `0` is returned when the
    legacy accessor raises `NotImplementedError`.
    """
    try:
        backing = tensor.untyped_storage()
        return backing.data_ptr()
    except Exception:
        # Fallback for torch==1.10
        try:
            legacy_backing = tensor.storage()
            return legacy_backing.data_ptr()
        except NotImplementedError:
            # Fallback for meta storage
            return 0
def _end_ptr(tensor: torch.Tensor) -> int:
    """
    Returns the address just past the last element of `tensor`.

    For a tensor without elements this is simply its data pointer.
    """
    if not tensor.nelement():
        return tensor.data_ptr()
    # Address of the final element, plus the width of one element.
    final_element = tensor.view(-1)[-1]
    return final_element.data_ptr() + _SIZE[tensor.dtype]
def storage_size(tensor: torch.Tensor) -> int:
    """
    Returns the size in bytes of the storage backing `tensor`.

    Uses `untyped_storage().nbytes()` when available. Otherwise the size is
    derived from an element count multiplied by the per-element width taken
    from `_SIZE`.
    """
    try:
        return tensor.untyped_storage().nbytes()
    except AttributeError:
        # Fallback for torch==1.10
        try:
            element_count = tensor.storage().size()
        except NotImplementedError:
            # Fallback for meta storage
            # On torch >=2.0 this is the tensor size
            element_count = tensor.nelement()
        return element_count * _SIZE[tensor.dtype]
- def _filter_shared_not_shared(
- tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
- ) -> List[Set[str]]:
- filtered_tensors = []
- for shared in tensors:
- if len(shared) < 2:
- filtered_tensors.append(shared)
- continue
- areas = []
- for name in shared:
- tensor = state_dict[name]
- areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
- areas.sort()
- _, last_stop, last_name = areas[0]
- filtered_tensors.append({last_name})
- for start, stop, name in areas[1:]:
- if start >= last_stop:
- filtered_tensors.append({name})
- else:
- filtered_tensors[-1].add(name)
- last_stop = stop
- return filtered_tensors
def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
    """
    Groups the names of `state_dict` by the memory their tensors occupy.

    Tensors on the meta device, and tensors whose storage pointer or storage
    size is zero, are ignored. The remaining names are bucketed by
    `(device, storage pointer, storage size)` and the buckets are then refined
    by `_filter_shared_not_shared`.
    """
    meta_device = torch.device("meta")
    by_storage = defaultdict(set)
    for name, tensor in state_dict.items():
        if tensor.device == meta_device:
            continue
        ptr = storage_ptr(tensor)
        if ptr == 0:
            continue
        size = storage_size(tensor)
        if size == 0:
            continue
        # Need to add device as key because of multiple GPU.
        by_storage[(tensor.device, ptr, size)].add(name)
    # NOTE: sets compare by subset relation, so this ordering is whatever
    # `sorted` yields for that comparison.
    groups = sorted(by_storage.values())
    return _filter_shared_not_shared(groups, state_dict)
def _is_complete(tensor: torch.Tensor) -> bool:
    """
    Tells whether `tensor` spans its whole storage.

    True when the tensor starts at the storage pointer and its byte length
    (element count times `_SIZE` width) equals the storage size.
    """
    if tensor.data_ptr() != storage_ptr(tensor):
        return False
    byte_length = tensor.nelement() * _SIZE[tensor.dtype]
    return byte_length == storage_size(tensor)
def _remove_duplicate_names(
    state_dict: Dict[str, torch.Tensor],
    *,
    preferred_names: Optional[List[str]] = None,
    discard_names: Optional[List[str]] = None,
) -> Dict[str, List[str]]:
    """
    Chooses, for every group of tensors sharing memory, one name to keep.

    Only names whose tensor covers its entire storage (`_is_complete`) are
    eligible. Among eligible names the choice is, by increasing priority:
    the alphabetically first one, the alphabetically first one not listed in
    `discard_names`, the alphabetically first one listed in `preferred_names`.

    Args:
        state_dict (`Dict[str, torch.Tensor]`):
            The tensors to inspect.
        preferred_names (`List[str]`, *optional*):
            Names to keep in priority when they are eligible.
        discard_names (`List[str]`, *optional*):
            Names to avoid keeping when another eligible name exists.

    Returns:
        A mapping from each kept name to the sorted list of the other names
        of its group. Groups with nothing to drop produce no entry.

    Raises:
        `RuntimeError`: when a group contains no name covering the storage.
    """
    preferred_set = set() if preferred_names is None else set(preferred_names)
    discard_set = set() if discard_names is None else set(discard_names)

    to_remove = defaultdict(list)
    for shared in _find_shared_tensors(state_dict):
        complete_names = {name for name in shared if _is_complete(state_dict[name])}
        if not complete_names:
            raise RuntimeError(
                "Error while trying to find names to remove to save state dict, but found no suitable name to keep"
                f" for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model"
                " since you could be storing much more memory than needed. Please refer to"
                " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an"
                " issue."
            )

        # Mechanism to preferentially select keys to keep
        # coming from the on-disk file to allow
        # loading models saved with a different choice
        # of keep_name
        candidates = complete_names
        not_discarded = complete_names - discard_set
        if not_discarded:
            candidates = not_discarded
        explicitly_preferred = preferred_set & complete_names
        if explicitly_preferred:
            candidates = explicitly_preferred
        keep_name = min(candidates)

        dropped = [name for name in sorted(shared) if name != keep_name]
        if dropped:
            # Only touch the defaultdict when there is something to record,
            # so groups without duplicates leave no empty entry behind.
            to_remove[keep_name].extend(dropped)
    return to_remove
def save_model(
    model: torch.nn.Module,
    filename: str,
    metadata: Optional[Dict[str, str]] = None,
    force_contiguous: bool = True,
):
    """
    Saves a given torch model to specified filename.
    This method exists specifically to avoid tensor sharing issues which are
    not allowed in `safetensors`. [More information on tensor sharing](../torch_shared_tensors)

    Args:
        model (`torch.nn.Module`):
            The model to save on disk.
        filename (`str`):
            The filename location to save the file
        metadata (`Dict[str, str]`, *optional*):
            Extra information to save along with the file.
            Some metadata will be added for each dropped tensors.
            This information will not be enough to recover the entire
            shared structure but might help understanding things.
            The dictionary passed in is not modified.
        force_contiguous (`boolean`, *optional*, defaults to True):
            Forcing the state_dict to be saved as contiguous tensors.
            This has no effect on the correctness of the model, but it
            could potentially change performance if the layout of the tensor
            was chosen specifically for that reason.

    Raises:
        `ValueError`: when `save_file` rejects the tensors; the message is
        extended with a hint about `force_contiguous` and the original error
        is attached as the cause.
    """
    state_dict = model.state_dict()
    to_removes = _remove_duplicate_names(state_dict)

    if any(to_removes.values()):
        # Alias entries are about to be added: work on a private copy so the
        # dictionary owned by the caller is left untouched.
        metadata = {} if metadata is None else dict(metadata)

    for kept_name, to_remove_group in to_removes.items():
        for to_remove in to_remove_group:
            # Do not override user data
            metadata.setdefault(to_remove, kept_name)
            del state_dict[to_remove]
    if force_contiguous:
        state_dict = {k: v.contiguous() for k, v in state_dict.items()}
    try:
        save_file(state_dict, filename, metadata=metadata)
    except ValueError as e:
        msg = str(e)
        msg += " Or use save_model(..., force_contiguous=True), read the docs for potential caveats."
        # Chain explicitly so the original traceback stays attached.
        raise ValueError(msg) from e
def load_model(
    model: torch.nn.Module,
    filename: Union[str, os.PathLike],
    strict: bool = True,
    device: Union[str, int] = "cpu",
) -> Tuple[List[str], List[str]]:
    """
    Loads a given filename onto a torch model.
    This method exists specifically to avoid tensor sharing issues which are
    not allowed in `safetensors`. [More information on tensor sharing](../torch_shared_tensors)

    Args:
        model (`torch.nn.Module`):
            The model to load onto.
        filename (`str`, or `os.PathLike`):
            The filename location to load the file from.
        strict (`bool`, *optional*, defaults to True):
            Whether to fail if you're missing keys or having unexpected ones.
            When false, the function simply returns missing and unexpected names.
        device (`Union[str, int]`, *optional*, defaults to `cpu`):
            The device where the tensors need to be located after load.
            available options are all regular torch device locations.

    Returns:
        `(missing, unexpected)`
            `missing` are names in the model which were not modified during loading
            `unexpected` are names that are on the file, but weren't used during
            the load.
            NOTE: `missing` is handed back as a `set`, `unexpected` as the list
            produced by `load_state_dict`.

    Raises:
        `RuntimeError`: when `strict` is true and either collection is non-empty.
    """
    state_dict = load_file(filename, device=device)
    model_state_dict = model.state_dict()
    # Names present in the file are preferred as the "kept" name of a shared
    # group, so files saved with a different choice still load.
    to_removes = _remove_duplicate_names(
        model_state_dict, preferred_names=state_dict.keys()
    )
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    missing = set(missing)
    for alias_group in to_removes.values():
        for alias in alias_group:
            if alias in missing:
                # The alias shares memory with a kept name, so it is not
                # reported as missing.
                missing.remove(alias)
            else:
                unexpected.append(alias)

    if not strict or not (missing or unexpected):
        return missing, unexpected

    error = f"Error(s) in loading state_dict for {model.__class__.__name__}:"
    if missing:
        missing_keys = ", ".join(f'"{k}"' for k in sorted(missing))
        error += f"\n Missing key(s) in state_dict: {missing_keys}"
    if unexpected:
        unexpected_keys = ", ".join(f'"{k}"' for k in sorted(unexpected))
        error += f"\n Unexpected key(s) in state_dict: {unexpected_keys}"
    raise RuntimeError(error)
def save(
    tensors: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
) -> bytes:
    """
    Serializes a dictionary of tensors to raw bytes in safetensors format.

    Args:
        tensors (`Dict[str, torch.Tensor]`):
            The incoming tensors. Tensors need to be contiguous and dense.
        metadata (`Dict[str, str]`, *optional*, defaults to `None`):
            Optional text only metadata you might want to save in your header.
            For instance it can be useful to specify more about the underlying
            tensors. This is purely informative and does not affect tensor loading.

    Returns:
        `bytes`: The raw bytes representing the format

    Example:

    ```python
    from safetensors.torch import save
    import torch

    tensors = {"embedding": torch.zeros((512, 1024)), "attention": torch.zeros((256, 256))}
    byte_data = save(tensors)
    ```
    """
    flattened = _flatten(tensors)
    return bytes(serialize(flattened, metadata=metadata))
def save_file(
    tensors: Dict[str, torch.Tensor],
    filename: Union[str, os.PathLike],
    metadata: Optional[Dict[str, str]] = None,
):
    """
    Saves a dictionary of tensors to a file in safetensors format.

    Args:
        tensors (`Dict[str, torch.Tensor]`):
            The incoming tensors. Tensors need to be contiguous and dense.
        filename (`str`, or `os.PathLike`):
            The filename we're saving into.
        metadata (`Dict[str, str]`, *optional*, defaults to `None`):
            Optional text only metadata you might want to save in your header.
            For instance it can be useful to specify more about the underlying
            tensors. This is purely informative and does not affect tensor loading.

    Returns:
        `None`

    Example:

    ```python
    from safetensors.torch import save_file
    import torch

    tensors = {"embedding": torch.zeros((512, 1024)), "attention": torch.zeros((256, 256))}
    save_file(tensors, "model.safetensors")
    ```
    """
    flattened = _flatten(tensors)
    serialize_file(flattened, filename, metadata=metadata)
def load_file(
    filename: Union[str, os.PathLike], device: Union[str, int] = "cpu"
) -> Dict[str, torch.Tensor]:
    """
    Reads every tensor of a safetensors file into a dictionary.

    Args:
        filename (`str`, or `os.PathLike`):
            The name of the file which contains the tensors
        device (`Union[str, int]`, *optional*, defaults to `cpu`):
            The device where the tensors need to be located after load.
            available options are all regular torch device locations.

    Returns:
        `Dict[str, torch.Tensor]`: dictionary that contains name as key, value as `torch.Tensor`

    Example:

    ```python
    from safetensors.torch import load_file

    file_path = "./my_folder/bert.safetensors"
    loaded = load_file(file_path)
    ```
    """
    with safe_open(filename, framework="pt", device=device) as handle:
        # Tensors are fetched in the order `offset_keys()` yields the names.
        return {key: handle.get_tensor(key) for key in handle.offset_keys()}
def load(data: bytes) -> Dict[str, torch.Tensor]:
    """
    Decodes the raw bytes of a safetensors file into torch tensors.

    Args:
        data (`bytes`):
            The content of a safetensors file

    Returns:
        `Dict[str, torch.Tensor]`: dictionary that contains name as key, value as `torch.Tensor` on cpu

    Example:

    ```python
    from safetensors.torch import load

    file_path = "./my_folder/bert.safetensors"
    with open(file_path, "rb") as f:
        data = f.read()

    loaded = load(data)
    ```
    """
    return _view2torch(deserialize(data))
# torch.float8 formats require 2.1; we do not support these dtypes on earlier versions
# Each alias is `None` when the running torch build does not expose the dtype.
_float8_e4m3fn = getattr(torch, "float8_e4m3fn", None)
_float8_e5m2 = getattr(torch, "float8_e5m2", None)
_float8_e8m0 = getattr(torch, "float8_e8m0fnu", None)
_float4_e2m1_x2 = getattr(torch, "float4_e2m1fn_x2", None)

# Width in bytes of one element of each supported dtype.
# Aliases that resolved to `None` collapse into a single `None` key.
_SIZE = {
    torch.int64: 8,
    torch.float32: 4,
    torch.int32: 4,
    torch.bfloat16: 2,
    torch.float16: 2,
    torch.int16: 2,
    torch.uint8: 1,
    torch.int8: 1,
    torch.bool: 1,
    torch.float64: 8,
    torch.complex64: 8,
    _float8_e4m3fn: 1,
    _float8_e5m2: 1,
    _float8_e8m0: 1,
    _float4_e2m1_x2: 1,
}
# The wider unsigned integer dtypes are only registered from torch 2.3.0 on.
if Version(torch.__version__) >= Version("2.3.0"):
    _SIZE[torch.uint64] = 8
    _SIZE[torch.uint32] = 4
    _SIZE[torch.uint16] = 2
# Maps the dtype strings looked up by `_getdtype` to torch dtypes.
# The float8 entries are `None` when the running torch build lacks them.
_TYPES = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I64": torch.int64,
    "I32": torch.int32,
    "I16": torch.int16,
    "I8": torch.int8,
    "U8": torch.uint8,
    "BOOL": torch.bool,
    "F8_E4M3": _float8_e4m3fn,
    "F8_E5M2": _float8_e5m2,
    "C64": torch.complex64,
}
# The wider unsigned integer dtypes are only registered from torch 2.3.0 on.
if Version(torch.__version__) >= Version("2.3.0"):
    _TYPES["U64"] = torch.uint64
    _TYPES["U32"] = torch.uint32
    _TYPES["U16"] = torch.uint16
def _getdtype(dtype_str: str) -> torch.dtype:
    """Looks up `dtype_str` in `_TYPES`; an unknown string raises `KeyError`."""
    resolved = _TYPES[dtype_str]
    return resolved
- def _view2torch(safeview) -> Dict[str, torch.Tensor]:
- result = {}
- for k, v in safeview:
- dtype = _getdtype(v["dtype"])
- if len(v["data"]) == 0:
- # Workaround because frombuffer doesn't accept zero-size tensors
- assert any(x == 0 for x in v["shape"])
- arr = torch.empty(v["shape"], dtype=dtype)
- else:
- arr = torch.frombuffer(v["data"], dtype=dtype).reshape(v["shape"])
- if sys.byteorder == "big":
- arr = torch.from_numpy(arr.numpy().byteswap(inplace=False))
- result[k] = arr
- return result
- def _tobytes(tensor: torch.Tensor, name: str) -> bytes:
- if tensor.layout != torch.strided:
- raise ValueError(
- f"You are trying to save a sparse tensor: `{name}` which this library does not support."
- " You can make it a dense tensor before saving with `.to_dense()` but be aware this might"
- " make a much larger file than needed."
- )
- if not tensor.is_contiguous():
- raise ValueError(
- f"You are trying to save a non contiguous tensor: `{name}` which is not allowed. It either means you"
- " are trying to save tensors which are reference of each other in which case it's recommended to save"
- " only the full tensors, and reslice at load time, or simply call `.contiguous()` on your tensor to"
- " pack it before saving."
- )
- if tensor.device.type != "cpu":
- # Moving tensor to cpu before saving
- tensor = tensor.to("cpu")
- import ctypes
- import numpy as np
- # When shape is empty (scalar), np.prod returns a float
- # we need a int for the following calculations
- length = int(np.prod(tensor.shape).item())
- bytes_per_item = _SIZE[tensor.dtype]
- total_bytes = length * bytes_per_item
- ptr = tensor.data_ptr()
- if ptr == 0:
- return b""
- newptr = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_ubyte))
- data = np.ctypeslib.as_array(newptr, (total_bytes,)) # no internal copy
- if sys.byteorder == "big":
- NPDTYPES = {
- torch.int64: np.int64,
- torch.float32: np.float32,
- torch.int32: np.int32,
- # XXX: This is ok because both have the same width
- torch.bfloat16: np.float16,
- torch.float16: np.float16,
- torch.int16: np.int16,
- torch.uint8: np.uint8,
- torch.int8: np.int8,
- torch.bool: bool,
- torch.float64: np.float64,
- # XXX: This is ok because both have the same width and byteswap is a no-op anyway
- _float8_e4m3fn: np.uint8,
- _float8_e5m2: np.uint8,
- torch.complex64: np.complex64,
- }
- npdtype = NPDTYPES[tensor.dtype]
- # Not in place as that would potentially modify a live running model
- data = data.view(npdtype).byteswap(inplace=False)
- return data.tobytes()
- def _flatten(tensors: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, Any]]:
- if not isinstance(tensors, dict):
- raise ValueError(
- f"Expected a dict of [str, torch.Tensor] but received {type(tensors)}"
- )
- invalid_tensors = []
- for k, v in tensors.items():
- if not isinstance(v, torch.Tensor):
- raise ValueError(
- f"Key `{k}` is invalid, expected torch.Tensor but received {type(v)}"
- )
- if v.layout != torch.strided:
- invalid_tensors.append(k)
- if invalid_tensors:
- raise ValueError(
- f"You are trying to save a sparse tensors: `{invalid_tensors}` which this library does not support."
- " You can make it a dense tensor before saving with `.to_dense()` but be aware this might"
- " make a much larger file than needed."
- )
- shared_pointers = _find_shared_tensors(tensors)
- failing = []
- for names in shared_pointers:
- if len(names) > 1:
- failing.append(names)
- if failing:
- raise RuntimeError(
- f"""
- Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: {failing}.
- A potential way to correctly save your model is to use `save_model`.
- More information at https://huggingface.co/docs/safetensors/torch_shared_tensors
- """
- )
- return {
- k: {
- "dtype": str(v.dtype).split(".")[-1],
- "shape": v.shape,
- "data": _tobytes(v, k),
- }
- for k, v in tensors.items()
- }
|