| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- import functools
- import operator
- from collections import defaultdict
- from dataclasses import dataclass, field
- from typing import Literal
# Alias for a file name; used below as the key type of `files_metadata` and the value type of `weight_map`.
FILENAME_T = str
# Alias for a tensor name; used below as the key type of `tensors` and of `weight_map`.
TENSOR_NAME_T = str
# Closed set of data type identifiers accepted for `TensorInfo.dtype` and used as keys of the `parameter_count` maps.
DTYPE_T = Literal["F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL"]
@dataclass
class TensorInfo:
    """Information about a tensor.

    For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        dtype (`str`):
            The data type of the tensor ("F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL").
        shape (`list[int]`):
            The shape of the tensor. An empty shape describes a scalar value.
        data_offsets (`tuple[int, int]`):
            The offsets of the data in the file as a tuple `[BEGIN, END]`.
        parameter_count (`int`):
            The number of parameters in the tensor, i.e. the product of all dimensions in `shape`. It is computed
            after initialization and is not accepted by the constructor.
    """

    dtype: DTYPE_T
    shape: list[int]
    data_offsets: tuple[int, int]
    parameter_count: int = field(init=False)

    def __post_init__(self) -> None:
        """Derive `parameter_count` from `shape`.

        Raises:
            TypeError: If `shape` is not iterable or contains items that cannot be multiplied together.
        """
        # Seeding the fold with 1 gives the empty-shape (scalar) case its value directly, instead of relying on
        # a caught `TypeError` that would also hide malformed shapes behind a parameter count of 1.
        self.parameter_count = functools.reduce(operator.mul, self.shape, 1)
@dataclass
class SafetensorsFileMetadata:
    """Metadata for a Safetensors file hosted on the Hub.

    This class is returned by [`parse_safetensors_file_metadata`].

    For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        metadata (`dict`):
            The metadata contained in the file.
        tensors (`dict[str, TensorInfo]`):
            A map of all tensors, from tensor name to the [`TensorInfo`] object describing that tensor.
        parameter_count (`dict[str, int]`):
            A map from data type to the total number of parameters of that data type in the file. It is computed
            from `tensors` after initialization and is not accepted by the constructor.
    """

    metadata: dict[str, str]
    tensors: dict[TENSOR_NAME_T, TensorInfo]
    parameter_count: dict[DTYPE_T, int] = field(init=False)

    def __post_init__(self) -> None:
        """Sum the parameter counts of all tensors, grouped by data type."""
        totals: dict[DTYPE_T, int] = {}
        for info in self.tensors.values():
            # A data type seen for the first time starts from zero.
            totals[info.dtype] = totals.get(info.dtype, 0) + info.parameter_count
        self.parameter_count = totals
@dataclass
class SafetensorsRepoMetadata:
    """Metadata for a Safetensors repo.

    A repo is considered to be a Safetensors repo if it contains either a 'model.safetensors' weight file
    (non-sharded model) or a 'model.safetensors.index.json' index file (sharded model) at its root.

    This class is returned by [`get_safetensors_metadata`].

    For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        metadata (`dict`, *optional*):
            The metadata contained in the 'model.safetensors.index.json' file, if it exists. Only populated for
            sharded models.
        sharded (`bool`):
            Whether the repo contains a sharded model or not.
        weight_map (`dict[str, str]`):
            A map of all weights, from tensor name to the name of the file containing that tensor.
        files_metadata (`dict[str, SafetensorsFileMetadata]`):
            A map of all files metadata, from filename to the [`SafetensorsFileMetadata`] object of that file.
        parameter_count (`dict[str, int]`):
            A map from data type to the total number of parameters of that data type across all files. It is
            computed from `files_metadata` after initialization and is not accepted by the constructor.
    """

    metadata: dict | None
    sharded: bool
    weight_map: dict[TENSOR_NAME_T, FILENAME_T]  # tensor name -> filename
    files_metadata: dict[FILENAME_T, SafetensorsFileMetadata]  # filename -> metadata
    parameter_count: dict[DTYPE_T, int] = field(init=False)

    def __post_init__(self) -> None:
        """Merge the per-file parameter counts into a single per-data-type total."""
        totals: dict[DTYPE_T, int] = {}
        for per_file in self.files_metadata.values():
            for dtype_name, count in per_file.parameter_count.items():
                # A data type seen for the first time starts from zero.
                totals[dtype_name] = totals.get(dtype_name, 0) + count
        self.parameter_count = totals
|