tpu.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705
  1. import glob
  2. import logging
  3. import os
  4. import re
  5. from functools import lru_cache
  6. from typing import Dict, List, Optional, Set, Tuple
  7. import requests
  8. import ray
  9. from ray._private.accelerators.accelerator import AcceleratorManager
  10. from ray._private.ray_constants import env_bool
  11. from ray.util.placement_group import PlacementGroup, placement_group
  12. from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
  13. logger = logging.getLogger(__name__)
  14. TPU_VALID_CHIP_OPTIONS = (1, 2, 4, 8)
  15. GKE_TPU_ACCELERATOR_TYPE_ENV_VAR = "TPU_ACCELERATOR_TYPE"
  16. GKE_TPU_TOPOLOGY_ENV_VAR = "TPU_TOPOLOGY"
  17. GKE_TPU_WORKER_ID_ENV_VAR = "TPU_WORKER_ID"
  18. GKE_TPU_NAME_ENV_VAR = "TPU_NAME"
  19. # Constants for accessing the `accelerator-type` from TPU VM
  20. # instance metadata.
  21. # See https://cloud.google.com/compute/docs/metadata/overview
  22. # for more details about VM instance metadata.
  23. GCE_TPU_ACCELERATOR_ENDPOINT = (
  24. "http://metadata.google.internal/computeMetadata/v1/instance/attributes/"
  25. )
  26. GCE_TPU_HEADERS = {"Metadata-Flavor": "Google"}
  27. GCE_TPU_ACCELERATOR_KEY = "accelerator-type"
  28. GCE_TPU_ENV_KEY = "tpu-env"
  29. GCE_TPU_INSTANCE_ID_KEY = "instance-id"
  30. GCE_TPU_WORKER_ID_KEY = "agent-worker-number"
  31. TPU_VISIBLE_CHIPS_ENV_VAR = "TPU_VISIBLE_CHIPS"
  32. NOSET_TPU_VISIBLE_CHIPS_ENV_VAR = "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS"
  33. # The following defines environment variables that allow
  34. # us to access a subset of TPU visible chips.
  35. #
  36. # See: https://github.com/google/jax/issues/14977 for an example/more details.
  37. TPU_CHIPS_PER_HOST_BOUNDS_ENV_VAR = "TPU_CHIPS_PER_HOST_BOUNDS"
  38. TPU_CHIPS_PER_HOST_BOUNDS_1_CHIP_CONFIG = "1,1,1"
  39. TPU_CHIPS_PER_HOST_BOUNDS_2_CHIP_CONFIG = "1,2,1"
  40. TPU_HOST_BOUNDS_ENV_VAR = "TPU_HOST_BOUNDS"
  41. TPU_SINGLE_HOST_BOUNDS = "1,1,1"
  42. # By default TPU VMs come with 4 chips per host and 2 tensorcores per chip.
  43. # For more details: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm
  44. DEFAULT_TPU_NUM_CHIPS_PER_HOST = 4
  45. DEFAULT_TPU_NUM_CORES_PER_CHIP = 2
  46. # Accelerators that are 4 chips per host: v2, v3, v4, v5p, v7x
  47. # Accelerators that are 8 chips per host: v5e, v6e
  48. SINGLE_HOST_8_CHIPS_TPU_TYPES = ("v5litepod", "v6e")
  49. # Accelerators that are 2 cores per chip: v2, v3, v4, v5p, v7x
  50. # Accelerators that are 1 core per chip: v5e, v6e
  51. SINGLE_CORE_TPU_TYPES = ("v5litepod", "v6e")
  52. # The valid TPU types.
  53. VALID_TPU_TYPES = ("v2", "v3", "v4", "v5p", "v5litepod", "v6e", "v7x")
  54. # This is only used to construct TPU 3D topologies
  55. def _get_larger_3d_topologies(max_x: int, max_y: int, max_z: int) -> Set[str]:
  56. """Returns a set of larger 3D TPU topologies given the max x,y,z value. Using DEFAULT_TPU_NUM_CHIPS_PER_HOST as increment"""
  57. topologies = set()
  58. for x in range(
  59. DEFAULT_TPU_NUM_CHIPS_PER_HOST, max_x + 1, DEFAULT_TPU_NUM_CHIPS_PER_HOST
  60. ):
  61. for y in range(
  62. DEFAULT_TPU_NUM_CHIPS_PER_HOST, max_y + 1, DEFAULT_TPU_NUM_CHIPS_PER_HOST
  63. ):
  64. for z in range(
  65. DEFAULT_TPU_NUM_CHIPS_PER_HOST,
  66. max_z + 1,
  67. DEFAULT_TPU_NUM_CHIPS_PER_HOST,
  68. ):
  69. topologies.add(f"{x}x{y}x{z}")
  70. return topologies
  71. # The valid TPU topologies for each of the TPU types.
  72. VALID_TPU_TOPOLOGY = {
  73. "v2": {"4x4", "4x8", "8x8", "8x16", "16x16"},
  74. "v3": {"4x4", "4x8", "8x8", "8x16", "16x16", "16x32", "32x32"},
  75. "v4": {"2x2x1", "2x2x2", "2x2x4", "2x4x4"}.union(
  76. _get_larger_3d_topologies(12, 12, 16)
  77. ),
  78. "v5p": {
  79. "2x2x1",
  80. "2x2x2",
  81. "2x2x4",
  82. "2x4x4",
  83. }.union(_get_larger_3d_topologies(16, 16, 24)),
  84. "v5litepod": {"1x1", "2x2", "2x4", "2x8", "4x4", "4x8", "8x8", "8x16", "16x16"},
  85. "v6e": {"1x1", "2x2", "2x4", "2x8", "4x4", "4x8", "8x8", "8x16", "16x16"},
  86. "v7x": {
  87. "2x2x1",
  88. "2x2x2",
  89. "2x2x4",
  90. "2x4x4",
  91. "4x4x4",
  92. "4x4x8",
  93. "4x8x8",
  94. "8x8x8",
  95. "8x8x16",
  96. "8x16x16",
  97. },
  98. }
  99. def _get_tpu_metadata(key: str) -> Optional[str]:
  100. """Poll and get TPU metadata."""
  101. try:
  102. accelerator_type_request = requests.get(
  103. os.path.join(GCE_TPU_ACCELERATOR_ENDPOINT, key),
  104. headers=GCE_TPU_HEADERS,
  105. )
  106. if (
  107. accelerator_type_request.status_code == 200
  108. and accelerator_type_request.text
  109. ):
  110. return accelerator_type_request.text
  111. else:
  112. logging.debug(
  113. "Unable to poll TPU GCE Metadata. Got "
  114. f"status code: {accelerator_type_request.status_code} and "
  115. f"content: {accelerator_type_request.text}"
  116. )
  117. except requests.RequestException as e:
  118. logging.debug("Unable to poll the TPU GCE Metadata: %s", e)
  119. return None
  120. def _accelerator_type_check(accelerator_type: str):
  121. if not accelerator_type.startswith(VALID_TPU_TYPES):
  122. raise ValueError(
  123. f"Invalid accelerator type: {accelerator_type}. Must start with one of: {VALID_TPU_TYPES}"
  124. )
  125. def get_num_tpu_visible_chips_per_host(accelerator_type: str) -> int:
  126. _accelerator_type_check(accelerator_type)
  127. if accelerator_type.startswith(SINGLE_HOST_8_CHIPS_TPU_TYPES):
  128. return 8
  129. return DEFAULT_TPU_NUM_CHIPS_PER_HOST
  130. def get_tpu_cores_per_chip(accelerator_type: str) -> int:
  131. _accelerator_type_check(accelerator_type)
  132. if accelerator_type.startswith(SINGLE_CORE_TPU_TYPES):
  133. return 1
  134. return DEFAULT_TPU_NUM_CORES_PER_CHIP
  135. def get_num_chips_from_topology(topology: str) -> int:
  136. """
  137. Calculates the total number of chips in a TPU topology.
  138. Ex: "2x2x2" -> 8
  139. """
  140. total_chips = 1
  141. for dim in topology.strip().lower().split("x"):
  142. total_chips *= int(dim)
  143. return total_chips
  144. def infer_tpu_pod_type_from_topology(
  145. topology: str, accelerator_type: str
  146. ) -> Optional[str]:
  147. """Infer the TPU pod type (e.g. v4-32) from topology and accelerator type."""
  148. if not topology or not accelerator_type:
  149. return None
  150. try:
  151. num_chips = get_num_chips_from_topology(topology)
  152. generation = accelerator_type.lower().replace("tpu-", "")
  153. num_cores = num_chips * get_tpu_cores_per_chip(generation)
  154. return f"{generation}-{num_cores}"
  155. except Exception as e:
  156. raise ValueError(
  157. f"Failed to infer pod type from topology '{topology}' "
  158. f"and type '{accelerator_type}'"
  159. ) from e
  160. def fetch_tpu_slice_name_from_pg(pg):
  161. @ray.remote(num_cpus=0)
  162. def _get_tpu_slice_name():
  163. return TPUAcceleratorManager.get_current_node_tpu_name()
  164. tpu_name_ref = _get_tpu_slice_name.options(
  165. scheduling_strategy=PlacementGroupSchedulingStrategy(
  166. placement_group=pg, placement_group_bundle_index=0
  167. )
  168. ).remote()
  169. return ray.get(tpu_name_ref)
  170. def get_chips_per_host(topology: str, accelerator_version: str) -> int:
  171. """Get the number of chips per host based on topology and accelerator version.
  172. The current rule is as follows:
  173. Default chips per host is 4.
  174. If accelerator_version is v5e or v6e:
  175. If topology total chips < 8, return total chips (partial host).
  176. Otherwise return 8.
  177. If accelerator_version is v5p or other versions, the chips per host will be 4
  178. Args:
  179. topology: The TPU topology string (e.g. "2x2x2").
  180. accelerator_version: The accelerator version of the node (e.g. "V4", "v4").
  181. Returns:
  182. A int representing the number of chips per host
  183. """
  184. total_chips = get_num_chips_from_topology(topology)
  185. # Check for 8-chip host types (v5litepod, v6e)
  186. if accelerator_version.strip().lower() in SINGLE_HOST_8_CHIPS_TPU_TYPES:
  187. if total_chips < 8:
  188. return total_chips
  189. return 8
  190. return DEFAULT_TPU_NUM_CHIPS_PER_HOST
  191. def reserve_tpu_slice(
  192. topology: str,
  193. accelerator_type: str,
  194. ) -> Optional[Tuple[str, PlacementGroup]]:
  195. """Reserves a TPU slice using its head resource and returns the slice name.
  196. This enables gang scheduling of training workers with multi-host TPUs.
  197. This is used by JaxTrainer with TPUs in Ray Train.
  198. Args:
  199. topology: The TPU topology string (e.g. "2x2x2").
  200. accelerator_type: The accelerator type of the node (e.g. "TPU-V4").
  201. Returns:
  202. A tuple of a string representing a unique TPU slice name and the placement
  203. group handle reserving the TPU head.
  204. """
  205. pod_type = infer_tpu_pod_type_from_topology(topology, accelerator_type)
  206. if pod_type is None:
  207. return None
  208. # Reserve a slice by creating a placement group on the TPU head.
  209. head_label_selector = {
  210. "ray.io/tpu-worker-id": "0",
  211. "ray.io/tpu-pod-type": pod_type,
  212. }
  213. head_placement_group = placement_group(
  214. bundles=[{f"TPU-{pod_type}-head": 1}],
  215. bundle_label_selector=[head_label_selector],
  216. )
  217. logger.debug("Waiting to reserve multi-host slice head.")
  218. timeout = 100
  219. ready, _ = ray.wait([head_placement_group.ready()], timeout=timeout)
  220. if not ready:
  221. raise TimeoutError(
  222. "Failed to reserve TPU head for slice with shape: {}. "
  223. "Ensure your cluster has sufficient resources. Requesting TPU "
  224. "head node with labels: {}. Current resources: {}".format(
  225. pod_type, head_label_selector, ray.available_resources()
  226. )
  227. )
  228. # Retrieve the unique slice ID.
  229. slice_name = fetch_tpu_slice_name_from_pg(head_placement_group)
  230. if slice_name is None:
  231. raise RuntimeError(
  232. "Failed to retrieve TPU slice name after reserving head placement group. "
  233. "Ensure that TPU slice metadata is available and correctly configured on multi-host nodes."
  234. )
  235. return (slice_name, head_placement_group)
  236. class TPUAcceleratorManager(AcceleratorManager):
  237. """Google TPU accelerators."""
  238. @staticmethod
  239. def get_resource_name() -> str:
  240. return "TPU"
  241. @staticmethod
  242. def get_visible_accelerator_ids_env_var() -> str:
  243. return TPU_VISIBLE_CHIPS_ENV_VAR
  244. @staticmethod
  245. def get_current_process_visible_accelerator_ids() -> Optional[List[str]]:
  246. tpu_visible_chips = os.environ.get(
  247. TPUAcceleratorManager.get_visible_accelerator_ids_env_var(), None
  248. )
  249. if tpu_visible_chips is None:
  250. return None
  251. if tpu_visible_chips == "":
  252. return []
  253. return list(tpu_visible_chips.split(","))
  254. @staticmethod
  255. @lru_cache()
  256. def get_current_node_num_accelerators() -> int:
  257. """Attempt to detect the number of TPUs on this machine.
  258. TPU chips are represented as devices within `/dev/`, either as
  259. `/dev/accel*` or `/dev/vfio/*`.
  260. Returns:
  261. The number of TPUs if any were detected, otherwise 0.
  262. """
  263. accel_files = glob.glob("/dev/accel*")
  264. if accel_files:
  265. return len(accel_files)
  266. try:
  267. vfio_entries = os.listdir("/dev/vfio")
  268. numeric_entries = [int(entry) for entry in vfio_entries if entry.isdigit()]
  269. return len(numeric_entries)
  270. except FileNotFoundError as e:
  271. logger.debug("Failed to detect number of TPUs: %s", e)
  272. return 0
  273. @staticmethod
  274. def is_valid_tpu_accelerator_type(tpu_accelerator_type: str) -> bool:
  275. """Check whether the tpu accelerator_type is formatted correctly.
  276. The accelerator_type field typically follows a form of v{generation}-{cores/chips},
  277. but newer generations like 7x may follow tpu{generation}-{cores/chips}.
  278. See the following for more information:
  279. https://cloud.google.com/sdk/gcloud/reference/compute/tpus/tpu-vm/accelerator-types/describe
  280. Args:
  281. tpu_accelerator_type: The string representation of the accelerator type
  282. to be checked for validity.
  283. Returns:
  284. True if it's valid, false otherwise.
  285. """
  286. # 1. Legacy format: v2-8, v3-32.
  287. # 2. Newer format with letters in generation: v5litepod-16, v6e-4.
  288. # 3. Ironwood TPU format which contains a tpu prefix: tpu7x-16.
  289. expected_pattern = re.compile(r"^(v|tpu)\d+[a-zA-Z]*-\d+$")
  290. if not expected_pattern.match(tpu_accelerator_type):
  291. return False
  292. return True
  293. @staticmethod
  294. def is_valid_tpu_accelerator_topology(
  295. tpu_accelerator_version: str, tpu_topology: str
  296. ) -> bool:
  297. """Check whether the tpu topology is valid.
  298. The accelerator_type field follows a form of v{generation}.
  299. The accelerator_topology field follows either the form {A}x{B} or {A}x{B}x{C} depending on the v{generation}
  300. Args:
  301. tpu_accelerator_version: The string representation of the accelerator version. (e.g. v6e, V5P)
  302. tpu_topology: The string representation of the accelerator topology
  303. to be checked for validity
  304. Returns:
  305. True if it's a valid topology, False otherwise.
  306. """
  307. tpu_version_formatted = tpu_accelerator_version.strip().lower().split("-")[0]
  308. if tpu_version_formatted.startswith("tpu"):
  309. tpu_version_formatted = "v" + tpu_version_formatted[3:]
  310. if (
  311. tpu_version_formatted.lower() not in VALID_TPU_TOPOLOGY
  312. or tpu_topology.strip().lower()
  313. not in VALID_TPU_TOPOLOGY[tpu_version_formatted]
  314. ):
  315. return False
  316. return True
  317. @staticmethod
  318. def validate_resource_request_quantity(
  319. quantity: float,
  320. ) -> Tuple[bool, Optional[str]]:
  321. if quantity not in TPU_VALID_CHIP_OPTIONS:
  322. return (
  323. False,
  324. f"The number of requested 'TPU' was set to {quantity} which "
  325. "is not a supported chip configuration. Supported configs: "
  326. f"{TPU_VALID_CHIP_OPTIONS}",
  327. )
  328. else:
  329. return (True, None)
  330. @staticmethod
  331. def set_current_process_visible_accelerator_ids(
  332. visible_tpu_chips: List[str],
  333. ) -> None:
  334. """Set TPU environment variables based on the provided visible_tpu_chips.
  335. To access a subset of the TPU visible chips, we must use a combination of
  336. environment variables that tells the compiler (via ML framework) the:
  337. - Visible chips
  338. - The physical bounds of chips per host
  339. - The host bounds within the context of a TPU pod.
  340. See: https://github.com/google/jax/issues/14977 for an example/more details.
  341. Args:
  342. visible_tpu_chips (List[str]): List of int representing TPU chips.
  343. """
  344. if env_bool(NOSET_TPU_VISIBLE_CHIPS_ENV_VAR, False):
  345. return
  346. num_visible_tpu_chips = len(visible_tpu_chips)
  347. num_accelerators_on_node = (
  348. TPUAcceleratorManager.get_current_node_num_accelerators()
  349. )
  350. if num_visible_tpu_chips == num_accelerators_on_node:
  351. # Let the ML framework use the defaults
  352. os.environ.pop(TPU_CHIPS_PER_HOST_BOUNDS_ENV_VAR, None)
  353. os.environ.pop(TPU_HOST_BOUNDS_ENV_VAR, None)
  354. return
  355. os.environ[
  356. TPUAcceleratorManager.get_visible_accelerator_ids_env_var()
  357. ] = ",".join([str(i) for i in visible_tpu_chips])
  358. if num_visible_tpu_chips == 1:
  359. os.environ[
  360. TPU_CHIPS_PER_HOST_BOUNDS_ENV_VAR
  361. ] = TPU_CHIPS_PER_HOST_BOUNDS_1_CHIP_CONFIG
  362. os.environ[TPU_HOST_BOUNDS_ENV_VAR] = TPU_SINGLE_HOST_BOUNDS
  363. elif num_visible_tpu_chips == 2:
  364. os.environ[
  365. TPU_CHIPS_PER_HOST_BOUNDS_ENV_VAR
  366. ] = TPU_CHIPS_PER_HOST_BOUNDS_2_CHIP_CONFIG
  367. os.environ[TPU_HOST_BOUNDS_ENV_VAR] = TPU_SINGLE_HOST_BOUNDS
  368. @staticmethod
  369. def get_current_node_tpu_pod_type() -> Optional[str]:
  370. """Get the TPU pod type of the current node if applicable.
  371. Individual TPU VMs within a TPU pod must know what type
  372. of pod it is a part of. This is necessary for the
  373. ML framework to work properly.
  374. The logic is different if the TPU was provisioned via:
  375. ```
  376. gcloud tpus tpu-vm create ...
  377. ```
  378. (i.e. a GCE VM), vs through GKE:
  379. - GCE VMs will always have a metadata server to poll this info
  380. - GKE VMS will have environment variables preset.
  381. Returns:
  382. A string representing the current TPU pod type, e.g.
  383. v4-16.
  384. """
  385. # Start with GKE-based check
  386. accelerator_type = os.getenv(GKE_TPU_ACCELERATOR_TYPE_ENV_VAR, "")
  387. if not accelerator_type:
  388. # GCE-based VM check
  389. accelerator_type = _get_tpu_metadata(key=GCE_TPU_ACCELERATOR_KEY)
  390. if accelerator_type and TPUAcceleratorManager.is_valid_tpu_accelerator_type(
  391. tpu_accelerator_type=accelerator_type
  392. ):
  393. if accelerator_type.lower().startswith("tpu"):
  394. return "v" + accelerator_type.lower()[3:]
  395. return accelerator_type
  396. logging.debug("Failed to get a valid accelerator type.")
  397. return None
  398. @staticmethod
  399. def get_current_node_tpu_name() -> Optional[str]:
  400. """Return the name of the TPU pod that this worker node is a part of.
  401. For instance, if the TPU was created with name "my-tpu", this function
  402. will return "my-tpu".
  403. If created through the Ray cluster launcher, the
  404. name will typically be something like "ray-my-tpu-cluster-worker-aa946781-tpu".
  405. In case the TPU was created through KubeRay, we currently expect that the
  406. environment variable TPU_NAME is set per TPU pod slice, in which case
  407. this function will return the value of that environment variable.
  408. """
  409. try:
  410. # Start with GKE-based check
  411. tpu_name = os.getenv(GKE_TPU_NAME_ENV_VAR, None)
  412. if not tpu_name:
  413. # GCE-based VM check
  414. tpu_name = _get_tpu_metadata(key=GCE_TPU_INSTANCE_ID_KEY)
  415. return tpu_name
  416. except ValueError as e:
  417. logging.debug("Could not get TPU name: %s", e)
  418. return None
  419. @staticmethod
  420. def get_current_node_tpu_worker_id() -> Optional[int]:
  421. """Return the worker index of the TPU pod."""
  422. try:
  423. # Start with GKE-based check
  424. worker_id = os.getenv(GKE_TPU_WORKER_ID_ENV_VAR, None)
  425. if not worker_id:
  426. # GCE-based VM check
  427. worker_id = _get_tpu_metadata(key=GCE_TPU_WORKER_ID_KEY)
  428. if worker_id:
  429. return int(worker_id)
  430. else:
  431. return None
  432. except ValueError as e:
  433. logging.debug("Could not get TPU worker id: %s", e)
  434. return None
  435. @staticmethod
  436. def get_num_workers_in_current_tpu_pod() -> Optional[int]:
  437. """Return the total number of workers in a TPU pod."""
  438. tpu_pod_type = TPUAcceleratorManager.get_current_node_tpu_pod_type()
  439. chips_per_host = TPUAcceleratorManager.get_current_node_num_accelerators()
  440. cores_per_chip = get_tpu_cores_per_chip(tpu_pod_type) # Hard-coded map.
  441. cores_per_host = chips_per_host * cores_per_chip
  442. if tpu_pod_type and cores_per_host > 0:
  443. num_cores = int(tpu_pod_type.split("-")[1])
  444. num_workers = num_cores // cores_per_host
  445. # If the chip count doesn't fill a full host, a sub-host is still treated as a host.
  446. if num_cores % cores_per_host != 0:
  447. num_workers += 1
  448. return num_workers
  449. else:
  450. logging.debug("Could not get num workers in TPU pod.")
  451. return None
  452. @staticmethod
  453. def get_current_node_tpu_topology() -> Optional[str]:
  454. try:
  455. # Attempt GKE based lookup first
  456. if topology := os.environ.get(GKE_TPU_TOPOLOGY_ENV_VAR):
  457. return topology.strip().lower()
  458. # GCE-based VM check using TPU env string.
  459. tpu_env = _get_tpu_metadata(key=GCE_TPU_ENV_KEY)
  460. if tpu_env:
  461. topology = re.search(r"TOPOLOGY:\s*'([^']+)'", tpu_env)
  462. if topology:
  463. return topology.group(1).strip().lower()
  464. except ValueError as e:
  465. logging.debug("Could not get TPU topology: %s", e)
  466. return None
  467. @staticmethod
  468. def get_current_node_accelerator_type() -> Optional[str]:
  469. """Attempt to detect the TPU accelerator type.
  470. The output of this function will return the "ray accelerator type"
  471. resource (e.g. TPU-V4) that indicates the TPU version.
  472. We also expect that our TPU nodes contain a "TPU pod type"
  473. resource, which indicates information about the topology of
  474. the TPU pod slice.
  475. We expect that the "TPU pod type" resource to be used when
  476. running multi host workers, i.e. when TPU units are pod slices.
  477. We expect that the "ray accelerator type" resource to be used when
  478. running single host workers, i.e. when TPU units are single hosts.
  479. Returns:
  480. A string representing the TPU accelerator type,
  481. e.g. "TPU-V2", "TPU-V3", "TPU-V4" if applicable, else None.
  482. """
  483. def tpu_pod_type_to_ray_accelerator_type(
  484. tpu_pod_type: str,
  485. ) -> Optional[str]:
  486. return "TPU-" + str(tpu_pod_type.split("-")[0].upper())
  487. ray_accelerator_type = None
  488. tpu_pod_type = TPUAcceleratorManager.get_current_node_tpu_pod_type()
  489. if tpu_pod_type is not None:
  490. ray_accelerator_type = tpu_pod_type_to_ray_accelerator_type(
  491. tpu_pod_type=tpu_pod_type
  492. )
  493. if ray_accelerator_type is None:
  494. logger.info(
  495. "While trying to autodetect a TPU type, "
  496. f"received malformed accelerator_type: {tpu_pod_type}"
  497. )
  498. if ray_accelerator_type is None:
  499. logging.info("Failed to auto-detect TPU type.")
  500. return ray_accelerator_type
  501. @staticmethod
  502. def get_current_node_additional_resources() -> Optional[Dict[str, float]]:
  503. """Get additional resources required for TPU nodes.
  504. This will populate the TPU pod type and the TPU name which
  505. is used for TPU pod execution.
  506. When running workloads on a TPU pod, we need a way to run
  507. the same binary on every worker in the TPU pod.
  508. See https://jax.readthedocs.io/en/latest/multi_process.html
  509. for more information.
  510. To do this in ray, we take advantage of custom resources. We
  511. mark worker 0 of the TPU pod as a "coordinator" that identifies
  512. the other workers in the TPU pod. We therefore need:
  513. - worker 0 to be targetable.
  514. - all workers in the TPU pod to have a unique identifier consistent
  515. within a TPU pod.
  516. So assuming we want to run the following workload:
  517. @ray.remote
  518. def my_jax_fn():
  519. import jax
  520. return jax.device_count()
  521. We could broadcast this on a TPU pod (e.g. a v4-16) as follows:
  522. @ray.remote(resources={"TPU-v4-16-head"})
  523. def run_jax_fn(executable):
  524. # Note this will execute on worker 0
  525. tpu_name = ray.util.tpu.get_current_pod_name()
  526. num_hosts = ray.util.tpu.get_current_pod_worker_count()
  527. tpu_executable = executable.options(resources={"TPU": 4, tpu_name: 1})
  528. return [tpu_executable.remote() for _ in range(num_hosts)]
  529. Returns:
  530. A dictionary representing additional resources that may be
  531. necessary for a particular accelerator type.
  532. """
  533. resources = {}
  534. tpu_name = TPUAcceleratorManager.get_current_node_tpu_name()
  535. worker_id = TPUAcceleratorManager.get_current_node_tpu_worker_id()
  536. tpu_pod_type = TPUAcceleratorManager.get_current_node_tpu_pod_type()
  537. if tpu_name and worker_id is not None and tpu_pod_type:
  538. pod_head_resource_name = f"TPU-{tpu_pod_type}-head"
  539. # Add the name of the TPU to the resource.
  540. resources[tpu_name] = 1
  541. # Only add in the TPU pod type resource to worker 0.
  542. if worker_id == 0:
  543. resources[pod_head_resource_name] = 1
  544. else:
  545. logging.info(
  546. "Failed to configure TPU pod. Got: "
  547. "tpu_name: %s, worker_id: %s, accelerator_type: %s",
  548. tpu_name,
  549. worker_id,
  550. tpu_pod_type,
  551. )
  552. if resources:
  553. return resources
  554. return None
  555. @staticmethod
  556. def get_current_node_accelerator_labels() -> Dict[str, str]:
  557. """Get default TPU-specific Ray node labels for the current node.
  558. For TPUs, these labels include:
  559. - ray.io/tpu-slice-name: the name of the TPU Pod or slice
  560. - ray.io/tpu-worker-id: the integer worker ID within the slice
  561. - ray.io/tpu-topology: the TPU topology (e.g. 4x4)
  562. - ray.io/tpu-pod-type: the TPU pod type (e.g. v4-8)
  563. Returns:
  564. A dictionary of TPU label keys and resolved values.
  565. """
  566. tpu_labels = {}
  567. tpu_name = TPUAcceleratorManager.get_current_node_tpu_name()
  568. if tpu_name:
  569. tpu_labels[ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY] = tpu_name
  570. worker_id = TPUAcceleratorManager.get_current_node_tpu_worker_id()
  571. if worker_id is not None:
  572. tpu_labels[ray._raylet.RAY_NODE_TPU_WORKER_ID_KEY] = str(worker_id)
  573. tpu_topology = TPUAcceleratorManager.get_current_node_tpu_topology()
  574. if tpu_topology:
  575. tpu_labels[ray._raylet.RAY_NODE_TPU_TOPOLOGY_KEY] = tpu_topology
  576. pod_type = TPUAcceleratorManager.get_current_node_tpu_pod_type()
  577. if pod_type:
  578. tpu_labels[ray._raylet.RAY_NODE_TPU_POD_TYPE_KEY] = pod_type
  579. return tpu_labels