| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
 |
- # mypy: allow-untyped-defs
- # Copyright (c) Meta Platforms, Inc. and affiliates
- import dataclasses
- import io
- import logging
- import math
- import sys
- from bisect import bisect_right, insort
- from collections import ChainMap
- from typing import Any, cast
- import torch
- from torch.distributed._shard._utils import narrow_tensor_by_index
- from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
- from torch.distributed.checkpoint._nested_dict import (
- FLATTEN_MAPPING,
- flatten_state_dict,
- )
- from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
- from torch.distributed.checkpoint._traverse import set_element
- from torch.distributed.checkpoint.metadata import (
- BytesStorageMetadata,
- ChunkStorageMetadata,
- Metadata,
- MetadataIndex,
- STATE_DICT_TYPE,
- STORAGE_TYPES,
- StorageMeta,
- TensorStorageMetadata,
- )
- from torch.distributed.checkpoint.planner import (
- LoadPlan,
- LoadPlanner,
- ReadItem,
- SavePlan,
- SavePlanner,
- WriteItem,
- WriteItemType,
- )
- from torch.distributed.checkpoint.planner_helpers import (
- _compare_save_plans,
- _contains_usable_plan,
- _create_default_metadata_only_plan,
- _create_read_items,
- _create_write_items,
- _init_state_dict,
- _merge_delta_local_plans,
- )
- from torch.distributed.checkpoint.utils import find_state_dict_object
- from torch.distributed.tensor import DTensor
- from . import _version
# Module-level logger shared by the planners and validation helpers in this file.
logger: logging.Logger = logging.getLogger(__name__)

# Public API of this module; the underscore-prefixed helpers stay private.
__all__ = [
    "DefaultSavePlanner",
    "DefaultLoadPlanner",
    "create_default_local_load_plan",
    "create_default_global_load_plan",
    "create_default_local_save_plan",
    "create_default_global_save_plan",
]
- # TODO: Update docstrings for default_planner.py
class DefaultSavePlanner(SavePlanner):
    """
    Default ``SavePlanner`` implementation.

    Features layered on top of the base planner interface:

    * ``flatten_state_dict``: flatten the state dict with ``flatten_state_dict``
      before planning and attach the resulting mappings as ``planner_data`` so
      they are merged into the checkpoint ``Metadata``.
    * ``flatten_sharded_tensors``: flatten nested ShardedTensors (for FSDP in
      2D parallel mode).
    * ``dedup_save_to_lowest_rank``: forwarded to ``dedup_save_plans`` when the
      coordinator deduplicates write items across ranks.
    * ``enable_plan_caching``: reuse plans from a previous save performed by a
      planner of the same class and only exchange the plans that changed.

    ``dedup_replicated_tensors`` is deprecated and ignored; a warning is logged
    when it is passed.
    """

    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
        dedup_replicated_tensors: bool | None = None,
        dedup_save_to_lowest_rank: bool = False,
        enable_plan_caching: bool = False,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.mappings = {}
        self.dedup_save_to_lowest_rank = dedup_save_to_lowest_rank
        if dedup_replicated_tensors is not None:
            logger.warning(
                "DefaultSavePlanner's `dedup_replicated_tensors` argument is being "
                "deprecated, and no longer has any effect. Please remove this argument "
                "from your call."
            )
        # The plan caches are dictionaries stored on ``SavePlanner`` itself, so
        # entries are shared by every planner instance of the same class.
        self._cached_plans_key: str = self.__class__.__name__
        self._enable_plan_caching = enable_plan_caching

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        storage_meta: StorageMeta | None = None,
        is_coordinator: bool = False,
    ) -> None:
        """Store the (optionally flattened) state dict and this rank's coordinator role."""
        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)
        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)
        self.state_dict = state_dict
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> SavePlan:
        """
        Create this rank's local save plan.

        With plan caching enabled, a plan equal to the one cached by the previous
        save is replaced by an empty, unusable ``SavePlan`` so nothing has to be
        sent to the coordinator. ``self.plan`` always holds the full plan.
        """
        plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
        if self.flatten_state_dict:
            plan = dataclasses.replace(plan, planner_data=self.mappings)
        self.plan = plan

        if self._enable_plan_caching:
            key = self._cached_plans_key
            if key in SavePlanner._cached_save_plan and _compare_save_plans(
                plan, SavePlanner._cached_save_plan[key]
            ):
                logger.info(
                    "No change in the local plan. Skipping sending the plan to the coordinator"
                )
                return SavePlan([], usable=False)
            SavePlanner._cached_save_plan[key] = plan

        return self.plan

    def _dedup_save_plans(self, all_plans: list[SavePlan]) -> list[SavePlan]:
        """Remove write items duplicated across ranks (see ``dedup_save_plans``)."""
        return dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)

    def _create_global_plan(
        self, all_plans: list[SavePlan]
    ) -> tuple[list[SavePlan], Metadata]:
        """
        Deduplicate ``all_plans``, build the global plan and metadata, and validate them.

        Raises:
            ValueError: if ``_validate_global_plan`` rejects the result.
        """
        deduped_plans = self._dedup_save_plans(all_plans)
        global_plan, metadata = create_default_global_save_plan(deduped_plans)

        if self.flatten_state_dict:
            # Merge every rank's flatten mappings. For a key present in several
            # plans the first plan's entry wins (ChainMap lookup order).
            merged_mappings = dict(ChainMap(*(p.planner_data for p in global_plan)))
            metadata = dataclasses.replace(metadata, planner_data=merged_mappings)

        if not _validate_global_plan(global_plan, metadata):
            raise ValueError("Failed to validate global plan")

        return global_plan, metadata

    def _create_global_plan_with_caching(
        self, all_plans: list[SavePlan]
    ) -> tuple[list[SavePlan], list[SavePlan], Metadata]:
        """
        Create global plan with caching.

        Returns a tuple of ``(global_plan_delta, global_plan, metadata)``.
        ``global_plan_delta`` has one entry per plan in ``global_plan``: either
        the new plan for that rank, or an empty unusable ``SavePlan`` when it is
        unchanged from the cached global plan.

        The ``SavePlanner`` caches (keyed by ``self._cached_plans_key``) hold the
        merged local plans, the deduped/validated global plan and the metadata
        produced by the previous call.
        """
        key = self._cached_plans_key

        if key not in SavePlanner._cached_all_plans:
            # Case 1: cold cache. Hydrate it with all_plans, the deduped global
            # plan and the metadata; every rank receives its full plan.
            SavePlanner._cached_all_plans[key] = all_plans
            global_plan, metadata = self._create_global_plan(all_plans)
            SavePlanner._cached_global_plan[key] = global_plan
            SavePlanner._cached_metadata[key] = metadata
            return global_plan, global_plan, metadata

        # These may be missing if an earlier call cached the local plans but
        # then failed (e.g. validation raised) before caching its results.
        cached_global_plan = SavePlanner._cached_global_plan.get(key)
        cached_metadata = SavePlanner._cached_metadata.get(key)

        if (
            not _contains_usable_plan(all_plans)
            and cached_global_plan is not None
            and cached_metadata is not None
        ):
            # Case 2.1: no local plan changed. Reuse the cached global plan and
            # metadata and send empty plans to avoid the collective payload.
            global_plan_delta = [SavePlan([], usable=False) for _ in all_plans]
            return global_plan_delta, cached_global_plan, cached_metadata

        # Case 2.2: some local plans changed (or there is no usable cached
        # result). Merge the changed plans into the cached local plans, cache
        # the merge, and build a fresh global plan and metadata.
        merged_plans = _merge_delta_local_plans(
            SavePlanner._cached_all_plans[key], all_plans
        )
        SavePlanner._cached_all_plans[key] = merged_plans
        global_plan, metadata = self._create_global_plan(merged_plans)

        if cached_global_plan is None or len(cached_global_plan) != len(global_plan):
            # Nothing comparable to diff against: every rank gets its full plan
            # (previously this produced an empty or truncated delta).
            global_plan_delta = list(global_plan)
        else:
            # Only plans that differ from the cached global plan are sent.
            global_plan_delta = [
                SavePlan([], usable=False)
                if _compare_save_plans(cached_plan, new_plan)
                else new_plan
                for cached_plan, new_plan in zip(cached_global_plan, global_plan)
            ]

        SavePlanner._cached_global_plan[key] = global_plan
        SavePlanner._cached_metadata[key] = metadata
        return global_plan_delta, global_plan, metadata

    def create_global_plan(
        self, all_plans: list[SavePlan]
    ) -> tuple[list[SavePlan], Metadata]:
        """
        Create the global plan on the coordinator.

        Returns the per-rank plans to scatter and the checkpoint metadata. With
        plan caching, unchanged plans are replaced by empty unusable plans and
        ranks fall back to their cached final plans. The full global plan is
        stored in ``self.global_plan``.
        """
        if self._enable_plan_caching:
            global_plan_delta, global_plan, metadata = (
                self._create_global_plan_with_caching(all_plans)
            )
        else:
            global_plan, metadata = self._create_global_plan(all_plans)
            # Without caching, the delta is simply the whole global plan.
            global_plan_delta = global_plan

        self.global_plan = global_plan
        self.metadata = metadata
        return global_plan_delta, self.metadata

    def _finish_plan_with_caching(self, new_plan: SavePlan) -> SavePlan:
        """
        Resolve the plan received from the coordinator against the cache.

        An unusable plan means "unchanged", so the final plan cached by the
        previous save is returned. A usable plan is cached and returned.
        """
        key = self._cached_plans_key
        if not new_plan.usable:
            return SavePlanner._cached_final_save_plan[key]
        SavePlanner._cached_final_save_plan[key] = new_plan
        return new_plan

    def finish_plan(self, new_plan: SavePlan) -> SavePlan:
        """Record and return this rank's final save plan."""
        if self._enable_plan_caching:
            new_plan = self._finish_plan_with_caching(new_plan)
        self.plan = new_plan
        return self.plan

    def resolve_data(self, write_item: WriteItem) -> torch.Tensor | io.BytesIO:
        """Look up the object referenced by ``write_item`` and transform it for writing."""
        obj = self.lookup_object(write_item.index)
        return self.transform_object(write_item, obj)

    def lookup_object(self, index: MetadataIndex) -> Any:
        """Extension from the planner interface to make it easy to extend the default planner."""
        return find_state_dict_object(self.state_dict, index)

    def transform_object(self, write_item: WriteItem, object: Any):
        """
        Extension from the planner interface to make it easy to extend the default planner.

        ``BYTE_IO`` items are serialized with ``torch.save`` into a new
        ``io.BytesIO``; anything else is returned unchanged. (The parameter name
        ``object`` shadows the builtin but is kept for keyword callers.)
        """
        if write_item.type == WriteItemType.BYTE_IO:
            buffer = io.BytesIO()
            torch.save(object, buffer)
            return buffer
        return object
class DefaultLoadPlanner(LoadPlanner):
    """
    DefaultLoadPlanner that adds multiple features on top of LoadPlanner.

    In particular it adds the following:

    flatten_state_dict: Handle state_dict with nested dicts
    flatten_sharded_tensors: For FSDP in 2D parallel mode
    allow_partial_load: If False, will raise a runtime error if a key is present in state_dict, but not in the checkpoint.
    """

    original_state_dict: STATE_DICT_TYPE
    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
        allow_partial_load: bool = False,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.original_state_dict = {}
        self.mappings = {}
        self.allow_partial_load = allow_partial_load

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata | None = None,
        is_coordinator: bool = False,
    ) -> None:
        """
        Prepare the planner for loading into ``state_dict``.

        ``_init_state_dict`` is called on ``state_dict`` first, and the object is
        kept as ``self.original_state_dict``. ``self.state_dict`` is the
        (optionally flattened) view that planning works on.
        """
        _init_state_dict(state_dict)
        self.original_state_dict = state_dict

        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)

        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)

        self.state_dict = state_dict
        self.metadata = metadata
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> LoadPlan:
        """
        Create this rank's load plan.

        Raises:
            AssertionError: if ``set_up_planner`` was called without metadata.
        """
        if self.metadata is None:
            raise AssertionError("self.metadata is not None")
        if self.flatten_state_dict:
            # To support checkpoints that are saved before v2.4, we have to
            # differentiate if the missing keys are due to old checkpoints.
            # The contracts are:
            # 1. There are 4 cases when we find a missing key:
            #    1.1 Actual missing key, allow_partial_load is False
            #    1.2 Actual missing key, allow_partial_load is True
            #    1.3 Old checkpoint, allow_partial_load is False
            #    1.4 Old checkpoint, allow_partial_load is True
            # 2. If we find a missing key, we first convert the keys back to
            #    the key format of v2.3.
            # 3. If any previously missing key is among the v2.3 keys, we
            #    assume this is an old checkpoint.
            # 4. Pass the state_dict to `create_default_local_load_plan()`,
            #    which has the logic to check missing keys for
            #    allow_partial_load. So for cases 1.2 and 1.4 we delegate the
            #    allow_partial_load check to `create_default_local_load_plan()`.
            #    The logic here only determines whether the checkpoint belongs
            #    to 2.3 (or before) or 2.4 (or after).
            current_keys = set(self.state_dict.keys())
            load_keys = set(self.metadata.state_dict_metadata.keys())
            missing_keys = load_keys - current_keys
            if missing_keys:
                # _derived_version is only used by flatten_state_dict now;
                # "2_3" makes it produce the v2.3 key format.
                _version._derived_version = "2_3"
                try:
                    old_state_dict, old_mappings = flatten_state_dict(
                        self.original_state_dict
                    )
                finally:
                    # Set it back to None so that later we can save to a new
                    # version -- also when flattening raises, so the v2.3
                    # override can never leak into unrelated calls.
                    _version._derived_version = None
                old_keys = set(old_state_dict.keys())
                if old_keys & missing_keys:
                    self.state_dict, self.mappings = old_state_dict, old_mappings

        return create_default_local_load_plan(
            self.state_dict, self.metadata, not self.allow_partial_load
        )

    def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]:
        """Return the global load plan (see ``create_default_global_load_plan``)."""
        return create_default_global_load_plan(global_plan)

    def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
        """Return ``new_plan`` unchanged."""
        return new_plan

    def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
        """
        Deserialize a non-tensor value with ``torch.load`` and store it.

        With ``flatten_state_dict`` the value is written into
        ``original_state_dict`` at the nested path recorded in ``mappings``;
        otherwise it is stored in ``state_dict`` under its FQN.
        """
        # NOTE(review): weights_only=False performs full unpickling of the
        # checkpoint bytes; only load checkpoints from trusted sources.
        fqn = read_item.dest_index.fqn
        if self.flatten_state_dict:
            set_element(
                self.original_state_dict,
                self.mappings[fqn],
                torch.load(value, weights_only=False),
            )
        else:
            self.state_dict[fqn] = torch.load(value, weights_only=False)

    def resolve_tensor(self, read_item: ReadItem):
        """Return the destination tensor region that ``read_item`` should be read into."""
        tensor = self.lookup_tensor(read_item.dest_index)
        return self.transform_tensor(read_item, tensor)

    def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
        """No-op in the default planner."""
        pass

    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
        """Extension from the planner interface to make it easy to extend the default planner."""
        return find_state_dict_object(self.state_dict, index)

    def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
        """Extension from the planner interface to make it easy to extend the default planner."""
        return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)
class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.

    Useful for loading in state_dict without first initializing a model, such as
    when converting a DCP checkpoint into a Torch save file.

    ``keys`` optionally restricts which checkpoint entries are rebuilt (see
    ``_should_include_key``); ``None`` rebuilds every entry.

    .. note::
        `state_dict` must be an empty dictionary when used with this LoadPlanner.

    .. warning::
        Because the entire state dict is initialized, it's recommended to only utilize
        this LoadPlanner on a single rank or process to avoid OOM.
    """

    def __init__(self, keys=None, *args, **kwargs):
        self.keys = keys
        super().__init__(*args, **kwargs)

    def _should_include_key(self, key: str, metadata: Metadata) -> bool:
        """
        Return whether checkpoint entry ``key`` should be materialized.

        Every key is included when ``self.keys`` is None. Otherwise ``key`` is
        included if it is listed in ``self.keys`` itself, or if any dotted prefix
        of its unflattened path from ``metadata.planner_data`` is listed. For a
        path ``("model", "layer", "weight")`` the candidates are ``"model"``,
        ``"model.layer"`` and ``"model.layer.weight"``.
        """
        if self.keys is None:
            return True
        if key in self.keys:
            return True

        planner_data = metadata.planner_data
        # Checkpoints saved without state-dict flattening may carry no planner
        # data, or no entry for this key. There is then no unflattened path to
        # match, so only the exact-key check above applies (previously this
        # crashed iterating/indexing None).
        if planner_data is None or key not in planner_data:
            return False

        prefix = ""
        for depth, part in enumerate(planner_data[key]):
            prefix = part if depth == 0 else f"{prefix}.{part}"
            if prefix in self.keys:
                return True
        return False

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata | None = None,
        is_coordinator: bool = False,
    ) -> None:
        """
        Populate the empty ``state_dict`` from ``metadata``, then set up the base planner.

        Tensor entries become uninitialized ``torch.empty`` tensors with the saved
        size and dtype; other entries are stored as their storage-metadata
        object. Entries that have a ``planner_data`` mapping are placed at that
        nested path, the rest directly under their key.
        """
        if state_dict:
            raise AssertionError("not state_dict")
        if metadata is None:
            raise AssertionError("metadata is not None")

        # rebuild the state dict from the metadata
        for k, v in metadata.state_dict_metadata.items():
            if not self._should_include_key(k, metadata):
                continue

            if isinstance(v, TensorStorageMetadata):
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            if metadata.planner_data is not None and k in metadata.planner_data:
                set_element(state_dict, metadata.planner_data[k], v)
            else:
                state_dict[k] = v

        super().set_up_planner(state_dict, metadata, is_coordinator)
def create_default_local_load_plan(
    state_dict: dict[str, Any], metadata: Metadata, strict: bool = True
) -> LoadPlan:
    """
    Create the ``LoadPlan`` used by DefaultLoadPlanner.

    It produces the read items for every value in ``state_dict`` using the
    metadata in ``metadata``.

    The default behavior is to match key exactly between state_dict and metadata.
    It handles resharding by issuing multiple read requests against storage in order to match
    load requirements.

    A ``DTensor`` whose device mesh does not include the current rank
    (``device_mesh.get_coordinate()`` is None) produces no read items.

    Args:
        state_dict: the state dict to load into, keyed by FQN.
        metadata: the checkpoint metadata.
        strict: if True, a key in ``state_dict`` that is absent from the
            checkpoint raises; if False such keys are skipped.

    Raises:
        RuntimeError: ``strict`` is True and a key is missing from the checkpoint.
        ValueError: a tensor's current size differs from its saved size.
    """
    requests = []
    for fqn, obj in state_dict.items():
        # Keys absent from the checkpoint metadata are skipped when
        # strict=False (partial load) and rejected otherwise.
        if fqn not in metadata.state_dict_metadata:
            if strict:
                raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
            continue

        md = metadata.state_dict_metadata[fqn]
        if (
            isinstance(md, TensorStorageMetadata)
            and getattr(obj, "size", None) is not None
            and md.size != obj.size()
        ):
            raise ValueError(
                f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}",
            )

        # Since DTensor supports submesh, only create read items when the
        # current rank is part of the mesh of the corresponding DTensor.
        if isinstance(obj, DTensor) and obj.device_mesh.get_coordinate() is None:
            continue
        requests += _create_read_items(fqn, md, obj)

    return LoadPlan(requests)
def create_default_global_load_plan(
    all_plans: list[LoadPlan],
) -> list[LoadPlan]:
    """
    Build the global load plan for DefaultLoadPlanner.

    Loading needs no cross-rank coordination by default, so the local plans are
    handed back untouched -- the very same list object is returned.
    """
    global_plans = all_plans  # identity: no global rewriting is performed
    return global_plans
def create_default_local_save_plan(
    state_dict: dict[str, Any], is_coordinator: bool
) -> SavePlan:
    """
    Create the ``SavePlan`` used by DefaultSavePlanner.

    Every rank produces write items for every value in ``state_dict``, plain
    tensors and non-tensor objects included; duplicates across ranks are
    resolved later by the coordinator, based on the keys. A ``DTensor`` whose
    device mesh does not include the current rank
    (``device_mesh.get_coordinate()`` is None) produces no write items.

    Args:
        state_dict: the state dict to save, keyed by FQN.
        is_coordinator: currently unused; the plan is built the same way on
            every rank.
    """
    requests: list[WriteItem] = []
    for fqn, obj in state_dict.items():
        # Since DTensor supports submesh, only ranks that are part of the
        # DTensor's mesh create write items for it.
        if isinstance(obj, DTensor) and obj.device_mesh.get_coordinate() is None:
            continue
        # Plain tensors and non-tensor values are requested on all ranks; the
        # coordinator decides whether to deduplicate them based on the keys.
        requests += _create_write_items(fqn, obj)
    return SavePlan(requests)
def create_default_global_save_plan(
    all_plans: list[SavePlan],
    rewrite_index_hints: bool = True,
) -> tuple[list[SavePlan], Metadata]:
    """
    Build the global plan and checkpoint ``Metadata`` for DefaultSavePlanner.

    The metadata is assembled from the ``WriteItem`` objects of all supplied
    plans, visited in order:

    * ``BYTE_IO`` items map to a ``BytesStorageMetadata`` entry.
    * Every other item must carry ``tensor_data``. Items sharing an FQN share
      one ``TensorStorageMetadata`` entry (created from the first such item),
      and each item's chunk is appended to its ``chunks``.

    A non-``SHARD`` item whose FQN already has an entry is rejected. When
    ``rewrite_index_hints`` is True, each tensor item's ``MetadataIndex.index``
    is set to the position its chunk takes in ``chunks``; that is the only
    change made to the plans.
    """
    storage_md: dict[str, STORAGE_TYPES] = {}
    rewritten_plans = []
    for local_plan in all_plans:
        rewritten_items = []
        for write_item in local_plan.items:
            fqn = write_item.index.fqn
            if write_item.type != WriteItemType.SHARD and fqn in storage_md:
                raise AssertionError("item.index.fqn not in md")

            if write_item.type == WriteItemType.BYTE_IO:
                storage_md[fqn] = BytesStorageMetadata()
                rewritten_items.append(write_item)
                continue

            tensor_data = write_item.tensor_data
            if tensor_data is None:
                raise AssertionError("item.tensor_data is not None")
            tensor_md = cast(
                TensorStorageMetadata,
                storage_md.setdefault(
                    fqn,
                    TensorStorageMetadata(
                        properties=tensor_data.properties,
                        size=tensor_data.size,
                        chunks=[],
                    ),
                ),
            )

            if rewrite_index_hints:
                # The hint is the slot this item's chunk occupies once appended below.
                hinted_index = dataclasses.replace(
                    write_item.index, index=len(tensor_md.chunks)
                )
                rewritten_items.append(dataclasses.replace(write_item, index=hinted_index))
            else:
                rewritten_items.append(write_item)

            if tensor_data.chunk is None:
                raise AssertionError(
                    f"\nCannot create MD for tensor without bounds.\nFQN: {fqn}\n"
                )
            tensor_md.chunks.append(tensor_data.chunk)
        rewritten_plans.append(dataclasses.replace(local_plan, items=rewritten_items))
    return rewritten_plans, Metadata(storage_md)
def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
    """
    Return the ``Metadata`` that DefaultSavePlanner would write for ``state_dict``.

    A metadata-only local plan is pushed through the global save-planning step
    as if it came from the only rank; the rewritten plans are discarded.
    """
    metadata_only_plan = _create_default_metadata_only_plan(state_dict)
    return create_default_global_save_plan([metadata_only_plan])[1]
- def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
- """Check if two boxes overlap. Tuples are (offset, lengths)."""
- # For each dim of each shard, check if one shard resides on the other
- # end of second shard with respect to that dim. As an example for a 2D
- # shard, we would check if one shard is above or on the left of the
- # other shard.
- ndims = len(box0.offsets)
- for i in range(ndims):
- if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]:
- return False
- if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]:
- return False
- return True
- def _check_box_bounds(
- outer_box_size: torch.Size, inner_box: ChunkStorageMetadata
- ) -> bool:
- for i in range(len(outer_box_size)):
- if inner_box.offsets[i] < 0:
- return False
- if inner_box.sizes[i] < 0:
- return False
- if inner_box.offsets[i] + inner_box.sizes[i] > outer_box_size[i]:
- return False
- return True
def _validate_global_plan(global_plan: list[SavePlan], metadata: Metadata) -> bool:
    """
    Sanity-check the tensor chunk layout recorded in ``metadata``.

    For every tensor entry with at least one dimension (``BytesStorageMetadata``
    entries and 0-dim tensors are skipped) this verifies that:

    * every chunk lies within the tensor bounds (``_check_box_bounds``);
    * no two chunks overlap (``_check_box_overlap``);
    * when ``global_plan`` contains more than one plan, the chunk volumes sum
      to the tensor volume. ``global_plan`` is only consulted for its length.

    Problems are logged as warnings rather than raised, so a single call
    reports all of them.

    Returns:
        True if every check passed, False otherwise.
    """
    all_good = True
    for key, value in metadata.state_dict_metadata.items():
        if isinstance(value, BytesStorageMetadata):
            continue
        if len(value.size) == 0:
            continue
        chunks = value.chunks
        chunks_volume = 0
        for chunk in chunks:
            # Compute the volume
            if not _check_box_bounds(value.size, chunk):
                logger.warning(
                    """
key:%s has out of bounds chunk:
tensor-size:%s chunk: %s
""",
                    key,
                    value.size,
                    chunk,
                )
                all_good = False
            chunks_volume += math.prod(chunk.sizes)

        if len(chunks) > 1:
            # Overlap check as a sweep line along the tensor's largest
            # dimension, so chunks that are already separated along that
            # dimension never need a full pairwise comparison.
            dims = len(value.size)
            sweep_dim = max(range(dims), default=0, key=lambda d: value.size[d])
            # Visit chunks by their start offset on sweep_dim; ties are broken
            # by the full offset tuple.
            sorted_indices = sorted(
                range(len(chunks)),
                key=lambda idx: (
                    chunks[idx].offsets[sweep_dim],
                    *(chunks[idx].offsets[d] for d in range(dims)),
                ),
            )
            # (end on sweep_dim, chunk index) for visited chunks, kept sorted
            # by end via insort.
            active: list[tuple[int, int]] = []
            for idx in sorted_indices:
                current = chunks[idx]
                start = current.offsets[sweep_dim]
                end = start + current.sizes[sweep_dim]
                # Drop chunks ending at or before `start`: they cannot overlap
                # this chunk or any later one, since later chunks start no
                # earlier on sweep_dim. Pairing `start` with sys.maxsize makes
                # bisect_right also count entries whose end equals `start`.
                cutoff = bisect_right(active, (start, sys.maxsize))
                if cutoff:
                    del active[:cutoff]
                # Every remaining active chunk intersects `current` along
                # sweep_dim, so only these need the full all-dimension test.
                for _, other_idx in active:
                    other = chunks[other_idx]
                    if _check_box_overlap(current, other):
                        logger.warning(
                            "key:%s has overlapping chunks: %s %s",
                            key,
                            current,
                            other,
                        )
                        all_good = False
                insort(active, (end, idx))

        # Check whether the combined chunks cover the whole tensor. This
        # compares volumes only; gaps are detected because in-bounds,
        # non-overlapping chunks can only reach the full volume by covering it.
        tensor_volume = math.prod(value.size)
        if len(global_plan) > 1 and chunks_volume != tensor_volume:
            logger.warning(
                """
key:%s invalid fill tensor-volume:
%s chunks-volume: %s
""",
                key,
                tensor_volume,
                chunks_volume,
            )
            all_good = False

    return all_good
|