torch.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. import os
  2. import sys
  3. from collections import defaultdict
  4. from typing import Any, Dict, List, Optional, Set, Tuple, Union
  5. from packaging.version import Version
  6. import torch
  7. from safetensors import deserialize, safe_open, serialize, serialize_file
  def storage_ptr(tensor: torch.Tensor) -> int:
      """Return the address of the storage that backs `tensor`.

      Views of the same underlying storage report the same value, which is
      what makes this usable as a "do these tensors share memory" key.

      Args:
          tensor (`torch.Tensor`):
              The tensor whose storage address is wanted.

      Returns:
          `int`: the storage `data_ptr()`, or `0` when the legacy
          `tensor.storage()` accessor raises `NotImplementedError`
          (meta storage).
      """
      try:
          address = tensor.untyped_storage().data_ptr()
      except Exception:
          # Fallback for torch==1.10, which has no `untyped_storage`.
          try:
              address = tensor.storage().data_ptr()
          except NotImplementedError:
              # Fallback for meta storage: there is no real memory behind it.
              address = 0
      return address
  def _end_ptr(tensor: torch.Tensor) -> int:
      """Return the address one byte past the last element of `tensor`.

      For an empty tensor this is simply `tensor.data_ptr()`. Otherwise it is
      the address of the last element of the flattened view plus the byte
      width of the dtype as listed in `_SIZE`.
      """
      if not tensor.nelement():
          return tensor.data_ptr()
      last_element = tensor.view(-1)[-1]
      return last_element.data_ptr() + _SIZE[tensor.dtype]
  def storage_size(tensor: torch.Tensor) -> int:
      """Return the size in bytes of the storage that backs `tensor`.

      Args:
          tensor (`torch.Tensor`):
              The tensor whose storage is measured.

      Returns:
          `int`: number of bytes of the whole storage (not only of the view
          that `tensor` represents).
      """
      try:
          nbytes = tensor.untyped_storage().nbytes()
      except AttributeError:
          # Fallback for torch==1.10
          try:
              nbytes = tensor.storage().size() * _SIZE[tensor.dtype]
          except NotImplementedError:
              # Fallback for meta storage
              # On torch >=2.0 this is the tensor size
              nbytes = tensor.nelement() * _SIZE[tensor.dtype]
      return nbytes
  def _filter_shared_not_shared(
      tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
  ) -> List[Set[str]]:
      """Split groups of same-storage names into groups that really overlap.

      Tensors can live in the same storage without touching the same bytes
      (for instance two disjoint slices of one buffer). Each incoming group is
      therefore re-partitioned by sweeping over the `[data_ptr, end_ptr)` byte
      ranges in address order and merging only the ranges that intersect.

      Args:
          tensors (`List[Set[str]]`):
              Groups of names whose tensors were found on the same storage.
          state_dict (`Dict[str, torch.Tensor]`):
              Mapping used to resolve each name to its tensor.

      Returns:
          `List[Set[str]]`: groups of names whose byte ranges overlap. Groups
          with fewer than two names are passed through untouched.
      """
      filtered_tensors = []
      for shared in tensors:
          if len(shared) < 2:
              filtered_tensors.append(shared)
              continue

          # (start, stop, name) triples ordered by start address.
          areas = sorted(
              (state_dict[name].data_ptr(), _end_ptr(state_dict[name]), name)
              for name in shared
          )

          _, last_stop, first_name = areas[0]
          filtered_tensors.append({first_name})
          for start, stop, name in areas[1:]:
              if start >= last_stop:
                  filtered_tensors.append({name})
              else:
                  filtered_tensors[-1].add(name)
              # Keep the furthest end seen so far. Plainly assigning `stop`
              # would let a small view nested inside a larger tensor shrink
              # the running range, and a later view that still overlaps the
              # larger tensor would wrongly be reported as not shared.
              last_stop = max(last_stop, stop)

      return filtered_tensors
  def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
      """Group the names of `state_dict` by the memory their tensors share.

      Tensors on the meta device, with a null storage pointer, or with an
      empty storage are ignored. The remaining names are bucketed by
      `(device, storage pointer, storage size)` and each bucket is then refined
      by `_filter_shared_not_shared`.

      Returns:
          `List[Set[str]]`: one set of names per group.
      """
      meta_device = torch.device("meta")
      by_storage = defaultdict(set)
      for name, tensor in state_dict.items():
          if tensor.device == meta_device:
              continue
          ptr = storage_ptr(tensor)
          if ptr == 0:
              continue
          size = storage_size(tensor)
          if size == 0:
              continue
          # Need to add device as key because of multiple GPU.
          by_storage[(tensor.device, ptr, size)].add(name)

      groups = sorted(by_storage.values())
      return _filter_shared_not_shared(groups, state_dict)
  def _is_complete(tensor: torch.Tensor) -> bool:
      """Tell whether `tensor` spans its whole storage.

      True only when the tensor starts at the beginning of its storage and
      its element count times the dtype width equals the storage size.
      """
      if tensor.data_ptr() != storage_ptr(tensor):
          return False
      tensor_nbytes = tensor.nelement() * _SIZE[tensor.dtype]
      return tensor_nbytes == storage_size(tensor)
  def _remove_duplicate_names(
      state_dict: Dict[str, torch.Tensor],
      *,
      preferred_names: Optional[List[str]] = None,
      discard_names: Optional[List[str]] = None,
  ) -> Dict[str, List[str]]:
      """Choose, for every group of shared tensors, one name to keep.

      Only names whose tensor covers its entire storage (`_is_complete`) are
      candidates. Among the candidates the alphabetically first one wins, with
      two overrides applied in this order: candidates not listed in
      `discard_names` beat the plain choice, and candidates listed in
      `preferred_names` beat both.

      Args:
          state_dict (`Dict[str, torch.Tensor]`):
              The tensors to inspect.
          preferred_names (`List[str]`, *optional*):
              Names that should be kept when they are valid candidates.
          discard_names (`List[str]`, *optional*):
              Names that should be dropped when another candidate exists.

      Returns:
          `Dict[str, List[str]]`: maps each kept name to the sorted list of
          names sharing its memory that should be removed. Groups without
          duplicates produce no entry.

      Raises:
          RuntimeError: when a group has no tensor covering the whole storage.
      """
      preferred_set = set() if preferred_names is None else set(preferred_names)
      discard_set = set() if discard_names is None else set(discard_names)

      to_remove = defaultdict(list)
      for shared in _find_shared_tensors(state_dict):
          complete_names = {name for name in shared if _is_complete(state_dict[name])}
          if not complete_names:
              raise RuntimeError(
                  "Error while trying to find names to remove to save state dict, but found no suitable name to keep"
                  f" for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model"
                  " since you could be storing much more memory than needed. Please refer to"
                  " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an"
                  " issue."
              )

          keep_name = min(complete_names)

          # Mechanism to preferentially select keys to keep
          # coming from the on-disk file to allow
          # loading models saved with a different choice
          # of keep_name
          not_discarded = complete_names - discard_set
          if not_discarded:
              keep_name = min(not_discarded)

          if preferred_set:
              wanted = preferred_set & complete_names
              if wanted:
                  keep_name = min(wanted)

          duplicates = [name for name in sorted(shared) if name != keep_name]
          if duplicates:
              # Only touch the defaultdict when there is something to record,
              # so groups of a single name never create an empty entry.
              to_remove[keep_name].extend(duplicates)
      return to_remove
  def save_model(
      model: torch.nn.Module,
      filename: Union[str, os.PathLike],
      metadata: Optional[Dict[str, str]] = None,
      force_contiguous: bool = True,
  ):
      """
      Saves a given torch model to specified filename.
      This method exists specifically to avoid tensor sharing issues which are
      not allowed in `safetensors`. [More information on tensor sharing](../torch_shared_tensors)

      Args:
          model (`torch.nn.Module`):
              The model to save on disk.
          filename (`str`, or `os.PathLike`):
              The filename location to save the file
          metadata (`Dict[str, str]`, *optional*):
              Extra information to save along with the file.
              One entry `dropped_name -> kept_name` is added for each dropped
              tensor unless the key is already present. The dictionary passed
              by the caller is never modified; a copy is written to the file.
              This information will not be enough to recover the entire
              shared structure but might help understanding things
          force_contiguous (`boolean`, *optional*, defaults to True):
              Forcing the state_dict to be saved as contiguous tensors.
              This has no effect on the correctness of the model, but it
              could potentially change performance if the layout of the tensor
              was chosen specifically for that reason.

      Raises:
          ValueError: re-raised from `save_file` with an extra hint about
              `force_contiguous`, chained to the original error.
      """
      state_dict = model.state_dict()
      to_removes = _remove_duplicate_names(state_dict)

      if any(to_removes.values()):
          # Work on a private copy so the caller's dictionary is not mutated.
          metadata = {} if metadata is None else dict(metadata)

      for kept_name, to_remove_group in to_removes.items():
          for to_remove in to_remove_group:
              # `setdefault` keeps user supplied values: do not override user data.
              metadata.setdefault(to_remove, kept_name)
              del state_dict[to_remove]

      if force_contiguous:
          state_dict = {k: v.contiguous() for k, v in state_dict.items()}

      try:
          save_file(state_dict, filename, metadata=metadata)
      except ValueError as e:
          msg = str(e)
          msg += " Or use save_model(..., force_contiguous=True), read the docs for potential caveats."
          raise ValueError(msg) from e
  def load_model(
      model: torch.nn.Module,
      filename: Union[str, os.PathLike],
      strict: bool = True,
      device: Union[str, int] = "cpu",
  ) -> Tuple[List[str], List[str]]:
      """
      Loads a given filename onto a torch model.
      This method exists specifically to avoid tensor sharing issues which are
      not allowed in `safetensors`. [More information on tensor sharing](../torch_shared_tensors)

      Args:
          model (`torch.nn.Module`):
              The model to load onto.
          filename (`str`, or `os.PathLike`):
              The filename location to load the file from.
          strict (`bool`, *optional*, defaults to True):
              Whether to fail if you're missing keys or having unexpected ones.
              When false, the function simply returns missing and unexpected names.
          device (`Union[str, int]`, *optional*, defaults to `cpu`):
              The device where the tensors need to be located after load.
              available options are all regular torch device locations.

      Returns:
          `(missing, unexpected)`
              `missing` are names in the model which were not modified during loading
              `unexpected` are names that are on the file, but weren't used during
              the load.
              NOTE(review): `missing` is returned as a `set`, not a `list` as
              the annotation states; `unexpected` is a list.

      Raises:
          RuntimeError: when `strict` is true and either collection is non-empty.
      """
      state_dict = load_file(filename, device=device)
      # Names sharing memory inside the model: prefer keeping the ones that
      # are actually present in the file.
      aliases = _remove_duplicate_names(
          model.state_dict(), preferred_names=state_dict.keys()
      )
      missing, unexpected = model.load_state_dict(state_dict, strict=False)
      missing = set(missing)
      for alias_group in aliases.values():
          for alias in alias_group:
              if alias in missing:
                  # Filled through the tensor it shares memory with.
                  missing.remove(alias)
              else:
                  unexpected.append(alias)

      if not (strict and (missing or unexpected)):
          return missing, unexpected

      error = f"Error(s) in loading state_dict for {model.__class__.__name__}:"
      if missing:
          missing_keys = ", ".join([f'"{k}"' for k in sorted(missing)])
          error += f"\n Missing key(s) in state_dict: {missing_keys}"
      if unexpected:
          unexpected_keys = ", ".join([f'"{k}"' for k in sorted(unexpected)])
          error += f"\n Unexpected key(s) in state_dict: {unexpected_keys}"
      raise RuntimeError(error)
  def save(
      tensors: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
  ) -> bytes:
      """
      Serializes a dictionary of tensors into raw bytes in safetensors format.

      Args:
          tensors (`Dict[str, torch.Tensor]`):
              The incoming tensors. Tensors need to be contiguous and dense.
          metadata (`Dict[str, str]`, *optional*, defaults to `None`):
              Optional text only metadata you might want to save in your header.
              For instance it can be useful to specify more about the underlying
              tensors. This is purely informative and does not affect tensor loading.

      Returns:
          `bytes`: The raw bytes representing the format

      Example:

      ```python
      from safetensors.torch import save
      import torch

      tensors = {"embedding": torch.zeros((512, 1024)), "attention": torch.zeros((256, 256))}
      byte_data = save(tensors)
      ```
      """
      flattened = _flatten(tensors)
      return bytes(serialize(flattened, metadata=metadata))
  def save_file(
      tensors: Dict[str, torch.Tensor],
      filename: Union[str, os.PathLike],
      metadata: Optional[Dict[str, str]] = None,
  ):
      """
      Saves a dictionary of tensors into a file in safetensors format.

      Args:
          tensors (`Dict[str, torch.Tensor]`):
              The incoming tensors. Tensors need to be contiguous and dense.
          filename (`str`, or `os.PathLike`):
              The filename we're saving into.
          metadata (`Dict[str, str]`, *optional*, defaults to `None`):
              Optional text only metadata you might want to save in your header.
              For instance it can be useful to specify more about the underlying
              tensors. This is purely informative and does not affect tensor loading.

      Returns:
          `None`

      Example:

      ```python
      from safetensors.torch import save_file
      import torch

      tensors = {"embedding": torch.zeros((512, 1024)), "attention": torch.zeros((256, 256))}
      save_file(tensors, "model.safetensors")
      ```
      """
      flattened = _flatten(tensors)
      serialize_file(flattened, filename, metadata=metadata)
  def load_file(
      filename: Union[str, os.PathLike], device: Union[str, int] = "cpu"
  ) -> Dict[str, torch.Tensor]:
      """
      Loads a safetensors file into torch format.

      Args:
          filename (`str`, or `os.PathLike`):
              The name of the file which contains the tensors
          device (`Union[str, int]`, *optional*, defaults to `cpu`):
              The device where the tensors need to be located after load.
              available options are all regular torch device locations.

      Returns:
          `Dict[str, torch.Tensor]`: dictionary that contains name as key, value as `torch.Tensor`

      Example:

      ```python
      from safetensors.torch import load_file

      file_path = "./my_folder/bert.safetensors"
      loaded = load_file(file_path)
      ```
      """
      with safe_open(filename, framework="pt", device=device) as handle:
          # Keys are visited in the order reported by `offset_keys()`.
          return {key: handle.get_tensor(key) for key in handle.offset_keys()}
  def load(data: bytes) -> Dict[str, torch.Tensor]:
      """
      Loads a safetensors file into torch format from pure bytes.

      Args:
          data (`bytes`):
              The content of a safetensors file

      Returns:
          `Dict[str, torch.Tensor]`: dictionary that contains name as key, value as `torch.Tensor` on cpu

      Example:

      ```python
      from safetensors.torch import load

      file_path = "./my_folder/bert.safetensors"
      with open(file_path, "rb") as f:
          data = f.read()

      loaded = load(data)
      ```
      """
      return _view2torch(deserialize(data))
  # Optional dtypes are looked up by attribute so that older torch builds,
  # which do not define them, simply yield `None` instead of failing at import.
  # torch.float8 formats require 2.1; we do not support these dtypes on earlier versions
  _float8_e4m3fn = getattr(torch, "float8_e4m3fn", None)
  _float8_e5m2 = getattr(torch, "float8_e5m2", None)
  _float8_e8m0 = getattr(torch, "float8_e8m0fnu", None)
  _float4_e2m1_x2 = getattr(torch, "float4_e2m1fn_x2", None)

  # Width in bytes of one element of each supported dtype.
  # When an optional dtype is unavailable its entry ends up keyed by `None`.
  _SIZE = {
      torch.int64: 8,
      torch.float32: 4,
      torch.int32: 4,
      torch.bfloat16: 2,
      torch.float16: 2,
      torch.int16: 2,
      torch.uint8: 1,
      torch.int8: 1,
      torch.bool: 1,
      torch.float64: 8,
      torch.complex64: 8,
      _float8_e4m3fn: 1,
      _float8_e5m2: 1,
      _float8_e8m0: 1,
      _float4_e2m1_x2: 1,
  }

  # safetensors dtype code -> torch dtype, used when decoding raw bytes.
  # NOTE(review): `_SIZE` knows float8_e8m0fnu / float4_e2m1fn_x2 but there is
  # no string code for them here - confirm whether `load()` should decode them.
  _TYPES = {
      "F64": torch.float64,
      "F32": torch.float32,
      "F16": torch.float16,
      "BF16": torch.bfloat16,
      "I64": torch.int64,
      "I32": torch.int32,
      "I16": torch.int16,
      "I8": torch.int8,
      "U8": torch.uint8,
      "BOOL": torch.bool,
      "F8_E4M3": _float8_e4m3fn,
      "F8_E5M2": _float8_e5m2,
      "C64": torch.complex64,
  }

  # Wider unsigned integers only exist on recent torch releases. Detect them
  # by attribute, like the float8 dtypes above, rather than by parsing
  # `torch.__version__`: pre-release builds compare lower than the final
  # release and custom builds may carry version strings that do not parse.
  if all(hasattr(torch, _name) for _name in ("uint64", "uint32", "uint16")):
      _SIZE.update(
          {
              torch.uint64: 8,
              torch.uint32: 4,
              torch.uint16: 2,
          }
      )
      _TYPES.update(
          {
              "U64": torch.uint64,
              "U32": torch.uint32,
              "U16": torch.uint16,
          }
      )
  def _getdtype(dtype_str: str) -> torch.dtype:
      """Translate a safetensors dtype code (a key of `_TYPES`) to a torch dtype.

      Raises:
          KeyError: when `dtype_str` is not a key of `_TYPES`.
      """
      torch_dtype = _TYPES[dtype_str]
      return torch_dtype
  def _view2torch(safeview) -> Dict[str, torch.Tensor]:
      """Turn deserialized `(name, entry)` pairs into torch tensors.

      Each entry is indexed with the keys `"dtype"`, `"shape"` and `"data"`.
      Non-empty buffers are wrapped with `torch.frombuffer` and reshaped; on
      big-endian hosts the result is byte-swapped through numpy.

      Returns:
          `Dict[str, torch.Tensor]`: tensors keyed by name, in iteration order.
      """
      big_endian = sys.byteorder == "big"
      tensors = {}
      for name, entry in safeview:
          dtype = _getdtype(entry["dtype"])
          raw = entry["data"]
          if len(raw) == 0:
              # Workaround because frombuffer doesn't accept zero-size tensors
              assert any(x == 0 for x in entry["shape"])
              tensor = torch.empty(entry["shape"], dtype=dtype)
          else:
              tensor = torch.frombuffer(raw, dtype=dtype).reshape(entry["shape"])
              if big_endian:
                  swapped = tensor.numpy().byteswap(inplace=False)
                  tensor = torch.from_numpy(swapped)
          tensors[name] = tensor
      return tensors
  def _tobytes(tensor: torch.Tensor, name: str) -> bytes:
      """Copy the memory of a dense, contiguous tensor into a `bytes` object.

      Args:
          tensor (`torch.Tensor`):
              The tensor to dump. Non-CPU tensors are first moved to CPU.
          name (`str`):
              Name of the tensor, only used in error messages.

      Returns:
          `bytes`: the raw element bytes (byte-swapped on big-endian hosts),
          or `b""` when the tensor has a null data pointer.

      Raises:
          ValueError: when the tensor is not strided or not contiguous.
      """
      if tensor.layout != torch.strided:
          raise ValueError(
              f"You are trying to save a sparse tensor: `{name}` which this library does not support."
              " You can make it a dense tensor before saving with `.to_dense()` but be aware this might"
              " make a much larger file than needed."
          )

      if not tensor.is_contiguous():
          raise ValueError(
              f"You are trying to save a non contiguous tensor: `{name}` which is not allowed. It either means you"
              " are trying to save tensors which are reference of each other in which case it's recommended to save"
              " only the full tensors, and reslice at load time, or simply call `.contiguous()` on your tensor to"
              " pack it before saving."
          )

      if tensor.device.type != "cpu":
          # Moving tensor to cpu before saving
          tensor = tensor.to("cpu")

      import ctypes

      import numpy as np

      # When shape is empty (scalar), np.prod returns a float
      # we need a int for the following calculations
      element_count = int(np.prod(tensor.shape).item())
      total_bytes = element_count * _SIZE[tensor.dtype]

      address = tensor.data_ptr()
      if address == 0:
          return b""

      byte_ptr = ctypes.cast(address, ctypes.POINTER(ctypes.c_ubyte))
      data = np.ctypeslib.as_array(byte_ptr, (total_bytes,))  # no internal copy
      if sys.byteorder == "big":
          # numpy dtype used only to pick the element width for the swap.
          NPDTYPES = {
              torch.int64: np.int64,
              torch.float32: np.float32,
              torch.int32: np.int32,
              # XXX: This is ok because both have the same width
              torch.bfloat16: np.float16,
              torch.float16: np.float16,
              torch.int16: np.int16,
              torch.uint8: np.uint8,
              torch.int8: np.int8,
              torch.bool: bool,
              torch.float64: np.float64,
              # XXX: This is ok because both have the same width and byteswap is a no-op anyway
              _float8_e4m3fn: np.uint8,
              _float8_e5m2: np.uint8,
              torch.complex64: np.complex64,
          }
          # Not in place as that would potentially modify a live running model
          data = data.view(NPDTYPES[tensor.dtype]).byteswap(inplace=False)
      return data.tobytes()
  def _flatten(tensors: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, Any]]:
      """Validate `tensors` and convert them to the plain structure to serialize.

      Args:
          tensors (`Dict[str, torch.Tensor]`):
              Name to tensor mapping.

      Returns:
          `Dict[str, Dict[str, Any]]`: for every name a dictionary with the
          keys `"dtype"` (the torch dtype name without the `torch.` prefix),
          `"shape"` and `"data"` (raw bytes produced by `_tobytes`).

      Raises:
          ValueError: when `tensors` is not a dict, a value is not a
              `torch.Tensor`, or a tensor is not strided.
          RuntimeError: when several tensors share memory.
      """
      if not isinstance(tensors, dict):
          raise ValueError(
              f"Expected a dict of [str, torch.Tensor] but received {type(tensors)}"
          )

      # Collect every non-strided tensor so they are all reported at once.
      invalid_tensors = []
      for key, value in tensors.items():
          if not isinstance(value, torch.Tensor):
              raise ValueError(
                  f"Key `{key}` is invalid, expected torch.Tensor but received {type(value)}"
              )
          if value.layout != torch.strided:
              invalid_tensors.append(key)
      if invalid_tensors:
          raise ValueError(
              f"You are trying to save a sparse tensors: `{invalid_tensors}` which this library does not support."
              " You can make it a dense tensor before saving with `.to_dense()` but be aware this might"
              " make a much larger file than needed."
          )

      failing = [names for names in _find_shared_tensors(tensors) if len(names) > 1]
      if failing:
          # NOTE(review): the indentation inside this multi-line message could
          # not be recovered from the pasted source - confirm against upstream.
          raise RuntimeError(
              "\n"
              f"Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: {failing}.\n"
              "A potential way to correctly save your model is to use `save_model`.\n"
              "More information at https://huggingface.co/docs/safetensors/torch_shared_tensors\n"
          )

      flattened = {}
      for key, value in tensors.items():
          flattened[key] = {
              "dtype": str(value.dtype).split(".")[-1],
              "shape": value.shape,
              "data": _tobytes(value, key),
          }
      return flattened