tpu.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. import logging
  2. import math
  3. from typing import Dict, List, Optional, Tuple
  4. import ray
  5. from ray._private.accelerators import TPUAcceleratorManager
  6. from ray._private.accelerators.tpu import (
  7. VALID_TPU_TYPES,
  8. get_chips_per_host,
  9. get_num_chips_from_topology,
  10. reserve_tpu_slice,
  11. )
  12. from ray._private.client_mode_hook import client_mode_wrap
  13. from ray.util.annotations import PublicAPI
  14. from ray.util.placement_group import (
  15. PlacementGroup,
  16. placement_group,
  17. remove_placement_group,
  18. )
# Module-level logger, named after this module via ``__name__`` so its records
# sit in the standard logging hierarchy under this module's dotted path.
logger = logging.getLogger(__name__)
  20. @PublicAPI(stability="alpha")
  21. def get_tpu_version_from_type(accelerator_type: str) -> str:
  22. """Extracts the version from the accelerator type.
  23. Args:
  24. accelerator_type: The full accelerator type string (e.g. "TPU-V6E").
  25. Returns:
  26. The version string (e.g. "v6e").
  27. Raises:
  28. ValueError: If the accelerator type is invalid.
  29. """
  30. accel_type_lower = accelerator_type.lower()
  31. if accel_type_lower.startswith("tpu-"):
  32. version = accel_type_lower.replace("tpu-", "")
  33. elif accel_type_lower.startswith("tpu"):
  34. version = accel_type_lower.replace("tpu", "v")
  35. else:
  36. version = accel_type_lower
  37. if version not in VALID_TPU_TYPES:
  38. raise ValueError(
  39. f"Invalid accelerator_type: {accelerator_type}. "
  40. f"Must be one of {list(VALID_TPU_TYPES)} or start with 'TPU-' followed by a valid type."
  41. )
  42. return version
  43. @PublicAPI(stability="alpha")
  44. def get_current_pod_name() -> Optional[str]:
  45. """
  46. Return the name of the TPU pod that the worker is a part of.
  47. Returns:
  48. The name of the TPU pod. Returns None if not part of a TPU pod.
  49. """
  50. tpu_name = TPUAcceleratorManager.get_current_node_tpu_name()
  51. if tpu_name == "":
  52. tpu_name = None
  53. return tpu_name
  54. @PublicAPI(stability="alpha")
  55. def get_current_pod_worker_count() -> Optional[int]:
  56. """
  57. Count the number of workers associated with the TPU pod that the worker belongs to.
  58. Returns:
  59. The total number of workers in the TPU pod. Returns None if the worker is not
  60. part of a TPU pod.
  61. """
  62. return TPUAcceleratorManager.get_num_workers_in_current_tpu_pod()
  63. @PublicAPI(stability="alpha")
  64. def get_num_tpu_chips_on_node() -> int:
  65. """
  66. Return the number of TPU chips on the node.
  67. Returns:
  68. The total number of chips on the TPU node. Returns 0 if none are found.
  69. """
  70. return TPUAcceleratorManager.get_current_node_num_accelerators()
  71. @PublicAPI(stability="alpha")
  72. def get_tpu_num_slices_for_workers(
  73. topology: str,
  74. accelerator_type: str,
  75. num_workers: int,
  76. resources_per_worker: Optional[Dict[str, float]] = None,
  77. ) -> int:
  78. """
  79. Calculates the number of slices needed to accommodate the specified number of workers.
  80. Args:
  81. topology: The TPU topology string.
  82. accelerator_type: The accelerator type string.
  83. num_workers: The desired number of workers.
  84. resources_per_worker: Optional dict of resources per worker.
  85. Returns:
  86. The number of slices required. Returns 1 if inputs are invalid or incomplete.
  87. """
  88. if not topology or not accelerator_type:
  89. return 1
  90. try:
  91. # Calculate how many workers fit in a single slice (num_slices=1)
  92. # given the topology and resources per worker.
  93. workers_per_slice, _ = get_tpu_worker_resources(
  94. topology=topology,
  95. accelerator_type=accelerator_type,
  96. resources_per_unit=resources_per_worker,
  97. num_slices=1,
  98. )
  99. if workers_per_slice == 0:
  100. return 1
  101. return max(1, math.ceil(num_workers / workers_per_slice))
  102. except Exception:
  103. # Fallback to 1 if calculation fails.
  104. return 1
  105. @PublicAPI(stability="alpha")
  106. def get_tpu_worker_resources(
  107. topology: str,
  108. accelerator_type: str,
  109. resources_per_unit: Optional[Dict[str, float]] = None,
  110. num_slices: int = 1,
  111. ) -> Tuple[int, Dict[str, float]]:
  112. """
  113. Calculates the number of workers and the resources required for each worker
  114. to run based on a TPU topology.
  115. Args:
  116. topology: The TPU topology string.
  117. accelerator_type: The accelerator string.
  118. resources_per_unit: Optional manual override for resources per unit. If
  119. unspecified, the number of TPU chips in a host is assumed.
  120. num_slices: The number of TPU slices.
  121. Returns:
  122. A tuple containing:
  123. - num_workers: Total workers required.
  124. - unit_resources: The resource dictionary for a single worker.
  125. """
  126. accelerator_version = get_tpu_version_from_type(accelerator_type)
  127. chips_per_host = get_chips_per_host(topology, accelerator_version)
  128. total_chips_per_slice = get_num_chips_from_topology(topology)
  129. total_chips_available = total_chips_per_slice * num_slices
  130. # Calculate the per-unit resources based on the TPU topology.
  131. final_resources = resources_per_unit.copy() if resources_per_unit else {}
  132. if "CPU" not in final_resources:
  133. final_resources["CPU"] = 1
  134. # If user didn't specify TPU, default to # of chips on 1 host.
  135. if "TPU" not in final_resources:
  136. final_resources["TPU"] = chips_per_host
  137. tpus_per_unit = final_resources["TPU"]
  138. # Validate TPU resource values.
  139. if tpus_per_unit <= 0:
  140. raise ValueError("TPU resources must be positive.")
  141. if total_chips_available % tpus_per_unit != 0:
  142. raise ValueError(
  143. f"Total chips ({total_chips_available}) not divisible by "
  144. f"TPUs requested per unit ({tpus_per_unit})."
  145. )
  146. if total_chips_per_slice % tpus_per_unit != 0:
  147. raise ValueError(
  148. f"The requested resources per bundle ({tpus_per_unit} TPU chips) do not "
  149. f"divide evenly into the chips available per slice ({total_chips_per_slice}). "
  150. "This configuration results in an uneven distribution of workers across slices, "
  151. "which is not supported."
  152. )
  153. num_workers = int(total_chips_available // tpus_per_unit)
  154. return num_workers, final_resources
  155. @PublicAPI(stability="alpha")
  156. def get_tpu_coordinator_env_vars(
  157. coordinator_address: str,
  158. num_slices: int,
  159. slice_id: int,
  160. coordinator_port: str = "8081",
  161. ) -> Dict[str, str]:
  162. """
  163. Returns the environment variables required for JAX multi-slice coordination.
  164. Args:
  165. coordinator_address: The IP address or hostname of the coordinator.
  166. num_slices: The total number of slices in the cluster.
  167. slice_id: The index of the current slice.
  168. coordinator_port: The port the coordinator is listening on.
  169. Returns:
  170. A dictionary mapping environment variable names to their values.
  171. """
  172. return {
  173. "MEGASCALE_COORDINATOR_ADDRESS": coordinator_address,
  174. "MEGASCALE_PORT": coordinator_port,
  175. "MEGASCALE_NUM_SLICES": str(num_slices),
  176. "MEGASCALE_SLICE_ID": str(slice_id),
  177. }
  178. @PublicAPI(stability="alpha")
  179. class SlicePlacementGroup:
  180. """
  181. A handle to a placement group reservation for a TPU slice.
  182. The following definitions are added for clarity:
  183. - Accelerator type: A string describing the accelerator type and version (e.g. TPU-V2, TPU-V6E).
  184. - Accelerator version: The accelerator generation only (e.g. v6e, v5p, v5litepod).
  185. - Pod type: The TPU accelerator version and the number of chips in a topology. (e.g. v6e-128, v5p-8).
  186. - Accelerator topology: The physical topology representing the structure (e.g. 2x2x2, 16x16).
  187. Args:
  188. topology: The TPU topology string (e.g. "2x2x2").
  189. accelerator_version: The TPU accelerator generation (e.g. "v6e", "v5p", "v4").
  190. resources_per_bundle: Optionally specify the resources to include in every worker bundle.
  191. strategy: PlacementGroup parameter. The strategy to create the placement group. Currently default to "SPREAD"
  192. - "PACK": Packs Bundles into as few nodes as possible.
  193. - "SPREAD": Places Bundles across distinct nodes as even as possible.
  194. - "STRICT_PACK": Packs Bundles into one node. The group is
  195. not allowed to span multiple nodes.
  196. - "STRICT_SPREAD": Packs Bundles across distinct nodes.
  197. lifetime: PlacementGroup parameter. Either `None`, which defaults to the placement group
  198. will fate share with its creator and will be deleted once its
  199. creator is dead, or "detached", which means the placement group
  200. will live as a global object independent of the creator.
  201. num_slices: Number of TPU slices in the SlicePlacementGroup. Defaults to 1 when unspecified.
  202. Examples:
  203. .. testcode:: python
  204. :skipif: True
  205. import ray
  206. from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
  207. from ray.util.tpu import SlicePlacementGroup
  208. slice_handle = SlicePlacementGroup(topology="4x4", accelerator_version="v6e")
  209. slice_pg = slice_handle.placement_group
  210. ray.get(slice_pg.ready(), timeout=10)
  211. @ray.remote(num_cpus=0, resources={'TPU': 4})
  212. def spmd_task(world, rank):
  213. print(f"Current TPU is rank {rank} of {world}")
  214. tasks = [
  215. spmd_task.options(
  216. scheduling_strategy=PlacementGroupSchedulingStrategy(
  217. placement_group=slice_pg,
  218. )
  219. ).remote(world=4, rank=i)
  220. for i in range(slice_handle.num_hosts)
  221. ]
  222. """
  223. def __init__(
  224. self,
  225. topology: str,
  226. accelerator_version: str,
  227. resources_per_bundle: Optional[Dict[str, float]] = None,
  228. # below are args related to PG
  229. strategy: str = "SPREAD",
  230. name: str = "",
  231. lifetime: Optional[str] = None,
  232. # default
  233. num_slices: int = 1,
  234. ):
  235. self._topology = topology.strip().lower()
  236. self._accelerator_version = accelerator_version.strip().lower()
  237. self._resources_per_bundle = resources_per_bundle or {}
  238. self._num_slices = num_slices
  239. # Calculate number of bundles and bundle resources for specified TPU topology.
  240. self._num_bundles, self._bundle_resources = get_tpu_worker_resources(
  241. topology=self._topology,
  242. accelerator_type=self._accelerator_version,
  243. resources_per_unit=resources_per_bundle,
  244. num_slices=self._num_slices,
  245. )
  246. self._chips_per_host = get_chips_per_host(
  247. self._topology, self._accelerator_version
  248. )
  249. total_chips = get_num_chips_from_topology(self._topology)
  250. hosts_per_slice = max(1, total_chips // self._chips_per_host)
  251. self._num_hosts = hosts_per_slice * self._num_slices
  252. self._head_pgs: List[PlacementGroup] = []
  253. self._bundle_label_selector: List[Dict[str, str]] = []
  254. self._validate_tpu_config()
  255. self._placement_group = None
  256. # Reserve a TPU slice of the provided accelerator version and topology.
  257. self._placement_group = self._reserve_slice(
  258. strategy,
  259. name,
  260. lifetime,
  261. )
  262. def _accelerator_version_check(self, accelerator_version: str):
  263. if accelerator_version not in VALID_TPU_TYPES:
  264. raise ValueError(
  265. f"Invalid accelerator version: {accelerator_version}. Must be one of: {VALID_TPU_TYPES}"
  266. )
  267. def _validate_tpu_config(self):
  268. # Should validate topology and generation values and return a
  269. # ValueError if invalid.
  270. self._accelerator_version_check(self.accelerator_version)
  271. if not TPUAcceleratorManager.is_valid_tpu_accelerator_topology(
  272. tpu_accelerator_version=self.accelerator_version,
  273. tpu_topology=self._topology,
  274. ):
  275. raise ValueError(
  276. f"Invalid accelerator topology: '{self._topology}' for "
  277. f"accelerator version: '{self.accelerator_version}'"
  278. )
  279. def _reserve_slice(
  280. self,
  281. strategy: str = "SPREAD",
  282. name: str = "",
  283. lifetime: Optional[str] = None,
  284. ) -> PlacementGroup:
  285. """Performs the two-step scheduling to reserve a TPU slice."""
  286. self._bundle_label_selector = []
  287. bundles = []
  288. bundles_per_slice = self._num_bundles // self._num_slices
  289. # Construct accelerator format for reserve_tpu_slice. e.g. From "v6e" to "TPU-V6E", "v5p" to "TPU-V5P".
  290. accelerator_type = "TPU-" + self.accelerator_version.upper()
  291. try:
  292. for _ in range(self.num_slices):
  293. reservation = reserve_tpu_slice(self._topology, accelerator_type)
  294. if not reservation:
  295. raise RuntimeError(
  296. f"Failed to reserve TPU slice. Requested {self.num_slices} "
  297. f"slice(s) of topology '{self._topology}' with accelerator type "
  298. f"'{accelerator_type}'. Ensure that sufficient TPU resources are "
  299. "available in the cluster."
  300. )
  301. # Store the head placement group for clean-up when un-reserving the slice.
  302. slice_name, head_pg = reservation
  303. self._head_pgs.append(head_pg)
  304. # Reserving a slice is done through constructing num_hosts bundles, each with a label selector for
  305. # the unique name of an available TPU slice.
  306. selector = {ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: slice_name}
  307. self._bundle_label_selector.extend([selector] * bundles_per_slice)
  308. bundles += [
  309. self._bundle_resources.copy() for _ in range(bundles_per_slice)
  310. ]
  311. pg = placement_group(
  312. bundles=bundles,
  313. strategy=strategy,
  314. name=name,
  315. lifetime=lifetime,
  316. bundle_label_selector=self._bundle_label_selector,
  317. )
  318. return pg
  319. except Exception:
  320. self.shutdown()
  321. raise
  322. @property
  323. def placement_group(self) -> PlacementGroup:
  324. """The underlying PlacementGroup object."""
  325. return self._placement_group
  326. @property
  327. def chips_per_host(self) -> int:
  328. """The number of chips per host for this TPU slice."""
  329. # This is the same value as resources per worker for TPU.
  330. return self._chips_per_host
  331. @property
  332. def num_hosts(self) -> int:
  333. """The total number of hosts in the SlicePlacementGroup."""
  334. return self._num_hosts
  335. @property
  336. def num_bundles(self) -> int:
  337. """The total number of bundles in the SlicePlacementGroup."""
  338. return self._num_bundles
  339. @property
  340. def topology(self) -> str:
  341. """The physical topology of the TPU slice."""
  342. return self._topology
  343. @property
  344. def accelerator_version(self) -> str:
  345. """The TPU accelerator type of the slice."""
  346. return self._accelerator_version
  347. @property
  348. def num_slices(self) -> int:
  349. """The number of TPU slices this SlicePlacementGroup spans."""
  350. return self._num_slices
  351. @property
  352. def head_placement_groups(self) -> List[PlacementGroup]:
  353. """The internal head PGs used to reserve the slices."""
  354. return self._head_pgs
  355. @property
  356. def bundle_label_selector(self) -> List[Dict[str, str]]:
  357. """The bundle label selector list for the worker PG."""
  358. return self._bundle_label_selector
  359. @property
  360. def bundle_resources(self) -> Dict[str, float]:
  361. """The resources that are assigned to each bundle."""
  362. return self._bundle_resources
  363. def shutdown(self):
  364. """Removes the worker placement group and all internal head PGs."""
  365. if self._placement_group:
  366. remove_placement_group(self._placement_group)
  367. self._placement_group = None
  368. for head_pg in self._head_pgs:
  369. remove_placement_group(head_pg)
  370. self._head_pgs = []
  371. @PublicAPI(stability="alpha")
  372. @client_mode_wrap
  373. def slice_placement_group(
  374. topology: str,
  375. accelerator_version: str,
  376. resources_per_bundle: Optional[Dict[str, float]] = None,
  377. num_slices: int = 1,
  378. **kwargs,
  379. ) -> SlicePlacementGroup:
  380. """Asynchronously creates a PlacementGroup for a TPU slice.
  381. A slice placement group reserves num_slices TPU slice(s) and creates a placement
  382. group for scheduling tasks or actors.
  383. Args:
  384. topology: The desired TPU pod topology (e.g. "4x4", "2x8").
  385. accelerator_version: The TPU accelerator generation, (e.g. "v4", "v5p", "v6e").
  386. resources_per_bundle: Specify the number of resources to reserve per bundle.
  387. When unspecified, SlicePlacementGroup defaults to reserving 1 bundle per TPU host in
  388. a topology, with the bundle resources set to the number of TPU in a host.
  389. Ex: Specifying {"TPU": 1} for a 4x4 topology would result in 16 bundles, each with 1 TPU.
  390. If resources_per_bundle=None for the same topology, there would be 4 bundles with 4 TPU each.
  391. num_slices: The number of tpu slices within the placement group
  392. **kwargs: Additional arguments for the placement group, such as 'name', 'lifetime', or 'strategy'.
  393. Returns:
  394. The handle for the created SlicePlacementGroup.
  395. """
  396. return SlicePlacementGroup(
  397. topology=topology,
  398. accelerator_version=accelerator_version,
  399. resources_per_bundle=resources_per_bundle,
  400. num_slices=num_slices,
  401. **kwargs,
  402. )