_torch.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Contains pytorch-specific helpers."""
  15. import importlib
  16. import importlib.util
  17. import json
  18. import os
  19. import re
  20. from collections import defaultdict, namedtuple
  21. from collections.abc import Iterable
  22. from functools import lru_cache
  23. from pathlib import Path
  24. from typing import TYPE_CHECKING, Any, NamedTuple, Union
  25. from packaging import version
  26. from .. import constants, logging
  27. from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
  # Name the logger after the module's dotted import path (the usual convention) rather than its file path.
  # Logger names form a dot-separated hierarchy, so a path-based name does not nest under the package's logger.
  # Presumably the package-level `logging.get_logger` forwards this name to the stdlib `logging.getLogger`.
  logger = logging.get_logger(__name__)
  29. if TYPE_CHECKING:
  30. import torch
  31. # SAVING
  def save_torch_model(
      model: "torch.nn.Module",
      save_directory: str | Path,
      *,
      filename_pattern: str | None = None,
      force_contiguous: bool = True,
      max_shard_size: int | str = MAX_SHARD_SIZE,
      metadata: dict[str, str] | None = None,
      safe_serialization: bool = True,
      is_main_process: bool = True,
      shared_tensors_to_discard: list[str] | None = None,
  ):
      """
      Save a PyTorch model to disk, taking care of sharding and of tensors that share memory.

      This is a convenience wrapper: it takes `model.state_dict()` and forwards it, together with every other
      argument, to [`save_torch_state_dict`]. Use that function directly to save an arbitrary state dict. For
      background on tensor sharing, see [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors).

      The state dict is split into shards (via [`split_torch_state_dict_into_shards`]) that each stay below
      `max_shard_size`, and the shards are written to `save_directory` using `filename_pattern`. When more than one
      shard is needed, an index file mapping every tensor to its shard is written next to them. Shards are saved as
      safetensors by default, or as pickle when `safe_serialization=False`. Shard files left in `save_directory` by
      a previous save are removed first.

      > [!WARNING]
      > A single tensor larger than `max_shard_size` is placed alone in its own shard, which will therefore be
      > bigger than `max_shard_size`.

      > [!WARNING]
      > For a `transformers.PreTrainedModel`, pass `model._tied_weights_keys` as `shared_tensors_to_discard` so that
      > the correct duplicate (tied) tensors are dropped when saving.

      Args:
          model (`torch.nn.Module`):
              The model to save on disk.
          save_directory (`str` or `Path`):
              The directory in which the model will be saved.
          filename_pattern (`str`, *optional*):
              The pattern used to name the saved files. It is formatted with `filename_pattern.format(suffix=...)`
              and must therefore contain the `{suffix}` placeholder. Defaults to `"model{suffix}.safetensors"` or
              `"pytorch_model{suffix}.bin"` depending on `safe_serialization`.
          force_contiguous (`boolean`, *optional*):
              Save every tensor as a contiguous tensor. This does not affect the model's correctness, but may change
              performance if a tensor's memory layout was chosen on purpose. Defaults to `True`.
          max_shard_size (`int` or `str`, *optional*):
              The maximum size of each shard, in bytes. Defaults to 5GB.
          metadata (`dict[str, str]`, *optional*):
              Extra information to save along with the model. Entries may be added for each dropped shared tensor;
              they are not enough to rebuild the full sharing structure but can help understand it.
          safe_serialization (`bool`, *optional*):
              Whether to save as safetensors (the default). If `False`, the shards are saved as pickle, which is
              discouraged for security reasons, deprecated, and will be removed in a future version.
          is_main_process (`bool`, *optional*):
              Whether the calling process is the main one. In distributed setups (e.g. TPUs) where every process
              calls this function, set `is_main_process=True` only on the main process to avoid race conditions.
              Defaults to `True`.
          shared_tensors_to_discard (`list[str]`, *optional*):
              Names of the tensors to drop when shared tensors are detected. If not provided, the first name in
              alphabetical order is dropped.

      Example:

      ```py
      >>> from huggingface_hub import load_torch_model, save_torch_model
      >>> model = ...  # A PyTorch model

      # Save to "path/to/folder" as safetensors, in shards of at most 5GB each.
      >>> save_torch_model(model, "path/to/folder")

      # Load the weights back into the model.
      >>> load_torch_model(model, "path/to/folder")
      ```
      """
      # Everything except extracting the state dict is handled by `save_torch_state_dict`.
      weights = model.state_dict()
      save_torch_state_dict(
          weights,
          save_directory,
          shared_tensors_to_discard=shared_tensors_to_discard,
          is_main_process=is_main_process,
          safe_serialization=safe_serialization,
          metadata=metadata,
          max_shard_size=max_shard_size,
          force_contiguous=force_contiguous,
          filename_pattern=filename_pattern,
      )
  def _build_existing_files_regex(filename_pattern: str) -> "re.Pattern[str]":
      """Compile a regex matching the files a previous save with `filename_pattern` may have left behind.

      The regex covers the single-file name, the shard names (`-XXXXX-of-YYYYY` in place of `{suffix}`) and the
      `.index.json` file. It is meant to be used with `fullmatch`. Literal parts of the pattern are escaped, so that
      e.g. the `.` of `.safetensors` only matches a dot.
      """
      # NUL can never appear in a file name, so it safely marks where `{suffix}` was substituted.
      marker = "\x00"
      head, _, tail = filename_pattern.format(suffix=marker).partition(marker)
      return re.compile(re.escape(head) + r"(-\d{5}-of-\d{5})?" + re.escape(tail) + r"(\.index\.json)?")


  def save_torch_state_dict(
      state_dict: dict[str, "torch.Tensor"],
      save_directory: str | Path,
      *,
      filename_pattern: str | None = None,
      force_contiguous: bool = True,
      max_shard_size: int | str = MAX_SHARD_SIZE,
      metadata: dict[str, str] | None = None,
      safe_serialization: bool = True,
      is_main_process: bool = True,
      shared_tensors_to_discard: list[str] | None = None,
  ) -> None:
      """
      Save a model state dictionary to disk, taking care of sharding and of tensors that share memory.

      See also [`save_torch_model`] to save a PyTorch model directly. For background on tensor sharing, see
      [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors).

      The state dict is split into shards (via [`split_torch_state_dict_into_shards`]) that each stay below
      `max_shard_size`, and each shard is written to `save_directory` using `filename_pattern`. When more than one
      shard is needed, an index file (`filename_pattern` formatted with an empty suffix, plus `.index.json`) mapping
      every tensor to its shard is written as well. Shards are saved as safetensors by default, or as pickle when
      `safe_serialization=False`.

      Before saving, the main process removes the files of `save_directory` whose names exactly match what a
      previous save with the same `filename_pattern` could have produced (single file, numbered shards or index).
      Other files are left untouched. `save_directory` must already exist.

      > [!WARNING]
      > A single tensor larger than `max_shard_size` is placed alone in its own shard, which will therefore be
      > bigger than `max_shard_size`.

      > [!WARNING]
      > For a `transformers.PreTrainedModel`, pass `model._tied_weights_keys` as `shared_tensors_to_discard` so that
      > the correct duplicate (tied) tensors are dropped when saving.

      Args:
          state_dict (`dict[str, torch.Tensor]`):
              The state dictionary to save.
          save_directory (`str` or `Path`):
              The existing directory in which the model will be saved.
          filename_pattern (`str`, *optional*):
              The pattern used to name the saved files. It is formatted with `filename_pattern.format(suffix=...)`
              and must therefore contain the `{suffix}` placeholder. Defaults to `"model{suffix}.safetensors"` or
              `"pytorch_model{suffix}.bin"` depending on `safe_serialization`.
          force_contiguous (`boolean`, *optional*):
              Save every tensor as a contiguous tensor. This does not affect the model's correctness, but may change
              performance if a tensor's memory layout was chosen on purpose. Defaults to `True`.
          max_shard_size (`int` or `str`, *optional*):
              The maximum size of each shard, in bytes. Defaults to 5GB.
          metadata (`dict[str, str]`, *optional*):
              Extra information to save along with the model. It is stored in the index file when the checkpoint is
              sharded, and in the safetensors file otherwise (it is not stored for a single pickle file). Entries may
              be added for each dropped shared tensor; they are not enough to rebuild the full sharing structure but
              can help understand it. The dict passed by the caller is not modified.
          safe_serialization (`bool`, *optional*):
              Whether to save as safetensors (the default). If `False`, the shards are saved as pickle, which is
              discouraged for security reasons, deprecated, and will be removed in a future version.
          is_main_process (`bool`, *optional*):
              Whether the calling process is the main one. In distributed setups (e.g. TPUs) where every process
              calls this function, set `is_main_process=True` only on the main process to avoid race conditions.
              Only the main process cleans up previous files. Defaults to `True`.
          shared_tensors_to_discard (`list[str]`, *optional*):
              Names of the tensors to drop when shared tensors are detected. If not provided, the first name in
              alphabetical order is dropped.

      Raises:
          [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
              If `safe_serialization=True` and `safetensors` is not installed.

      Example:

      ```py
      >>> from huggingface_hub import save_torch_state_dict
      >>> model = ...  # A PyTorch model

      # Save the state dict to "path/to/folder" as safetensors, in shards of at most 5GB each.
      >>> state_dict = model.state_dict()
      >>> save_torch_state_dict(state_dict, "path/to/folder")
      ```
      """
      save_directory = str(save_directory)
      if filename_pattern is None:
          filename_pattern = (
              constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
              if safe_serialization
              else constants.PYTORCH_WEIGHTS_FILE_PATTERN
          )
      # Work on a copy: entries about dropped shared tensors may be added to it below, and saving must not modify
      # the caller's dict as a side effect.
      metadata = {} if metadata is None else dict(metadata)

      if safe_serialization:
          try:
              from safetensors.torch import save_file as save_file_fn
          except ImportError as e:
              raise ImportError(
                  "Please install `safetensors` to use safe serialization. "
                  "You can install it with `pip install safetensors`."
              ) from e

          # Drop duplicated shared tensors (recording them in `metadata`) and make tensors savable as safetensors.
          state_dict = _clean_state_dict_for_safetensors(
              state_dict,
              metadata,
              force_contiguous=force_contiguous,
              shared_tensors_to_discard=shared_tensors_to_discard,
          )
      else:
          from torch import save as save_file_fn  # type: ignore[assignment, no-redef]

          logger.warning(
              "You are using unsafe serialization. Due to security reasons, it is recommended not to load "
              "pickled models from untrusted sources. If you intend to share your model, we strongly recommend "
              "using safe serialization by installing `safetensors` with `pip install safetensors`."
          )

      # Decide which tensors go in which shard file.
      state_dict_split = split_torch_state_dict_into_shards(
          state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
      )

      # Only the main process removes leftovers from a previous save, to avoid races in distributed environments.
      # `fullmatch` on an escaped pattern (instead of `match` on the raw one) guarantees that unrelated files which
      # merely start like a checkpoint name (e.g. `model.safetensors.bak`) are never deleted.
      if is_main_process:
          existing_files_regex = _build_existing_files_regex(filename_pattern)
          for filename in os.listdir(save_directory):
              if existing_files_regex.fullmatch(filename):
                  try:
                      logger.debug(f"Removing existing file '{filename}' from folder.")
                      os.remove(os.path.join(save_directory, filename))
                  except Exception as e:
                      # Best-effort cleanup: a file that cannot be removed must not prevent saving.
                      logger.warning(
                          f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..."
                      )

      # Save each shard. User metadata goes in the file itself only when there is a single file; for a sharded
      # checkpoint it is stored in the index instead.
      per_file_metadata = {"format": "pt"}
      if not state_dict_split.is_sharded:
          per_file_metadata.update(metadata)
      safe_file_kwargs = {"metadata": per_file_metadata} if safe_serialization else {}
      for filename, tensors in state_dict_split.filename_to_tensors.items():
          shard = {tensor: state_dict[tensor] for tensor in tensors}
          save_file_fn(shard, os.path.join(save_directory, filename), **safe_file_kwargs)  # ty: ignore[invalid-argument-type]
          logger.debug(f"Shard saved to {filename}")

      # Save the index (only needed when the checkpoint is sharded).
      if state_dict_split.is_sharded:
          index_path = filename_pattern.format(suffix="") + ".index.json"
          index = {
              "metadata": {**state_dict_split.metadata, **metadata},
              "weight_map": state_dict_split.tensor_to_filename,
          }
          # Explicit encoding, matching how the index is read back when loading.
          with open(os.path.join(save_directory, index_path), "w", encoding="utf-8") as f:
              json.dump(index, f, indent=2)
          logger.info(
              f"The model is bigger than the maximum size per checkpoint ({max_shard_size}). "
              f"Model weights have been saved in {len(state_dict_split.filename_to_tensors)} checkpoint shards. "
              f"You can find where each parameters has been saved in the index located at {index_path}."
          )

      logger.info(f"Model weights successfully saved to {save_directory}!")
  def split_torch_state_dict_into_shards(
      state_dict: dict[str, "torch.Tensor"],
      *,
      filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
      max_shard_size: int | str = MAX_SHARD_SIZE,
  ) -> StateDictSplit:
      """
      Split a model state dictionary into shards that each stay below a given size.

      Tensors are assigned to shards greedily, following the order of the `state_dict` keys; no attempt is made to
      pack shards as close as possible to the limit. For instance, with a 10GB limit and tensors of sizes
      [6GB, 6GB, 2GB, 6GB, 2GB, 2GB], the shards are [6GB], [6+2GB], [6+2+2GB] rather than [6+2+2GB], [6+2GB], [6GB].

      > [!TIP]
      > To write a state dictionary to disk, use [`save_torch_state_dict`], which relies on
      > `split_torch_state_dict_into_shards` internally.

      > [!WARNING]
      > A single tensor larger than `max_shard_size` is placed alone in its own shard, which will therefore be
      > bigger than `max_shard_size`.

      Args:
          state_dict (`dict[str, torch.Tensor]`):
              The state dictionary to split.
          filename_pattern (`str`, *optional*):
              The pattern used to name the shard files. It is formatted with `filename_pattern.format(suffix=...)`
              and must therefore contain the `{suffix}` placeholder. Defaults to `"model{suffix}.safetensors"`.
          max_shard_size (`int` or `str`, *optional*):
              The maximum size of each shard, in bytes. Defaults to 5GB.

      Returns:
          [`StateDictSplit`]: An object describing the shards and the index needed to retrieve them.

      Example:

      ```py
      >>> import json
      >>> import os
      >>> import torch
      >>> from safetensors.torch import save_file as safe_save_file
      >>> from huggingface_hub import split_torch_state_dict_into_shards

      >>> def save_state_dict(state_dict: dict[str, torch.Tensor], save_directory: str):
      ...     state_dict_split = split_torch_state_dict_into_shards(state_dict)
      ...     for filename, tensors in state_dict_split.filename_to_tensors.items():
      ...         shard = {tensor: state_dict[tensor] for tensor in tensors}
      ...         safe_save_file(
      ...             shard,
      ...             os.path.join(save_directory, filename),
      ...             metadata={"format": "pt"},
      ...         )
      ...     if state_dict_split.is_sharded:
      ...         index = {
      ...             "metadata": state_dict_split.metadata,
      ...             "weight_map": state_dict_split.tensor_to_filename,
      ...         }
      ...         with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
      ...             f.write(json.dumps(index, indent=2))
      ```
      """
      # The sharding itself is delegated to the generic factory; this wrapper only plugs in the torch-specific
      # callbacks that compute a tensor's storage size and storage identifier.
      split = split_state_dict_into_shards_factory(
          state_dict,
          get_storage_id=get_torch_storage_id,
          get_storage_size=get_torch_storage_size,
          filename_pattern=filename_pattern,
          max_shard_size=max_shard_size,
      )
      return split
  311. # LOADING
  312. def load_torch_model(
  313. model: "torch.nn.Module",
  314. checkpoint_path: str | os.PathLike,
  315. *,
  316. strict: bool = False,
  317. safe: bool = True,
  318. weights_only: bool = False,
  319. map_location: Union[str, "torch.device"] | None = None,
  320. mmap: bool = False,
  321. filename_pattern: str | None = None,
  322. ) -> NamedTuple:
  323. """
  324. Load a checkpoint into a model, handling both sharded and non-sharded checkpoints.
  325. Args:
  326. model (`torch.nn.Module`):
  327. The model in which to load the checkpoint.
  328. checkpoint_path (`str` or `os.PathLike`):
  329. Path to either the checkpoint file or directory containing the checkpoint(s).
  330. strict (`bool`, *optional*, defaults to `False`):
  331. Whether to strictly enforce that the keys in the model state dict match the keys in the checkpoint.
  332. safe (`bool`, *optional*, defaults to `True`):
  333. If `safe` is True, the safetensors files will be loaded. If `safe` is False, the function
  334. will first attempt to load safetensors files if they are available, otherwise it will fall back to loading
  335. pickle files. `filename_pattern` parameter takes precedence over `safe` parameter.
  336. weights_only (`bool`, *optional*, defaults to `False`):
  337. If True, only loads the model weights without optimizer states and other metadata.
  338. Only supported in PyTorch >= 1.13.
  339. map_location (`str` or `torch.device`, *optional*):
  340. A `torch.device` object, string or a dict specifying how to remap storage locations. It
  341. indicates the location where all tensors should be loaded.
  342. mmap (`bool`, *optional*, defaults to `False`):
  343. Whether to use memory-mapped file loading. Memory mapping can improve loading performance
  344. for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints.
  345. filename_pattern (`str`, *optional*):
  346. The pattern to look for the index file. Pattern must be a string that
  347. can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
  348. Defaults to `"model{suffix}.safetensors"`.
  349. Returns:
  350. `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields.
  351. - `missing_keys` is a list of str containing the missing keys, i.e. keys that are in the model but not in the checkpoint.
  352. - `unexpected_keys` is a list of str containing the unexpected keys, i.e. keys that are in the checkpoint but not in the model.
  353. Raises:
  354. [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
  355. If the checkpoint file or directory does not exist.
  356. [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
  357. If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.
  358. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  359. If the checkpoint path is invalid or if the checkpoint format cannot be determined.
  360. Example:
  361. ```python
  362. >>> from huggingface_hub import load_torch_model
  363. >>> model = ... # A PyTorch model
  364. >>> load_torch_model(model, "path/to/checkpoint")
  365. ```
  366. """
  367. checkpoint_path = Path(checkpoint_path)
  368. if not checkpoint_path.exists():
  369. raise ValueError(f"Checkpoint path {checkpoint_path} does not exist")
  370. # 1. Check if checkpoint is a single file
  371. if checkpoint_path.is_file():
  372. state_dict = load_state_dict_from_file(
  373. checkpoint_file=checkpoint_path,
  374. map_location=map_location,
  375. weights_only=weights_only,
  376. )
  377. return model.load_state_dict(state_dict, strict=strict)
  378. # 2. If not, checkpoint_path is a directory
  379. if filename_pattern is None:
  380. filename_pattern = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
  381. index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
  382. # Only fallback to pickle format if safetensors index is not found and safe is False.
  383. if not index_path.is_file() and not safe:
  384. filename_pattern = constants.PYTORCH_WEIGHTS_FILE_PATTERN
  385. index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
  386. if index_path.is_file():
  387. return _load_sharded_checkpoint(
  388. model=model,
  389. save_directory=checkpoint_path,
  390. strict=strict,
  391. weights_only=weights_only,
  392. filename_pattern=filename_pattern,
  393. )
  394. # Look for single model file
  395. model_files = list(checkpoint_path.glob("*.safetensors" if safe else "*.bin"))
  396. if len(model_files) == 1:
  397. state_dict = load_state_dict_from_file(
  398. checkpoint_file=model_files[0],
  399. map_location=map_location,
  400. weights_only=weights_only,
  401. mmap=mmap,
  402. )
  403. return model.load_state_dict(state_dict, strict=strict)
  404. raise ValueError(
  405. f"Directory '{checkpoint_path}' does not contain a valid checkpoint. "
  406. "Expected either a sharded checkpoint with an index file, or a single model file."
  407. )
  def _load_sharded_checkpoint(
      model: "torch.nn.Module",
      save_directory: str | os.PathLike,
      *,
      strict: bool = False,
      weights_only: bool = False,
      filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
  ) -> NamedTuple:
      """
      Load a sharded checkpoint into a model, one shard at a time.

      This is the sharded counterpart of
      [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict):
      each shard is loaded on CPU, copied into the model, then released before the next one is loaded.

      Args:
          model (`torch.nn.Module`):
              The model in which to load the checkpoint.
          save_directory (`str` or `os.PathLike`):
              A path to a folder containing the sharded checkpoint.
          strict (`bool`, *optional*, defaults to `False`):
              Whether to strictly enforce that the keys in the model state dict match the keys of the whole sharded
              checkpoint (as listed in the index). The check is done before any shard is loaded.
          weights_only (`bool`, *optional*, defaults to `False`):
              Only used for pickle shards (PyTorch >= 1.13): if `True`, restricts what can be unpickled (see
              `torch.load`). Has no effect on safetensors shards.
          filename_pattern (`str`, *optional*, defaults to `"model{suffix}.safetensors"`):
              The pattern used to locate the index file (`filename_pattern.format(suffix="") + ".index.json"`). Must
              contain the `{suffix}` placeholder. Its extension is also the extension every shard must have.

      Returns:
          `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields,
              - `missing_keys` is a list of str containing the missing keys
              - `unexpected_keys` is a list of str containing the unexpected keys

      Raises:
          [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
              If the index file is malformed or references an invalid shard filename.
      """
      # 1. Load the index file, which maps every parameter name to the shard file storing it.
      index_path = filename_pattern.format(suffix="") + ".index.json"
      index_file = os.path.join(save_directory, index_path)
      with open(index_file, encoding="utf-8") as f:
          index = json.load(f)
      weight_map = index.get("weight_map") if isinstance(index, dict) else None
      if not isinstance(weight_map, dict) or not all(isinstance(name, str) for name in weight_map.values()):
          raise ValueError(
              f"Invalid index file '{index_file}': expected a 'weight_map' mapping parameter names to shard filenames."
          )

      # 2. Validate shard filenames from the index before reading any of them.
      # This prevents path traversal attacks and extension confusion attacks
      # (e.g. a safetensors index referencing .bin pickle files).
      expected_extension = Path(filename_pattern.format(suffix="")).suffix  # e.g. ".safetensors"
      # Unique shard files (many parameters share a shard), sorted so the loading order is deterministic.
      shard_files = sorted(set(weight_map.values()))
      for shard_file in shard_files:
          # Reject path traversal (e.g. "../malicious.bin", absolute paths)
          if os.path.isabs(shard_file) or ".." in Path(shard_file).parts:
              raise ValueError(
                  f"Invalid shard filename '{shard_file}' in index file '{index_file}'. "
                  "Shard filenames must be relative paths without '..' components."
              )
          # Reject extension mismatch (e.g. .bin shard in a .safetensors index)
          if not shard_file.endswith(expected_extension):
              raise ValueError(
                  f"Invalid shard filename '{shard_file}' in index file '{index_file}'. "
                  f"Expected '{expected_extension}' extension to match the index format."
              )

      # 3. In strict mode, validate the keys of the whole checkpoint before loading any shard, to fail fast.
      if strict:
          _validate_keys_for_strict_loading(model, weight_map.keys())

      # 4. Load each shard into the model.
      for shard_file in shard_files:
          shard_path = os.path.join(save_directory, shard_file)
          state_dict = load_state_dict_from_file(
              shard_path,
              map_location="cpu",
              weights_only=weights_only,
          )
          # A shard only holds a subset of the model's keys, so loading it with `strict=True` would always report
          # the other shards' keys as missing and raise. Strictness is enforced once for the whole checkpoint in
          # step 3 instead.
          model.load_state_dict(state_dict, strict=False)
          # Release the shard before loading the next one to keep peak memory low.
          del state_dict

      # 5. Report incompatibilities between the checkpoint (as described by the index) and the model.
      loaded_keys = set(weight_map.keys())
      model_keys = set(model.state_dict().keys())
      return _IncompatibleKeys(
          missing_keys=list(model_keys - loaded_keys), unexpected_keys=list(loaded_keys - model_keys)
      )
  487. def load_state_dict_from_file(
  488. checkpoint_file: str | os.PathLike,
  489. map_location: Union[str, "torch.device"] | None = None,
  490. weights_only: bool = False,
  491. mmap: bool = False,
  492. ) -> dict[str, "torch.Tensor"] | Any:
  493. """
  494. Loads a checkpoint file, handling both safetensors and pickle checkpoint formats.
  495. Args:
  496. checkpoint_file (`str` or `os.PathLike`):
  497. Path to the checkpoint file to load. Can be either a safetensors or pickle (`.bin`) checkpoint.
  498. map_location (`str` or `torch.device`, *optional*):
  499. A `torch.device` object, string or a dict specifying how to remap storage locations. It
  500. indicates the location where all tensors should be loaded.
  501. weights_only (`bool`, *optional*, defaults to `False`):
  502. If True, only loads the model weights without optimizer states and other metadata.
  503. Only supported for pickle (`.bin`) checkpoints with PyTorch >= 1.13. Has no effect when
  504. loading safetensors files.
  505. mmap (`bool`, *optional*, defaults to `False`):
  506. Whether to use memory-mapped file loading. Memory mapping can improve loading performance
  507. for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. Has no effect when
  508. loading safetensors files, as the `safetensors` library uses memory mapping by default.
  509. Returns:
  510. `Union[dict[str, "torch.Tensor"], Any]`: The loaded checkpoint.
  511. - For safetensors files: always returns a dictionary mapping parameter names to tensors.
  512. - For pickle files: returns any Python object that was pickled (commonly a state dict, but could be
  513. an entire model, optimizer state, or any other Python object).
  514. Raises:
  515. [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
  516. If the checkpoint file does not exist.
  517. [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
  518. If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.
  519. [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
  520. If the checkpoint file format is invalid or if git-lfs files are not properly downloaded.
  521. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  522. If the checkpoint file path is empty or invalid.
  523. Example:
  524. ```python
  525. >>> from huggingface_hub import load_state_dict_from_file
  526. # Load a PyTorch checkpoint
  527. >>> state_dict = load_state_dict_from_file("path/to/model.bin", map_location="cpu")
  528. >>> model.load_state_dict(state_dict)
  529. # Load a safetensors checkpoint
  530. >>> state_dict = load_state_dict_from_file("path/to/model.safetensors")
  531. >>> model.load_state_dict(state_dict)
  532. ```
  533. """
  534. checkpoint_path = Path(checkpoint_file)
  535. # Check if file exists and is a regular file (not a directory)
  536. if not checkpoint_path.is_file():
  537. raise FileNotFoundError(
  538. f"No checkpoint file found at '{checkpoint_path}'. Please verify the path is correct and "
  539. "the file has been properly downloaded."
  540. )
  541. # Load safetensors checkpoint
  542. if checkpoint_path.suffix == ".safetensors":
  543. try:
  544. from safetensors import safe_open
  545. from safetensors.torch import load_file
  546. except ImportError as e:
  547. raise ImportError(
  548. "Please install `safetensors` to load safetensors checkpoint. "
  549. "You can install it with `pip install safetensors`."
  550. ) from e
  551. # Check format of the archive
  552. with safe_open(checkpoint_file, framework="pt") as f: # type: ignore[attr-defined]
  553. metadata = f.metadata()
  554. # see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966
  555. if metadata is not None and metadata.get("format") not in ["pt", "mlx"]:
  556. raise OSError(
  557. f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
  558. "you save your model with the `save_torch_model` method."
  559. )
  560. device = str(map_location.type) if map_location is not None and hasattr(map_location, "type") else map_location
  561. # meta device is not supported with safetensors, falling back to CPU
  562. if device == "meta":
  563. logger.warning("Meta device is not supported with safetensors. Falling back to CPU device.")
  564. device = "cpu"
  565. return load_file(checkpoint_file, device=device) # type: ignore[arg-type]
  566. # Otherwise, load from pickle
  567. try:
  568. import torch
  569. from torch import load
  570. except ImportError as e:
  571. raise ImportError(
  572. "Please install `torch` to load torch tensors. You can install it with `pip install torch`."
  573. ) from e
  574. # Add additional kwargs, mmap is only supported in torch >= 2.1.0
  575. additional_kwargs = {}
  576. if version.parse(torch.__version__) >= version.parse("2.1.0"):
  577. additional_kwargs["mmap"] = mmap
  578. # weights_only is only supported in torch >= 1.13.0
  579. if version.parse(torch.__version__) >= version.parse("1.13.0"):
  580. additional_kwargs["weights_only"] = weights_only
  581. return load(
  582. checkpoint_file,
  583. map_location=map_location,
  584. **additional_kwargs,
  585. )
  586. # HELPERS
  587. def _validate_keys_for_strict_loading(
  588. model: "torch.nn.Module",
  589. loaded_keys: Iterable[str],
  590. ) -> None:
  591. """
  592. Validate that model keys match loaded keys when strict loading is enabled.
  593. Args:
  594. model: The PyTorch model being loaded
  595. loaded_keys: The keys present in the checkpoint
  596. Raises:
  597. RuntimeError: If there are missing or unexpected keys in strict mode
  598. """
  599. loaded_keys_set = set(loaded_keys)
  600. model_keys = set(model.state_dict().keys())
  601. missing_keys = model_keys - loaded_keys_set # Keys in model but not in checkpoint
  602. unexpected_keys = loaded_keys_set - model_keys # Keys in checkpoint but not in model
  603. if missing_keys or unexpected_keys:
  604. error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
  605. if missing_keys:
  606. str_missing_keys = ",".join([f'"{k}"' for k in sorted(missing_keys)])
  607. error_message += f"\nMissing key(s): {str_missing_keys}."
  608. if unexpected_keys:
  609. str_unexpected_keys = ",".join([f'"{k}"' for k in sorted(unexpected_keys)])
  610. error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
  611. raise RuntimeError(error_message)
  612. def _get_unique_id(tensor: "torch.Tensor") -> int | tuple[Any, ...]:
  613. """Returns a unique id for plain tensor
  614. or a (potentially nested) Tuple of unique id for the flattened Tensor
  615. if the input is a wrapper tensor subclass Tensor
  616. """
  617. try:
  618. from torch.distributed.tensor import DTensor
  619. if isinstance(tensor, DTensor):
  620. local_tensor = tensor.to_local()
  621. return local_tensor.storage().data_ptr()
  622. except ImportError:
  623. pass
  624. try:
  625. # for torch 2.1 and above we can also handle tensor subclasses
  626. from torch.utils._python_dispatch import is_traceable_wrapper_subclass
  627. if is_traceable_wrapper_subclass(tensor):
  628. attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined]
  629. return tuple(_get_unique_id(getattr(tensor, attr)) for attr in attrs)
  630. except ImportError:
  631. # for torch version less than 2.1, we can fall back to original implementation
  632. pass
  633. if tensor.device.type == "xla" and is_torch_tpu_available():
  634. # NOTE: xla tensors don't have storage
  635. # use some other unique id to distinguish.
  636. # this is a XLA tensor, it must be created using torch_xla's
  637. # device. So the following import is safe:
  638. import torch_xla # type: ignore[import]
  639. unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
  640. else:
  641. unique_id = storage_ptr(tensor)
  642. return unique_id
  643. def get_torch_storage_id(tensor: "torch.Tensor") -> tuple["torch.device", int | tuple[Any, ...], int] | None:
  644. """
  645. Return unique identifier to a tensor storage.
  646. Multiple different tensors can share the same underlying storage. This identifier is
  647. guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
  648. non-overlapping lifetimes may have the same id.
  649. In the case of meta tensors, we return None since we can't tell if they share the same storage.
  650. Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
  651. """
  652. if tensor.device.type == "meta":
  653. return None
  654. else:
  655. return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)
  656. def get_torch_storage_size(tensor: "torch.Tensor") -> int:
  657. """
  658. Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
  659. """
  660. try:
  661. from torch.distributed.tensor import DTensor
  662. if isinstance(tensor, DTensor):
  663. # this returns the size of the FULL tensor in bytes
  664. return tensor.nbytes
  665. except ImportError:
  666. pass
  667. try:
  668. # for torch 2.1 and above we can also handle tensor subclasses
  669. from torch.utils._python_dispatch import is_traceable_wrapper_subclass
  670. if is_traceable_wrapper_subclass(tensor):
  671. attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined]
  672. return sum(get_torch_storage_size(getattr(tensor, attr)) for attr in attrs)
  673. except ImportError:
  674. # for torch version less than 2.1, we can fall back to original implementation
  675. pass
  676. try:
  677. return tensor.untyped_storage().nbytes()
  678. except AttributeError:
  679. # Fallback for torch==1.10
  680. try:
  681. return tensor.storage().size() * _get_dtype_size(tensor.dtype)
  682. except NotImplementedError:
  683. # Fallback for meta storage
  684. # On torch >=2.0 this is the tensor size
  685. return tensor.nelement() * _get_dtype_size(tensor.dtype)
  686. @lru_cache
  687. def is_torch_tpu_available(check_device=True):
  688. """
  689. Checks if `torch_xla` is installed and potentially if a TPU is in the environment
  690. Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/utils/import_utils.py#L463.
  691. """
  692. if importlib.util.find_spec("torch_xla") is not None:
  693. if check_device:
  694. # We need to check if `xla_device` can be found, will raise a RuntimeError if not
  695. try:
  696. import torch_xla.core.xla_model as xm # type: ignore[import]
  697. _ = xm.xla_device()
  698. return True
  699. except RuntimeError:
  700. return False
  701. return True
  702. return False
  703. def storage_ptr(tensor: "torch.Tensor") -> int | tuple[Any, ...]:
  704. """
  705. Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11.
  706. """
  707. try:
  708. # for torch 2.1 and above we can also handle tensor subclasses
  709. from torch.utils._python_dispatch import is_traceable_wrapper_subclass
  710. if is_traceable_wrapper_subclass(tensor):
  711. return _get_unique_id(tensor) # type: ignore
  712. except ImportError:
  713. # for torch version less than 2.1, we can fall back to original implementation
  714. pass
  715. try:
  716. return tensor.untyped_storage().data_ptr()
  717. except Exception:
  718. # Fallback for torch==1.10
  719. try:
  720. return tensor.storage().data_ptr()
  721. except NotImplementedError:
  722. # Fallback for meta storage
  723. return 0
  724. def _clean_state_dict_for_safetensors(
  725. state_dict: dict[str, "torch.Tensor"],
  726. metadata: dict[str, str],
  727. force_contiguous: bool = True,
  728. shared_tensors_to_discard: list[str] | None = None,
  729. ):
  730. """Remove shared tensors from state_dict and update metadata accordingly (for reloading).
  731. Warning: `state_dict` and `metadata` are mutated in-place!
  732. Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155.
  733. """
  734. to_removes = _remove_duplicate_names(state_dict, discard_names=shared_tensors_to_discard)
  735. for kept_name, to_remove_group in to_removes.items():
  736. for to_remove in to_remove_group:
  737. if metadata is None:
  738. metadata = {}
  739. if to_remove not in metadata:
  740. # Do not override user data
  741. metadata[to_remove] = kept_name
  742. del state_dict[to_remove]
  743. if force_contiguous:
  744. state_dict = {k: v.contiguous() for k, v in state_dict.items()}
  745. return state_dict
  746. def _end_ptr(tensor: "torch.Tensor") -> int:
  747. """
  748. Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L23.
  749. """
  750. if tensor.nelement():
  751. stop = tensor.view(-1)[-1].data_ptr() + _get_dtype_size(tensor.dtype)
  752. else:
  753. stop = tensor.data_ptr()
  754. return stop
  755. def _filter_shared_not_shared(tensors: list[set[str]], state_dict: dict[str, "torch.Tensor"]) -> list[set[str]]:
  756. """
  757. Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L44
  758. """
  759. filtered_tensors = []
  760. for shared in tensors:
  761. if len(shared) < 2:
  762. filtered_tensors.append(shared)
  763. continue
  764. areas = []
  765. for name in shared:
  766. tensor = state_dict[name]
  767. areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
  768. areas.sort()
  769. _, last_stop, last_name = areas[0]
  770. filtered_tensors.append({last_name})
  771. for start, stop, name in areas[1:]:
  772. if start >= last_stop:
  773. filtered_tensors.append({name})
  774. else:
  775. filtered_tensors[-1].add(name)
  776. last_stop = stop
  777. return filtered_tensors
  778. def _find_shared_tensors(state_dict: dict[str, "torch.Tensor"]) -> list[set[str]]:
  779. """
  780. Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L69.
  781. """
  782. import torch
  783. tensors_dict = defaultdict(set)
  784. for k, v in state_dict.items():
  785. if v.device != torch.device("meta") and storage_ptr(v) != 0 and get_torch_storage_size(v) != 0:
  786. # Need to add device as key because of multiple GPU.
  787. tensors_dict[(v.device, storage_ptr(v), get_torch_storage_size(v))].add(k)
  788. tensors = list(sorted(tensors_dict.values()))
  789. tensors = _filter_shared_not_shared(tensors, state_dict)
  790. return tensors
  791. def _is_complete(tensor: "torch.Tensor") -> bool:
  792. """
  793. Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
  794. """
  795. try:
  796. # for torch 2.1 and above we can also handle tensor subclasses
  797. from torch.utils._python_dispatch import is_traceable_wrapper_subclass
  798. if is_traceable_wrapper_subclass(tensor):
  799. attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined]
  800. return all(_is_complete(getattr(tensor, attr)) for attr in attrs)
  801. except ImportError:
  802. # for torch version less than 2.1, we can fall back to original implementation
  803. pass
  804. return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _get_dtype_size(
  805. tensor.dtype
  806. ) == get_torch_storage_size(tensor)
  807. def _remove_duplicate_names(
  808. state_dict: dict[str, "torch.Tensor"],
  809. *,
  810. preferred_names: list[str] | None = None,
  811. discard_names: list[str] | None = None,
  812. ) -> dict[str, list[str]]:
  813. """
  814. Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
  815. """
  816. if preferred_names is None:
  817. preferred_names = []
  818. unique_preferred_names = set(preferred_names)
  819. if discard_names is None:
  820. discard_names = []
  821. unique_discard_names = set(discard_names)
  822. shareds = _find_shared_tensors(state_dict)
  823. to_remove = defaultdict(list)
  824. for shared in shareds:
  825. complete_names = {name for name in shared if _is_complete(state_dict[name])}
  826. if not complete_names:
  827. raise RuntimeError(
  828. "Error while trying to find names to remove to save state dict, but found no suitable name to keep"
  829. f" for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model"
  830. " since you could be storing much more memory than needed. Please refer to"
  831. " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an"
  832. " issue."
  833. )
  834. keep_name = sorted(list(complete_names))[0]
  835. # Mechanism to preferentially select keys to keep
  836. # coming from the on-disk file to allow
  837. # loading models saved with a different choice
  838. # of keep_name
  839. preferred = complete_names.difference(unique_discard_names)
  840. if preferred:
  841. keep_name = sorted(list(preferred))[0]
  842. if unique_preferred_names:
  843. preferred = unique_preferred_names.intersection(complete_names)
  844. if preferred:
  845. keep_name = sorted(list(preferred))[0]
  846. for name in sorted(shared):
  847. if name != keep_name:
  848. to_remove[keep_name].append(name)
  849. return to_remove
  850. @lru_cache
  851. def _get_dtype_size(dtype: "torch.dtype") -> int:
  852. """
  853. Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L344
  854. """
  855. import torch
  856. # torch.float8 formats require 2.1; we do not support these dtypes on earlier versions
  857. _float8_e4m3fn = getattr(torch, "float8_e4m3fn", None)
  858. _float8_e5m2 = getattr(torch, "float8_e5m2", None)
  859. _SIZE = {
  860. torch.int64: 8,
  861. torch.float32: 4,
  862. torch.int32: 4,
  863. torch.bfloat16: 2,
  864. torch.float16: 2,
  865. torch.int16: 2,
  866. torch.uint8: 1,
  867. torch.int8: 1,
  868. torch.bool: 1,
  869. torch.float64: 8,
  870. _float8_e4m3fn: 1,
  871. _float8_e5m2: 1,
  872. }
  873. return _SIZE[dtype]
  874. class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
  875. """
  876. This is used to report missing and unexpected keys in the state dict.
  877. Taken from https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L52.
  878. """
  879. def __repr__(self) -> str:
  880. if not self.missing_keys and not self.unexpected_keys:
  881. return "<All keys matched successfully>"
  882. return super().__repr__()
  883. __str__ = __repr__