planner_helpers.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. # mypy: allow-untyped-defs
  2. import io
  3. import itertools
  4. from bisect import bisect_right, insort
  5. from collections.abc import Callable
  6. from typing import Any, cast
  7. import torch
  8. import torch.distributed as dist
  9. from torch._utils import _get_device_module
  10. from torch.distributed._shard.metadata import ShardMetadata
  11. from torch.distributed._shard.sharded_tensor import ShardedTensor
  12. from torch.distributed.tensor import DTensor
  13. from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
  14. from .metadata import (
  15. BytesStorageMetadata,
  16. ChunkStorageMetadata,
  17. MetadataIndex,
  18. STATE_DICT_TYPE,
  19. STORAGE_TYPES,
  20. TensorProperties,
  21. TensorStorageMetadata,
  22. )
  23. from .planner import (
  24. LoadItemType,
  25. ReadItem,
  26. SavePlan,
  27. TensorWriteData,
  28. WriteItem,
  29. WriteItemType,
  30. )
  31. from .resharding import (
  32. _check_shard_metadata_pair_overlap,
  33. _shards_get_overlap_region_wrt_saved_tensor,
  34. )
  # Only the chunk-list reader is part of this module's public API; a
  # ``from ... import *`` of this module exports nothing else.
  __all__: list[str] = [
      "create_read_items_for_chunk_list",
  ]
  def _compare_save_plans(plan: SavePlan, other_plan: SavePlan) -> bool:
      """
      Decide whether two ``SavePlan`` objects describe the same writes.

      Two plans match when they agree on ``usable``, have the same number of
      write items, and every pair of items (taken in order) agrees on item type,
      metadata index (``fqn``, ``offset``, ``index``), presence of tensor data,
      global ``size``, presence of a chunk and the chunk's offsets and sizes.
      Tensor properties are not compared.

      Args:
          plan (SavePlan): Plan on the left-hand side of the comparison.
          other_plan (SavePlan): Plan on the right-hand side of the comparison.

      Returns:
          True if the two plans are equal, False otherwise.
      """
      if plan.usable != other_plan.usable or len(plan.items) != len(
          other_plan.items
      ):
          return False

      for lhs, rhs in zip(plan.items, other_plan.items):
          if lhs.type != rhs.type:
              return False

          lhs_idx, rhs_idx = lhs.index, rhs.index
          if (
              lhs_idx.fqn != rhs_idx.fqn
              or lhs_idx.offset != rhs_idx.offset
              or lhs_idx.index != rhs_idx.index
          ):
              return False

          lhs_data, rhs_data = lhs.tensor_data, rhs.tensor_data
          # Tensor data on only one side means the plans differ.
          if bool(lhs_data) != bool(rhs_data):
              return False
          if not lhs_data:
              # Neither item carries tensor data; nothing more to compare.
              continue

          if lhs_data.size != rhs_data.size:
              return False

          lhs_chunk, rhs_chunk = lhs_data.chunk, rhs_data.chunk
          if bool(lhs_chunk) != bool(rhs_chunk):
              return False
          if lhs_chunk and (
              lhs_chunk.offsets != rhs_chunk.offsets
              or lhs_chunk.sizes != rhs_chunk.sizes
          ):
              return False

      return True
  def _contains_usable_plan(delta_plans: list[SavePlan]) -> bool:
      """
      Report whether at least one delta plan is present and marked ``usable``.

      A usable delta plan signals that the corresponding local plan has changed.

      Args:
          delta_plans (List[SavePlan]): Delta plans to inspect. Falsy entries
              are treated as not usable.

      Returns:
          True if any delta plan is usable, False otherwise.
      """
      for candidate in delta_plans:
          if candidate and candidate.usable:
              return True
      return False
  def _merge_delta_local_plans(
      cached_plans: list[SavePlan],
      delta_plans: list[SavePlan],
  ) -> list[SavePlan]:
      """
      Merge a list of delta plans into a single plan.

      Each position takes the delta plan when it is present and usable, and the
      cached plan at the same position otherwise. A falsy (absent/empty) delta
      plan is therefore treated exactly like a non-usable one, consistent with
      ``_contains_usable_plan``.

      Args:
          cached_plans (List[SavePlan]): A list of cached plans.
          delta_plans (List[SavePlan]): A list of delta plans to merge. It can
              contain empty plans.

      Returns:
          One plan per position. Pairs are formed with ``zip``, so entries
          beyond the length of the shorter list are ignored.
      """
      # Previously a falsy delta plan was copied into the result as-is (e.g.
      # ``None``), even though it carries no new information; fall back to the
      # cached plan instead.
      return [
          delta_plan if delta_plan and delta_plan.usable else cached_plan
          for cached_plan, delta_plan in zip(cached_plans, delta_plans)
      ]
  def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
      """
      Describe the whole of ``tensor`` as a single chunk.

      The chunk starts at offset 0 in every dimension and its sizes are the
      tensor's full shape.
      """
      full_shape = tensor.size()
      origin = torch.Size([0 for _ in full_shape])
      return ChunkStorageMetadata(offsets=origin, sizes=full_shape)
  def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
      """Convert a ShardedTensor ``ShardMetadata`` into a ``ChunkStorageMetadata``."""
      shard_offsets = torch.Size(shard_md.shard_offsets)
      shard_sizes = torch.Size(shard_md.shard_sizes)
      return ChunkStorageMetadata(offsets=shard_offsets, sizes=shard_sizes)
  def _sharded_tensor_metadata(
      sharded_tensor: ShardedTensor, shard_md: ShardMetadata
  ) -> TensorWriteData:
      """
      Build the ``TensorWriteData`` for one shard of ``sharded_tensor``.

      Tensor properties are copied field by field from
      ``sharded_tensor.metadata().tensor_properties``; ``size`` is taken from
      ``sharded_tensor.metadata().size`` and ``chunk`` describes ``shard_md``.
      """
      st_metadata = sharded_tensor.metadata()
      src_props = st_metadata.tensor_properties
      write_props = TensorProperties(
          dtype=src_props.dtype,
          layout=src_props.layout,
          requires_grad=src_props.requires_grad,
          memory_format=src_props.memory_format,
          pin_memory=src_props.pin_memory,
      )
      shard_chunk = _chunk_for_shard(shard_md)
      return TensorWriteData(
          chunk=shard_chunk,
          properties=write_props,
          size=st_metadata.size,
      )
  def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
      """
      Build the SHARD ``WriteItem`` describing the local piece of ``tensor``.

      The chunk's sizes and offsets come from
      ``compute_local_shape_and_global_offset``; the tensor properties are taken
      from ``tensor.to_local()`` and ``size`` is the DTensor's ``size()``.
      """
      local_shape, global_offset = compute_local_shape_and_global_offset(
          tensor.shape, tensor.device_mesh, tensor.placements
      )
      chunk_sizes = torch.Size(local_shape)
      chunk_offsets = torch.Size(global_offset)

      write_data = TensorWriteData(
          chunk=ChunkStorageMetadata(offsets=chunk_offsets, sizes=chunk_sizes),
          properties=TensorProperties.create_from_tensor(tensor.to_local()),
          size=tensor.size(),
      )
      return WriteItem(
          index=MetadataIndex(fqn, chunk_offsets),
          type=WriteItemType.SHARD,
          tensor_data=write_data,
      )
  def _create_write_item_for_shard(
      fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
  ) -> WriteItem:
      """Build the SHARD ``WriteItem`` for the shard of ``sharded_tensor`` described by ``shard_md``."""
      item_index = MetadataIndex(fqn, torch.Size(shard_md.shard_offsets))
      return WriteItem(
          index=item_index,
          type=WriteItemType.SHARD,
          tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md),
      )
  def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
      """
      Build the TENSOR ``WriteItem`` for a plain tensor.

      The tensor is written as one chunk that starts at offset 0 in every
      dimension and spans its full shape.
      """
      shape = tensor.size()
      origin = torch.Size([0] * len(shape))
      write_data = TensorWriteData(
          chunk=ChunkStorageMetadata(offsets=origin, sizes=shape),
          properties=TensorProperties.create_from_tensor(tensor),
          size=shape,
      )
      return WriteItem(
          index=MetadataIndex(fqn, origin),
          type=WriteItemType.TENSOR,
          tensor_data=write_data,
      )
  def _create_write_item_for_bytesio(fqn: str, bytes: Any):
      """
      Build the BYTE_IO ``WriteItem`` for a non-tensor state_dict entry.

      Only ``fqn`` is recorded in the item; the ``bytes`` argument is not used
      here.
      """
      item_index = MetadataIndex(fqn)
      return WriteItem(index=item_index, type=WriteItemType.BYTE_IO)
  def _create_read_item_for_byteio(
      dest_index, dest_offset, storage_index, storage_offset, length
  ):
      """
      Build a BYTE_IO ``ReadItem``.

      The scalar ``dest_offset``, ``storage_offset`` and ``length`` are each
      wrapped into a one-element ``torch.Size``.
      """
      dest_offsets = torch.Size((dest_offset,))
      storage_offsets = torch.Size((storage_offset,))
      lengths = torch.Size((length,))
      return ReadItem(
          type=LoadItemType.BYTE_IO,
          dest_index=dest_index,
          dest_offsets=dest_offsets,
          storage_index=storage_index,
          storage_offsets=storage_offsets,
          lengths=lengths,
      )
  def _create_read_item_for_tensor(
      dest_index, dest_offsets, storage_index, storage_offsets, lengths
  ):
      """
      Build a TENSOR ``ReadItem``.

      ``dest_offsets``, ``storage_offsets`` and ``lengths`` are sequences that
      are each converted to a ``torch.Size``.
      """
      dest_offsets_size = torch.Size(dest_offsets)
      storage_offsets_size = torch.Size(storage_offsets)
      lengths_size = torch.Size(lengths)
      return ReadItem(
          type=LoadItemType.TENSOR,
          dest_index=dest_index,
          dest_offsets=dest_offsets_size,
          storage_index=storage_index,
          storage_offsets=storage_offsets_size,
          lengths=lengths_size,
      )
  def create_read_items_for_chunk_list(
      fqn: str,
      checkpoint_md: TensorStorageMetadata,
      local_chunks: list[ChunkStorageMetadata],
  ) -> list[ReadItem]:
      """
      Create a list of ``ReadItem`` based on the checkpoint and local chunks.

      This applies the resharding algorithm and computes the reads needed
      to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``.

      Candidate (local, saved) chunk pairs are found with a sweep line along one
      dimension: both chunk lists are ordered by their start offset in that
      dimension, and a saved chunk is only considered while its extent along
      the sweep dimension can still reach the current local chunk. Every
      candidate is then confirmed with a full per-dimension overlap check before
      a ``ReadItem`` is emitted.

      Args:
          fqn (str) : The state_dict FQN to pass to ``ReadItem``.
          checkpoint_md (TensorStorageMetadata): metadata for a given tensor
              from a checkpoint.
          local_chunks (List[ChunkStorageMetadata]): Local chunks that needs to be
              loaded.

      Returns:
          A list of ``ReadItem`` that will satisfy all input chunks. The list is
          empty when either ``local_chunks`` or the checkpoint's chunks are
          empty.
      """
      read_items: list[ReadItem] = []
      saved_chunks = checkpoint_md.chunks
      if not local_chunks or not saved_chunks:
          return read_items

      num_dims = len(local_chunks[0].offsets)

      # Sweep along the dimension with the largest extent (max offset + size over
      # all chunks) as a heuristic for better pruning.
      sweep_dim = 0
      if num_dims > 1:
          max_size = 0
          for dim in range(num_dims):
              dim_size = max(
                  chunk.offsets[dim] + chunk.sizes[dim]
                  for chunk in itertools.chain(local_chunks, saved_chunks)
              )
              if dim_size > max_size:
                  max_size = dim_size
                  sweep_dim = dim

      # Half-open (start, end) bounds of every chunk along the sweep dimension.
      # A 0-d tensor has no dimension to sweep, so every chunk gets (0, 1) and
      # all saved chunks are considered for every local chunk.
      if num_dims == 0:
          saved_bounds = [(0, 1)] * len(saved_chunks)
          local_bounds = [(0, 1)] * len(local_chunks)
      else:
          saved_bounds = [
              (c.offsets[sweep_dim], c.offsets[sweep_dim] + c.sizes[sweep_dim])
              for c in saved_chunks
          ]
          local_bounds = [
              (c.offsets[sweep_dim], c.offsets[sweep_dim] + c.sizes[sweep_dim])
              for c in local_chunks
          ]

      saved_sorted_indices = sorted(
          range(len(saved_chunks)),
          key=lambda idx: saved_bounds[idx][0],
      )
      local_sorted_indices = sorted(
          range(len(local_chunks)),
          key=lambda idx: local_bounds[idx][0],
      )

      # Saved chunks that may still overlap the current local chunk, kept sorted
      # as (end along the sweep dim, index into saved_chunks).
      active_saved: list[tuple[int, int]] = []
      saved_ptr = 0
      num_saved = len(saved_sorted_indices)

      for local_idx in local_sorted_indices:
          local_chunk = local_chunks[local_idx]
          local_start, local_end = local_bounds[local_idx]

          # Local chunks are visited in increasing start order, so a saved chunk
          # ending before local_start cannot overlap this or any later local
          # chunk. Entries ending exactly at local_start are kept; the overlap
          # check below rejects them.
          cutoff = bisect_right(active_saved, (local_start, -1))
          if cutoff:
              del active_saved[:cutoff]

          # Activate every not-yet-seen saved chunk that starts before this
          # local chunk ends.
          while saved_ptr < num_saved:
              storage_idx = saved_sorted_indices[saved_ptr]
              saved_start, saved_end = saved_bounds[storage_idx]
              if saved_start >= local_end:
                  break
              insort(active_saved, (saved_end, storage_idx))
              saved_ptr += 1

          for _, storage_idx in active_saved:
              storage_chunk = saved_chunks[storage_idx]
              # The sweep only filters along one dimension; confirm the pair
              # overlaps in every dimension before emitting a read.
              if not _check_shard_metadata_pair_overlap(local_chunk, storage_chunk):
                  continue

              storage_offsets = []
              dest_offsets = []
              lengths = []
              for (
                  _dim,
                  offset_for_saved_tensor,
                  offset_for_current_tensor,
                  length,
              ) in _shards_get_overlap_region_wrt_saved_tensor(
                  saved_shard=storage_chunk, current_shard=local_chunk
              ):
                  storage_offsets.append(offset_for_saved_tensor)
                  dest_offsets.append(offset_for_current_tensor)
                  lengths.append(length)

              read_items.append(
                  _create_read_item_for_tensor(
                      dest_index=MetadataIndex(fqn, local_chunk.offsets, local_idx),
                      dest_offsets=dest_offsets,
                      storage_index=MetadataIndex(
                          fqn, storage_chunk.offsets, storage_idx
                      ),
                      storage_offsets=storage_offsets,
                      lengths=lengths,
                  )
              )

      return read_items
  def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
      """
      Build a ``SavePlan`` with one or more write items per ``state_dict`` entry.

      DTensors contribute one SHARD item, ShardedTensors one item for every shard
      listed in ``obj.metadata().shards_metadata``, plain tensors one TENSOR item,
      and any other value one BYTE_IO item.
      """
      write_items = []
      for fqn, obj in state_dict.items():
          if isinstance(obj, DTensor):
              write_items.append(_create_write_items_for_dtensor(fqn, obj))
          elif isinstance(obj, ShardedTensor):
              for shard_md in obj.metadata().shards_metadata:
                  write_items.append(_create_write_item_for_shard(fqn, obj, shard_md))
          elif isinstance(obj, torch.Tensor):
              write_items.append(_create_write_item_for_tensor(fqn, obj))
          else:
              write_items.append(_create_write_item_for_bytesio(fqn, obj))
      return SavePlan(write_items)
  def _create_write_items(fqn: str, object: Any) -> list[WriteItem]:
      """
      Return the ``WriteItem`` list needed to save ``object`` under ``fqn``.

      Objects exposing ``__create_write_items__`` build their own items; a
      ShardedTensor gets one item per local shard; a plain tensor gets a single
      TENSOR item; anything else is written as BYTE_IO.
      """
      if hasattr(object, "__create_write_items__"):
          # DTensor implements _Checkpointable
          return object.__create_write_items__(fqn, object)
      if isinstance(object, ShardedTensor):
          return [
              _create_write_item_for_shard(fqn, object, local_shard.metadata)
              for local_shard in object.local_shards()
          ]
      if isinstance(object, torch.Tensor):
          return [_create_write_item_for_tensor(fqn, object)]
      return [_create_write_item_for_bytesio(fqn, object)]
  def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
      """
      Describe the local piece of ``tensor`` as a ``ChunkStorageMetadata``.

      Sizes and offsets come from ``compute_local_shape_and_global_offset``.
      """
      local_shape, global_offset = compute_local_shape_and_global_offset(
          tensor.shape, tensor.device_mesh, tensor.placements
      )
      chunk_sizes = torch.Size(local_shape)
      chunk_offsets = torch.Size(global_offset)
      return ChunkStorageMetadata(offsets=chunk_offsets, sizes=chunk_sizes)
  def _create_chunk_list(tensor: torch.Tensor) -> list[ChunkStorageMetadata]:
      """
      Return the chunks of ``tensor`` that this process holds.

      Objects exposing ``__create_chunk_list__`` describe themselves; a
      ShardedTensor yields one chunk per local shard; a plain tensor yields one
      chunk covering the whole tensor.

      Raises:
          ValueError: if ``tensor`` is none of the supported types.
      """
      if hasattr(tensor, "__create_chunk_list__"):
          # DTensor implements _Checkpointable
          return tensor.__create_chunk_list__()  # type: ignore[attr-defined]
      if isinstance(tensor, ShardedTensor):
          return [
              _chunk_for_shard(local_shard.metadata)
              for local_shard in tensor.local_shards()
          ]
      if isinstance(tensor, torch.Tensor):
          return [_create_chunk_from_tensor(tensor)]
      raise ValueError(
          "Unsupported Type, expecting one of [Tensor, DTensor, ShardedTensor] "
          f",but got {type(tensor)}"
      )
  def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> list[ReadItem]:
      """
      Build the ``ReadItem`` list needed to load ``obj`` stored under ``fqn``.

      Byte-storage entries produce a single BYTE_IO read whose offsets and
      length are all passed as 0. Tensor entries are resolved against the
      chunks of ``obj`` via ``create_read_items_for_chunk_list``.

      Raises:
          ValueError: if the checkpoint holds tensor metadata but ``obj`` is not
              a supported tensor type.
      """
      if isinstance(md, BytesStorageMetadata):
          return [
              _create_read_item_for_byteio(
                  dest_index=MetadataIndex(fqn),
                  dest_offset=0,
                  storage_index=MetadataIndex(fqn),
                  storage_offset=0,
                  length=0,
              )
          ]

      try:
          local_chunks = _create_chunk_list(obj)
      except ValueError as ex:
          raise ValueError(
              f"Invalid checkpoint metadata for {fqn}, "
              f"expected BytesStorageMetadata but found {type(md)}",
          ) from ex
      return create_read_items_for_chunk_list(fqn, md, local_chunks)
  def _init_state_dict(state_dict: dict[str, Any]) -> Any:
      """
      Replace meta-device tensors in ``state_dict`` with freshly allocated ones.

      DTensors and plain tensors whose ``device`` is ``meta`` are re-created with
      ``torch.empty_like`` on the current device of the default process group's
      device type. A ShardedTensor on the meta device raises ``RuntimeError``.
      Every other value is kept as-is. The traversal is delegated to
      ``_iterate_state_dict``; this function itself returns ``None``.
      """
      meta_device = torch.device("meta")

      def _materialization_device():
          # Current device of the default process group's device type; the cast
          # only informs the type checker.
          device_type = dist.distributed_c10d._get_pg_default_device().type
          return cast(torch.device, _get_device_module(device_type).current_device())

      def dtensor_func(value: DTensor):
          if getattr(value, "device", None) != meta_device:
              return value
          device = _materialization_device()
          new_local_tensor = torch.empty_like(value.to_local(), device=device)
          # Shape and stride are passed explicitly, since the DTensor might be
          # sharded unevenly.
          return DTensor.from_local(
              new_local_tensor,
              device_mesh=value.device_mesh,
              placements=value.placements,
              shape=value.size(),
              stride=value.stride(),
          )

      def sharded_tensor_func(value: Any):
          if getattr(value, "device", None) != meta_device:
              return value
          raise RuntimeError(
              f"Found unsupported type {type(value)} for meta device loading."
          )

      def tensor_func(value: torch.Tensor):
          if getattr(value, "device", None) != meta_device:
              return value
          return torch.empty_like(value, device=_materialization_device())

      _iterate_state_dict(
          state_dict,
          dtensor_func,
          sharded_tensor_func,
          tensor_func,
      )
  431. def _iterate_state_dict(
  432. iter_object: Any,
  433. dtensor_func: Callable,
  434. sharded_tensor_func: Callable,
  435. tensor_func: Callable,
  436. ):
  437. """
  438. Iterate through the state dict, applying the given functions to each tensor type
  439. and update the state dict in place.
  440. Args:
  441. iter_object (Any): the target state_dict.
  442. sharded_tensor_func (Callable): the function to apply to ShardedTensor
  443. dtensor_func (Callable): the function to apply to DTensor
  444. tensor_func (Callable): the function to apply to Tensor
  445. # TODO: let state_dict_util._iterate_state_dict() to support in place option
  446. so we don't need to have two versions of _iterate_state_dict.
  447. """
  448. if isinstance(iter_object, DTensor):
  449. return dtensor_func(iter_object)
  450. elif isinstance(iter_object, ShardedTensor):
  451. return sharded_tensor_func(iter_object)
  452. elif isinstance(iter_object, torch.Tensor):
  453. return tensor_func(iter_object)
  454. elif (
  455. isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
  456. or iter_object is None
  457. ):
  458. return iter_object
  459. elif isinstance(iter_object, dict):
  460. for key, value in iter_object.items():
  461. iter_object[key] = _iterate_state_dict(
  462. value, dtensor_func, sharded_tensor_func, tensor_func
  463. )
  464. return iter_object
  465. elif isinstance(iter_object, (list, tuple)):
  466. ret = [
  467. _iterate_state_dict(v, dtensor_func, sharded_tensor_func, tensor_func)
  468. for v in iter_object
  469. ]
  470. if isinstance(iter_object, tuple):
  471. ret = tuple(ret) # type: ignore[assignment]
  472. return ret