default_planner.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import dataclasses
  4. import io
  5. import logging
  6. import math
  7. import sys
  8. from bisect import bisect_right, insort
  9. from collections import ChainMap
  10. from typing import Any, cast
  11. import torch
  12. from torch.distributed._shard._utils import narrow_tensor_by_index
  13. from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
  14. from torch.distributed.checkpoint._nested_dict import (
  15. FLATTEN_MAPPING,
  16. flatten_state_dict,
  17. )
  18. from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
  19. from torch.distributed.checkpoint._traverse import set_element
  20. from torch.distributed.checkpoint.metadata import (
  21. BytesStorageMetadata,
  22. ChunkStorageMetadata,
  23. Metadata,
  24. MetadataIndex,
  25. STATE_DICT_TYPE,
  26. STORAGE_TYPES,
  27. StorageMeta,
  28. TensorStorageMetadata,
  29. )
  30. from torch.distributed.checkpoint.planner import (
  31. LoadPlan,
  32. LoadPlanner,
  33. ReadItem,
  34. SavePlan,
  35. SavePlanner,
  36. WriteItem,
  37. WriteItemType,
  38. )
  39. from torch.distributed.checkpoint.planner_helpers import (
  40. _compare_save_plans,
  41. _contains_usable_plan,
  42. _create_default_metadata_only_plan,
  43. _create_read_items,
  44. _create_write_items,
  45. _init_state_dict,
  46. _merge_delta_local_plans,
  47. )
  48. from torch.distributed.checkpoint.utils import find_state_dict_object
  49. from torch.distributed.tensor import DTensor
  50. from . import _version
  51. logger: logging.Logger = logging.getLogger(__name__)
  52. __all__ = [
  53. "DefaultSavePlanner",
  54. "DefaultLoadPlanner",
  55. "create_default_local_load_plan",
  56. "create_default_global_load_plan",
  57. "create_default_local_save_plan",
  58. "create_default_global_save_plan",
  59. ]
  60. # TODO: Update docstrings for default_planner.py
  61. class DefaultSavePlanner(SavePlanner):
  62. mappings: FLATTEN_MAPPING
  63. def __init__(
  64. self,
  65. flatten_state_dict: bool = True,
  66. flatten_sharded_tensors: bool = True,
  67. dedup_replicated_tensors: bool | None = None,
  68. dedup_save_to_lowest_rank: bool = False,
  69. enable_plan_caching: bool = False,
  70. ) -> None:
  71. self.flatten_state_dict = flatten_state_dict
  72. self.flatten_sharded_tensors = flatten_sharded_tensors
  73. self.mappings = {}
  74. self.dedup_save_to_lowest_rank = dedup_save_to_lowest_rank
  75. if dedup_replicated_tensors is not None:
  76. logger.warning(
  77. "DefaultSavePlanner's `dedup_replicated_tensors` argument is being "
  78. "deprecated, and no longer has any effect. Please remove this argument "
  79. "from your call."
  80. )
  81. self._cached_plans_key: str = self.__class__.__name__
  82. self._enable_plan_caching = enable_plan_caching
  83. def set_up_planner(
  84. self,
  85. state_dict: STATE_DICT_TYPE,
  86. storage_meta: StorageMeta | None = None,
  87. is_coordinator: bool = False,
  88. ) -> None:
  89. if self.flatten_state_dict:
  90. state_dict, self.mappings = flatten_state_dict(state_dict)
  91. if self.flatten_sharded_tensors:
  92. state_dict = _flatten_sharded_tensors(state_dict)
  93. self.state_dict = state_dict
  94. self.is_coordinator = is_coordinator
  95. def create_local_plan(self) -> SavePlan:
  96. plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
  97. if self.flatten_state_dict:
  98. plan = dataclasses.replace(plan, planner_data=self.mappings)
  99. self.plan = plan
  100. if self._enable_plan_caching:
  101. # If plans are equal, we can skip sending the plan to the coordinator.
  102. if (
  103. self._cached_plans_key in SavePlanner._cached_save_plan
  104. and _compare_save_plans(
  105. plan, SavePlanner._cached_save_plan[self._cached_plans_key]
  106. )
  107. ):
  108. logger.info(
  109. "No change in the local plan. Skipping sending the plan to the coordinator"
  110. )
  111. return SavePlan([], usable=False)
  112. else:
  113. SavePlanner._cached_save_plan[self._cached_plans_key] = plan
  114. return self.plan
  115. def _dedup_save_plans(self, all_plans: list[SavePlan]) -> list[SavePlan]:
  116. return dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)
  117. def _create_global_plan(
  118. self, all_plans: list[SavePlan]
  119. ) -> tuple[list[SavePlan], Metadata]:
  120. deduped_plans = self._dedup_save_plans(all_plans)
  121. global_plan, metadata = create_default_global_save_plan(deduped_plans)
  122. if self.flatten_state_dict:
  123. # | does not work for Python 3.8 or older version.
  124. # merged_mappings = reduce(
  125. # lambda x, y: x | y, (p.planner_data for p in global_plan)
  126. # )
  127. planner_data_dict = [p.planner_data for p in global_plan]
  128. merged_mappings = dict(ChainMap(*planner_data_dict))
  129. metadata = dataclasses.replace(metadata, planner_data=merged_mappings)
  130. if not _validate_global_plan(global_plan, metadata):
  131. raise ValueError("Failed to validate global plan")
  132. return global_plan, metadata
  133. def _create_global_plan_with_caching(
  134. self, all_plans: list[SavePlan]
  135. ) -> tuple[list[SavePlan], list[SavePlan], Metadata]:
  136. """
  137. Create global plan with caching.
  138. Returns a tuple of global_plan_delta, global_plan, metadata.
  139. """
  140. global_plan_delta: list[SavePlan] = []
  141. if self._cached_plans_key not in SavePlanner._cached_all_plans:
  142. # Case 1: If the plans are not cached, the cache will be hydrated with the
  143. # all_plans, global_plans (Deduped), and metadata.
  144. # Cache the original all_plans
  145. SavePlanner._cached_all_plans[self._cached_plans_key] = all_plans
  146. global_plan, metadata = self._create_global_plan(all_plans)
  147. # Cache the deduped and validated global_plan
  148. SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan
  149. # Cache the metadata
  150. SavePlanner._cached_metadata[self._cached_plans_key] = metadata
  151. # If plans are not cached, global_plan delta will be the same as global plan.
  152. return global_plan, global_plan, metadata
  153. # Case 2: Plans are cached
  154. if not _contains_usable_plan(all_plans):
  155. # Case 2.1: Plans are cached and the local plans have NOT changed (No usable plans).
  156. # Global plan delta will be empty plans to avoid the collective overhead.
  157. # We can reuse the deduped global plan and metadata from the cache directly.
  158. global_plan_delta = [SavePlan([], usable=False)] * len(all_plans)
  159. global_plan = SavePlanner._cached_global_plan[self._cached_plans_key]
  160. metadata = SavePlanner._cached_metadata[self._cached_plans_key]
  161. else:
  162. # Case 2.2: Plans are cached but the local plans have changed.
  163. # We will merge the changed local plans with the cached local plans.
  164. # Updated plans will overwrite the cached plans. New global plan and metadata will be created and cached.
  165. # Global plan delta will be created by comparing the new global plan with the cached global plan.
  166. # Only the global plan delta (updated ones) will be sent to the coordinator to avoid the collective overhead.
  167. merged_plans = _merge_delta_local_plans(
  168. SavePlanner._cached_all_plans[self._cached_plans_key], all_plans
  169. )
  170. # Cache the updated local plans
  171. SavePlanner._cached_all_plans[self._cached_plans_key] = merged_plans
  172. global_plan, metadata = self._create_global_plan(merged_plans)
  173. if self._cached_plans_key in self._cached_global_plan:
  174. for cached_plan, new_plan in zip(
  175. SavePlanner._cached_global_plan[self._cached_plans_key], global_plan
  176. ):
  177. if _compare_save_plans(cached_plan, new_plan):
  178. global_plan_delta.append(SavePlan([], usable=False))
  179. else:
  180. global_plan_delta.append(new_plan)
  181. # Cache the new global plan and the metadata
  182. SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan
  183. SavePlanner._cached_metadata[self._cached_plans_key] = metadata
  184. return global_plan_delta, global_plan, metadata
  185. def create_global_plan(
  186. self, all_plans: list[SavePlan]
  187. ) -> tuple[list[SavePlan], Metadata]:
  188. global_plan_delta: list[SavePlan] = []
  189. if self._enable_plan_caching:
  190. # If the plans are cached, we only need to send the global plan delta to be scattered
  191. # across ranks. Ranks will use the cached final plans instead.
  192. (
  193. global_plan_delta,
  194. global_plan,
  195. metadata,
  196. ) = self._create_global_plan_with_caching(all_plans)
  197. else:
  198. global_plan, metadata = self._create_global_plan(all_plans)
  199. # If the caching is not enabled, global delta plan will always be same as the new global plan.
  200. global_plan_delta = global_plan
  201. self.global_plan = global_plan
  202. self.metadata = metadata
  203. return global_plan_delta, self.metadata
  204. def _finish_plan_with_caching(self, new_plan: SavePlan) -> SavePlan:
  205. finished_plan: SavePlan = new_plan
  206. if not new_plan.usable:
  207. finished_plan = SavePlanner._cached_final_save_plan[self._cached_plans_key]
  208. else:
  209. finished_plan = new_plan
  210. SavePlanner._cached_final_save_plan[self._cached_plans_key] = new_plan
  211. return finished_plan
  212. def finish_plan(self, new_plan: SavePlan) -> SavePlan:
  213. finished_plan: SavePlan = new_plan
  214. if self._enable_plan_caching:
  215. finished_plan = self._finish_plan_with_caching(new_plan)
  216. self.plan = finished_plan
  217. return self.plan
  218. def resolve_data(self, write_item: WriteItem) -> torch.Tensor | io.BytesIO:
  219. object = self.lookup_object(write_item.index)
  220. return self.transform_object(write_item, object)
  221. def lookup_object(self, index: MetadataIndex) -> Any:
  222. """Extension from the planner interface to make it easy to extend the default planner."""
  223. return find_state_dict_object(self.state_dict, index)
  224. def transform_object(self, write_item: WriteItem, object: Any):
  225. """Extension from the planner interface to make it easy to extend the default planner."""
  226. if write_item.type == WriteItemType.BYTE_IO:
  227. bytes = io.BytesIO()
  228. torch.save(object, bytes)
  229. object = bytes
  230. return object
  231. class DefaultLoadPlanner(LoadPlanner):
  232. """
  233. DefaultLoadPlanner that adds multiple features on top of LoadPlanner.
  234. In particular it adds the following:
  235. flatten_state_dict: Handle state_dict with nested dicts
  236. flatten_sharded_tensors: For FSDP in 2D parallel mode
  237. allow_partial_load: If False, will raise a runtime error if a key is present in state_dict, but not in the checkpoint.
  238. """
  239. original_state_dict: STATE_DICT_TYPE
  240. mappings: FLATTEN_MAPPING
  241. def __init__(
  242. self,
  243. flatten_state_dict: bool = True,
  244. flatten_sharded_tensors: bool = True,
  245. allow_partial_load: bool = False,
  246. ) -> None:
  247. self.flatten_state_dict = flatten_state_dict
  248. self.flatten_sharded_tensors = flatten_sharded_tensors
  249. self.original_state_dict = {}
  250. self.mappings = {}
  251. self.allow_partial_load = allow_partial_load
  252. def set_up_planner(
  253. self,
  254. state_dict: STATE_DICT_TYPE,
  255. metadata: Metadata | None = None,
  256. is_coordinator: bool = False,
  257. ) -> None:
  258. _init_state_dict(state_dict)
  259. self.original_state_dict = state_dict
  260. if self.flatten_sharded_tensors:
  261. state_dict = _flatten_sharded_tensors(state_dict)
  262. if self.flatten_state_dict:
  263. state_dict, self.mappings = flatten_state_dict(state_dict)
  264. self.state_dict = state_dict
  265. self.metadata = metadata
  266. self.is_coordinator = is_coordinator
  267. def create_local_plan(self) -> LoadPlan:
  268. if self.metadata is None:
  269. raise AssertionError("self.metadata is not None")
  270. if self.flatten_state_dict:
  271. # To support checkpoints that are saved before v2.4, we have to
  272. # differentiate if the missing keys are due to old checkpoints.
  273. # The contracts are:
  274. # 1. There are 3 cases when we found a missing key.
  275. # 1.1 Actual missing key, but allow_partial_load is False
  276. # 1.2 Actual missing key, but allow_partial load is True
  277. # 1.3 Old checkpoint, but allow_partial_load is False
  278. # 1.4 Old checkpoint, but allow_partial_load is True
  279. # 2. If we found a missing key, we first convert the keys back to
  280. # the key format of v2.3
  281. # 3. If the previous missing keys are in the v2.3 keys, we assume
  282. # this is a old checkpoint.
  283. # 4. Pass the state_dict to `create_default_local_load_plan()`,
  284. # which has the logic to check missing for allow_partial_load.
  285. # So for 1.2 and 1.4 cases, we delegate allow_partial_load check to
  286. # `create_default_local_load_plan()`. The logic here is to determine
  287. # whether the checkpoint belong to 2.3 (or before) or 2.4 (or after).
  288. current_keys = set(self.state_dict.keys())
  289. load_keys = set(self.metadata.state_dict_metadata.keys())
  290. missing_keys = load_keys - current_keys
  291. if missing_keys:
  292. _version._derived_version = "2_3"
  293. old_state_dict, old_mappings = flatten_state_dict(
  294. self.original_state_dict
  295. )
  296. old_keys = set(old_state_dict.keys())
  297. if old_keys & missing_keys:
  298. self.state_dict, self.mappings = old_state_dict, old_mappings
  299. # _derived_version is only used by flatten_state_dict now.
  300. # Set it back to None so that later we can save to a new version.
  301. _version._derived_version = None
  302. return create_default_local_load_plan(
  303. self.state_dict, self.metadata, not self.allow_partial_load
  304. )
  305. def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]:
  306. return create_default_global_load_plan(global_plan)
  307. def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
  308. return new_plan
  309. def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
  310. if self.flatten_state_dict:
  311. set_element(
  312. self.original_state_dict,
  313. self.mappings[read_item.dest_index.fqn],
  314. torch.load(value, weights_only=False),
  315. )
  316. else:
  317. self.state_dict[read_item.dest_index.fqn] = torch.load(
  318. value, weights_only=False
  319. )
  320. def resolve_tensor(self, read_item: ReadItem):
  321. tensor = self.lookup_tensor(read_item.dest_index)
  322. return self.transform_tensor(read_item, tensor)
  323. def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
  324. pass
  325. def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
  326. """Extension from the planner interface to make it easy to extend the default planner."""
  327. return find_state_dict_object(self.state_dict, index)
  328. def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
  329. """Extension from the planner interface to make it easy to extend the default planner."""
  330. return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)
  331. class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
  332. """
  333. Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.
  334. Useful for loading in state_dict without first initializing a model, such as
  335. when converting a DCP checkpoint into a Torch save file.
  336. . N.B. `state_dict` must be an empty dictionary when used with this LoadPlanner
  337. .. warning::
  338. Because the entire state dict is initialized, It's recommended to only utilize
  339. this LoadPlanner on a single rank or process to avoid OOM.
  340. """
  341. def __init__(self, keys=None, *args, **kwargs):
  342. self.keys = keys
  343. super().__init__(*args, **kwargs)
  344. def _should_include_key(self, key: str, metadata: Metadata) -> bool:
  345. if self.keys is None:
  346. return True
  347. if key in self.keys:
  348. return True
  349. unflattened_keys: list[str] = []
  350. planner_data = metadata.planner_data.get(key)
  351. for unflattened_key in planner_data:
  352. if unflattened_keys:
  353. unflattened_keys.append(
  354. ".".join([unflattened_keys[-1], str(unflattened_key)])
  355. )
  356. else:
  357. unflattened_keys.append(unflattened_key)
  358. if any(unflattened_key in self.keys for unflattened_key in unflattened_keys):
  359. return True
  360. return False
  361. def set_up_planner(
  362. self,
  363. state_dict: STATE_DICT_TYPE,
  364. metadata: Metadata | None = None,
  365. is_coordinator: bool = False,
  366. ) -> None:
  367. if state_dict:
  368. raise AssertionError("not state_dict")
  369. if metadata is None:
  370. raise AssertionError("metadata is not None")
  371. # rebuild the state dict from the metadata
  372. for k, v in metadata.state_dict_metadata.items():
  373. if not self._should_include_key(k, metadata):
  374. continue
  375. if isinstance(v, TensorStorageMetadata):
  376. v = torch.empty(v.size, dtype=v.properties.dtype) # type: ignore[assignment]
  377. if metadata.planner_data is not None and k in metadata.planner_data:
  378. set_element(state_dict, metadata.planner_data[k], v)
  379. else:
  380. state_dict[k] = v
  381. super().set_up_planner(state_dict, metadata, is_coordinator)
  382. def create_default_local_load_plan(
  383. state_dict: dict[str, Any], metadata: Metadata, strict: bool = True
  384. ) -> LoadPlan:
  385. requests = []
  386. """
  387. Create the ``LoadPlan`` used by DefaultLoadPlanner.
  388. It produces one read item per value in ``state_dict`` using the metadata in ``metadata``.
  389. The default behavior is to match key exactly between state_dict and metadata.
  390. It handles resharding by issuing multiple read requests against storage in order to match
  391. load requirements.
  392. """
  393. for fqn, obj in state_dict.items():
  394. # ignore state_dict keys which do not exist in `state_dict` if strict=False
  395. if fqn not in metadata.state_dict_metadata:
  396. if strict:
  397. raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
  398. else:
  399. continue
  400. md = metadata.state_dict_metadata[fqn]
  401. if (
  402. isinstance(md, TensorStorageMetadata)
  403. and getattr(obj, "size", None) is not None
  404. and md.size != obj.size()
  405. ):
  406. raise ValueError(
  407. f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}",
  408. )
  409. # Since DTensor supports submesh, adding extra check to ensure _create_read_items()
  410. # gets called only when the current rank is part of the mesh for the corresponding DTensor.
  411. if isinstance(obj, DTensor):
  412. if obj.device_mesh.get_coordinate() is not None:
  413. requests += _create_read_items(fqn, md, obj)
  414. else:
  415. requests += _create_read_items(fqn, md, obj)
  416. return LoadPlan(requests)
  417. def create_default_global_load_plan(
  418. all_plans: list[LoadPlan],
  419. ) -> list[LoadPlan]:
  420. """
  421. Create global load plan used by DefaultLoadPlanner.
  422. The default load behavior involved no global coordination and this function
  423. currently doesn't change the local plans.
  424. """
  425. return all_plans
  426. def create_default_local_save_plan(
  427. state_dict: dict[str, Any], is_coordinator: bool
  428. ) -> SavePlan:
  429. """
  430. Create the ``SavePlan`` used by DefaultSavePlanner.
  431. On non-coordinator ranks, this function ignores tensors and non-tensor objects,
  432. only producing writes for ShardedTensor objects.
  433. On the coordinator rank, produce writes for all values.
  434. """
  435. requests = []
  436. for fqn, obj in state_dict.items():
  437. # Since DTensor supports submesh, adding extra check to ensure _create_write_items()
  438. # gets called only when the current rank is part of the mesh for the corresponding DTensor.
  439. if isinstance(obj, DTensor):
  440. if obj.device_mesh.get_coordinate() is not None:
  441. requests += _create_write_items(fqn, obj)
  442. else:
  443. # For the plain tensor and non-tensor values, add the request for all
  444. # the ranks. Coordinator will decides whether to deduplicate the
  445. # values based on the keys.
  446. requests += _create_write_items(fqn, obj)
  447. return SavePlan(requests)
  448. def create_default_global_save_plan(
  449. all_plans: list[SavePlan],
  450. rewrite_index_hints: bool = True,
  451. ) -> tuple[list[SavePlan], Metadata]:
  452. """
  453. Create the global plan and metadata used by DefaultSavePlanner.
  454. Metadata is produced by concatenating the metadata of all ``WriteItem`` from the supplied plans.
  455. The only global planning change is to update index hints in all ``MetadataIndex`` objects if
  456. ``rewrite_index_hints`` is True.
  457. """
  458. md: dict[str, STORAGE_TYPES] = {}
  459. new_plans = []
  460. for plan in all_plans:
  461. new_items = []
  462. for item in plan.items:
  463. if item.type != WriteItemType.SHARD:
  464. if item.index.fqn in md:
  465. raise AssertionError("item.index.fqn not in md")
  466. if item.type == WriteItemType.BYTE_IO:
  467. md[item.index.fqn] = BytesStorageMetadata()
  468. new_items.append(item)
  469. else:
  470. if item.tensor_data is None:
  471. raise AssertionError("item.tensor_data is not None")
  472. tensor_md = cast(
  473. TensorStorageMetadata,
  474. md.setdefault(
  475. item.index.fqn,
  476. TensorStorageMetadata(
  477. properties=item.tensor_data.properties,
  478. size=item.tensor_data.size,
  479. chunks=[],
  480. ),
  481. ),
  482. )
  483. new_item = item
  484. if rewrite_index_hints:
  485. new_index = dataclasses.replace(
  486. item.index, index=len(tensor_md.chunks)
  487. )
  488. new_item = dataclasses.replace(item, index=new_index)
  489. new_items.append(new_item)
  490. if item.tensor_data.chunk is None:
  491. raise AssertionError(f"""
  492. Cannot create MD for tensor without bounds.
  493. FQN: {item.index.fqn}
  494. """)
  495. tensor_md.chunks.append(item.tensor_data.chunk)
  496. new_plans.append(dataclasses.replace(plan, items=new_items))
  497. return (new_plans, Metadata(md))
  498. def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
  499. """Return the ``Metadata`` if DefaultSavePlanner was used to checkpoint ``state_dict``."""
  500. plan = _create_default_metadata_only_plan(state_dict)
  501. _, md = create_default_global_save_plan([plan])
  502. return md
  503. def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
  504. """Check if two boxes overlap. Tuples are (offset, lengths)."""
  505. # For each dim of each shard, check if one shard resides on the other
  506. # end of second shard with respect to that dim. As an example for a 2D
  507. # shard, we would check if one shard is above or on the left of the
  508. # other shard.
  509. ndims = len(box0.offsets)
  510. for i in range(ndims):
  511. if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]:
  512. return False
  513. if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]:
  514. return False
  515. return True
  516. def _check_box_bounds(
  517. outer_box_size: torch.Size, inner_box: ChunkStorageMetadata
  518. ) -> bool:
  519. for i in range(len(outer_box_size)):
  520. if inner_box.offsets[i] < 0:
  521. return False
  522. if inner_box.sizes[i] < 0:
  523. return False
  524. if inner_box.offsets[i] + inner_box.sizes[i] > outer_box_size[i]:
  525. return False
  526. return True
def _validate_global_plan(global_plan: list[SavePlan], metadata: Metadata) -> bool:
    """
    Sanity-check the tensor chunks recorded in ``metadata``.

    For every tensor entry with at least one dimension:

    * each chunk must lie within the tensor bounds (``_check_box_bounds``);
    * no two chunks of the same tensor may overlap (``_check_box_overlap``);
    * when ``global_plan`` holds more than one plan, the chunk volumes must sum
      to the tensor volume.

    Byte entries and zero-dimensional tensors are skipped. Every violation is
    logged as a warning rather than raised.

    Returns:
        True if no violation was found, False otherwise.
    """
    all_good = True
    for key, value in metadata.state_dict_metadata.items():
        if isinstance(value, BytesStorageMetadata):
            continue
        if len(value.size) == 0:
            continue
        chunks = value.chunks
        chunks_volume = 0
        for chunk in chunks:
            # Bounds check, then accumulate this chunk's volume.
            if not _check_box_bounds(value.size, chunk):
                logger.warning(
                    """
key:%s has out of bounds chunk:
tensor-size:%s chunk: %s
""",
                    key,
                    value.size,
                    chunk,
                )
                all_good = False
            chunks_volume += math.prod(chunk.sizes)

        if len(chunks) > 1:
            # Sweep-line overlap detection along the tensor's largest dimension:
            # visit chunks in order of their start along sweep_dim. `active` holds
            # (end, idx) pairs sorted by end; chunks that end at or before the
            # current start can never overlap this or any later chunk (later
            # chunks start no earlier), so they are dropped. Only the remaining
            # candidates are checked with the full N-dimensional overlap test.
            dims = len(value.size)
            sweep_dim = max(range(dims), default=0, key=lambda d: value.size[d])
            sorted_indices = sorted(
                range(len(chunks)),
                key=lambda idx: (
                    chunks[idx].offsets[sweep_dim],
                    *(chunks[idx].offsets[d] for d in range(dims)),
                ),
            )
            active: list[tuple[int, int]] = []
            for idx in sorted_indices:
                current = chunks[idx]
                start = current.offsets[sweep_dim]
                end = start + current.sizes[sweep_dim]
                # (start, sys.maxsize) sorts after every (end <= start, idx) pair.
                cutoff = bisect_right(active, (start, sys.maxsize))
                if cutoff:
                    del active[:cutoff]
                for _, other_idx in active:
                    other = chunks[other_idx]
                    if _check_box_overlap(current, other):
                        logger.warning(
                            "key:%s has overlapping chunks: %s %s",
                            key,
                            current,
                            other,
                        )
                        all_good = False
                insort(active, (end, idx))

        # Check whether the combined chunks cover the whole tensor. This check
        # only runs when the global plan contains more than one plan.
        tensor_volume = math.prod(value.size)
        if len(global_plan) > 1 and chunks_volume != tensor_volume:
            logger.warning(
                """
key:%s invalid fill tensor-volume:
%s chunks-volume: %s
""",
                key,
                tensor_volume,
                chunks_volume,
            )
            all_good = False

    return all_good