resharding.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from torch.distributed.checkpoint.metadata import ChunkStorageMetadata
  2. __all__: list[str] = []
  3. def _check_shard_metadata_pair_overlap(
  4. shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata
  5. ) -> bool:
  6. """Check if two shards overlap."""
  7. # For each dim of each shard, check if one shard resides on the other
  8. # end of second shard with respect to that dim. As an example for a 2D
  9. # shard, we would check if one shard is above or on the left of the
  10. # other shard.
  11. ndims = len(shard1.offsets)
  12. for i in range(ndims):
  13. if shard1.offsets[i] >= shard2.offsets[i] + shard2.sizes[i]:
  14. return False
  15. if shard2.offsets[i] >= shard1.offsets[i] + shard1.sizes[i]:
  16. return False
  17. return True
  18. def _shards_get_overlap_region_wrt_saved_tensor(
  19. saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
  20. ) -> list[tuple[int, int, int, int]]:
  21. """
  22. Return the overlapping region between saved_shard and current_shard.
  23. There returned list has the same number of elements as the tensor's dimension.
  24. For each element, we produce a tuple with the following contents:
  25. (dimension, `saved_shard` offset, `current_shard` offset, length)
  26. Offsets are relative to each shard.
  27. """
  28. narrows = []
  29. for dim, (
  30. saved_shard_offset,
  31. current_shard_offset,
  32. saved_shard_size,
  33. current_shard_size,
  34. ) in enumerate(
  35. zip(
  36. saved_shard.offsets,
  37. current_shard.offsets,
  38. saved_shard.sizes,
  39. current_shard.sizes,
  40. )
  41. ):
  42. min_range_end = min(
  43. saved_shard_offset + saved_shard_size,
  44. current_shard_offset + current_shard_size,
  45. )
  46. length = min_range_end - max(current_shard_offset, saved_shard_offset)
  47. if saved_shard_offset > current_shard_offset:
  48. offset_for_saved_tensor = 0
  49. offset_for_current_tensor = saved_shard_offset - current_shard_offset
  50. else:
  51. offset_for_saved_tensor = current_shard_offset - saved_shard_offset
  52. offset_for_current_tensor = 0
  53. narrows.append(
  54. (dim, offset_for_saved_tensor, offset_for_current_tensor, length)
  55. )
  56. return narrows