metadata.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # mypy: allow-untyped-defs
  2. from dataclasses import dataclass
  3. from functools import reduce
  4. from torch.distributed.remote_device import _remote_device
  5. @dataclass
  6. class ShardMetadata:
  7. """
  8. Represents a shard of the overall Tensor including its
  9. offsets, lengths and device placement.
  10. Args:
  11. shard_offsets(List[int]): Offsets in the original tensor indicating
  12. the start offsets for this shard. Should have the same rank as
  13. the original tensor.
  14. shard_sizes(List[int]): Integers indicating the size of each
  15. dimension for this shard. Should have the same rank as the
  16. original tensor.
  17. placement(:class:`torch.distributed._remote_device`):
  18. Specifies the placement of this shard.
  19. """
  20. __slots__ = ["shard_offsets", "shard_sizes", "placement"]
  21. shard_offsets: list[int]
  22. shard_sizes: list[int]
  23. placement: _remote_device | None
  24. def __init__(
  25. self,
  26. shard_offsets: list[int],
  27. shard_sizes: list[int],
  28. placement: str | _remote_device | None = None,
  29. ):
  30. self.shard_offsets = shard_offsets
  31. self.shard_sizes = shard_sizes
  32. if isinstance(placement, str):
  33. self.placement = _remote_device(placement)
  34. else:
  35. self.placement = placement
  36. if len(self.shard_offsets) != len(self.shard_sizes):
  37. raise ValueError(
  38. f"shard_offsets and shard_sizes should have "
  39. f"the same number of elements, found {len(self.shard_offsets)} "
  40. f"and {self.shard_sizes} respectively"
  41. )
  42. for i in range(len(self.shard_offsets)):
  43. if self.shard_offsets[i] < 0:
  44. raise ValueError("shard_offsets should be >=0")
  45. if self.shard_sizes[i] < 0:
  46. raise ValueError("shard_sizes should be >= 0")
  47. def __hash__(self):
  48. def _hash_reduce(a, b):
  49. return (a << 8) + hash(b)
  50. res = reduce(_hash_reduce, self.shard_offsets, 37)
  51. res = reduce(_hash_reduce, self.shard_sizes, res)
  52. res = _hash_reduce(res, self.placement)
  53. return res