remote_device.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # mypy: allow-untyped-defs
  2. import torch
  3. class _remote_device:
  4. """
  5. Represents a device on a remote worker.
  6. Args:
  7. remote_device (str or torch.device): Represents a device on a remote worker.
  8. The string format should be one of the following:
  9. 1. "<workername>/<device>", where the device field can be parsed as torch.device type.
  10. E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
  11. In addition, the device field can be optional and the default value is "cpu".
  12. 2. "rank:<rank>/<device>", where <rank> is the rank of the
  13. process and device can be parsed as torch.device type.
  14. E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"
  15. 3. <workername> and <rank> are optional and formats like "cpu"
  16. and "cuda:1", just represent local devices.
  17. """
  18. def __init__(self, remote_device: str | torch.device):
  19. PARSE_ERROR = (
  20. f"Could not parse remote_device: {remote_device}. The valid format is "
  21. "'<workername>/<device>' or 'rank:<rank>/<device>' or '<device>'"
  22. )
  23. self._worker_name = None
  24. self._rank = None
  25. self._device: str | int | torch.device | None = None
  26. if isinstance(remote_device, torch.device):
  27. self._device = remote_device
  28. elif isinstance(remote_device, str):
  29. fields = remote_device.split("/")
  30. if len(fields) == 2:
  31. # pyrefly: ignore [bad-assignment]
  32. self._worker_name, self._device = fields
  33. elif len(fields) == 1:
  34. # Check if this is a valid device.
  35. if _remote_device._is_valid_local_device(fields[0]):
  36. self._device = fields[0]
  37. else:
  38. # pyrefly: ignore [bad-assignment]
  39. self._worker_name = fields[0]
  40. self._device = "cpu"
  41. else:
  42. raise ValueError(PARSE_ERROR)
  43. else:
  44. raise TypeError(f"Invalid type for remote_device: {type(remote_device)}")
  45. # Do some basic sanity check (no empty string)
  46. if self._worker_name is not None and not self._worker_name:
  47. raise ValueError(PARSE_ERROR)
  48. # Validate the device.
  49. self._device = torch.device(self._device)
  50. # Check for rank based format.
  51. if self._worker_name is not None:
  52. fields = self._worker_name.split(":")
  53. if len(fields) == 2:
  54. # rank:<rank>/device format, extract rank
  55. if fields[0] == "rank" and fields[1].isdigit():
  56. self._rank = int(fields[1]) # type: ignore[assignment]
  57. # pyrefly: ignore [bad-assignment]
  58. self._worker_name = None
  59. else:
  60. raise ValueError(PARSE_ERROR)
  61. elif len(fields) > 2:
  62. raise ValueError(PARSE_ERROR)
  63. @staticmethod
  64. def _is_valid_local_device(device):
  65. # Check for torch.device
  66. try:
  67. torch.device(device)
  68. return True
  69. except Exception:
  70. return False
  71. def worker_name(self) -> str | None:
  72. """Return the name of remote worker representing the remote device and ``None`` if no worker name is available."""
  73. return self._worker_name
  74. def rank(self) -> int | None:
  75. """
  76. Returns the rank of remote worker representing the remote device.
  77. Returns ``None`` if no rank is available.
  78. """
  79. return self._rank
  80. def device(self) -> torch.device:
  81. """Return the local device on the remote worker."""
  82. return self._device # type: ignore[return-value]
  83. def __repr__(self):
  84. if self._device is not None:
  85. if self._worker_name is not None:
  86. return f"{self._worker_name}/{self._device}"
  87. elif self._rank is not None:
  88. return f"rank:{self._rank}/{self._device}"
  89. else:
  90. return str(self._device)
  91. else:
  92. if self._worker_name is not None:
  93. return f"{self._worker_name}"
  94. elif self._rank is not None:
  95. return f"{self._rank}"
  96. else:
  97. raise RuntimeError("Invalid state!")
  98. def __eq__(self, other):
  99. return isinstance(other, _remote_device) and (
  100. self._worker_name == other._worker_name
  101. and self._device == other._device
  102. and self._rank == other._rank
  103. )
  104. def __hash__(self):
  105. return hash(self._worker_name) ^ hash(self._device) ^ hash(self._rank)