| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- from torch.distributed.checkpoint.metadata import ChunkStorageMetadata
- __all__: list[str] = []
- def _check_shard_metadata_pair_overlap(
- shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata
- ) -> bool:
- """Check if two shards overlap."""
- # For each dim of each shard, check if one shard resides on the other
- # end of second shard with respect to that dim. As an example for a 2D
- # shard, we would check if one shard is above or on the left of the
- # other shard.
- ndims = len(shard1.offsets)
- for i in range(ndims):
- if shard1.offsets[i] >= shard2.offsets[i] + shard2.sizes[i]:
- return False
- if shard2.offsets[i] >= shard1.offsets[i] + shard1.sizes[i]:
- return False
- return True
- def _shards_get_overlap_region_wrt_saved_tensor(
- saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
- ) -> list[tuple[int, int, int, int]]:
- """
- Return the overlapping region between saved_shard and current_shard.
- There returned list has the same number of elements as the tensor's dimension.
- For each element, we produce a tuple with the following contents:
- (dimension, `saved_shard` offset, `current_shard` offset, length)
- Offsets are relative to each shard.
- """
- narrows = []
- for dim, (
- saved_shard_offset,
- current_shard_offset,
- saved_shard_size,
- current_shard_size,
- ) in enumerate(
- zip(
- saved_shard.offsets,
- current_shard.offsets,
- saved_shard.sizes,
- current_shard.sizes,
- )
- ):
- min_range_end = min(
- saved_shard_offset + saved_shard_size,
- current_shard_offset + current_shard_size,
- )
- length = min_range_end - max(current_shard_offset, saved_shard_offset)
- if saved_shard_offset > current_shard_offset:
- offset_for_saved_tensor = 0
- offset_for_current_tensor = saved_shard_offset - current_shard_offset
- else:
- offset_for_saved_tensor = current_shard_offset - saved_shard_offset
- offset_for_current_tensor = 0
- narrows.append(
- (dim, offset_for_saved_tensor, offset_for_current_tensor, length)
- )
- return narrows
|