| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030 |
- # mypy: allow-untyped-defs
- import collections
- import dataclasses
- import io
- import json
- import operator
- import os
- import pickle
- import queue
- import threading
- import uuid
- import warnings
- from abc import ABC, abstractmethod
- from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
- from contextlib import contextmanager
- from dataclasses import dataclass
- from enum import Enum
- from io import UnsupportedOperation
- from pathlib import Path
- from typing import Any, cast, Final, IO
- # introduced as collections.abc.Buffer in Python 3.12
- from typing_extensions import Buffer
- import torch
- from torch import Tensor
- from torch._utils import _get_available_device_type, _get_device_module
- from torch.distributed._shard._utils import narrow_tensor_by_index
- from torch.distributed.checkpoint._extension import (
- ExtensionRegistry,
- StreamTransformExtension,
- )
- from torch.distributed.checkpoint._hf_utils import (
- CUSTOM_METADATA_KEY,
- DCP_VERSION_KEY,
- FORMAT_KEY,
- FORMAT_VALUE,
- HF_DCP_VERSION,
- )
- from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE, StorageMeta
- from torch.distributed.checkpoint.planner import (
- LoadItemType,
- LoadPlan,
- LoadPlanner,
- ReadItem,
- SavePlan,
- SavePlanner,
- WriteItem,
- WriteItemType,
- )
- from torch.distributed.checkpoint.staging import BlockingAsyncStager
- from torch.distributed.checkpoint.storage import (
- StorageReader,
- StorageWriter,
- WriteResult,
- )
- from torch.distributed.checkpoint.utils import _create_file_view
- from torch.futures import Future
# Public names exported by ``from <this module> import *``.
__all__ = [
    "FileSystemWriter",
    "FileSystemReader",
    "FileSystem",
    "FileSystemBase",
    "SerializationFormat",
]
# Base file name of the pickled checkpoint metadata; per-rank metadata files
# are named "__{rank}" followed by this value.
_metadata_fn: str = ".metadata"

# Version string the writer stamps into the metadata when a save finishes.
CURRENT_DCP_VERSION: Final[str] = "1.0.0"
- @dataclass
- class _StorageInfo:
- """This is the per entry storage info."""
- relative_path: str
- offset: int
- length: int
- transform_descriptors: Sequence[str] | None = None
- def __getstate__(self):
- return {k: v for k, v in self.__dict__.items() if v is not None}
- @dataclass
- class _StoragePrefix:
- prefix: str
class SerializationFormat(Enum):
    """Serialization formats the file-system writer can be configured with."""

    TORCH_SAVE = "torch_save"
    SAFETENSORS = "safetensors"
# File-name suffix appended to every data file the writer generates.
DEFAULT_SUFFIX = ".distcp"
- def _generate_uuid() -> str:
- return str(uuid.uuid4())
- class _TensorLoader(ABC):
- @abstractmethod
- def add(self, size: int, obj: object) -> None:
- pass
- @abstractmethod
- def start_loading(self) -> None:
- pass
- @abstractmethod
- def values(self) -> Iterator[tuple[torch.Tensor, object]]:
- pass
- class _SerialCpuLoader(_TensorLoader):
- def __init__(self, resolve_fun: Callable) -> None:
- self.resolve_fun = resolve_fun
- self.items: list[tuple[int, object]] = []
- def add(self, size: int, obj: object) -> None:
- self.items.append((size, obj))
- def start_loading(self) -> None:
- pass
- def values(self) -> Iterator[tuple[torch.Tensor, object]]:
- for _, obj in self.items:
- tensor = self.resolve_fun(obj).detach()
- tensor = tensor.cpu()
- if tensor.untyped_storage().size() != tensor.nbytes:
- # creates a new tensor with minimal storage while preserving memory format.
- tensor = tensor.clone()
- yield (
- tensor,
- obj,
- )
- class _OverlappingCpuLoader(_TensorLoader):
- def __init__(
- self,
- resolve_fun: Callable,
- stream: torch.Stream | None = None,
- inflight_threshhold: int = 1_000_000,
- ) -> None:
- self.resolve_fun = resolve_fun
- self.items: list[tuple[int, object]] = []
- self.inflight_threshhold = inflight_threshhold
- self.in_flight_data = 0
- self.current_items: collections.deque = collections.deque()
- self.idx = 0
- self.started = False
- self.device_type = (
- stream.device_type if stream else _get_available_device_type()
- )
- self.device_module = _get_device_module(self.device_type)
- self.stream = cast(
- torch.cuda.Stream, stream or self.device_module.current_stream()
- )
- if self.stream != self.device_module.current_stream():
- self.stream.wait_stream(self.device_module.current_stream())
- @property
- def _done(self) -> bool:
- return self.idx >= len(self.items)
- def _drain(self) -> list[tuple[torch.Tensor, object]]:
- drained = []
- if self.in_flight_data >= self.inflight_threshhold:
- self.stream.synchronize()
- while self.in_flight_data >= self.inflight_threshhold:
- val = self.current_items.popleft()
- self.in_flight_data -= val[0].numel() * val[0].element_size()
- drained.append(val)
- return drained
- def _refill(self) -> None:
- with self.device_module.stream(self.stream):
- while not self._done and self.in_flight_data < self.inflight_threshhold:
- _, obj = self.items[self.idx]
- self.idx += 1
- tensor = self.resolve_fun(obj).detach()
- if tensor.device.type == self.device_type:
- tensor = tensor.to(device="cpu", non_blocking=True)
- elif tensor.device == torch.device("cpu"):
- if (
- tensor.untyped_storage().size()
- != tensor.numel() * tensor.itemsize
- ):
- # creates a new tensor with minimal storage while preserving memory format.
- tensor = tensor.clone()
- self.current_items.append(
- (
- tensor,
- obj,
- )
- )
- self.in_flight_data += tensor.numel() * tensor.element_size()
- def _finish(self) -> Iterable[tuple[torch.Tensor, object]]:
- if not self._done:
- raise AssertionError("_finish called before all items were processed")
- if len(self.current_items) > 0:
- self.stream.synchronize()
- return self.current_items
- def add(self, size: int, obj: object) -> None:
- if self.started:
- raise RuntimeError("cannot add items after loading started")
- self.items.append((size, obj))
- def start_loading(self) -> None:
- if self.started:
- return
- self.started = True
- self.items.sort(key=operator.itemgetter(0))
- self._refill()
- def values(self) -> Iterator[tuple[torch.Tensor, object]]:
- self.start_loading()
- while not self._done:
- drained = self._drain()
- self._refill()
- yield from drained
- yield from self._finish()
- class _StorageWriterTransforms:
- """
- This is experimental, and will likely move elsewhere in the
- future. It lives here to minimize changes while we are still
- learning and gathering feedback.
- """
- def __init__(
- self, extensions: Sequence[StreamTransformExtension] | None = None
- ) -> None:
- """
- If the extensions arg is None, this means the implementation
- should provide whatever defaults it chooses. An empty
- sequence indicates no extensions should be used. At this
- time, the default extensions sequence is empty.
- """
- self.extensions = () if extensions is None else extensions
- def transform_save_stream(
- self, write_item: WriteItem, raw_stream: io.IOBase
- ) -> tuple[IO[bytes], list[str]]:
- # In order to avoid leaking fds, transformers' close must
- # cascade to wrapped streams, but since this function can
- # append to the raw stream, we can't close the actual stream.
- # So, we use this to put a wrapper around the raw stream's
- # close() to make it a noop, and it gets closed once all files
- # are appended.
- class NoCloseWriter(io.IOBase):
- def __init__(self, raw: io.IOBase):
- self.raw = raw
- def writeable(self) -> bool:
- return True
- def write(self, b: Buffer) -> int:
- return self.raw.write(b)
- def close(self):
- self.flush()
- self.raw.flush()
- # but not close.
- transform_to = cast(IO[bytes], NoCloseWriter(raw_stream))
- for ex in self.extensions:
- transform_to = ex.transform_to(transform_to)
- return (transform_to, [ex.get_descriptor() for ex in reversed(self.extensions)])
- def _item_size(item: WriteItem) -> int:
- size = 1
- if item.tensor_data is None:
- raise AssertionError("WriteItem tensor_data must not be None")
- # can't use math.prod as PT needs to support older python
- for s in item.tensor_data.size:
- size *= s
- dtype = item.tensor_data.properties.dtype
- return size * torch._utils._element_size(dtype)
- def _split_by_size_and_type(bins: int, items: list[WriteItem]) -> list[list[WriteItem]]:
- if bins == 1:
- return [items]
- bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
- tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]
- buckets: list[list[WriteItem]] = [[] for _ in range(bins)]
- bucket_sizes = [0 for _ in range(bins)]
- tensor_w.sort(key=_item_size, reverse=True)
- for i, wi in enumerate(bytes_w):
- buckets[i % bins].append(wi)
- for wi in tensor_w:
- # TODO replace with headq
- idx = min(enumerate(bucket_sizes), key=operator.itemgetter(1))[0]
- buckets[idx].append(wi)
- bucket_sizes[idx] += _item_size(wi)
- return buckets
def _write_item(
    transforms: _StorageWriterTransforms,
    stream: io.IOBase,
    data: io.BytesIO | torch.Tensor,
    write_item: WriteItem,
    storage_key: str,
    serialization_format: SerializationFormat,
) -> WriteResult:
    """
    Write one item to ``stream`` and describe where it landed.

    BYTE_IO items are written through the transform stream as raw bytes.
    Tensors are written with ``torch.save`` only for the TORCH_SAVE format;
    for other formats nothing is written here and the recorded length is the
    tensor's byte size.

    Raises:
        AssertionError: if ``data`` has the wrong type for the item, or a
            tensor is not on the CPU.
    """
    start_offset = stream.tell()
    sink, descriptors = transforms.transform_save_stream(write_item, stream)
    uses_torch_save = serialization_format == SerializationFormat.TORCH_SAVE

    if write_item.type == WriteItemType.BYTE_IO:
        if not isinstance(data, io.BytesIO):
            raise AssertionError("Data must be io.BytesIO for BYTE_IO write items")
        sink.write(data.getbuffer())
    else:
        if not isinstance(data, torch.Tensor):
            raise AssertionError(
                "Data must be torch.Tensor for non-BYTE_IO write items"
            )
        if data.device != torch.device("cpu"):
            raise AssertionError("Tensor must be on CPU device")
        if uses_torch_save:
            torch.save(data, sink)
    sink.close()

    if uses_torch_save or isinstance(data, io.BytesIO):
        # Bytes actually emitted to the underlying stream, after transforms.
        length = stream.tell() - start_offset
    else:
        length = data.numel() * data.element_size()

    # For consistency with earlier versions, leave this field out of the
    # metadata if there are no extensions.
    storage_info = _StorageInfo(
        storage_key,
        start_offset,
        length,
        transform_descriptors=descriptors if len(descriptors) != 0 else None,
    )
    return WriteResult(
        index=write_item.index,
        size_in_bytes=length,
        storage_data=storage_info,
    )
def _write_files_from_queue(
    create_stream: Callable,
    file_queue: queue.Queue,
    result_queue: queue.Queue,
    planner: SavePlanner,
    transforms: _StorageWriterTransforms,
    inflight_threshhold: int,
    use_fsync: bool,
    thread_count: int,
    serialization_format: SerializationFormat,
) -> None:
    """
    Drain ``file_queue``, writing one file per queue entry.

    Each entry is ``(file_name, storage_key, write_items)``. For every file the
    BYTE_IO items are written first, then the tensor items in the order the
    loader yields them. The list of write results for the file is put on
    ``result_queue``. The function returns once the queue is empty.
    """
    try:
        while True:
            file_name, storage_key, write_items = file_queue.get_nowait()

            privateuse1_name = torch._C._get_privateuse1_backend_name()
            privateuse1_mod = getattr(torch, privateuse1_name, None)

            loader: _TensorLoader
            # TODO: Using the OverlappingCpuLoader with multiple threads creates significant
            # performance degradation, observed as being related to cuda stream syncs. We
            # should try to fix this and use _OverlappingCpuLoader for all threaded cases
            if (
                thread_count == 1
                and (
                    torch.cuda.is_available()
                    or (privateuse1_mod and privateuse1_mod.is_available())
                )
                and inflight_threshhold > 0
            ):
                loader = _OverlappingCpuLoader(
                    planner.resolve_data,
                    inflight_threshhold=inflight_threshhold,
                )
            else:
                loader = _SerialCpuLoader(planner.resolve_data)

            # Tensor items go through the loader; blobs are resolved directly.
            blob_items = []
            for item in write_items:
                if item.type == WriteItemType.BYTE_IO:
                    blob_items.append(item)
                else:
                    loader.add(_item_size(item), item)
            loader.start_loading()

            write_results = []
            with create_stream(file_name, "wb") as stream:
                for blob_item in blob_items:
                    blob = planner.resolve_data(blob_item)
                    result = _write_item(
                        transforms,
                        stream,
                        blob,
                        blob_item,
                        storage_key,
                        serialization_format,
                    )
                    write_results.append(result)

                # Collected for the safetensors payload written below.
                tensor_dict = {}
                metadata_dict = {}
                for tensor, tensor_item in loader.values():
                    if not tensor.is_cpu:
                        raise AssertionError("Tensor must be on CPU")
                    result = _write_item(
                        transforms,
                        stream,
                        tensor,
                        tensor_item,  # type: ignore[arg-type]
                        storage_key,
                        serialization_format,
                    )
                    write_results.append(result)
                    fqn = tensor_item.index.fqn  # type: ignore[attr-defined]
                    tensor_dict[fqn] = tensor
                    metadata_dict[fqn] = {
                        "saved_offsets": tensor_item.tensor_data.chunk.offsets  # type: ignore[attr-defined]
                    }

                if serialization_format == SerializationFormat.SAFETENSORS:
                    from safetensors.torch import save  # type: ignore[import-not-found]

                    header_metadata = {
                        CUSTOM_METADATA_KEY: json.dumps(metadata_dict),
                        DCP_VERSION_KEY: str(HF_DCP_VERSION),
                        FORMAT_KEY: FORMAT_VALUE,
                    }
                    stream.write(save(tensor_dict, metadata=header_metadata))

                if use_fsync:
                    try:
                        os.fsync(stream.fileno())
                    except (AttributeError, UnsupportedOperation):
                        # The stream has no usable file descriptor.
                        os.sync()
                stream.close()
            result_queue.put(write_results)
    except queue.Empty:
        # Queue exhausted: nothing left for this worker to do.
        pass
- class FileSystemBase(ABC):
- @contextmanager
- @abstractmethod
- def create_stream(
- self, path: str | os.PathLike, mode: str
- ) -> Generator[io.IOBase, None, None]: ...
- @abstractmethod
- def concat_path(
- self, path: str | os.PathLike, suffix: str
- ) -> str | os.PathLike: ...
- @abstractmethod
- def rename(self, path: str | os.PathLike, new_path: str | os.PathLike) -> None: ...
- @abstractmethod
- def init_path(self, path: str | os.PathLike) -> str | os.PathLike: ...
- @abstractmethod
- def mkdir(self, path: str | os.PathLike) -> None: ...
- @classmethod
- @abstractmethod
- def validate_checkpoint_id(cls, checkpoint_id: str | os.PathLike) -> bool: ...
- @abstractmethod
- def exists(self, path: str | os.PathLike) -> bool: ...
- @abstractmethod
- def rm_file(self, path: str | os.PathLike) -> None: ...
- class FileSystem(FileSystemBase):
- @contextmanager
- def create_stream(
- self, path: str | os.PathLike, mode: str
- ) -> Generator[io.IOBase, None, None]:
- if not isinstance(path, Path):
- path = Path(path)
- with path.open(mode) as stream:
- yield cast(io.IOBase, stream)
- def concat_path(self, path: str | os.PathLike, suffix: str) -> str | os.PathLike:
- if not isinstance(path, Path):
- path = Path(path)
- return path / suffix
- def init_path(self, path: str | os.PathLike) -> str | os.PathLike:
- if not isinstance(path, Path):
- path = Path(path)
- return path
- def rename(self, path: str | os.PathLike, new_path: str | os.PathLike) -> None:
- if not isinstance(path, Path):
- path = Path(path)
- path.rename(cast(Path, new_path))
- def mkdir(self, path: str | os.PathLike) -> None:
- if not isinstance(path, Path):
- path = Path(path)
- path.mkdir(parents=True, exist_ok=True)
- @classmethod
- def validate_checkpoint_id(cls, checkpoint_id: str | os.PathLike) -> bool:
- if isinstance(checkpoint_id, Path):
- return True
- if "://" in str(checkpoint_id):
- return False
- for p in Path(checkpoint_id).parents:
- if p.exists() and os.access(str(p), os.W_OK):
- return True
- return False
- def exists(self, path: str | os.PathLike) -> bool:
- if not isinstance(path, Path):
- path = Path(path)
- return path.exists()
- def rm_file(self, path: str | os.PathLike) -> None:
- if not isinstance(path, Path):
- path = Path(path)
- path.unlink()
- def ls(self, path: str | os.PathLike) -> list[str]:
- if not isinstance(path, Path):
- path = Path(path)
- return [str(p) for p in path.iterdir()]
class _FileSystemWriter(StorageWriter):
    """
    Basic implementation of StorageWriter using file IO.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic

    The checkpoint consists of one file per write request plus
    a `.metadata` file with the serialized metadata.
    """

    def __init__(
        self,
        path: str | os.PathLike,
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        overwrite: bool = True,
        _extensions: Sequence[StreamTransformExtension] | None = None,
        serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True.
            sync_files : force files to be synced to permanent storage. Default to True.
            thread_count: Number of IO threads to use to write. Default to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb.
            overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.
            _extensions: Extensions to apply to output streams (EXPERIMENTAL)
            serialization_format: Format used to serialize tensor data. Defaults to TORCH_SAVE.

        N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        super().__init__()
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)
        self.single_file_per_rank = single_file_per_rank
        self.sync_files = sync_files
        self.thread_count = thread_count
        self.per_thread_copy_ahead = per_thread_copy_ahead
        self.save_id = _generate_uuid()
        self.overwrite = overwrite
        self.transforms = _StorageWriterTransforms(_extensions)
        self.serialization_format = serialization_format
        # Both are filled in by set_up_storage_writer.
        self.rank: int | None = None
        self.use_collectives: bool = True

    def reset(self, checkpoint_id: str | os.PathLike | None = None) -> None:
        """Optionally retarget the writer and always generate a new save id."""
        if checkpoint_id:
            self.path = self.fs.init_path(checkpoint_id)
        self.save_id = _generate_uuid()

    def set_up_storage_writer(
        self, is_coordinator: bool, *args: Any, **kwargs: Any
    ) -> None:
        """Record the ``rank`` and ``use_collectives`` keyword arguments."""
        self.rank = kwargs.get("rank")
        self.use_collectives = kwargs.get("use_collectives", True)

    def _metadata_exists(self) -> bool:
        """Return whether a metadata file is already present at the path."""
        if self.use_collectives:
            # A global checkpoint metadata file
            metadata_path = self._get_metadata_path(rank=None)
        else:
            # A rank 0 specific metadata file if every rank has written its own metadata
            # Just looking for lowest rank metadata file is sufficient
            metadata_path = self._get_metadata_path(rank=0)

        return self.fs.exists(metadata_path)

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        """
        Create the checkpoint directory and handle a pre-existing checkpoint.

        Raises:
            RuntimeError: if a checkpoint already exists and ``overwrite`` is
                False.
        """
        self.fs.mkdir(self.path)
        if self._metadata_exists():
            if self.overwrite:
                warnings.warn(
                    f"Detected an existing checkpoint in {self.path}, overwriting since {self.overwrite=}."
                    " Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to"
                    " maintain this functionality or False to raise when an existing checkpoint is found.",
                    stacklevel=2,
                )
            else:
                raise RuntimeError(f"Checkpoint already exists and {self.overwrite=}.")

        if self.rank is not None and not self.use_collectives:
            # Without collectives prepare_global_plan cannot hand out the
            # per-rank prefix, so it is assigned locally.
            plan = dataclasses.replace(
                plan, storage_data=_StoragePrefix(f"__{self.rank}_")
            )

        return plan

    def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
        """Give every plan without storage data a prefix based on its index."""
        new_plans = [
            dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
            if plan.storage_data is None
            else plan
            for i, plan in enumerate(plans)
        ]
        return new_plans

    def write_data(
        self,
        plan: SavePlan,
        planner: SavePlanner,
    ) -> Future[list[WriteResult]]:
        """
        Group the plan's items into files and write them.

        With ``single_file_per_rank`` the items are split into ``thread_count``
        buckets, one file each; otherwise every item gets its own file.
        """
        storage_plan: _StoragePrefix = plan.storage_data
        file_count = 0

        def gen_file():
            nonlocal file_count
            file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
            file_count += 1
            return file_name

        file_queue: queue.Queue = queue.Queue()
        if self.single_file_per_rank:
            for bucket in _split_by_size_and_type(self.thread_count, plan.items):
                file_name = gen_file()
                path = self.fs.concat_path(self.path, file_name)
                file_queue.put((path, file_name, bucket))
        else:
            for item in plan.items:
                file_name = gen_file()
                path = self.fs.concat_path(self.path, file_name)
                file_queue.put((path, file_name, [item]))

        return self._write_data(planner, file_queue)

    def _write_data(
        self,
        planner: SavePlanner,
        file_queue: queue.Queue,
    ) -> Future[list[WriteResult]]:
        """
        Write every queued file using ``thread_count`` workers.

        The calling thread acts as one worker and ``thread_count - 1`` extra
        threads are started. All threads are joined before returning.

        Raises:
            Exception: whatever a worker raised. A failure in a background
                thread is re-raised here instead of being dropped, because a
                lost failure would yield a result list with entries missing.
        """
        result_queue: queue.Queue = queue.Queue()
        worker_kwargs = {
            "create_stream": self.fs.create_stream,
            "file_queue": file_queue,
            "result_queue": result_queue,
            "planner": planner,
            "transforms": self.transforms,
            "inflight_threshhold": self.per_thread_copy_ahead,
            "use_fsync": self.sync_files,
            "thread_count": self.thread_count,
            "serialization_format": self.serialization_format,
        }
        worker_errors: list[Exception] = []

        def _worker() -> None:
            # Exceptions do not propagate out of a thread by themselves, so
            # capture them for the caller to re-raise after joining.
            try:
                _write_files_from_queue(**worker_kwargs)
            except Exception as exc:
                worker_errors.append(exc)

        threads = []
        for _ in range(1, self.thread_count):
            t = threading.Thread(target=_worker)
            t.start()
            threads.append(t)

        try:
            _write_files_from_queue(**worker_kwargs)
        finally:
            # Join even when the calling thread failed, so no worker is left
            # writing after this method has returned.
            for t in threads:
                t.join()

        if worker_errors:
            raise worker_errors[0]

        res: list[WriteResult] = []
        while True:
            try:
                res += result_queue.get_nowait()
            except queue.Empty:
                break

        fut: Future[list[WriteResult]] = Future()
        fut.set_result(res)
        return fut

    def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
        """
        Write the metadata file for the checkpoint.

        The metadata is pickled to a temporary file first and then renamed to
        its final name, replacing any metadata file already there.
        """
        metadata = dataclasses.replace(metadata, version=CURRENT_DCP_VERSION)

        storage_md = {}
        for wr_list in results:
            storage_md.update({wr.index: wr.storage_data for wr in wr_list})
        metadata.storage_data = storage_md

        metadata.storage_meta = self.storage_meta()

        # A per-rank metadata file is used only when collectives are disabled
        # and the rank is known.
        per_rank = not self.use_collectives and self.rank is not None
        tmp_filename = (
            f"__{self.rank}{_metadata_fn}.tmp" if per_rank else f"{_metadata_fn}.tmp"
        )
        tmp_path = cast(Path, self.fs.concat_path(self.path, tmp_filename))
        with self.fs.create_stream(tmp_path, "wb") as metadata_file:
            pickle.dump(metadata, metadata_file)
            if self.sync_files:
                try:
                    os.fsync(metadata_file.fileno())
                except (AttributeError, UnsupportedOperation):
                    os.sync()

        # delete in-case other checkpoints were present.
        if per_rank:
            metadata_path = self._get_metadata_path(self.rank)
        else:
            metadata_path = self._get_metadata_path()

        if self.fs.exists(metadata_path):
            self.fs.rm_file(metadata_path)

        self.fs.rename(tmp_path, metadata_path)

    def storage_meta(self) -> StorageMeta | None:
        """Return storage metadata carrying the checkpoint id and save id."""
        return StorageMeta(checkpoint_id=self.checkpoint_id, save_id=self.save_id)

    def _get_metadata_path(self, rank: int | None = None) -> os.PathLike:
        """Return the global metadata path, or the per-rank one if ``rank`` is given."""
        filename = f"{_metadata_fn}" if rank is None else f"__{rank}{_metadata_fn}"
        return cast(Path, self.fs.concat_path(self.path, filename))

    @property
    def checkpoint_id(self) -> str | os.PathLike:
        """
        return the checkpoint_id that will be used to save the checkpoint.
        """
        return self.path

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: str | os.PathLike) -> bool:
        """Delegate validation of ``checkpoint_id`` to ``FileSystem``."""
        return FileSystem.validate_checkpoint_id(checkpoint_id)
- class _StorageReaderTransforms:
- """
- This is experimental, and will likely move elsewhere in the
- future. It lives here to minimize changes while we are still
- learning and gathering feedback.
- """
- def __init__(self, extension_registry: ExtensionRegistry | None = None) -> None:
- self.extension_registry = (
- ExtensionRegistry() if extension_registry is None else extension_registry
- )
- def transform_load_stream(
- self,
- read_item: ReadItem,
- transform_descriptors: Sequence[str],
- raw_stream: IO[bytes],
- ) -> IO[bytes]:
- extensions = self.extension_registry.from_descriptor_list(transform_descriptors)
- transform_from = raw_stream
- for ex in extensions:
- if isinstance(ex, StreamTransformExtension):
- transform_from = ex.transform_from(transform_from)
- return transform_from
class FileSystemReader(StorageReader):
    """StorageReader that loads a checkpoint from files under a directory."""

    def __init__(
        self,
        path: str | os.PathLike,
        _extension_registry: ExtensionRegistry | None = None,  # EXPERIMENTAL
    ) -> None:
        super().__init__()
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)
        # Storage index -> storage info; populated by set_up_storage_reader.
        self.storage_data: dict[Any, Any] = {}
        self.load_id = _generate_uuid()
        self.transforms = _StorageReaderTransforms(_extension_registry)
        self.rank = None
        self.use_collectives = True

    def _slice_file(self, file, sinfo: _StorageInfo) -> IO[bytes]:
        """Return a view of ``file`` limited to the entry described by ``sinfo``."""
        view = _create_file_view(file, sinfo.offset, sinfo.length)
        return cast(IO[bytes], view)

    def reset(self, checkpoint_id: str | os.PathLike | None = None) -> None:
        """Clear cached storage data, optionally retarget, and renew the load id."""
        self.storage_data = {}
        if checkpoint_id:
            self.path = self.fs.init_path(checkpoint_id)
        self.load_id = _generate_uuid()

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Read every item of ``plan`` and hand it to ``planner``.

        Requests are grouped by file so each file is opened once. The returned
        future is already completed.

        Raises:
            AssertionError: if a loaded tensor's size differs from its target.
        """
        # group requests by file
        requests_by_file: dict[str, list[ReadItem]] = {}
        for read_item in plan.items:
            entry: _StorageInfo = self.storage_data[read_item.storage_index]
            requests_by_file.setdefault(entry.relative_path, []).append(read_item)

        for relative_path, reqs in requests_by_file.items():
            file_path = self.fs.concat_path(self.path, relative_path)
            with self.fs.create_stream(file_path, "rb") as stream:
                # TODO sort by offset and cache the reading
                for req in reqs:
                    item_md = self.storage_data[req.storage_index]
                    file_slice = self._slice_file(stream, item_md)
                    # transform_descriptors wasn't present in older
                    # implementations so provide a fallback.
                    descriptors = item_md.transform_descriptors or ()
                    source = self.transforms.transform_load_stream(
                        req, descriptors, file_slice
                    )

                    if req.type == LoadItemType.BYTE_IO:
                        read_bytes = io.BytesIO(source.read(-1))
                        read_bytes.seek(0)
                        planner.load_bytes(req, read_bytes)
                        continue

                    if source.seekable():
                        seekable = source
                    else:
                        # torch.load requires a seekable input, so read the transform
                        # stream now and store the output if needed
                        seekable = io.BytesIO(source.read(-1))
                        seekable.seek(0)

                    loaded = torch.load(
                        seekable,
                        map_location="cpu",
                        weights_only=True,
                    )
                    tensor = narrow_tensor_by_index(
                        cast(Tensor, loaded), req.storage_offsets, req.lengths
                    )
                    target_tensor = planner.resolve_tensor(req).detach()

                    if target_tensor.size() != tensor.size():
                        raise AssertionError(
                            f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
                        )
                    target_tensor.copy_(tensor)
                    planner.commit_tensor(req, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    def _get_metadata_path(self, rank: int | None = None) -> os.PathLike:
        """Return the global metadata path, or the per-rank one if ``rank`` is given."""
        if rank is None:
            filename = f"{_metadata_fn}"
        else:
            filename = f"__{rank}{_metadata_fn}"
        return cast(Path, self.fs.concat_path(self.path, filename))

    # Implementing the abstract function in StorageReader
    def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata:
        """
        Load the pickled metadata and stamp it with this reader's load id.

        The ``rank`` keyword argument, when given, selects the per-rank
        metadata file.
        """
        metadata_path = self._get_metadata_path(kwargs.get("rank"))
        with self.fs.create_stream(metadata_path, "rb") as metadata_file:
            # NOTE(review): pickle.load runs arbitrary code from the file;
            # only load checkpoints from trusted sources.
            metadata = pickle.load(metadata_file)

        if getattr(metadata, "storage_meta", None) is None:
            metadata.storage_meta = StorageMeta()
        metadata.storage_meta.load_id = self.load_id

        return metadata

    def set_up_storage_reader(
        self, metadata: Metadata, is_coordinator: bool, *args: Any, **kwargs: Any
    ) -> None:
        """
        Take the storage data from ``metadata`` and record rank settings.

        Raises:
            AssertionError: if the metadata carries no storage data.
        """
        self.storage_data = metadata.storage_data
        self.rank = kwargs.get("rank")
        self.use_collectives = kwargs.get("use_collectives", True)
        if self.storage_data is None:
            raise AssertionError("storage_data must not be None in metadata")

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """Return ``plan`` unchanged."""
        return plan

    def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
        """Return ``plans`` unchanged."""
        return plans

    @property
    def checkpoint_id(self) -> str | os.PathLike:
        """
        return the checkpoint_id that will be used to load the checkpoint.
        """
        return self.path

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: str | os.PathLike) -> bool:
        """Delegate validation of ``checkpoint_id`` to ``FileSystem``."""
        return FileSystem.validate_checkpoint_id(checkpoint_id)
class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
    """
    Basic implementation of StorageWriter using file IO.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic

    The checkpoint consists of one file per write request plus either:

    * a global `.metadata` file with the serialized metadata if rank coordination is enabled, or
    * a rank local `__{rank}.metadata` file with the serialized metadata if rank coordination is NOT enabled.
    """

    def __init__(
        self,
        path: str | os.PathLike,
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        cache_staged_state_dict: bool = False,
        overwrite: bool = True,
        _extensions: Sequence[StreamTransformExtension] | None = None,
        serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True.
            sync_files : force files to be synced to permanent storage. Default to True.
            thread_count: Number of IO threads to use to write. Default to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb.
            cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency
                at the cost of increases memory usage. Additionally, if this parameter is set to True, it's the expectation
                that the stager is maintained and reused for multiple dcp.async_save calls. Default to False.
            overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.
            _extensions: Extensions to apply to output streams (EXPERIMENTAL)
            serialization_format: Format passed through to the underlying writer. Defaults to TORCH_SAVE.

        N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        # Both bases are initialised explicitly because they take different
        # arguments.
        writer_options = {
            "path": path,
            "single_file_per_rank": single_file_per_rank,
            "sync_files": sync_files,
            "thread_count": thread_count,
            "per_thread_copy_ahead": per_thread_copy_ahead,
            "overwrite": overwrite,
            "_extensions": _extensions,
            "serialization_format": serialization_format,
        }
        _FileSystemWriter.__init__(self, **writer_options)
        BlockingAsyncStager.__init__(
            self,
            cache_staged_state_dict=cache_staged_state_dict,
        )

    def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
        """Override of AsyncStager.stage"""
        # in the async case, the state dict is already on CPU, so maintaining this
        # buffer makes no sense
        self.per_thread_copy_ahead = 0
        staged = super().stage(state_dict)
        return staged
|