contract.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. # mypy: allow-untyped-defs
  2. import uuid
  3. from collections import OrderedDict
  4. from collections.abc import Callable
  5. from functools import wraps
  6. from typing import Concatenate, Generic, Protocol
  7. from typing_extensions import ParamSpec, TypeVar
  8. import torch
  9. import torch.nn as nn
  10. from torch.distributed._composable_state import _State
  11. from torch.distributed.utils import _get_root_modules
  12. _T = TypeVar("_T", covariant=True)
  13. _P = ParamSpec("_P")
  14. def generate_state_key(string="__composable_api_state_key"):
  15. return f"{string}_{str(uuid.uuid4())}"
  16. STATE_KEY = generate_state_key()
  17. REGISTRY_KEY = generate_state_key()
  18. # TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
  19. # we can add args and kwargs here, and then we can detect whether fully_shard
  20. # is combined with reentrant activation checkpointing and error out with a clear
  21. # message.
  22. class RegistryItem:
  23. pass
  24. _TState = TypeVar("_TState", bound="_State", covariant=True)
  25. _M = TypeVar("_M", nn.Module, list[nn.Module])
  26. class _ContractFn(Protocol, Generic[_P, _T, _TState]):
  27. def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...
  28. def state(self, module: nn.Module) -> _TState: ...
  29. def contract(
  30. state_cls: type[_TState] = _State, # type: ignore[assignment]
  31. ) -> Callable[
  32. [Callable[Concatenate[_M, _P], _M]],
  33. _ContractFn[Concatenate[_M, _P], _M, _TState],
  34. ]:
  35. r"""
  36. Decorate a function as a composable distributed API, where the first
  37. argument of the function must be an :class:`nn.Module` instance or sequence
  38. of :class:`nn.Module` instances.
  39. The decorator verifies that the decorated function does not modify
  40. fully-qualified names (FQNs) for parameters, buffers, or modules. The
  41. decorated function can return different module instances than the input
  42. modules; the FQN invariant will be enforced following the input order.
  43. When a function ``func`` is decorated by ``@contract()``, a
  44. ``.state(module: nn.Module)`` method will be installed to the decorated
  45. function. Then you can retrieve and modify the state on a module by calling
  46. ``func.state(module)``.
  47. Example::
  48. >>> # xdoctest: +SKIP
  49. >>> import torch.nn as nn
  50. >>>
  51. >>> class MyModel(nn.Module):
  52. >>> def __init__(self) -> None:
  53. >>> super().__init__()
  54. >>> self.l1 = nn.Linear(10, 10)
  55. >>> self.l2 = nn.Linear(10, 10)
  56. >>>
  57. >>> def forward(self, x):
  58. >>> return self.l2(self.l1(x))
  59. >>>
  60. >>> @contract()
  61. >>> def my_feature(module: nn.Module) -> nn.Module:
  62. >>> my_feature.state(module).some_state = "any value"
  63. >>> return module
  64. >>>
  65. >>> model = MyModel()
  66. >>> my_feature(model.l1)
  67. >>> assert my_feature.state(model.l1).some_state == "any value"
  68. >>> my_feature(model.l2)
  69. >>> model(torch.randn(2, 10)).sum().backward()
  70. """
  71. # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
  72. @wraps(state_cls) # type: ignore[arg-type]
  73. def inner(
  74. func: Callable[Concatenate[_M, _P], _M],
  75. ) -> _ContractFn[Concatenate[_M, _P], _M, _TState]:
  76. @wraps(func)
  77. def wrapper(
  78. module: _M,
  79. *args: _P.args,
  80. **kwargs: _P.kwargs,
  81. ) -> _M:
  82. inp_module = module
  83. modules: list[nn.Module]
  84. if isinstance(module, nn.Module):
  85. modules = [module]
  86. else:
  87. # If the user passes a sequence of modules, then we assume that
  88. # we only need to insert the state object on the root modules
  89. # (i.e. those without a parent) among the passed-in modules.
  90. # pyrefly: ignore [no-matching-overload]
  91. modules = _get_root_modules(list(module))
  92. state = state_cls() # shared across all modules
  93. registry_item = RegistryItem() # shared across all modules
  94. # `func` is allowed to return different module instances than the
  95. # input modules as long as FQNs are preserved following the input
  96. # module order
  97. all_orig_named_params: list[dict[str, nn.Parameter]] = []
  98. all_orig_named_buffers: list[dict[str, torch.Tensor]] = []
  99. all_orig_named_modules: list[dict[str, nn.Module]] = []
  100. # pyrefly: ignore [bad-assignment]
  101. for module in modules:
  102. default_all_state: dict[Callable, _State] = OrderedDict()
  103. default_registry: dict[str, RegistryItem] = OrderedDict()
  104. all_state: dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload]
  105. STATE_KEY, default_all_state
  106. )
  107. if not isinstance(all_state, dict):
  108. raise AssertionError(
  109. f"Distributed composable API states corrupted: {all_state}"
  110. )
  111. registry: dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload]
  112. REGISTRY_KEY, default_registry
  113. )
  114. if not isinstance(registry, dict):
  115. raise AssertionError(
  116. f"Distributed composable API registry corrupted: {registry}"
  117. )
  118. if func in all_state or func.__name__ in registry:
  119. raise AssertionError(
  120. "Each distinct composable distributed API can only be applied to a "
  121. f"module once. {func.__name__} has already been applied to the "
  122. f"following module:\n{module}"
  123. )
  124. all_state.setdefault(func, state)
  125. registry.setdefault(func.__name__, registry_item)
  126. # pyrefly: ignore [missing-attribute]
  127. all_orig_named_params.append(OrderedDict(module.named_parameters()))
  128. # pyrefly: ignore [missing-attribute]
  129. all_orig_named_buffers.append(OrderedDict(module.named_buffers()))
  130. # pyrefly: ignore [missing-attribute]
  131. all_orig_named_modules.append(OrderedDict(module.named_modules()))
  132. updated = func(inp_module, *args, **kwargs)
  133. if updated is None:
  134. updated = inp_module # type: ignore[assignment]
  135. updated_modules: list[nn.Module]
  136. if isinstance(updated, nn.Module):
  137. updated_modules = [updated]
  138. else:
  139. updated_modules = _get_root_modules(list(inp_module)) # type: ignore[arg-type, call-overload]
  140. all_new_named_params: list[dict[str, nn.Parameter]] = []
  141. all_new_named_buffers: list[dict[str, torch.Tensor]] = []
  142. all_new_named_modules: list[dict[str, nn.Module]] = []
  143. # pyrefly: ignore [bad-assignment]
  144. for module in updated_modules:
  145. # pyrefly: ignore [missing-attribute]
  146. all_new_named_params.append(OrderedDict(module.named_parameters()))
  147. # pyrefly: ignore [missing-attribute]
  148. all_new_named_buffers.append(OrderedDict(module.named_buffers()))
  149. # pyrefly: ignore [missing-attribute]
  150. all_new_named_modules.append(OrderedDict(module.named_modules()))
  151. num_orig_modules = len(all_orig_named_modules)
  152. num_new_modules = len(all_new_named_modules)
  153. if num_orig_modules != num_new_modules:
  154. raise AssertionError(
  155. f"{func.__name__} should return the same number of modules as input modules"
  156. f"Inputs: {num_orig_modules} modules\n"
  157. f"Outputs: {num_new_modules} modules"
  158. )
  159. def check_fqn(orig_fqns: list[str], new_fqns: list[str], check_key: str):
  160. if orig_fqns == new_fqns:
  161. return
  162. orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
  163. orig_only = orig_fqn_set - new_fqn_set
  164. new_only = new_fqn_set - orig_fqn_set
  165. if len(orig_only) or len(new_only):
  166. raise RuntimeError(
  167. f"{check_key}"
  168. "Composable distributed API implementations cannot modify FQNs.\n"
  169. f"FQNs only in original: {orig_only}\n"
  170. f"FQNs only in new: {new_only}"
  171. )
  172. else:
  173. raise RuntimeError(
  174. f"{check_key}"
  175. "Composable distributed API implementations cannot modify "
  176. "the order of FQNs.\n"
  177. f"Original FQNs: {orig_only}\n"
  178. f"New FQNs: {new_only}"
  179. )
  180. for orig_named_params, new_named_params in zip(
  181. all_orig_named_params, all_new_named_params
  182. ):
  183. check_fqn(
  184. list(orig_named_params.keys()),
  185. list(new_named_params.keys()),
  186. "Checking parameters: ",
  187. )
  188. for orig_named_buffers, new_named_buffers in zip(
  189. all_orig_named_buffers, all_new_named_buffers
  190. ):
  191. check_fqn(
  192. list(orig_named_buffers.keys()),
  193. list(new_named_buffers.keys()),
  194. "Checking buffers: ",
  195. )
  196. for orig_named_modules, new_named_modules in zip(
  197. all_orig_named_modules, all_new_named_modules
  198. ):
  199. check_fqn(
  200. list(orig_named_modules.keys()),
  201. list(new_named_modules.keys()),
  202. "Checking modules: ",
  203. )
  204. # TODO: verify that installed distributed paradigms are compatible with
  205. # each other.
  206. return updated
  207. def get_state(module: nn.Module) -> _State:
  208. return module.__dict__.setdefault( # type: ignore[call-overload]
  209. STATE_KEY,
  210. {}, # TODO(@yhcharles): this is a temporary fix, need a better way
  211. ).get(func) # type: ignore[call-overload]
  212. wrapper.state = get_state # type: ignore[attr-defined]
  213. return wrapper # type: ignore[return-value]
  214. return inner # type: ignore[return-value]
  215. def _get_registry(module: nn.Module) -> dict[str, RegistryItem] | None:
  216. r"""
  217. Get an ``OrderedDict`` of composable APIs that have been applied to the
  218. ``module``, indexed by the API name. If no API has been applied, then this
  219. returns ``None``.
  220. """
  221. return getattr(module, REGISTRY_KEY, None)