| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258 |
- # mypy: allow-untyped-defs
- import uuid
- from collections import OrderedDict
- from collections.abc import Callable
- from functools import wraps
- from typing import Concatenate, Generic, Protocol
- from typing_extensions import ParamSpec, TypeVar
- import torch
- import torch.nn as nn
- from torch.distributed._composable_state import _State
- from torch.distributed.utils import _get_root_modules
- _T = TypeVar("_T", covariant=True)
- _P = ParamSpec("_P")
- def generate_state_key(string="__composable_api_state_key"):
- return f"{string}_{str(uuid.uuid4())}"
- STATE_KEY = generate_state_key()
- REGISTRY_KEY = generate_state_key()
- # TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
- # we can add args and kwargs here, and then we can detect whether fully_shard
- # is combined with reentrant activation checkpointing and error out with a clear
- # message.
- class RegistryItem:
- pass
- _TState = TypeVar("_TState", bound="_State", covariant=True)
- _M = TypeVar("_M", nn.Module, list[nn.Module])
- class _ContractFn(Protocol, Generic[_P, _T, _TState]):
- def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...
- def state(self, module: nn.Module) -> _TState: ...
- def contract(
- state_cls: type[_TState] = _State, # type: ignore[assignment]
- ) -> Callable[
- [Callable[Concatenate[_M, _P], _M]],
- _ContractFn[Concatenate[_M, _P], _M, _TState],
- ]:
- r"""
- Decorate a function as a composable distributed API, where the first
- argument of the function must be an :class:`nn.Module` instance or sequence
- of :class:`nn.Module` instances.
- The decorator verifies that the decorated function does not modify
- fully-qualified names (FQNs) for parameters, buffers, or modules. The
- decorated function can return different module instances than the input
- modules; the FQN invariant will be enforced following the input order.
- When a function ``func`` is decorated by ``@contract()``, a
- ``.state(module: nn.Module)`` method will be installed to the decorated
- function. Then you can retrieve and modify the state on a module by calling
- ``func.state(module)``.
- Example::
- >>> # xdoctest: +SKIP
- >>> import torch.nn as nn
- >>>
- >>> class MyModel(nn.Module):
- >>> def __init__(self) -> None:
- >>> super().__init__()
- >>> self.l1 = nn.Linear(10, 10)
- >>> self.l2 = nn.Linear(10, 10)
- >>>
- >>> def forward(self, x):
- >>> return self.l2(self.l1(x))
- >>>
- >>> @contract()
- >>> def my_feature(module: nn.Module) -> nn.Module:
- >>> my_feature.state(module).some_state = "any value"
- >>> return module
- >>>
- >>> model = MyModel()
- >>> my_feature(model.l1)
- >>> assert my_feature.state(model.l1).some_state == "any value"
- >>> my_feature(model.l2)
- >>> model(torch.randn(2, 10)).sum().backward()
- """
- # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
- @wraps(state_cls) # type: ignore[arg-type]
- def inner(
- func: Callable[Concatenate[_M, _P], _M],
- ) -> _ContractFn[Concatenate[_M, _P], _M, _TState]:
- @wraps(func)
- def wrapper(
- module: _M,
- *args: _P.args,
- **kwargs: _P.kwargs,
- ) -> _M:
- inp_module = module
- modules: list[nn.Module]
- if isinstance(module, nn.Module):
- modules = [module]
- else:
- # If the user passes a sequence of modules, then we assume that
- # we only need to insert the state object on the root modules
- # (i.e. those without a parent) among the passed-in modules.
- # pyrefly: ignore [no-matching-overload]
- modules = _get_root_modules(list(module))
- state = state_cls() # shared across all modules
- registry_item = RegistryItem() # shared across all modules
- # `func` is allowed to return different module instances than the
- # input modules as long as FQNs are preserved following the input
- # module order
- all_orig_named_params: list[dict[str, nn.Parameter]] = []
- all_orig_named_buffers: list[dict[str, torch.Tensor]] = []
- all_orig_named_modules: list[dict[str, nn.Module]] = []
- # pyrefly: ignore [bad-assignment]
- for module in modules:
- default_all_state: dict[Callable, _State] = OrderedDict()
- default_registry: dict[str, RegistryItem] = OrderedDict()
- all_state: dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload]
- STATE_KEY, default_all_state
- )
- if not isinstance(all_state, dict):
- raise AssertionError(
- f"Distributed composable API states corrupted: {all_state}"
- )
- registry: dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload]
- REGISTRY_KEY, default_registry
- )
- if not isinstance(registry, dict):
- raise AssertionError(
- f"Distributed composable API registry corrupted: {registry}"
- )
- if func in all_state or func.__name__ in registry:
- raise AssertionError(
- "Each distinct composable distributed API can only be applied to a "
- f"module once. {func.__name__} has already been applied to the "
- f"following module:\n{module}"
- )
- all_state.setdefault(func, state)
- registry.setdefault(func.__name__, registry_item)
- # pyrefly: ignore [missing-attribute]
- all_orig_named_params.append(OrderedDict(module.named_parameters()))
- # pyrefly: ignore [missing-attribute]
- all_orig_named_buffers.append(OrderedDict(module.named_buffers()))
- # pyrefly: ignore [missing-attribute]
- all_orig_named_modules.append(OrderedDict(module.named_modules()))
- updated = func(inp_module, *args, **kwargs)
- if updated is None:
- updated = inp_module # type: ignore[assignment]
- updated_modules: list[nn.Module]
- if isinstance(updated, nn.Module):
- updated_modules = [updated]
- else:
- updated_modules = _get_root_modules(list(inp_module)) # type: ignore[arg-type, call-overload]
- all_new_named_params: list[dict[str, nn.Parameter]] = []
- all_new_named_buffers: list[dict[str, torch.Tensor]] = []
- all_new_named_modules: list[dict[str, nn.Module]] = []
- # pyrefly: ignore [bad-assignment]
- for module in updated_modules:
- # pyrefly: ignore [missing-attribute]
- all_new_named_params.append(OrderedDict(module.named_parameters()))
- # pyrefly: ignore [missing-attribute]
- all_new_named_buffers.append(OrderedDict(module.named_buffers()))
- # pyrefly: ignore [missing-attribute]
- all_new_named_modules.append(OrderedDict(module.named_modules()))
- num_orig_modules = len(all_orig_named_modules)
- num_new_modules = len(all_new_named_modules)
- if num_orig_modules != num_new_modules:
- raise AssertionError(
- f"{func.__name__} should return the same number of modules as input modules"
- f"Inputs: {num_orig_modules} modules\n"
- f"Outputs: {num_new_modules} modules"
- )
- def check_fqn(orig_fqns: list[str], new_fqns: list[str], check_key: str):
- if orig_fqns == new_fqns:
- return
- orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
- orig_only = orig_fqn_set - new_fqn_set
- new_only = new_fqn_set - orig_fqn_set
- if len(orig_only) or len(new_only):
- raise RuntimeError(
- f"{check_key}"
- "Composable distributed API implementations cannot modify FQNs.\n"
- f"FQNs only in original: {orig_only}\n"
- f"FQNs only in new: {new_only}"
- )
- else:
- raise RuntimeError(
- f"{check_key}"
- "Composable distributed API implementations cannot modify "
- "the order of FQNs.\n"
- f"Original FQNs: {orig_only}\n"
- f"New FQNs: {new_only}"
- )
- for orig_named_params, new_named_params in zip(
- all_orig_named_params, all_new_named_params
- ):
- check_fqn(
- list(orig_named_params.keys()),
- list(new_named_params.keys()),
- "Checking parameters: ",
- )
- for orig_named_buffers, new_named_buffers in zip(
- all_orig_named_buffers, all_new_named_buffers
- ):
- check_fqn(
- list(orig_named_buffers.keys()),
- list(new_named_buffers.keys()),
- "Checking buffers: ",
- )
- for orig_named_modules, new_named_modules in zip(
- all_orig_named_modules, all_new_named_modules
- ):
- check_fqn(
- list(orig_named_modules.keys()),
- list(new_named_modules.keys()),
- "Checking modules: ",
- )
- # TODO: verify that installed distributed paradigms are compatible with
- # each other.
- return updated
- def get_state(module: nn.Module) -> _State:
- return module.__dict__.setdefault( # type: ignore[call-overload]
- STATE_KEY,
- {}, # TODO(@yhcharles): this is a temporary fix, need a better way
- ).get(func) # type: ignore[call-overload]
- wrapper.state = get_state # type: ignore[attr-defined]
- return wrapper # type: ignore[return-value]
- return inner # type: ignore[return-value]
- def _get_registry(module: nn.Module) -> dict[str, RegistryItem] | None:
- r"""
- Get an ``OrderedDict`` of composable APIs that have been applied to the
- ``module``, indexed by the API name. If no API has been applied, then this
- returns ``None``.
- """
- return getattr(module, REGISTRY_KEY, None)
|