_safetensors.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import functools
  2. import operator
  3. from collections import defaultdict
  4. from dataclasses import dataclass, field
  5. from typing import Literal
  6. FILENAME_T = str
  7. TENSOR_NAME_T = str
  8. DTYPE_T = Literal["F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL"]
  9. @dataclass
  10. class TensorInfo:
  11. """Information about a tensor.
  12. For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.
  13. Attributes:
  14. dtype (`str`):
  15. The data type of the tensor ("F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL").
  16. shape (`list[int]`):
  17. The shape of the tensor.
  18. data_offsets (`tuple[int, int]`):
  19. The offsets of the data in the file as a tuple `[BEGIN, END]`.
  20. parameter_count (`int`):
  21. The number of parameters in the tensor.
  22. """
  23. dtype: DTYPE_T
  24. shape: list[int]
  25. data_offsets: tuple[int, int]
  26. parameter_count: int = field(init=False)
  27. def __post_init__(self) -> None:
  28. # Taken from https://stackoverflow.com/a/13840436
  29. try:
  30. self.parameter_count = functools.reduce(operator.mul, self.shape)
  31. except TypeError:
  32. self.parameter_count = 1 # scalar value has no shape
  33. @dataclass
  34. class SafetensorsFileMetadata:
  35. """Metadata for a Safetensors file hosted on the Hub.
  36. This class is returned by [`parse_safetensors_file_metadata`].
  37. For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.
  38. Attributes:
  39. metadata (`dict`):
  40. The metadata contained in the file.
  41. tensors (`dict[str, TensorInfo]`):
  42. A map of all tensors. Keys are tensor names and values are information about the corresponding tensor, as a
  43. [`TensorInfo`] object.
  44. parameter_count (`dict[str, int]`):
  45. A map of the number of parameters per data type. Keys are data types and values are the number of parameters
  46. of that data type.
  47. """
  48. metadata: dict[str, str]
  49. tensors: dict[TENSOR_NAME_T, TensorInfo]
  50. parameter_count: dict[DTYPE_T, int] = field(init=False)
  51. def __post_init__(self) -> None:
  52. parameter_count: dict[DTYPE_T, int] = defaultdict(int)
  53. for tensor in self.tensors.values():
  54. parameter_count[tensor.dtype] += tensor.parameter_count
  55. self.parameter_count = dict(parameter_count)
  56. @dataclass
  57. class SafetensorsRepoMetadata:
  58. """Metadata for a Safetensors repo.
  59. A repo is considered to be a Safetensors repo if it contains either a 'model.safetensors' weight file (non-shared
  60. model) or a 'model.safetensors.index.json' index file (sharded model) at its root.
  61. This class is returned by [`get_safetensors_metadata`].
  62. For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.
  63. Attributes:
  64. metadata (`dict`, *optional*):
  65. The metadata contained in the 'model.safetensors.index.json' file, if it exists. Only populated for sharded
  66. models.
  67. sharded (`bool`):
  68. Whether the repo contains a sharded model or not.
  69. weight_map (`dict[str, str]`):
  70. A map of all weights. Keys are tensor names and values are filenames of the files containing the tensors.
  71. files_metadata (`dict[str, SafetensorsFileMetadata]`):
  72. A map of all files metadata. Keys are filenames and values are the metadata of the corresponding file, as
  73. a [`SafetensorsFileMetadata`] object.
  74. parameter_count (`dict[str, int]`):
  75. A map of the number of parameters per data type. Keys are data types and values are the number of parameters
  76. of that data type.
  77. """
  78. metadata: dict | None
  79. sharded: bool
  80. weight_map: dict[TENSOR_NAME_T, FILENAME_T] # tensor name -> filename
  81. files_metadata: dict[FILENAME_T, SafetensorsFileMetadata] # filename -> metadata
  82. parameter_count: dict[DTYPE_T, int] = field(init=False)
  83. def __post_init__(self) -> None:
  84. parameter_count: dict[DTYPE_T, int] = defaultdict(int)
  85. for file_metadata in self.files_metadata.values():
  86. for dtype, nb_parameters_ in file_metadata.parameter_count.items():
  87. parameter_count[dtype] += nb_parameters_
  88. self.parameter_count = dict(parameter_count)