distributed.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  """
  Distributed computing variable tracking classes for PyTorch Dynamo.

  This module implements variable tracking for distributed computing components:
  - Process Groups (for collective communication)
  - Device Meshes (for distributed tensor sharding)
  - Placement Types (for specifying distribution strategies)
  - Distributed Tensors and their operations
  - Module-level backward hooks (torch.utils.hooks.BackwardHook)

  These classes track distributed operations during graph compilation. They
  install the guards each distributed object needs and handle
  distributed-specific behaviors. As a result, process groups, device meshes,
  and placement strategies keep their eager-mode semantics in the compiled code.

  The implementation also checks whether the torch.distributed package is
  available in the current build, because PyTorch can be built without it.
  """
  17. import functools
  18. import inspect
  19. from collections.abc import Sequence
  20. from typing import Any, Literal, TYPE_CHECKING
  21. import torch
  22. from torch.fx.experimental._backward_state import BackwardState
  23. from .. import compiled_autograd, variables
  24. from .._trace_wrapped_higher_order_op import trace_wrapped
  25. from ..bytecode_transformation import create_call_function
  26. from ..exc import unimplemented
  27. from ..external_utils import call_module_hooks_from_backward_state
  28. from ..guards import GuardBuilder, install_guard
  29. from ..source import AttrSource
  30. from ..utils import istype
  31. from .base import VariableTracker
  32. from .constant import CONSTANT_VARIABLE_NONE, ConstantVariable, EnumVariable
  33. if TYPE_CHECKING:
  34. from torch._dynamo.codegen import PyCodegen
  35. from torch._dynamo.symbolic_convert import InstructionTranslator
  class DistributedVariable(VariableTracker):
      """
      Shared base class for Dynamo variables that wrap torch.distributed objects
      (ProcessGroup, DeviceMesh, placements, ...).

      Construction is refused via ``unimplemented`` when this PyTorch build has
      no torch.distributed package; otherwise the wrapped object is stored in
      ``self.value``. Concrete subclasses layer object-specific handling on top.
      """

      def __init__(self, value: Any, **kwargs: Any) -> None:
          super().__init__(**kwargs)
          distributed_available = DistributedVariable.is_available()
          if not distributed_available:
              unimplemented(
                  gb_type="torch.distributed package is not available!",
                  context="",
                  explanation="The PyTorch package doesn't include torch.distributed when building from source.",
                  hints=[
                      "Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source."
                  ],
              )
          self.value = value

      def python_type(self) -> type:
          # Report the type of the wrapped distributed object.
          wrapped = self.value
          return type(wrapped)

      @staticmethod
      def is_available() -> bool:
          """Return whether torch.distributed is compiled into this PyTorch build."""
          return torch.distributed.is_available()

      def is_python_hashable(self) -> Literal[True]:
          return True

      def get_python_hash(self) -> int:
          # Hash the wrapped distributed object itself.
          return hash(self.value)

      def is_python_equal(self, other: object) -> bool:
          """Compare by constant value; non-VariableTracker operands are never equal."""
          if not isinstance(other, VariableTracker):
              return False
          # NOTE(review): DistributedVariable does not define as_python_constant
          # itself; subclasses that do not override it presumably fall back to
          # the VariableTracker default - confirm.
          return self.as_python_constant() == other.as_python_constant()
  def is_from_local(value: object) -> bool:
      """Return True iff ``value`` is exactly the ``DTensor.from_local`` function.

      Always False when torch.distributed is unavailable; the DTensor import is
      only attempted once availability has been confirmed.
      """
      if DistributedVariable.is_available():
          from torch.distributed.tensor import DTensor

          return inspect.isfunction(value) and value is DTensor.from_local
      return False
  def is_constant_pg_functions(value: object) -> bool:
      """Return True if ``value`` is one of the distributed_c10d helper functions
      that Dynamo treats as constant-returning for process-group arguments.

      Always False when torch.distributed is unavailable.
      """
      if not DistributedVariable.is_available():
          return False

      from torch.distributed.distributed_c10d import (
          _get_group_size_by_name,
          _get_group_tag,
          _rank_not_in_group,
          _resolve_group_name_by_ranks_and_tag,
          get_process_group_ranks,
      )

      if not inspect.isfunction(value):
          return False
      known_pg_helpers = (
          _get_group_size_by_name,
          _get_group_tag,
          _rank_not_in_group,
          get_process_group_ranks,
          _resolve_group_name_by_ranks_and_tag,
      )
      return value in known_pg_helpers
  class WorldMetaClassVariable(DistributedVariable):
      """
      Tracks ``torch.distributed.GroupMember`` and ``torch.distributed.group``,
      the classes whose metaclass is ``_WorldMeta``.
      """

      @classmethod
      def is_group_member_type(cls, value: object) -> bool:
          if cls.is_available():
              from torch.distributed.distributed_c10d import _WorldMeta

              return type(value) is _WorldMeta
          return False

      def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
          if name not in ("WORLD", "NON_GROUP_MEMBER"):
              return super().var_getattr(tx, name)

          # Both special attributes are guarded on object identity of the
          # attribute read through this variable's source.
          assert self.source
          attr_source = AttrSource(base=self.source, member=name)
          install_guard(attr_source.make_guard(GuardBuilder.ID_MATCH))
          attr_value = getattr(self.value, name)
          if name == "WORLD":
              return ProcessGroupVariable(attr_value)
          return EnumVariable(attr_value)
  class PlacementClassVariable(DistributedVariable):
      """Tracks a DTensor Placement *class* (e.g. ``Shard``) so that calling it
      inside compiled code produces a ``PlacementVariable`` instance."""

      @staticmethod
      def is_placement_type(value: object) -> bool:
          # torch.distributed may be missing from this build, so only import the
          # Placement base class after availability has been confirmed.
          if DistributedVariable.is_available():
              from torch.distributed.tensor.placement_types import Placement

              return isinstance(value, type) and issubclass(value, Placement)
          return False

      def as_python_constant(self) -> Any:
          return self.value

      def call_function(
          self,
          tx: "InstructionTranslator",
          args: Sequence[VariableTracker],
          kwargs: dict[str, VariableTracker],
      ) -> VariableTracker:
          if not self.source:
              return super().call_function(tx, args, kwargs)

          # NOTE: we don't need to track mutations to the placement class as they
          # are supposed to be immutable.
          placement_cls = self.value
          instance_var = PlacementVariable(placement_cls.__new__(placement_cls))
          if inspect.getattr_static(placement_cls, "__init__", None):
              # Run the constructor eagerly on the real object via the
              # PlacementVariable constant-folding path.
              instance_var.call_method(tx, "__init__", args, kwargs)
          return instance_var
  class PlacementVariable(DistributedVariable):
      """
      Tracks a DTensor placement *instance* such as ``Shard(0)``, ``Replicate()``
      or ``Partial()``.

      The wrapped placement is exposed as a Python constant. A small allow-list
      of methods is evaluated eagerly on the real object. Sourceless instances of
      the exact built-in types Shard / Replicate / Partial are reconstructed by
      re-invoking their constructor; anything else defers to the base class.
      """

      @staticmethod
      def is_placement(value: object) -> bool:
          # we can't rely on importing/accessing torch distributed, it is not always built.
          if not DistributedVariable.is_available():
              return False
          from torch.distributed.tensor.placement_types import Placement

          return isinstance(value, Placement)

      def as_python_constant(self) -> Any:
          return self.value

      def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
          if name == "dim":
              return ConstantVariable.create(self.value.dim)
          return super().var_getattr(tx, name)

      def call_method(
          self,
          tx: "InstructionTranslator",
          name: str,
          args: Sequence[VariableTracker],
          kwargs: dict[str, VariableTracker],
      ) -> VariableTracker:
          # Placement types dynamo tracking only allows the following methods;
          # __init__/__setattr__ are needed for construction, e.g. `Shard(dim)`.
          # Methods in the list must satisfy:
          # 1. Input arguments are constants and do not need to be guarded on;
          # 2. Output is constant with respect to their inputs.
          constant_fold_functions = (
              "__init__",
              "__setattr__",
              "is_shard",
              "is_partial",
              "is_replicate",
          )
          if name not in constant_fold_functions:
              return super().call_method(tx, name, args, kwargs)  # type: ignore[arg-type]

          value_type = type(self.value)
          try:
              if inspect.getattr_static(value_type, "__getattr__", None) is not None:
                  unimplemented(
                      gb_type="Placement with custom __getattr__ not supported",
                      context=f"{value_type.__name__} with custom __getattr__",
                      explanation="Dynamo does not support Placement types with custom __getattr__ methods",
                      hints=[
                          "Use Placement types without custom __getattr__ methods",
                          "Move the Placement usage outside the compiled region",
                      ],
                  )
              method = inspect.getattr_static(value_type, name)
          except AttributeError:
              method = None

          if method is object.__init__:
              # The placement has no __init__ of its own: nothing to run.
              return CONSTANT_VARIABLE_NONE

          # Keep the constant-folded values under their own names instead of
          # rebinding the `args`/`kwargs` parameters that hold VariableTrackers.
          const_args = [x.as_python_constant() for x in args]
          const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
          assert method is not None
          if name == "__setattr__":
              method(self.value, *const_args, **const_kwargs)
              return self
          return ConstantVariable.create(
              method(self.value, *const_args, **const_kwargs)
          )

      def reconstruct(self, codegen: "PyCodegen") -> None:
          """Emit bytecode that rebuilds the placement by calling its constructor,
          e.g. ``Shard(0)``, ``Replicate()``, ``Partial("sum")``.

          Only the exact types Shard / Replicate / Partial are rebuilt here. The
          class is re-imported by name from ``placement_types``, and subclasses
          (for example ``_StridedShard``, or user-defined placements) may live
          elsewhere or need extra constructor arguments. Those defer to the base
          implementation instead of emitting a call that would fail at runtime.
          """
          from torch.distributed.tensor.placement_types import Partial, Replicate, Shard

          placement = self.value
          if istype(placement, Shard):
              ctor_args: list[Any] = [placement.dim]
          elif istype(placement, Partial):
              # Preserve the reduction op; a bare `Partial()` would fall back to
              # its default reduce_op and silently change semantics.
              ctor_args = [placement.reduce_op]
          elif istype(placement, Replicate):
              ctor_args = []
          else:
              super().reconstruct(codegen)
              return

          # Load the class only once we know we will call it, so the fallback
          # path above never leaves a stray NULL/class on the stack.
          codegen.add_push_null(
              lambda: codegen.load_import_from(
                  "torch.distributed.tensor.placement_types", type(placement).__name__
              )
          )
          for ctor_arg in ctor_args:
              codegen(ConstantVariable.create(ctor_arg))
          codegen.extend_output(create_call_function(len(ctor_args), False))
  class DeviceMeshVariable(DistributedVariable):
      """Tracks a ``torch.distributed.device_mesh.DeviceMesh``.

      Selected attributes and methods are evaluated eagerly on the real mesh and
      returned as constants, process-group variables, or sourceless variables.
      """

      @staticmethod
      def is_device_mesh(value: object) -> bool:
          # torch.distributed is not always built, so import lazily after the check.
          if DistributedVariable.is_available():
              from torch.distributed.device_mesh import DeviceMesh

              return istype(value, DeviceMesh)
          return False

      def as_python_constant(self) -> Any:
          return self.value

      def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
          mesh = self.value
          if name in ("ndim", "device_type"):
              return ConstantVariable.create(getattr(mesh, name))
          if name == "mesh_dim_names":
              # Chain an attribute source when this mesh has one; otherwise pass
              # the (falsy) source through unchanged.
              names_source = (
                  AttrSource(base=self.source, member="mesh_dim_names")
                  if self.source
                  else self.source
              )
              return VariableTracker.build(tx, mesh.mesh_dim_names, names_source)
          return super().var_getattr(tx, name)

      def call_method(
          self,
          tx: "InstructionTranslator",
          name: str,
          args: list[VariableTracker],
          kwargs: dict[str, VariableTracker],
      ) -> VariableTracker:
          mesh = self.value

          def fold_call(method_name: str) -> Any:
              # Call mesh.<method_name> with every argument constant-folded.
              positional = [arg.as_python_constant() for arg in args]
              keyword = {key: val.as_python_constant() for key, val in kwargs.items()}
              return getattr(mesh, method_name)(*positional, **keyword)

          if name in ("size", "get_local_rank", "_sym_get_coordinate"):
              return ConstantVariable.create(fold_call(name))
          if name in ("get_coordinate", "get_rank", "_is_current_rank_part_of_mesh"):
              # Called without arguments; any passed-in args are ignored.
              return ConstantVariable.create(getattr(mesh, name)())
          if name == "get_group":
              return ProcessGroupVariable(fold_call(name))
          if name == "_get_or_create_default_group":
              return ProcessGroupVariable(mesh._get_or_create_default_group())
          if name == "_flatten":
              from .builder import SourcelessBuilder

              return SourcelessBuilder.create(tx, fold_call(name))
          return super().call_method(tx, name, args, kwargs)
  class ProcessGroupVariable(DistributedVariable):
      """
      We don't want a ProcessGroup object to end up in our output graph.

      Dynamo often intercepts a PG that is then used to read information such as
      rank() or world size. It is also passed to utility functions in
      distributed_c10d, which desugar it into plain types like a rank list and
      tag. For convenience and proper guarding, we construct a variable type.

      TODO: make it possible to use ProcessGroupVariable as input to simple functions
      like _expand_group without dynamo complaining about making a proxy for it.
      It is not a tensor-like type, and we don't want a proxy - but dynamo assumes
      torch library functions are dealing with tensor-like types and would have
      proxies for their args.

      TODO: should we graph-break whenever none of the special cases below is hit,
      instead of falling back to the inherited default behaviors?
      """

      def as_python_constant(self) -> Any:
          return self.value

      def call_method(
          self,
          tx: "InstructionTranslator",
          name: str,
          args: list[VariableTracker],
          kwargs: dict[str, VariableTracker],
      ) -> VariableTracker:
          # These zero-argument queries are evaluated eagerly on the real PG;
          # any passed-in args are ignored.
          if name in ("rank", "size", "_get_backend_name"):
              return variables.ConstantVariable.create(getattr(self.value, name)())
          return super().call_method(tx, name, args, kwargs)

      def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
          if name == "group_name":
              return variables.ConstantVariable.create(self.value.group_name)
          if name in ("rank", "size"):
              # Return a callable that routes back through call_method when invoked.
              def bound_query(*call_args: Any, **call_kwargs: Any) -> VariableTracker:
                  return self.call_method(tx, name, call_args, call_kwargs)

              return variables.LambdaVariable(bound_query)
          # TODO should this just raise unimplemented?
          return super().var_getattr(tx, name)

      @staticmethod
      def is_process_group(value: object) -> bool:
          # torch.distributed is not always built, so import lazily after the check.
          if DistributedVariable.is_available():
              from torch._C._distributed_c10d import ProcessGroup
              from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

              return istype(value, (ProcessGroup, FakeProcessGroup))
          return False
  class BackwardHookVariable(VariableTracker):
      """
      Handles torch.utils.hooks.BackwardHook for module-level backward hooks.

      The module and its user hooks are registered as backward-state hooks on the
      output graph. The graph then contains a call that builds a BackwardHook
      whose callbacks go through ``trace_wrapped``, which compiled autograd
      turns into the actual hook calls.
      """

      @staticmethod
      def create(
          tx: "InstructionTranslator",
          module: VariableTracker,
          user_hooks: VariableTracker,
          user_pre_hooks: VariableTracker,
      ) -> "BackwardHookVariable":
          if not compiled_autograd.compiled_autograd_enabled:
              unimplemented(
                  gb_type="Module-level backwards hooks require compiled autograd.",
                  context="",
                  explanation="",
                  hints=[
                      "Enable compiled autograd by setting torch._dynamo.config.compiled_autograd = True."
                  ],
              )

          def _in_graph_bw_hooks(
              bw_state: BackwardState,
          ) -> torch.utils.hooks.BackwardHook:
              """
              Rather than installing the user hooks in the graph (which
              don't survive AotAutograd), we install hooks that will call
              trace_wrapped in the backward pass that CompiledAutograd
              can turn into actual hook calls.
              """

              def _dispatch_to(hooks_name: str) -> "functools.partial[Any]":
                  # One trace_wrapped callback that runs the named hook list
                  # stored in the backward state.
                  return functools.partial(
                      trace_wrapped,
                      fn=call_module_hooks_from_backward_state,
                      bw_state=bw_state,
                      hooks_name=hooks_name,
                      module_name=module_name,
                  )

              post_hooks = (_dispatch_to(user_hooks_name),)
              pre_hooks = (_dispatch_to(user_pre_hooks_name),)
              return torch.utils.hooks.BackwardHook(None, post_hooks, pre_hooks)

          # Registration order is kept as-is: module first, then pre-hooks, then
          # hooks. The names are read by the closure above when it runs.
          module_name, bw_state_proxy = tx.output.add_backward_state_hook(module, "mod")
          user_pre_hooks_name, _ = tx.output.add_backward_state_hook(user_pre_hooks)
          user_hooks_name, _ = tx.output.add_backward_state_hook(user_hooks)

          hook_proxy = tx.output.create_proxy(
              "call_function",
              _in_graph_bw_hooks,
              (bw_state_proxy,),
              {},
          )
          # Placeholder example value with no module and no hooks.
          hook_proxy.node.meta["example_value"] = torch.utils.hooks.BackwardHook(
              None, (), ()
          )
          return BackwardHookVariable(hook_proxy, module, user_hooks, user_pre_hooks)

      def __init__(
          self,
          proxy: torch.fx.Proxy,
          module: VariableTracker,
          user_hooks: VariableTracker,
          user_pre_hooks: VariableTracker,
          **options: Any,
      ) -> None:
          super().__init__(**options)
          self.proxy = proxy
          self.module = module
          self.user_hooks = user_hooks
          self.user_pre_hooks = user_pre_hooks

      def as_proxy(self) -> torch.fx.Proxy:
          return self.proxy

      def call_method(
          self,
          tx: "InstructionTranslator",
          name: str,
          args: list[VariableTracker],
          kwargs: dict[str, VariableTracker],
      ) -> VariableTracker:
          if name not in ("setup_input_hook", "setup_output_hook"):
              return super().call_method(tx, name, args, kwargs)
          return self._setup_hook(tx, name, *args, **kwargs)

      def _setup_hook(
          self, tx: "InstructionTranslator", hook_method_name: str, args: VariableTracker
      ) -> VariableTracker:
          """Record ``<hook>.<hook_method_name>(args)`` as a graph call_method node."""
          from .builder import wrap_fx_proxy

          setup_proxy = tx.output.create_proxy(
              "call_method",
              hook_method_name,
              (self.as_proxy(), args.as_proxy()),
              {},
          )
          return wrap_fx_proxy(tx, setup_proxy)