device_mesh.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import logging
  4. import os
  5. import threading
  6. import warnings
  7. from collections.abc import Iterator
  8. from itertools import zip_longest
  9. from typing import Optional, TYPE_CHECKING, Union
  10. import torch
  11. from torch.distributed import is_available
  12. from torch.distributed._mesh_layout import _MeshLayout
  13. from torch.distributed._pycute import IntTuple, is_int, suffix_product
  14. from torch.utils._typing_utils import not_none
  15. __all__ = ["init_device_mesh", "DeviceMesh"]
  16. if not is_available():
  17. import sys
  18. # We need to create the stubs when distributed is not available.
  19. # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```),
  20. # since it would try to import ``torch.distributed.device_mesh`` or
  21. # ``torch.distributed.init_device_mesh`` but cannot find them.
  22. class _DeviceMeshStub:
  23. pass
  24. def _init_device_mesh_stub():
  25. pass
  26. sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub # type: ignore[attr-defined]
  27. # pyrefly: ignore [missing-attribute]
  28. sys.modules[
  29. "torch.distributed.device_mesh"
  30. ].init_device_mesh = _init_device_mesh_stub # type: ignore[attr-defined]
  31. else:
  32. from torch._C._distributed_c10d import Backend as C10dBackend
  33. from torch.distributed import config as dist_config
  34. from torch.distributed.distributed_c10d import (
  35. _get_default_group,
  36. _resolve_process_group,
  37. get_backend,
  38. get_process_group_ranks,
  39. get_rank,
  40. get_world_size,
  41. GroupName,
  42. init_process_group,
  43. is_initialized,
  44. new_group,
  45. ProcessGroup,
  46. split_group,
  47. )
# Module logger used by the DeviceMesh helpers in this file.
logger = logging.getLogger(__name__)
# only import numpy typing when type checking
# NOTE(review): TYPE_CHECKING is False when the module actually executes, so
# the warning in the except branch can never be emitted at runtime — TODO
# confirm whether it is still wanted.
if TYPE_CHECKING:
    try:
        from numpy.typing import ArrayLike
    except ImportError:
        logger.warning(
            "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
        )
# Per-mesh-dimension backend override: (backend name, backend options).
# ``(None, None)`` means "no override" for that dimension.
BackendConfig = tuple[str | None, C10dBackend.Options | None]
# Allowlist ``_MeshLayout`` for torch.serialization's safe-globals
# (weights-only) loading path.
torch.serialization.add_safe_globals([_MeshLayout])
  59. def _get_pg_from_name(mesh: "DeviceMesh", name: str) -> ProcessGroup:
  60. """
  61. This method allows us to torch.compile through DeviceMesh and lift its
  62. PGs a inputs to the graph since all PGs will have a source from the
  63. DeviceMesh through the `_pg_registry`.
  64. This will be moved to the DeviceMesh backend object once we separate
  65. DeviceMesh into the frontend and backend.
  66. """
  67. if torch.compiler.is_compiling():
  68. pg = mesh._pg_registry.get(name, None)
  69. if pg is None:
  70. raise RuntimeError(
  71. f"PG {name} was not found while torch.compile tracing "
  72. "This is probably because we pickle/unpickled a device mesh "
  73. "before the PGs were created."
  74. )
  75. return pg
  76. else:
  77. return _resolve_process_group(name) # pyrefly: ignore[bad-argument-type]
  78. class _MeshEnv(threading.local):
  79. def __init__(self) -> None:
  80. self.mesh_stack: list[DeviceMesh] = []
  81. def get_current_mesh(self) -> "DeviceMesh":
  82. if len(self.mesh_stack) == 0:
  83. raise RuntimeError("No device mesh is currently active!")
  84. return self.mesh_stack[-1]
  85. # TODO: to remove it once we move all use cases into new API.
  86. def get_root_mesh(self, device_mesh: "DeviceMesh") -> "DeviceMesh":
  87. # If a mesh could not be found in the child_to_root_mapping, it is a root mesh itself.
  88. # A root mesh is not created through slicing.
  89. # We considers the root mesh of a root mesh is itself.
  90. # We keep this function for backward compatibility.
  91. warnings.warn(
  92. "This get_root_mesh API will be deprecated soon."
  93. "Please use `get_root_mesh` inside DeviceMesh instead.",
  94. stacklevel=2,
  95. )
  96. if not device_mesh:
  97. return device_mesh
  98. return device_mesh._get_root_mesh()
  99. @staticmethod
  100. def num_devices_per_host(device_type: str) -> int:
  101. return _get_device_handle(device_type).device_count()
  102. @staticmethod
  103. def num_hosts(device_type: str) -> int:
  104. # ProcessGroup can't tell us this info so we have to infer it, assume
  105. # homogeneous hardware for now
  106. return get_world_size() // _MeshEnv.num_devices_per_host(device_type)
  107. # TODO: to remove it once we move all use cases into new API.
  108. # We keep this API for backward compatibility.
  109. def _get_all_submeshes(
  110. self, device_mesh: "DeviceMesh", mesh_dim_name: str
  111. ) -> list["DeviceMesh"]:
  112. warnings.warn(
  113. "This _get_all_submeshes API will be deprecated soon."
  114. "Please use `_get_all_submeshes` inside DeviceMesh instead.",
  115. stacklevel=2,
  116. )
  117. return device_mesh._get_all_submeshes(mesh_dim_name)
# Shared instance; since _MeshEnv subclasses threading.local, each thread sees
# its own ``mesh_stack``. DeviceMesh.__enter__/__exit__ push/pop onto it.
_mesh_resources: _MeshEnv = _MeshEnv()
  119. def _get_device_handle(device_type: str = "cuda"):
  120. """
  121. Get the module corresponding to the device_type which is cuda or cuda-like device.
  122. For example, when the device_type is cuda, the module `torch.cuda` is returned.
  123. Return None when there is no corresponding module for device_type, otherwise
  124. return the corresponding module.
  125. """
  126. return getattr(torch, device_type, None)
class DeviceMesh:
    """
    DeviceMesh represents a mesh of devices, where the layout of devices can be
    represented as an n-dimensional array, and each value of the n-dimensional
    array is the global id of the default process group ranks.

    DeviceMesh can be used to set up the N-dimensional device connections across the cluster,
    and to manage the ProcessGroups for N-dimensional parallelisms. Communications can happen on
    each dimension of the DeviceMesh separately. DeviceMesh respects the device that the user
    already selected (i.e. if the user calls `torch.cuda.set_device` before the DeviceMesh
    initialization), and will select/set the device for the current process if the user does not
    set the device beforehand. Note that manual device selection should happen BEFORE the
    DeviceMesh initialization.

    DeviceMesh can also be used as a context manager when used together with DTensor APIs.

    .. note::
        DeviceMesh follows the SPMD programming model, which means the same PyTorch Python program
        is running on all processes/ranks in the cluster. Therefore, users need to make sure the
        `mesh` array (which describes the layout of devices) is identical across all ranks.
        An inconsistent `mesh` will lead to a silent hang.

    Args:
        device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
            For "xla", process group setup is skipped.
        mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout
            of devices, where the IDs are global IDs of the default process group. A tensor
            must live on the CPU.
        mesh_dim_names (tuple[str, ...], optional): A name for each mesh dimension. Names are
            used for slicing (``mesh["dp"]``) and in the process group descriptions.
        backend_override (tuple, optional): One ``(backend, pg_options)`` pair per mesh
            dimension; must have the same length as the number of mesh dimensions.
            Defaults to ``(None, None)`` (no override) for every dimension.
        _rank (int): (experimental/internal)
            The global rank of the current process. If not provided, it will
            be inferred from the default process group.

    Returns:
        DeviceMesh: A :class:`DeviceMesh` object representing the device layout.

    The following program runs on each process/rank in an SPMD manner. In this example, we have 2
    hosts with 4 GPUs each.
    A reduction over the first dimension of mesh will reduce across
    columns (0, 4), .. and (3, 7), a reduction over the second dimension
    of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).

    Example::

        >>> # xdoctest: +SKIP("no rank")
        >>> from torch.distributed.device_mesh import DeviceMesh
        >>>
        >>> # Initialize device mesh as (2, 4) to represent the topology
        >>> # of cross-host(dim 0), and within-host (dim 1).
        >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
    """

    # Global rank of the current process; set in ``__init__`` from the
    # ``_rank`` argument or from ``get_rank()``.
    _rank: int
    # Device type string the mesh was created with (e.g. "cpu", "cuda").
    _device_type: str
    # 1-D, contiguous tensor of global ranks that ``_layout`` indexes into.
    _rank_map: torch.Tensor
    # Optional name for each mesh dimension.
    _mesh_dim_names: tuple[str, ...] | None
    # Sizes/strides describing how the mesh dimensions map onto ``_rank_map``.
    _layout: _MeshLayout
    # Mesh this one was derived from; ``None`` means this mesh is its own root.
    _root_mesh: Optional["DeviceMesh"] = None
    # Record flatten mesh name to its flattened mesh in root mesh.
    _flatten_mapping: dict[str, "DeviceMesh"]
    # Registry mapping group names to ProcessGroup objects (to avoid C++ lookup)
    _pg_registry: dict[str, ProcessGroup]
  176. def __init__(
  177. self,
  178. device_type: str,
  179. mesh: Union[torch.Tensor, "ArrayLike"] | None = None,
  180. *,
  181. mesh_dim_names: tuple[str, ...] | None = None,
  182. backend_override: tuple[BackendConfig, ...] | None = None,
  183. _init_backend: bool = True,
  184. _rank: int | None = None,
  185. _layout: _MeshLayout | None = None,
  186. _rank_map: torch.Tensor | None = None,
  187. _root_mesh: Optional["DeviceMesh"] = None,
  188. ) -> None:
  189. # no-op in OSS, logs API usage metrics in meta-internal runs
  190. torch._C._log_api_usage_once(
  191. "torch.distributed.device_mesh.DeviceMesh.__init__"
  192. )
  193. if mesh is not None:
  194. if _layout is not None or _rank_map is not None:
  195. raise TypeError(
  196. "Cannot provide _layout and/or _rank_map if passing explicit mesh"
  197. )
  198. if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
  199. raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
  200. mesh_tensor = (
  201. mesh.detach().to(dtype=torch.int).contiguous()
  202. if isinstance(mesh, torch.Tensor)
  203. else torch.tensor(mesh, device="cpu", dtype=torch.int)
  204. )
  205. _layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
  206. _rank_map = mesh_tensor.flatten()
  207. else:
  208. if _layout is None or _rank_map is None:
  209. raise TypeError(
  210. "The mesh argument is required except for PRIVATE USAGE ONLY!"
  211. )
  212. assert _layout.check_non_overlap(), (
  213. "Please use a non-overlapping layout when creating a DeviceMesh."
  214. )
  215. assert _rank_map.ndim == 1, "The rank map must be 1-dimensional"
  216. assert _rank_map.is_contiguous(), "The rank map must be contiguous"
  217. assert _rank_map.numel() >= _layout.cosize(), (
  218. f"The rank map contains {_rank_map.numel()} element, "
  219. f"which isn't large enough for layout {_layout}"
  220. )
  221. self._device_type = device_type
  222. self._layout = _layout
  223. self._rank_map = _rank_map
  224. self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
  225. self._root_mesh = _root_mesh
  226. if backend_override is None:
  227. backend_override = ((None, None),) * len(self._layout)
  228. elif len(backend_override) != len(self._layout):
  229. raise ValueError(
  230. f"backend_override should have the same length as the number of mesh dimensions, "
  231. f"but got {len(backend_override)} and {len(self._layout)}."
  232. )
  233. # Internal bookkeeping for the device mesh.
  234. self._layout = (
  235. _layout
  236. if _layout
  237. else _MeshLayout(self.mesh.size(), self.mesh.stride())
  238. )
  239. if not self._layout.check_non_overlap():
  240. raise AssertionError(
  241. "Please use a non-overlapping layout when creating a DeviceMesh."
  242. )
  243. # Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
  244. if self._layout.numel() != self.mesh.numel():
  245. raise AssertionError(
  246. "Please use a valid layout when creating a DeviceMesh."
  247. f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
  248. )
  249. # private field to pre-generate DeviceMesh's hash
  250. self._flatten_rank_map = tuple(self._rank_map.tolist())
  251. self._thread_id = None
  252. # Initialize instance-specific flatten mapping
  253. self._flatten_mapping = {}
  254. # Initialize process group registry
  255. self._pg_registry = {}
  256. # Skip process group initialization if xla device or init backend is False
  257. # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
  258. if device_type != "xla":
  259. # always try to create default (world) pg, even if it is not initialized
  260. # already. The world pg is used for device mesh identity (rank) on each
  261. # process (we need to know if the current global rank is in the mesh or not).
  262. if _init_backend:
  263. self._setup_world_group_and_device()
  264. self._dim_group_names = self._init_process_groups(
  265. self._layout,
  266. self._rank_map,
  267. self._mesh_dim_names,
  268. backend_override,
  269. )
  270. # Populate the process group registry
  271. # If we have a root mesh, add to root's registry for lookups
  272. target_registry = (
  273. self._root_mesh._pg_registry
  274. if self._root_mesh is not None
  275. else self._pg_registry
  276. )
  277. for name in self._dim_group_names:
  278. pg = _resolve_process_group(name)
  279. if pg is not None:
  280. target_registry[name] = pg
  281. if is_initialized() and get_backend() == "threaded":
  282. # pyrefly: ignore [bad-assignment]
  283. self._thread_id = threading.get_ident()
  284. # Now that the process group is initialized, we can get the rank
  285. if _rank is None:
  286. self._rank = get_rank()
  287. else:
  288. self._rank = _rank
  289. self._coordinate_on_dim = self._compute_coordinate_on_dim()
  290. @staticmethod
  291. def _compute_coordinates_from_mesh(
  292. mesh_tensor: torch.Tensor,
  293. rank: int,
  294. ) -> tuple[int, ...] | None:
  295. """
  296. Compute the coordinates of a rank within a mesh tensor.
  297. Args:
  298. mesh_tensor: The mesh tensor to search in
  299. rank: The rank to find coordinates for
  300. Returns:
  301. A tuple of coordinates if the rank is found in the mesh, None otherwise
  302. Raises:
  303. AssertionError: If the rank appears more than once in the mesh
  304. """
  305. rank_coords = (mesh_tensor == rank).nonzero()
  306. if rank_coords.size(0) not in (0, 1):
  307. raise AssertionError(
  308. f"rank_coords.size(0) must be 0 or 1, got {rank_coords.size(0)}"
  309. )
  310. if rank_coords.size(0) == 0:
  311. return None
  312. coords = rank_coords[0].tolist()
  313. return tuple(coords)
  314. def _compute_coordinate_on_dim(self) -> tuple[int, ...] | None:
  315. # calculate the coordinates of the current global rank on the mesh
  316. return self._compute_coordinates_from_mesh(self.mesh, self._rank)
  317. def __getstate__(self) -> dict:
  318. # Exclude _pg_registry from pickle since ProcessGroup objects can't be pickled
  319. state = self.__dict__.copy()
  320. state.pop("_pg_registry", None)
  321. return state
  322. def __setstate__(self, state: dict) -> None:
  323. self.__dict__.update(state)
  324. # Reconstruct _pg_registry from _dim_group_names
  325. self._pg_registry = {}
  326. if hasattr(self, "_dim_group_names"):
  327. for name in self._dim_group_names:
  328. try:
  329. pg = _resolve_process_group(name)
  330. if pg is not None:
  331. self._pg_registry[name] = pg
  332. except RuntimeError:
  333. # Note: process groups may not exist if loading in a different process
  334. logger.warning(
  335. "It seems like pickling/unpickling of the DeviceMesh "
  336. "occurred before the PGs were created. This will cause PG "
  337. "lookup to fail when torch.compile is enabled"
  338. )
  339. @property
  340. def device_type(self) -> str:
  341. """Returns the device type of the mesh."""
  342. return self._device_type
  343. @staticmethod
  344. def _get_mesh_tensor_from_full_mesh(
  345. full_mesh: torch.Tensor,
  346. current_rank: int | None = None,
  347. ) -> torch.Tensor:
  348. if full_mesh.size(0) == 1:
  349. return full_mesh[0]
  350. if current_rank is None:
  351. current_rank = get_rank()
  352. my_coords = (full_mesh == current_rank).nonzero()
  353. if my_coords.size(0) > 0:
  354. return full_mesh[my_coords[0, 0]]
  355. raise RuntimeError(
  356. "In order to get the mesh Tensor of a DeviceMesh it needs to "
  357. "either have all its original dimensions (e.g., no slicing) "
  358. "or it needs to contain the local rank"
  359. )
  360. @property
  361. def mesh(self) -> torch.Tensor:
  362. """Returns the tensor representing the layout of devices."""
  363. full_mesh = self._layout.remap_to_tensor(self._rank_map)
  364. return self._get_mesh_tensor_from_full_mesh(full_mesh)
  365. @property
  366. def mesh_dim_names(self) -> tuple[str, ...] | None:
  367. """Returns the names of mesh dimensions."""
  368. return self._mesh_dim_names
def _setup_world_group_and_device(self):
    """
    Ensure the default (world) process group exists and select this process's device.

    Steps:
    - Call ``init_process_group()`` if no default group is initialized yet.
    - Check that the mesh fits in the world size.
    - Unless the backend is "fake", and only when the device handle reports
      it is not initialized yet, set the device. ``LOCAL_RANK`` is used when
      present; otherwise ``global_rank % num_devices_per_host`` is used.

    Returns:
        The default process group (``_get_default_group()``) on every path.

    Raises:
        RuntimeError: if the mesh has more ranks than the world size, or (on
            the heuristic path) the world size exceeds and is not a multiple
            of the per-host device count.
    """
    default_initialized = is_initialized()
    # Lazily create the default process group if the user has not done so.
    if not default_initialized:
        init_process_group()

    world_size = get_world_size()
    if self._layout.numel() > world_size:
        raise RuntimeError(
            f"Mesh should not be bigger than default world size {world_size}, but found {self._layout.numel()} ranks!"
        )

    # Skip device setup for fake backend (cross-compilation mode).
    # The fake backend is used to simulate distributed training on a
    # single process without actual devices, enabling compilation of
    # GPU programs on CPU-only machines.
    backend = get_backend()
    if backend == "fake":
        return _get_default_group()

    # ONLY set the device if the current device is not initialized, if user already
    # set the device before DeviceMesh init, we respect the user's choice.
    device_handle = _get_device_handle(self._device_type)
    if device_handle and not device_handle.is_initialized():
        # auto set the cuda/cuda-like device only if user has not set it, if there's LOCAL_RANK
        # env variable from launchers, we use it to set the device.
        if "LOCAL_RANK" in os.environ:
            local_rank = int(os.environ["LOCAL_RANK"])
            logger.info(
                "Setting default device for the current process based on LOCAL_RANK=%s",
                local_rank,
            )
            device_handle.set_device(local_rank)
        else:
            # heuristic to set the current cuda/cuda-like device base on num of gpu devices available in each host
            # NOTE: This device selection would only work for homogeneous hardware.
            num_devices_per_host = device_handle.device_count()
            # Skip device setup if no devices are available (cross-compilation mode)
            if num_devices_per_host == 0:
                return _get_default_group()
            warnings.warn(
                "It seems like you did not set/select the default device for the current process before the DeviceMesh "
                "initialization or use a launcher (i.e. torchrun) which populates `LOCAL_RANK` environment variable. "
                "It is recommended to set the current device for the process BEFORE the DeviceMesh initialization so that "
                "the underlying communicator (i.e. NCCL) can be initialized properly. "
                "Given that the current process has no default device selected, DeviceMesh will use a heuristic to set the "
                "device_id via `global_rank % num_devices_per_host`, assuming homogeneous hardware cluster. ",
                stacklevel=2,
            )
            # Reject clusters where ranks cannot be spread evenly over hosts.
            if (
                world_size > num_devices_per_host
                and world_size % num_devices_per_host != 0
            ):
                raise RuntimeError(
                    f"DeviceMesh only support homogeneous hardware, but found "
                    f"{world_size} ranks and {num_devices_per_host} {self._device_type} devices!"
                )
            device_handle.set_device(get_rank() % num_devices_per_host)

    return _get_default_group()
@staticmethod
def _init_one_process_group(
    sub_layout: _MeshLayout,
    rank_map: torch.Tensor,
    dim_name: str,
    backend_override: BackendConfig,
) -> GroupName | None:
    """
    Create the process group(s) for one mesh dimension.

    Args:
        sub_layout: Layout of this single mesh dimension. Its
            ``nest().remap_to_tensor(rank_map)`` yields a 2-D tensor whose
            rows are the rank lists of this dimension's subgroups.
        rank_map: Global ranks the layout indexes into.
        dim_name: Used to build the group description ``f"mesh_{dim_name}"``.
        backend_override: ``(backend, pg_options)`` for this dimension;
            ``(None, None)`` means no override.

    Returns:
        The name of the subgroup that contains the current rank. Returns
        ``None`` when ``split_group`` returns ``None``, or (``new_group``
        path) when the current rank is in none of the subgroups.

    Raises:
        RuntimeError: on the ``new_group`` path, if the current rank appears
            in more than one subgroup of this dimension.
    """
    # Generate a 2D global mesh tensor for the current dim for PG creation.
    pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(rank_map)
    backend, pg_options = backend_override
    # We need to explicitly pass in timeout when specified in option, otherwise
    # the default timeout will be used to override the timeout set in option.
    # TODO: remove this once we have fixed inside c10d level.
    timeout = pg_options._timeout if pg_options else None

    # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group descriptions
    # of the subgroups are `mesh_dp` and `mesh_tp`.
    # If the mesh doesn't have mesh_dim_names, the caller passes "dim_0", "dim_1", ...
    # so the group descriptions are `mesh_dim_0` and `mesh_dim_1`.
    group_desc = f"mesh_{dim_name}"

    dim_group = None
    default_group = _get_default_group()

    # Early return when this single dimension spans the whole world and no
    # backend override was requested: reuse the default pg, unless CUDA is
    # available and the default pg's backend is gloo, in which case a fresh
    # world-sized group is created with `new_group`.
    # NOTE(review): this path uses ranks 0..world_size-1 in order and ignores
    # the order in `rank_map` — TODO confirm that is intended for permuted meshes.
    if sub_layout.numel() == get_world_size() and backend_override == (
        None,
        None,
    ):
        ranks = list(range(get_world_size()))
        dim_group = (
            new_group(
                backend=backend,
                ranks=ranks,
                group_desc="mesh_default",
            )
            if torch.cuda.is_available()
            and get_backend(default_group) == "gloo"
            else default_group
        )
        return dim_group.group_name  # type: ignore[union-attr]

    # If bound_device_id exists, it means the nccl communicator has been eagerly initialized
    # (torchcomms is treated the same way), so we can use `split_group` to create subgroups
    # through `ncclCommSplit`. In this case, we only need to make one API call (`split_group`)
    # for the subgroup creation for each mesh dimension. In a 2 * 4 mesh, we only need to make
    # two API calls per rank to create all the subgroups.
    # Otherwise, we need to make more than one API call (`new_group`) for subgroup creation. The
    # number of API calls equals the number of subgroups for each mesh dimension. In a 2 * 4
    # mesh, we need to make 2 + 4 = 6 API calls per rank to create all the subgroups.
    if (
        (
            getattr(default_group, "bound_device_id", None) is not None
            or dist_config.use_torchcomms
        )
        and torch.cuda.is_available()
        and (
            backend is None
            or default_group._get_backend(torch.device("cuda")).name()
            == backend
        )
    ):
        dim_group = split_group(
            parent_pg=default_group,
            timeout=timeout,
            pg_options=pg_options,
            split_ranks=pg_ranks_by_dim.tolist(),
            group_desc=group_desc,
        )
        if dim_group is None:
            return None
        return dim_group.group_name

    # Fallback: create every subgroup of this dimension with `new_group` (every rank
    # calls it for every subgroup), and keep the name of the one containing the
    # current rank.
    pg_name = None
    for dim_mesh in pg_ranks_by_dim:
        subgroup_ranks = dim_mesh.tolist()
        dim_group = new_group(
            ranks=subgroup_ranks,
            timeout=timeout,
            backend=backend,
            pg_options=pg_options,
            group_desc=group_desc,
        )
        # only add to dim_groups if the current rank in the subgroup
        if get_rank() in subgroup_ranks:
            if pg_name is not None:
                raise RuntimeError(
                    f"Each device mesh dimension should get only one process group, but got {get_rank()} "
                    f"in {subgroup_ranks}!"
                )
            pg_name = dim_group.group_name
    return pg_name
  517. @staticmethod
  518. def _init_process_groups(
  519. layout: _MeshLayout,
  520. rank_map: torch.Tensor,
  521. mesh_dim_names: tuple[str, ...] | None,
  522. backend_override: tuple[BackendConfig, ...],
  523. ) -> list[GroupName]:
  524. # group_name associated with each mesh dimension, each
  525. # mesh dimension should have one sub-group per rank
  526. dim_group_names: list[GroupName | None] = []
  527. # create sub pgs base on the mesh argument specified
  528. for dim in range(len(layout)):
  529. dim_name = mesh_dim_names[dim] if mesh_dim_names else f"dim_{dim}"
  530. dim_group_names.append(
  531. DeviceMesh._init_one_process_group(
  532. layout[dim],
  533. rank_map,
  534. dim_name,
  535. backend_override[dim],
  536. )
  537. )
  538. # Filter out None values. If any are None then they should all be None.
  539. dim_non_none_group_names = [n for n in dim_group_names if n is not None]
  540. assert not dim_non_none_group_names or len(dim_non_none_group_names) == len(
  541. dim_group_names
  542. )
  543. return dim_non_none_group_names
  544. def _get_root_mesh(self) -> "DeviceMesh":
  545. return self._root_mesh if self._root_mesh else self
  546. def __enter__(self) -> "DeviceMesh":
  547. # set this mesh as the current mesh in mesh env
  548. _mesh_resources.mesh_stack.append(self)
  549. return self
  550. # pyre-fixme[2]: Parameter must be annotated.
  551. def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
  552. # pop this mesh from mesh env
  553. _mesh_resources.mesh_stack.pop()
  554. def __repr__(self) -> str:
  555. device_mesh_repr = (
  556. f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._layout.top_level_sizes))})"
  557. if self._mesh_dim_names
  558. else f"{self._layout.top_level_sizes}"
  559. )
  560. device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._layout.strides}"
  561. # We only print the mesh tensor if the debug mode is turned on.
  562. if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL":
  563. device_mesh_repr += f", Mesh: {self.mesh.tolist()}"
  564. return f"{device_mesh_repr})"
  565. def __hash__(self):
  566. # lazily compute hash
  567. self._hash = getattr(self, "_hash", None)
  568. if not self._hash:
  569. self._hash = hash(
  570. (
  571. self._flatten_rank_map,
  572. self._layout,
  573. self._device_type,
  574. self._mesh_dim_names,
  575. self._thread_id,
  576. )
  577. )
  578. return self._hash
  579. def __eq__(self, other: object) -> bool:
  580. if self is other:
  581. return True
  582. if not isinstance(other, DeviceMesh):
  583. return False
  584. return (
  585. self._flatten_rank_map == other._flatten_rank_map
  586. and self._layout == other._layout
  587. and self._device_type == other._device_type
  588. and self._mesh_dim_names == other._mesh_dim_names
  589. and self._thread_id == other._thread_id
  590. )
  591. def __getitem__(self, mesh_dim_names: str | tuple[str, ...]) -> "DeviceMesh":
  592. """
  593. Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh.
  594. The submesh created consists of the dimensions and the communicators indicated by
  595. ``mesh_dim_names``
  596. Args:
  597. mesh_dim_names (Union[str, tuple[str, ...]]): the name or the tuple of names of the
  598. mesh dimension of the DeviceMesh to create the submesh for.
  599. Returns:
  600. A :class:`DeviceMesh` object
  601. The following program runs on each process/rank in an SPMD manner in a world size of 8.
  602. In the first example:
  603. Calling mesh_2d["tp"] on rank 0, 1, 2, 3 returns a 1D submesh of DeviceMesh:([0, 1, 2, 3]).
  604. Calling mesh_2d["tp"] on rank 4, 5, 6, 7 returns a 1D submesh of DeviceMesh:([4, 5, 6, 7]).
  605. Calling mesh_2d["dp"] on rank 0, 4 returns a 1D submesh of DeviceMesh:([0, 4]).
  606. Calling mesh_2d["dp"] on rank 1, 5 returns a 1D submesh of DeviceMesh:([1, 5]).
  607. Calling mesh_2d["dp"] on rank 2, 6 returns a 1D submesh of DeviceMesh:([2, 6]).
  608. Calling mesh_2d["dp"] on rank 3, 7 returns a 1D submesh of DeviceMesh:([3, 7]).
  609. In the second example:
  610. Calling mesh_3d["dp", "cp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 1], [4, 5]]).
  611. Calling mesh_3d["dp", "cp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 3], [6, 7]]).
  612. Calling mesh_3d["cp", "dp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 4], [1, 5]]).
  613. Calling mesh_3d["cp", "dp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 6], [3, 7]]).
  614. Example::
  615. >>> # xdoctest: +SKIP("no rank")
  616. >>> from torch.distributed.device_mesh import DeviceMesh
  617. >>>
  618. >>> # Initialize a 2D device mesh as (2, 4) to represent the topology
  619. >>> # of cross-host(dim 0), and within-host (dim 1).
  620. >>> mesh_2d = init_device_mesh(device_type="cuda", (2,4), mesh_dim_names=("dp", "tp"))
  621. >>> tp_mesh = mesh_2d["tp"]
  622. >>> dp_mesh = mesh_2d["dp"]
  623. >>>
  624. >>> # Initialize a 3D mesh.
  625. >>> mesh_3d = init_device_mesh(device_type="cuda", (2,2,2), mesh_dim_names=("dp", "pp", "cp"))
  626. >>> # The order of the mesh_dim_names provided deteremines the order of dimensions in the submesh.
  627. >>> dp_cp_mesh = mesh_3d["dp", "cp"]
  628. >>> cp_dp_mesh = mesh_3d["cp", "dp"]
  629. """
  630. if not self._mesh_dim_names:
  631. raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!")
  632. mesh_dim_names = (
  633. (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names
  634. )
  635. if mesh_dim_names == self._mesh_dim_names:
  636. return self
  637. else:
  638. sliced_mesh_layout = self._get_slice_mesh_layout(mesh_dim_names)
  639. # When using FakeTensorMode to trace the model, `_create_sub_mesh()` will
  640. # fail as it will require a real tensor to manipulate.
  641. # `unset_fake_temporarily()` will allow us to materialize the tensors
  642. # within `_create_sub_mesh`, which should not affect modling.
  643. #
  644. # Note that this should be orthogonal to torch.compile(). But whether
  645. # we can compile device_mesh `slicing` (no graph break) is not verified
  646. # yet and need a follow-up,
  647. # TODO: compiler + device_mesh slicing.
  648. with torch._subclasses.fake_tensor.unset_fake_temporarily():
  649. submesh = self._create_sub_mesh(sliced_mesh_layout, mesh_dim_names)
  650. return submesh
  651. def get_group(self, mesh_dim: int | str | None = None) -> ProcessGroup:
  652. """
  653. Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the
  654. DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh.
  655. Args:
  656. mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
  657. of the mesh dimension. Default is None.
  658. Returns:
  659. A :class:`ProcessGroup` object.
  660. """
  661. if not hasattr(self, "_dim_group_names"):
  662. raise RuntimeError("DeviceMesh process groups not initialized!")
  663. if len(self._layout) > 1 and mesh_dim is None:
  664. raise RuntimeError(
  665. f"Found the DeviceMesh have {len(self._layout)} dimensions",
  666. "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
  667. "If you want to get the list of all the ProcessGroups in the DeviceMesh,"
  668. "please use `get_all_groups()` instead.",
  669. )
  670. root_mesh = self._get_root_mesh()
  671. # Quick return if the current device_mesh is a 1D mesh.
  672. if len(self._layout) == 1 and mesh_dim is None:
  673. return not_none(_get_pg_from_name(root_mesh, self._dim_group_names[0]))
  674. root_to_flatten_mapping = root_mesh._flatten_mapping
  675. if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping:
  676. dim_group_name = root_to_flatten_mapping[
  677. mesh_dim # type: ignore[index]
  678. ]._dim_group_names[0]
  679. return not_none(_get_pg_from_name(root_mesh, dim_group_name))
  680. else:
  681. mesh_dim = (
  682. self._get_mesh_dim_by_name(mesh_dim)
  683. if isinstance(mesh_dim, str)
  684. else mesh_dim
  685. )
  686. if not isinstance(mesh_dim, int):
  687. raise AssertionError(
  688. f"mesh_dim must be an int, got {type(mesh_dim)}"
  689. )
  690. return not_none(
  691. _get_pg_from_name(root_mesh, self._dim_group_names[mesh_dim])
  692. )
  693. def get_all_groups(self) -> list[ProcessGroup]:
  694. """
  695. Returns a list of ProcessGroups for all mesh dimensions.
  696. Returns:
  697. A list of :class:`ProcessGroup` object.
  698. """
  699. return [self.get_group(i) for i in range(len(self._layout))]
  700. def _create_sub_mesh(
  701. self,
  702. layout: _MeshLayout,
  703. submesh_dim_names: tuple[str, ...],
  704. ) -> "DeviceMesh":
  705. root_mesh = self._get_root_mesh()
  706. slice_dim_group_name = []
  707. if len(self._dim_group_names) > 0:
  708. assert len(self._dim_group_names) == len(
  709. not_none(self._mesh_dim_names)
  710. ), (
  711. "The number of dim_group_names and mesh_dim_names "
  712. "should have the same length if the rank is in the mesh."
  713. )
  714. for name in submesh_dim_names:
  715. if name in not_none(self._mesh_dim_names):
  716. slice_dim_group_name.append(
  717. self._dim_group_names[
  718. not_none(self._mesh_dim_names).index(name)
  719. ]
  720. )
  721. else:
  722. # If device_mesh is not root_mesh, we already throw error in _get_slice_mesh_layout
  723. # Since we will deprecate the slicing of flattened dim_name from root mesh soon,
  724. # we don't want to optimize the code furthermore.
  725. flatten_mesh = self._flatten_mapping[name]
  726. slice_dim_group_name.append(
  727. flatten_mesh._dim_group_names[
  728. not_none(flatten_mesh._mesh_dim_names).index(name)
  729. ]
  730. )
  731. res_submesh = DeviceMesh(
  732. self._device_type,
  733. _layout=layout,
  734. _rank_map=root_mesh._rank_map,
  735. mesh_dim_names=submesh_dim_names,
  736. _root_mesh=root_mesh,
  737. _init_backend=False,
  738. )
  739. res_submesh._dim_group_names = slice_dim_group_name
  740. return res_submesh
  741. def _create_flatten_mesh(
  742. self,
  743. mesh_dim_name: str | None = None,
  744. backend_override: BackendConfig = (None, None),
  745. ) -> "DeviceMesh":
  746. root_mesh = self._get_root_mesh()
  747. if not mesh_dim_name:
  748. mesh_dim_name = "_".join(not_none(self._mesh_dim_names))
  749. # Flatten a 1D device mesh into its original mesh_dim_name will return itself.
  750. if self.ndim == 1 and mesh_dim_name in not_none(self._mesh_dim_names):
  751. return self
  752. # Check whether the mesh_dim_name for flattened mesh is valid.
  753. invalid_dim_names = not_none(root_mesh._mesh_dim_names)
  754. if mesh_dim_name in invalid_dim_names:
  755. raise ValueError(
  756. f"{mesh_dim_name} already exists for submesh of the {root_mesh}. ",
  757. f"The mesh_dim_names of submesh and flattened mesh are {invalid_dim_names}. "
  758. f"Please specify another valid mesh_dim_name.",
  759. )
  760. flattened_mesh_layout = self._layout.coalesce()
  761. if len(flattened_mesh_layout) > 1:
  762. flattened_mesh_layout = flattened_mesh_layout.nest()
  763. # Quick return if the flatten mesh has been created before.
  764. if mesh_dim_name in root_mesh._flatten_mapping:
  765. if (
  766. flattened_mesh_layout
  767. == root_mesh._flatten_mapping[mesh_dim_name]._layout
  768. ):
  769. return root_mesh._flatten_mapping[mesh_dim_name]
  770. else:
  771. raise ValueError(
  772. f"Flatten mesh with mesh_dim_name {mesh_dim_name} has been created before, "
  773. f"Please specify another valid mesh_dim_name."
  774. )
  775. res_flattened_mesh = DeviceMesh(
  776. root_mesh._device_type,
  777. _layout=flattened_mesh_layout,
  778. _rank_map=root_mesh._rank_map,
  779. mesh_dim_names=(mesh_dim_name,),
  780. _root_mesh=root_mesh,
  781. backend_override=(backend_override,),
  782. )
  783. root_mesh._flatten_mapping[mesh_dim_name] = res_flattened_mesh
  784. return res_flattened_mesh
  785. def _get_root_mesh_dim(self) -> int | None:
  786. """
  787. Returns the index of the mesh dim in the root mesh.
  788. The device_mesh passed in needs to be sliced out from the root mesh
  789. or submesh of the root mesh.
  790. """
  791. root_mesh = self._get_root_mesh()
  792. child_mesh_dim_names = self._mesh_dim_names
  793. if root_mesh and child_mesh_dim_names:
  794. if len(child_mesh_dim_names) != 1:
  795. raise AssertionError("The submesh can only be a 1D mesh.")
  796. child_mesh_dim_name = child_mesh_dim_names[0]
  797. return root_mesh._get_mesh_dim_by_name(child_mesh_dim_name)
  798. return None
  799. def _get_mesh_dim_by_name(self, mesh_dim_name: str) -> int:
  800. if self._mesh_dim_names is None or len(self._mesh_dim_names) == 0:
  801. raise KeyError(
  802. "No `mesh_dim_names` found.",
  803. )
  804. if mesh_dim_name not in self._mesh_dim_names:
  805. raise KeyError(
  806. f"Mesh dimension '{mesh_dim_name}' does not exist.",
  807. f"Available mesh dimensions are: mesh_dim_names={self._mesh_dim_names}",
  808. )
  809. return not_none(self._mesh_dim_names.index(mesh_dim_name))
  810. def _get_slice_mesh_layout(
  811. self, mesh_dim_names: tuple[str, ...]
  812. ) -> _MeshLayout:
  813. """
  814. Validate whether the mesh_dim_names is valid for slicing the given device_mesh.
  815. If valid, return dim indexes of the slice mesh in the device mesh.
  816. """
  817. slice_from_root = True
  818. if self != self._get_root_mesh():
  819. slice_from_root = False
  820. # The slice mesh_dim_names should consist either the current device_mesh's mesh_dim_names
  821. # or its flattened mesh's mesh_dim_names if it's root_mesh.
  822. flatten_name_to_root_layout = (
  823. {
  824. key: mesh._layout
  825. for key, mesh in self._get_root_mesh()._flatten_mapping.items()
  826. }
  827. if slice_from_root
  828. else {}
  829. )
  830. valid_mesh_dim_names = [
  831. *not_none(self._mesh_dim_names),
  832. *flatten_name_to_root_layout,
  833. ]
  834. if not all(
  835. mesh_dim_name in valid_mesh_dim_names
  836. for mesh_dim_name in mesh_dim_names
  837. ):
  838. raise KeyError(
  839. f"Invalid mesh_dim_names {mesh_dim_names} specified. "
  840. f"Valid mesh_dim_names are {valid_mesh_dim_names}."
  841. )
  842. layout_sliced = []
  843. for name in mesh_dim_names:
  844. if name in not_none(self._mesh_dim_names):
  845. layout_sliced.append(
  846. self._layout[not_none(self._mesh_dim_names).index(name)]
  847. )
  848. elif name in flatten_name_to_root_layout:
  849. warnings.warn(
  850. "Slicing a flattened dim from root mesh will be deprecated in PT 2.11. "
  851. "Users need to bookkeep the flattened mesh directly. ",
  852. stacklevel=2,
  853. )
  854. layout_sliced.append(flatten_name_to_root_layout[name])
  855. sliced_sizes = tuple(l.sizes for l in layout_sliced)
  856. sliced_strides = tuple(l.strides for l in layout_sliced)
  857. # The check below is from DeviceMesh's implementation before adopting CuTe layout for internal
  858. # bookkeeping and it can be removed but we need to define what is the expected behavior.
  859. # TODO: Remove the below check and define the expected behavior.
  860. # Validate the order of the slice mesh dim indices.
  861. # This needs to be in ascending order.
  862. pre_stride = -1
  863. for stride in reversed(sliced_strides):
  864. # Note that with CuTe layout, we can support slicing flattened non-contiguous mesh dims with no problem.
  865. # But we don't see a use case for now so we don't want to support it.
  866. if not is_int(stride):
  867. raise NotImplementedError(
  868. "Currently, this only allows slicing out a contiguous flattened dim."
  869. )
  870. # Note that with CuTe layout, we can support slicing non-ascending order dims with no problem.
  871. # But we don't see a use case for now so we don't want to support it.
  872. if stride < pre_stride:
  873. raise KeyError(
  874. f"Invalid mesh_dim_names {mesh_dim_names} specified. "
  875. "Mesh dim indices should be in ascending order."
  876. )
  877. pre_stride = stride
  878. # When users sliced dim_names outside from current mesh, we will check whether
  879. # there is layout overlap.
  880. # TODO: Eventually we will just directly throw error here because
  881. # we will deprecate the slicing of flattened dim_name from root mesh.
  882. layout_sliced = _MeshLayout(sliced_sizes, sliced_strides)
  883. if not layout_sliced.check_non_overlap():
  884. raise RuntimeError(
  885. f"Slicing overlapping dim_names {mesh_dim_names} is not allowed."
  886. )
  887. return layout_sliced
  888. # TODO: to make this use case by other components public API in the future.
  889. def _get_all_submeshes(self, mesh_dim_name: str) -> list["DeviceMesh"]:
  890. """
  891. Return all the submeshes of a given mesh dimension of the device mesh.
  892. """
  893. mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
  894. layout = self._layout[mesh_dim]
  895. pg_ranks_by_dim = layout.remap_to_tensor(self._rank_map)
  896. cur_rank = self.get_rank()
  897. res_submeshes = []
  898. for mesh_1d in pg_ranks_by_dim:
  899. submesh = DeviceMesh(
  900. self._device_type,
  901. mesh_1d,
  902. mesh_dim_names=(mesh_dim_name,),
  903. _init_backend=False,
  904. )
  905. submesh._dim_group_names = ( # type: ignore[has-type]
  906. [self._dim_group_names[mesh_dim]] # type: ignore[has-type]
  907. if cur_rank in mesh_1d
  908. else []
  909. )
  910. res_submeshes.append(submesh)
  911. return res_submeshes
  912. @staticmethod
  913. def from_group(
  914. group: ProcessGroup | list[ProcessGroup],
  915. device_type: str,
  916. mesh: Union[torch.Tensor, "ArrayLike"] | None = None,
  917. *,
  918. mesh_dim_names: tuple[str, ...] | None = None,
  919. ) -> "DeviceMesh":
  920. """
  921. Constructs a :class:`DeviceMesh` with ``device_type`` from an
  922. existing :class:`ProcessGroup` or a list of existing :class:`ProcessGroup`.
  923. The constructed device mesh has number of dimensions equal to the
  924. number of groups passed. For example, if a single process group is passed in,
  925. the resulted DeviceMesh is a 1D mesh. If a list of 2 process groups is passed in,
  926. the resulted DeviceMesh is a 2D mesh.
  927. If more than one group is passed, then the ``mesh`` and ``mesh_dim_names`` arguments
  928. are required. The order of the process groups passed in determines the topology of
  929. the mesh. For example, the first process group will be the 0th dimension of the DeviceMesh.
  930. The `mesh` tensor passed in must have the same number of dimensions as the number of process
  931. groups passed in, and the order of the dimensions in the `mesh` tensor must match the order
  932. in the process groups passed in.
  933. Args:
  934. group (ProcessGroup or list[ProcessGroup]): the existing ProcessGroup
  935. or a list of existing ProcessGroups.
  936. device_type (str): The device type of the mesh. Currently supports: "cpu",
  937. "cuda/cuda-like". Passing in a device type with a GPU index, such as "cuda:0",
  938. is not allowed.
  939. mesh (torch.Tensor or ArrayLike, optional): A multi-dimensional array or an
  940. integer tensor describing the layout of devices, where the IDs are global IDs
  941. of the default process group. Default is None.
  942. mesh_dim_names (tuple[str, ...], optional): A tuple of mesh dimension names to assign
  943. to each dimension of the multi-dimensional array describing the layout of devices.
  944. Its length must match the length of `mesh_shape`. Each string in `mesh_dim_names`
  945. must be unique. Default is None.
  946. Returns:
  947. DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
  948. """
  949. # 1D scenario
  950. if isinstance(group, ProcessGroup):
  951. group_ranks = get_process_group_ranks(group)
  952. if (
  953. isinstance(mesh, torch.Tensor) and mesh.tolist() != group_ranks
  954. ) or (
  955. mesh is not None
  956. and not isinstance(mesh, torch.Tensor)
  957. and mesh != group_ranks
  958. ):
  959. raise ValueError(
  960. f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}"
  961. )
  962. mesh = torch.tensor(group_ranks, device="cpu", dtype=torch.int)
  963. device_mesh = DeviceMesh(
  964. device_type,
  965. mesh,
  966. mesh_dim_names=mesh_dim_names,
  967. _init_backend=False,
  968. )
  969. device_mesh._dim_group_names = [group.group_name]
  970. device_mesh._pg_registry[group.group_name] = group
  971. return device_mesh
  972. # nD scenario
  973. groups = list(group)
  974. if len(groups) == 0:
  975. raise ValueError("Expects at least one ProcessGroup to be passed")
  976. if mesh is None:
  977. raise ValueError("Must pass mesh if passing multiple ProcessGroups")
  978. if mesh_dim_names is None:
  979. raise ValueError(
  980. "Must pass mesh_dim_names if passing multiple ProcessGroups"
  981. )
  982. # When init a DeviceMesh with multiple ProcessGroups directly, we need to make sure
  983. # the mesh tensor is contiguous. Otherwise, the layout we inferred from the mesh tensor
  984. # will have larger span than the actual tensor. This is just internal implementation detail
  985. # and does not affect user facing behavior.
  986. mesh = (
  987. mesh.detach().to(dtype=torch.int, device="cpu")
  988. if isinstance(mesh, torch.Tensor)
  989. else torch.tensor(mesh, device="cpu", dtype=torch.int)
  990. )
  991. if mesh.ndim != len(groups):
  992. raise ValueError(
  993. "Expects mesh with ndim equal to number of ProcessGroups but got "
  994. f"mesh {mesh.tolist()} and {len(groups)} ProcessGroups"
  995. )
  996. device_mesh = DeviceMesh(
  997. device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False
  998. )
  999. device_mesh._dim_group_names = [group.group_name for group in groups]
  1000. for group in groups:
  1001. device_mesh._pg_registry[group.group_name] = group
  1002. return device_mesh
  1003. def size(self, mesh_dim: int | None = None) -> int:
  1004. if mesh_dim is not None:
  1005. return self._layout[mesh_dim].numel()
  1006. return self._layout.numel()
  1007. @property
  1008. def ndim(self) -> int:
  1009. return len(self._layout)
  1010. @property
  1011. def shape(self) -> tuple[int, ...]:
  1012. return self._layout.top_level_sizes
  1013. def get_rank(self) -> int:
  1014. """
  1015. Returns the current global rank.
  1016. """
  1017. return get_rank()
  1018. def get_local_rank(self, mesh_dim: int | str | None = None) -> int:
  1019. """
  1020. Returns the local rank of the given mesh_dim of the DeviceMesh.
  1021. Args:
  1022. mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
  1023. of the mesh dimension. Default is None.
  1024. Returns:
  1025. An integer denotes the local rank.
  1026. The following program runs on each process/rank in an SPMD manner. In this example, we have 2
  1027. hosts with 4 GPUs each.
  1028. Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0.
  1029. Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1.
  1030. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0.
  1031. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1.
  1032. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2.
  1033. Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3.
  1034. Example::
  1035. >>> # xdoctest: +SKIP("no rank")
  1036. >>> from torch.distributed.device_mesh import DeviceMesh
  1037. >>>
  1038. >>> # Initialize device mesh as (2, 4) to represent the topology
  1039. >>> # of cross-host(dim 0), and within-host (dim 1).
  1040. >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
  1041. """
  1042. if self.ndim > 1 and mesh_dim is None:
  1043. raise RuntimeError(
  1044. f"Found the DeviceMesh have {len(self._layout)} dimensions",
  1045. "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
  1046. )
  1047. elif mesh_dim is None:
  1048. mesh_dim = 0
  1049. mesh_dim_group = not_none(self.get_group(mesh_dim))
  1050. if not isinstance(mesh_dim_group, ProcessGroup):
  1051. raise AssertionError(
  1052. "We expect ProcessGroup before calling `get_rank`!"
  1053. )
  1054. return not_none(get_rank(mesh_dim_group))
  1055. def _is_current_rank_part_of_mesh(self) -> bool:
  1056. """
  1057. Return True if the current rank is part of this mesh.
  1058. When a DeviceMesh is created with a subset of ranks (a sub-mesh),
  1059. ranks not included in the mesh are "non-participating ranks". These
  1060. ranks:
  1061. - Return None from get_coordinate()
  1062. - Hold empty tensors as their local DTensor representation
  1063. - Skip computation during DTensor dispatch (returning default values)
  1064. - Skip collective operations during redistribute
  1065. - Return 0 cost for redistribute cost calculations
  1066. This allows DTensor operations to execute correctly across the entire
  1067. process group while only performing actual work on participating ranks.
  1068. """
  1069. return self._coordinate_on_dim is not None
  1070. def get_coordinate(self) -> tuple[int, ...] | None:
  1071. """
  1072. Return the relative indices of this rank relative to all
  1073. dimensions of the mesh. If this rank is not part of the mesh, return None.
  1074. """
  1075. return self._coordinate_on_dim
  1076. def _sym_get_coordinate(self, index: int) -> int:
  1077. import torch.distributed.config as config
  1078. from torch._guards import detect_fake_mode
  1079. if not detect_fake_mode() or not config.compile_on_one_rank:
  1080. # This is only valid when the current rank is part of the mesh.
  1081. assert self._coordinate_on_dim is not None
  1082. return self._coordinate_on_dim[index]
  1083. # This will cause the ops to be registered - so don't let RUFF
  1084. # delete this import because it thinks it's unused...
  1085. from ._ops import device_mesh # noqa: F401
  1086. # Temporarily turn off tracing while we lift the constant
  1087. # rank_map to a list so it can be a constant in the graph.
  1088. with torch._subclasses.fake_tensor.unset_fake_temporarily():
  1089. rank_map_list = self._rank_map.tolist()
  1090. rank_map = torch.tensor(rank_map_list, device="cpu", dtype=torch.int)
  1091. full_mesh = self._layout.remap_to_tensor(rank_map)
  1092. return torch.ops.device_mesh._runtime_compute_coordinate_on_dim(
  1093. full_mesh, index
  1094. )
  1095. def _flatten(
  1096. self,
  1097. mesh_dim_name: str | None = None,
  1098. backend_override: str
  1099. | C10dBackend.Options
  1100. | tuple[str, C10dBackend.Options]
  1101. | None = None,
  1102. ) -> "DeviceMesh":
  1103. """
  1104. Returns a 1D DeviceMesh by flattening the current DeviceMesh.
  1105. If no mesh_dim_name is provided, the default is a string concatenating the mesh_dim_names of the
  1106. given submesh with each mesh_dim_name separated by "_". For example, if we have a 3D mesh
  1107. DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")), calling
  1108. mesh_3d["dp", "cp"]._flatten() will create a 1D submesh DeviceMesh([0, 2, 4, 6], mesh_dim_names=("dp_cp",))
  1109. on rank 0, 2, 4, 6 and a 1D submesh DeviceMesh([1, 3, 5, 7], mesh_dim_names=("dp_cp",)) on rank 1, 3, 5, 7.
  1110. After the flattened dimension is created, to access the flattened dimension in mesh_3d, one can use the
  1111. existing slicing method to obtain the flattened mesh through calling mesh_3d["dp_cp"].
  1112. """
  1113. if not self._mesh_dim_names:
  1114. raise RuntimeError(
  1115. "Cannot flatten a DeviceMesh without mesh_dim_names!"
  1116. )
  1117. if backend_override is not None:
  1118. (backend_override_tuple,) = _normalize_backend_override(
  1119. {0: backend_override}, 1
  1120. )
  1121. else:
  1122. backend_override_tuple = (None, None)
  1123. return self._create_flatten_mesh(mesh_dim_name, backend_override_tuple)
  1124. def _create_unflatten_mesh(
  1125. self,
  1126. dim: int,
  1127. mesh_sizes: tuple[int, ...],
  1128. mesh_dim_names: tuple[str, ...],
  1129. backend_override: tuple[
  1130. tuple[str | None, C10dBackend.Options | None], ...
  1131. ] = ((None, None),),
  1132. ) -> "DeviceMesh":
  1133. inner_layout = _MeshLayout(tuple(mesh_sizes), suffix_product(mesh_sizes))
  1134. if inner_layout.numel() != self._layout[dim].numel():
  1135. raise ValueError(
  1136. f"The product of {mesh_sizes=} is {inner_layout.numel()}, "
  1137. f"but the original dimension at dim={dim} has size {self._layout[dim].numel()}. "
  1138. f"These must be equal for unflatten to work correctly."
  1139. )
  1140. partial_layout = self._layout[dim].composition(inner_layout)
  1141. unflattened_layout = self._layout.splice(dim, dim + 1, partial_layout)
  1142. unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
  1143. unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)
  1144. root_mesh = self._get_root_mesh()
  1145. res_mesh = DeviceMesh(
  1146. self.device_type,
  1147. _layout=unflattened_layout,
  1148. _rank_map=root_mesh._rank_map,
  1149. mesh_dim_names=tuple(unflattened_mesh_dim_names),
  1150. _root_mesh=root_mesh,
  1151. _init_backend=False,
  1152. )
  1153. # If original mesh has initiated its backend, we need to initialize the backend
  1154. # of unflatten dims as well.
  1155. # TODO: To make backend init more efficient with cute layout representation and support
  1156. # per dim backend init.
  1157. if hasattr(self, "_dim_group_names"):
  1158. dim_group_names = self._dim_group_names.copy()
  1159. new_group_names = self._init_process_groups(
  1160. partial_layout,
  1161. root_mesh._rank_map,
  1162. mesh_dim_names,
  1163. backend_override,
  1164. )
  1165. dim_group_names[dim : dim + 1] = new_group_names
  1166. res_mesh._dim_group_names = dim_group_names
  1167. # Populate root mesh's pg registry with new groups
  1168. for name in new_group_names:
  1169. pg = _resolve_process_group(name)
  1170. if pg is not None:
  1171. root_mesh._pg_registry[name] = pg
  1172. return res_mesh
  1173. def _unflatten(
  1174. self,
  1175. dim: int | str,
  1176. mesh_sizes: tuple[int, ...],
  1177. mesh_dim_names: tuple[str, ...],
  1178. backend_override: dict[
  1179. str, str | C10dBackend.Options | tuple[str, C10dBackend.Options]
  1180. ]
  1181. | None = None,
  1182. ) -> "DeviceMesh":
  1183. """
  1184. Returns a DeviceMesh by unflatten the current DeviceMesh.
  1185. This api can be used to unflatten a N-D DeviceMesh into N-1+len(mesh_sizes)-D meshes or submeshes.
  1186. The dim is the dimension to be unflattened which can be either a string or an integer.
  1187. The mesh_sizes is a tuple which specifies the shape of the mesh unflatten into for the given dim.
  1188. The mesh_dim_names is a list of strings which specifies the names of the dimensions of the mesh unflatten into.
  1189. Its length must match the length of mesh_sizes.
  1190. For example, if we have a 1D mesh DeviceMesh([0, 1, 2, 3, 4, 5, 6, 7], mesh_dim_names=("world")),
  1191. calling mesh_1d._unflatten(0, (2, 2, 4), ["dp", "pp", "tp"]) will create a 3D mesh
  1192. DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")).
  1193. Note that after calling the unflatten, there is no access to the unflattened dimension in mesh_1d, one can only
  1194. use the newly unflattened mesh to slice out the unflattened mesh dims.
  1195. """
  1196. if isinstance(dim, int) and dim >= self.ndim:
  1197. raise ValueError(
  1198. f"dim {dim} specified in `_unflatten` is out of range {self.ndim}"
  1199. )
  1200. elif isinstance(dim, str) and dim in not_none(self.mesh_dim_names):
  1201. raise ValueError(
  1202. f"dim {dim} specified in `_unflatten` is not in {self.mesh_dim_names}"
  1203. )
  1204. if len(mesh_sizes) != len(mesh_dim_names):
  1205. raise RuntimeError(
  1206. "mesh_dim_names must have same length as mesh_sizes in _unflatten!"
  1207. )
  1208. if isinstance(dim, str):
  1209. dim = not_none(self.mesh_dim_names).index(dim)
  1210. if backend_override is not None:
  1211. backend_override_tuple = tuple(
  1212. _normalize_backend_override(
  1213. backend_override, # type: ignore[arg-type]
  1214. len(mesh_sizes),
  1215. mesh_dim_names,
  1216. )
  1217. )
  1218. else:
  1219. backend_override_tuple = ((None, None),) * len(mesh_dim_names)
  1220. return self._create_unflatten_mesh(
  1221. dim,
  1222. mesh_sizes,
  1223. mesh_dim_names,
  1224. backend_override_tuple,
  1225. )
  1226. @staticmethod
  1227. def _concatenate(device_mesh_list: list["DeviceMesh"]) -> "DeviceMesh":
  1228. concat_dim_names: list[str] = []
  1229. concat_sizes: list[IntTuple] = []
  1230. concat_strides: list[IntTuple] = []
  1231. concat_dim_group_name: list[GroupName] = []
  1232. flatten_rank_map = device_mesh_list[0]._flatten_rank_map
  1233. for dm in device_mesh_list:
  1234. for i in range(len(dm._layout)):
  1235. concat_sizes.append(dm._layout[i].sizes)
  1236. concat_strides.append(dm._layout[i].strides)
  1237. concat_dim_names.extend(not_none(dm.mesh_dim_names))
  1238. concat_dim_group_name.extend(not_none(dm._dim_group_names))
  1239. # Concatenate device mesh having different root mesh tensors are meaningless
  1240. # because the concatenated indices should be indexed by the same root mesh tensor.
  1241. if dm._flatten_rank_map != flatten_rank_map:
  1242. raise RuntimeError(
  1243. "Cannot concatenate DeviceMeshes derived from different device meshs"
  1244. )
  1245. concat_mesh_layout = _MeshLayout(tuple(concat_sizes), tuple(concat_strides))
  1246. if not concat_mesh_layout.check_non_overlap():
  1247. raise RuntimeError(
  1248. f"Cannot concatenate overlapping meshes: {device_mesh_list}"
  1249. )
  1250. res_mesh = DeviceMesh(
  1251. device_mesh_list[0].device_type,
  1252. _layout=concat_mesh_layout,
  1253. _rank_map=device_mesh_list[0]._rank_map,
  1254. mesh_dim_names=tuple(concat_dim_names),
  1255. _root_mesh=device_mesh_list[0]._get_root_mesh(),
  1256. _init_backend=False,
  1257. )
  1258. res_mesh._dim_group_names = concat_dim_group_name
  1259. return res_mesh
  1260. def _normalize_backend_override(
  1261. backend_override: dict[
  1262. int | str,
  1263. str | C10dBackend.Options | tuple[str, C10dBackend.Options],
  1264. ],
  1265. ndim: int,
  1266. mesh_dim_names: tuple[str, ...] | None = None,
  1267. ) -> Iterator[BackendConfig]:
  1268. if mesh_dim_names is None:
  1269. mesh_dim_names = ()
  1270. for dim_idx, dim_name in zip_longest(range(ndim), mesh_dim_names):
  1271. if dim_name is not None and dim_name in backend_override:
  1272. if dim_idx in backend_override:
  1273. raise RuntimeError(
  1274. f"Found redundant dim index {dim_idx} and "
  1275. f"name {dim_name} in backend_override"
  1276. )
  1277. val = backend_override.pop(dim_name)
  1278. elif dim_idx in backend_override:
  1279. val = backend_override.pop(dim_idx)
  1280. else:
  1281. yield (None, None)
  1282. continue
  1283. if isinstance(val, str):
  1284. yield (val, None)
  1285. elif isinstance(val, C10dBackend.Options):
  1286. yield (None, val)
  1287. else:
  1288. yield val
  1289. if backend_override:
  1290. raise RuntimeError(
  1291. f"Found invalid keys in backend_override: got {list(backend_override.keys())}, "
  1292. f"expected integers in range [0, {ndim}) or one of {mesh_dim_names}"
  1293. )
  1294. def init_device_mesh(
  1295. device_type: str,
  1296. mesh_shape: tuple[int, ...],
  1297. *,
  1298. mesh_dim_names: tuple[str, ...] | None = None,
  1299. backend_override: dict[
  1300. int | str, str | C10dBackend.Options | tuple[str, C10dBackend.Options]
  1301. ]
  1302. | None = None,
  1303. ) -> DeviceMesh:
  1304. """
  1305. Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
  1306. This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`.
  1307. If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`.
  1308. .. note::
  1309. `init_device_mesh` follows SPMD programming model, meaning the same PyTorch Python program
  1310. runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array
  1311. describing device layout) is identical across all ranks. Inconsistent `mesh_shape` may lead to hanging.
  1312. .. note::
  1313. If no process group is found, init_device_mesh will initialize distributed process group/groups
  1314. required for distributed communications behind the scene.
  1315. Args:
  1316. device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like", "xpu".
  1317. Passing in a device type with a GPU index, such as "cuda:0", is not allowed.
  1318. mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array
  1319. describing the layout of devices.
  1320. mesh_dim_names (tuple[str, ...], optional): A tuple of mesh dimension names to assign to each dimension
  1321. of the multi-dimensional array describing the layout of devices. Its length must match the length
  1322. of `mesh_shape`. Each string in `mesh_dim_names` must be unique.
  1323. backend_override (Dict[int | str, tuple[str, Options] | str | Options], optional): Overrides for some or all of
  1324. the ProcessGroups that will be created for each mesh dimension. Each key can be either the index of a
  1325. dimension or its name (if mesh_dim_names is provided). Each value can be a tuple containing the name
  1326. of the backend and its options, or just one of these two components (in which case the other will be
  1327. set to its default value).
  1328. Returns:
  1329. DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
  1330. Example::
  1331. >>> # xdoctest: +SKIP("no rank")
  1332. >>> from torch.distributed.device_mesh import init_device_mesh
  1333. >>>
  1334. >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
  1335. >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
  1336. """
  1337. if mesh_dim_names is not None:
  1338. if len(set(mesh_dim_names)) != len(mesh_dim_names):
  1339. raise RuntimeError(
  1340. "Each mesh_dim_name must be unique.",
  1341. f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",
  1342. )
  1343. if len(mesh_shape) != len(mesh_dim_names):
  1344. raise RuntimeError(
  1345. "mesh_shape and mesh_dim_names should have same length!",
  1346. f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",
  1347. )
  1348. if backend_override is not None:
  1349. backend_override_tuple = tuple(
  1350. _normalize_backend_override(
  1351. backend_override, len(mesh_shape), mesh_dim_names
  1352. )
  1353. )
  1354. else:
  1355. backend_override_tuple = None
  1356. # assume valid device types are all letters
  1357. if device_type and not device_type.isalpha():
  1358. raise RuntimeError(
  1359. f"Device type with index is not supported but got {device_type}. ",
  1360. "If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'.",
  1361. )
  1362. layout = _MeshLayout(tuple(mesh_shape), suffix_product(tuple(mesh_shape)))
  1363. # Always initialize the (identity) rank map on CPU, regardless of what the
  1364. # external device type has been set to be (e.g. meta)
  1365. with torch.device("cpu"):
  1366. rank_map = torch.arange(layout.numel(), dtype=torch.int)
  1367. device_mesh = DeviceMesh(
  1368. device_type=device_type,
  1369. _layout=layout,
  1370. _rank_map=rank_map,
  1371. mesh_dim_names=mesh_dim_names,
  1372. backend_override=backend_override_tuple,
  1373. )
  1374. return device_mesh