| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561 |
- # mypy: allow-untyped-defs
- import io
- import itertools
- from bisect import bisect_right, insort
- from collections.abc import Callable
- from typing import Any, cast
- import torch
- import torch.distributed as dist
- from torch._utils import _get_device_module
- from torch.distributed._shard.metadata import ShardMetadata
- from torch.distributed._shard.sharded_tensor import ShardedTensor
- from torch.distributed.tensor import DTensor
- from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
- from .metadata import (
- BytesStorageMetadata,
- ChunkStorageMetadata,
- MetadataIndex,
- STATE_DICT_TYPE,
- STORAGE_TYPES,
- TensorProperties,
- TensorStorageMetadata,
- )
- from .planner import (
- LoadItemType,
- ReadItem,
- SavePlan,
- TensorWriteData,
- WriteItem,
- WriteItemType,
- )
- from .resharding import (
- _check_shard_metadata_pair_overlap,
- _shards_get_overlap_region_wrt_saved_tensor,
- )
# Names exported by ``from ... import *``: only the chunk-list resharding
# entry point is part of this module's public API.
__all__: list[str] = [
    "create_read_items_for_chunk_list",
]
def _compare_save_plans(plan: SavePlan, other_plan: SavePlan) -> bool:
    """
    Compare the two Save plans and return True if they are equal.

    Two plans are equal when they share the same ``usable`` flag and contain
    the same write items in the same order. Write items are compared by type,
    metadata index (fqn, offset, index) and, when present, by their tensor
    data: global size, tensor properties (dtype, layout, ...) and chunk
    offsets/sizes.

    Args:
        plan (SavePlan): First SavePlan to compare.
        other_plan (SavePlan): Second SavePlan to compare.

    Returns:
        True if the two plans are equal, False otherwise.
    """
    if plan.usable != other_plan.usable:
        return False

    # Both the plans should have the same number of items
    if len(plan.items) != len(other_plan.items):
        return False

    # Both the plans should have the same write items, in the same order.
    for plan_item, other_plan_item in zip(plan.items, other_plan.items):
        # Write item type should be same
        if plan_item.type != other_plan_item.type:
            return False

        # Write item metadata_index should be same
        index = plan_item.index
        other_index = other_plan_item.index
        if (
            index.fqn != other_index.fqn
            or index.offset != other_index.offset
            or index.index != other_index.index
        ):
            return False

        # tensor_data must be present in both items or in neither of them.
        tensor_data = plan_item.tensor_data
        other_tensor_data = other_plan_item.tensor_data
        if bool(tensor_data) != bool(other_tensor_data):
            return False
        if not tensor_data:
            continue

        # Global tensor size should be same
        if tensor_data.size != other_tensor_data.size:
            return False

        # Tensor properties (dtype, layout, requires_grad, ...) must match as
        # well: two plans that only differ in dtype are not the same plan.
        # NOTE(review): presumably callers reuse a cached plan when this
        # returns True, so ignoring properties would let stale ones be reused.
        if tensor_data.properties != other_tensor_data.properties:
            return False

        # The chunk must be present in both or neither, with equal
        # offsets and sizes.
        chunk = tensor_data.chunk
        other_chunk = other_tensor_data.chunk
        if bool(chunk) != bool(other_chunk):
            return False
        if chunk and (
            chunk.offsets != other_chunk.offsets
            or chunk.sizes != other_chunk.sizes
        ):
            return False

    return True
def _contains_usable_plan(delta_plans: list[SavePlan]) -> bool:
    """
    Report whether at least one delta plan is usable, i.e. the plan changed.

    Falsy entries (for example ``None``) in ``delta_plans`` are skipped.

    Args:
        delta_plans (List[SavePlan]): The delta plans to inspect.

    Returns:
        True as soon as a truthy plan with a truthy ``usable`` flag is found,
        False otherwise.
    """
    for candidate in delta_plans:
        if candidate and candidate.usable:
            return True
    return False
def _merge_delta_local_plans(
    cached_plans: list[SavePlan],
    delta_plans: list[SavePlan],
) -> list[SavePlan]:
    """
    Combine cached plans with their corresponding delta plans, position by position.

    For each ``(cached, delta)`` pair, a truthy delta plan whose ``usable``
    flag is falsy means "unchanged", so the cached plan is kept. In every other
    case (a usable delta, or a falsy delta entry such as ``None``) the delta
    entry itself is kept. Pairs are formed with ``zip``, so the result is as
    long as the shorter input.

    Args:
        cached_plans (List[SavePlan]): The previously cached plans.
        delta_plans (List[SavePlan]): The delta plans; may contain empty plans.

    Returns:
        The merged list of plans.
    """
    return [
        cached if (delta and not delta.usable) else delta
        for cached, delta in zip(cached_plans, delta_plans)
    ]
def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
    """Describe ``tensor`` as one chunk starting at the origin and spanning its full size."""
    full_size = tensor.size()
    origin = torch.Size([0] * len(full_size))
    return ChunkStorageMetadata(offsets=origin, sizes=full_size)
def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
    """Convert a ShardedTensor shard's offsets/sizes into a ``ChunkStorageMetadata``."""
    offsets = torch.Size(shard_md.shard_offsets)
    sizes = torch.Size(shard_md.shard_sizes)
    return ChunkStorageMetadata(offsets=offsets, sizes=sizes)
def _sharded_tensor_metadata(
    sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> TensorWriteData:
    """
    Build the ``TensorWriteData`` for one shard of a ``ShardedTensor``.

    The ShardedTensor's tensor properties are copied field by field into a
    checkpoint ``TensorProperties``. The chunk is derived from ``shard_md``,
    and ``size`` is the size recorded in the ShardedTensor's metadata.
    """
    st_metadata = sharded_tensor.metadata()
    src_props = st_metadata.tensor_properties
    return TensorWriteData(
        chunk=_chunk_for_shard(shard_md),
        properties=TensorProperties(
            dtype=src_props.dtype,
            layout=src_props.layout,
            requires_grad=src_props.requires_grad,
            memory_format=src_props.memory_format,
            pin_memory=src_props.pin_memory,
        ),
        size=st_metadata.size,
    )
def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
    """
    Build the ``SHARD`` write item for the local shard of a ``DTensor``.

    The local chunk (global offsets and local sizes) is obtained from
    ``_create_chunk_from_dtensor``, so the DTensor chunk computation lives in
    a single place. Despite the plural name, a single ``WriteItem`` is returned.

    Args:
        fqn (str): The state_dict FQN of the tensor.
        tensor (DTensor): The distributed tensor being saved.

    Returns:
        The ``WriteItem`` describing this rank's shard, indexed by its offsets.
    """
    chunk = _create_chunk_from_dtensor(tensor)
    return WriteItem(
        index=MetadataIndex(fqn, chunk.offsets),
        type=WriteItemType.SHARD,
        tensor_data=TensorWriteData(
            chunk=chunk,
            # Properties are taken from the local tensor backing the DTensor.
            properties=TensorProperties.create_from_tensor(tensor.to_local()),
            size=tensor.size(),
        ),
    )
def _create_write_item_for_shard(
    fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> WriteItem:
    """Build the ``SHARD`` write item for one ShardedTensor shard, indexed by the shard's offsets."""
    write_data = _sharded_tensor_metadata(sharded_tensor, shard_md)
    return WriteItem(
        index=MetadataIndex(fqn, torch.Size(shard_md.shard_offsets)),
        type=WriteItemType.SHARD,
        tensor_data=write_data,
    )
def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
    """
    Build a ``TENSOR`` write item for a regular (non-distributed) tensor.

    The tensor is written as a single chunk at the origin covering its size.
    """
    shape = tensor.size()
    origin = torch.Size([0] * len(shape))
    write_data = TensorWriteData(
        chunk=ChunkStorageMetadata(offsets=origin, sizes=shape),
        properties=TensorProperties.create_from_tensor(tensor),
        size=shape,
    )
    return WriteItem(
        index=MetadataIndex(fqn, origin),
        type=WriteItemType.TENSOR,
        tensor_data=write_data,
    )
def _create_write_item_for_bytesio(fqn: str, bytes: Any):
    """
    Build a ``BYTE_IO`` write item for a non-tensor state_dict entry.

    Only ``fqn`` is used. The value itself (the ``bytes`` parameter, which
    shadows the builtin but is kept for keyword callers) is not inspected, and
    no ``tensor_data`` is passed to the item.
    """
    index = MetadataIndex(fqn)
    return WriteItem(index=index, type=WriteItemType.BYTE_IO)
def _create_read_item_for_byteio(
    dest_index, dest_offset, storage_index, storage_offset, length
):
    """
    Build a ``BYTE_IO`` read item.

    The scalar ``dest_offset``, ``storage_offset`` and ``length`` are each
    wrapped into a one-element ``torch.Size``.
    """

    def as_size(value):
        return torch.Size((value,))

    return ReadItem(
        type=LoadItemType.BYTE_IO,
        dest_index=dest_index,
        dest_offsets=as_size(dest_offset),
        storage_index=storage_index,
        storage_offsets=as_size(storage_offset),
        lengths=as_size(length),
    )
def _create_read_item_for_tensor(
    dest_index, dest_offsets, storage_index, storage_offsets, lengths
):
    """
    Build a ``TENSOR`` read item.

    The per-dimension ``dest_offsets``, ``storage_offsets`` and ``lengths``
    sequences are converted to ``torch.Size``.
    """
    dest_offsets_size = torch.Size(dest_offsets)
    storage_offsets_size = torch.Size(storage_offsets)
    lengths_size = torch.Size(lengths)
    return ReadItem(
        type=LoadItemType.TENSOR,
        dest_index=dest_index,
        dest_offsets=dest_offsets_size,
        storage_index=storage_index,
        storage_offsets=storage_offsets_size,
        lengths=lengths_size,
    )
- def create_read_items_for_chunk_list(
- fqn: str,
- checkpoint_md: TensorStorageMetadata,
- local_chunks: list[ChunkStorageMetadata],
- ) -> list[ReadItem]:
- """
- Create a list of ``ReadItem`` based on the checkpoint and local chunks.
- This applies the resharding algorithm and computes the reads needed
- to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``.
- Args:
- fqn (str) : The state_dict FQN to pass to ``ReadItem``.
- checkpoint_md (TensorStorageMetadata): metadata for a given tensor
- from a checkpoint.
- local_chunks (List[ChunkStorageMetadata]): Local chunks that needs to be
- loaded.
- Returns:
- A list of ``ReadItem`` that will satisfy all input chunks.
- """
- read_items: list[ReadItem] = []
- saved_chunks = checkpoint_md.chunks
- if not local_chunks or not saved_chunks:
- return read_items
- num_dims = len(local_chunks[0].offsets)
- # Find sweep dimension (dimension with largest extent for better pruning)
- sweep_dim = 0
- if num_dims > 1:
- max_size = 0
- for dim in range(num_dims):
- dim_size = max(
- chunk.offsets[dim] + chunk.sizes[dim]
- for chunk in itertools.chain(local_chunks, saved_chunks)
- )
- if dim_size > max_size:
- max_size = dim_size
- sweep_dim = dim
- # Pre-compute bounds: (start, end) for each chunk in sweep dimension
- # For 0-d tensors, use (0, 1) so all chunks overlap in the sweep line
- if num_dims == 0:
- saved_bounds = [(0, 1)] * len(saved_chunks)
- local_bounds = [(0, 1)] * len(local_chunks)
- else:
- saved_bounds = [
- (c.offsets[sweep_dim], c.offsets[sweep_dim] + c.sizes[sweep_dim])
- for c in saved_chunks
- ]
- local_bounds = [
- (c.offsets[sweep_dim], c.offsets[sweep_dim] + c.sizes[sweep_dim])
- for c in local_chunks
- ]
- saved_sorted_indices = sorted(
- range(len(saved_chunks)),
- key=lambda idx: saved_bounds[idx][0],
- )
- local_sorted_indices = sorted(
- range(len(local_chunks)),
- key=lambda idx: local_bounds[idx][0],
- )
- active_saved: list[tuple[int, int]] = []
- saved_ptr = 0
- num_saved = len(saved_sorted_indices)
- for local_idx in local_sorted_indices:
- local_chunk = local_chunks[local_idx]
- local_start, local_end = local_bounds[local_idx]
- cutoff = bisect_right(active_saved, (local_start, -1))
- if cutoff:
- del active_saved[:cutoff]
- while saved_ptr < num_saved:
- storage_idx = saved_sorted_indices[saved_ptr]
- storage_chunk = saved_chunks[storage_idx]
- saved_start, saved_end = saved_bounds[storage_idx]
- if saved_start >= local_end:
- break
- insort(active_saved, (saved_end, storage_idx))
- saved_ptr += 1
- for _, storage_idx in active_saved:
- storage_chunk = saved_chunks[storage_idx]
- if not _check_shard_metadata_pair_overlap(local_chunk, storage_chunk):
- continue
- storage_offsets = []
- dest_offsets = []
- lengths = []
- for (
- _dim,
- offset_for_saved_tensor,
- offset_for_current_tensor,
- length,
- ) in _shards_get_overlap_region_wrt_saved_tensor(
- saved_shard=storage_chunk, current_shard=local_chunk
- ):
- storage_offsets.append(offset_for_saved_tensor)
- dest_offsets.append(offset_for_current_tensor)
- lengths.append(length)
- read_items.append(
- _create_read_item_for_tensor(
- dest_index=MetadataIndex(fqn, local_chunk.offsets, local_idx),
- dest_offsets=dest_offsets,
- storage_index=MetadataIndex(
- fqn, storage_chunk.offsets, storage_idx
- ),
- storage_offsets=storage_offsets,
- lengths=lengths,
- )
- )
- return read_items
def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
    """
    Build a ``SavePlan`` with write items describing every ``state_dict`` entry.

    The items for each entry depend on its type:

    * DTensors contribute the write item of their local chunk.
    * ShardedTensors contribute one item per entry of ``metadata().shards_metadata``.
    * Plain tensors contribute a single ``TENSOR`` item.
    * Anything else contributes a ``BYTE_IO`` item.

    DTensor and ShardedTensor are checked before the generic ``torch.Tensor``
    branch. Items follow ``state_dict`` iteration order.
    """

    def items_for(fqn, value):
        if isinstance(value, DTensor):
            return [_create_write_items_for_dtensor(fqn, value)]
        if isinstance(value, ShardedTensor):
            return [
                _create_write_item_for_shard(fqn, value, shard_md)
                for shard_md in value.metadata().shards_metadata
            ]
        if isinstance(value, torch.Tensor):
            return [_create_write_item_for_tensor(fqn, value)]
        return [_create_write_item_for_bytesio(fqn, value)]

    return SavePlan(
        [item for fqn, value in state_dict.items() for item in items_for(fqn, value)]
    )
def _create_write_items(fqn: str, object: Any) -> list[WriteItem]:
    """
    Build the write items for a single state_dict entry.

    The first matching rule wins:

    * Objects implementing ``__create_write_items__`` (e.g. DTensor) delegate to it.
    * ShardedTensors get one item per local shard.
    * Plain tensors get a single ``TENSOR`` item.
    * Anything else gets a single ``BYTE_IO`` item.

    The ``object`` parameter shadows the builtin but is kept for keyword callers.
    """
    if hasattr(object, "__create_write_items__"):
        # DTensor implements _Checkpointable
        return object.__create_write_items__(fqn, object)

    if isinstance(object, ShardedTensor):
        shard_items = []
        for local_shard in object.local_shards():
            shard_items.append(
                _create_write_item_for_shard(fqn, object, local_shard.metadata)
            )
        return shard_items

    if isinstance(object, torch.Tensor):
        return [_create_write_item_for_tensor(fqn, object)]

    return [_create_write_item_for_bytesio(fqn, object)]
def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
    """
    Describe the local shard of ``tensor`` as a ``ChunkStorageMetadata``.

    The chunk's sizes are the local shard shape and its offsets are the shard's
    global offset. Both come from ``compute_local_shape_and_global_offset``
    applied to the tensor's global shape, device mesh and placements.
    """
    local_shape, global_offset = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    return ChunkStorageMetadata(
        offsets=torch.Size(global_offset),
        sizes=torch.Size(local_shape),
    )
def _create_chunk_list(tensor: torch.Tensor) -> list[ChunkStorageMetadata]:
    """
    Return the chunks of ``tensor`` held locally.

    Objects implementing ``__create_chunk_list__`` (e.g. DTensor) delegate to
    it. ShardedTensors yield one chunk per local shard, and a plain tensor is
    a single chunk covering the whole tensor.

    Raises:
        ValueError: if ``tensor`` matches none of the cases above.
    """
    if hasattr(tensor, "__create_chunk_list__"):
        # DTensor implements _Checkpointable
        return tensor.__create_chunk_list__()  # type: ignore[attr-defined]
    if isinstance(tensor, ShardedTensor):
        return [
            _chunk_for_shard(local_shard.metadata)
            for local_shard in tensor.local_shards()
        ]
    if isinstance(tensor, torch.Tensor):
        return [_create_chunk_from_tensor(tensor)]
    raise ValueError(
        "Unsupported Type, expecting one of [Tensor, DTensor, ShardedTensor] "
        f",but got {type(tensor)}"
    )
def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> list[ReadItem]:
    """
    Build the read items needed to load ``obj`` (the value at ``fqn``) from ``md``.

    Byte metadata yields a single ``BYTE_IO`` read item with zero offsets and
    length. Any other metadata is treated as tensor metadata: the local chunks
    of ``obj`` are resharded against the saved chunks.

    Raises:
        ValueError: if ``md`` is not byte metadata but ``obj`` is not a type
            whose chunks can be computed.
    """
    if isinstance(md, BytesStorageMetadata):
        return [
            _create_read_item_for_byteio(
                dest_index=MetadataIndex(fqn),
                dest_offset=0,
                storage_index=MetadataIndex(fqn),
                storage_offset=0,
                length=0,
            )
        ]

    try:
        local_chunks = _create_chunk_list(obj)
    except ValueError as ex:
        raise ValueError(
            f"Invalid checkpoint metadata for {fqn}, "
            + f"expected BytesStorageMetadata but found {type(md)}",
        ) from ex
    return create_read_items_for_chunk_list(fqn, md, local_chunks)
def _init_state_dict(state_dict: dict[str, Any]) -> Any:
    """
    Materialize meta-device tensors found in ``state_dict``.

    Every DTensor or plain ``torch.Tensor`` on the ``meta`` device is replaced
    by a ``torch.empty_like`` tensor on the current device of the default
    process group's device type. ShardedTensors on the meta device raise
    ``RuntimeError``, and all other values are returned as they are. The
    traversal and write-back are delegated to ``_iterate_state_dict``. Nothing
    is returned explicitly.
    """

    def _on_meta(value: Any) -> bool:
        return getattr(value, "device", None) == torch.device("meta")

    def _target_device():
        device_type = dist.distributed_c10d._get_pg_default_device().type
        return cast(
            torch.device, _get_device_module(device_type).current_device()
        )

    def dtensor_func(value: DTensor):
        if not _on_meta(value):
            return value
        device = _target_device()
        local_tensor = torch.empty_like(value.to_local(), device=device)
        # We need to pass shape and stride explicitly, since DTensor might be
        # sharded unevenly.
        return DTensor.from_local(
            local_tensor,
            device_mesh=value.device_mesh,
            placements=value.placements,
            shape=value.size(),
            stride=value.stride(),
        )

    def sharded_tensor_func(value: Any):
        if _on_meta(value):
            raise RuntimeError(
                f"Found unsupported type {type(value)} for meta device loading."
            )
        return value

    def tensor_func(value: torch.Tensor):
        if not _on_meta(value):
            return value
        device = _target_device()
        return torch.empty_like(value, device=device)

    _iterate_state_dict(
        state_dict,
        dtensor_func,
        sharded_tensor_func,
        tensor_func,
    )
- def _iterate_state_dict(
- iter_object: Any,
- dtensor_func: Callable,
- sharded_tensor_func: Callable,
- tensor_func: Callable,
- ):
- """
- Iterate through the state dict, applying the given functions to each tensor type
- and update the state dict in place.
- Args:
- iter_object (Any): the target state_dict.
- sharded_tensor_func (Callable): the function to apply to ShardedTensor
- dtensor_func (Callable): the function to apply to DTensor
- tensor_func (Callable): the function to apply to Tensor
- # TODO: let state_dict_util._iterate_state_dict() to support in place option
- so we don't need to have two versions of _iterate_state_dict.
- """
- if isinstance(iter_object, DTensor):
- return dtensor_func(iter_object)
- elif isinstance(iter_object, ShardedTensor):
- return sharded_tensor_func(iter_object)
- elif isinstance(iter_object, torch.Tensor):
- return tensor_func(iter_object)
- elif (
- isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
- or iter_object is None
- ):
- return iter_object
- elif isinstance(iter_object, dict):
- for key, value in iter_object.items():
- iter_object[key] = _iterate_state_dict(
- value, dtensor_func, sharded_tensor_func, tensor_func
- )
- return iter_object
- elif isinstance(iter_object, (list, tuple)):
- ret = [
- _iterate_state_dict(v, dtensor_func, sharded_tensor_func, tensor_func)
- for v in iter_object
- ]
- if isinstance(iter_object, tuple):
- ret = tuple(ret) # type: ignore[assignment]
- return ret
|