lazy.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # mypy: allow-untyped-defs
  2. import itertools
  3. from typing import Any, Protocol
  4. import torch
  5. from torch.nn.parameter import is_lazy
  6. __all__ = ["LazyModuleMixin"]
  7. class _LazyProtocol(Protocol):
  8. """This class is used to avoid errors with mypy checks for the attributes in a mixin.
  9. https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes
  10. """
  11. def _register_load_state_dict_pre_hook(self, hook): ...
  12. def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False): ...
  13. def _lazy_load_hook(
  14. self,
  15. state_dict,
  16. prefix,
  17. local_metadata,
  18. strict,
  19. missing_keys,
  20. unexpected_keys,
  21. error_msgs,
  22. ): ...
  23. def _get_name(self): ...
  24. def _infer_parameters(self, module, input): ...
  25. @property
  26. def _parameters(self): ...
  27. @property
  28. def _buffers(self): ...
  29. @property
  30. def _non_persistent_buffers_set(self): ...
  31. @property
  32. def _load_hook(self): ...
  33. @property
  34. def _initialize_hook(self): ...
  35. class LazyModuleMixin:
  36. r"""A mixin for modules that lazily initialize parameters, also known as "lazy modules".
  37. .. warning:
  38. Lazy modules are an experimental new feature under active development,
  39. and their API is likely to change.
  40. Modules that lazily initialize parameters, or "lazy modules",
  41. derive the shapes of their parameters from the first input(s)
  42. to their forward method. Until that first forward they contain
  43. :class:`torch.nn.UninitializedParameter` s that should not be accessed
  44. or used, and afterward they contain regular :class:`torch.nn.Parameter` s.
  45. Lazy modules are convenient since they don't require computing some
  46. module arguments, like the :attr:`in_features` argument of a
  47. typical :class:`torch.nn.Linear`.
  48. After construction, networks with lazy modules should first
  49. be converted to the desired dtype and placed on the expected device.
  50. This is because lazy modules only perform shape inference so the usual dtype
  51. and device placement behavior applies.
  52. The lazy modules should then perform "dry runs" to initialize all the components in the module.
  53. These "dry runs" send inputs of the correct size, dtype, and device through
  54. the network and to each one of its lazy modules. After this the network can be used as usual.
  55. >>> # xdoctest: +SKIP
  56. >>> class LazyMLP(torch.nn.Module):
  57. ... def __init__(self) -> None:
  58. ... super().__init__()
  59. ... self.fc1 = torch.nn.LazyLinear(10)
  60. ... self.relu1 = torch.nn.ReLU()
  61. ... self.fc2 = torch.nn.LazyLinear(1)
  62. ... self.relu2 = torch.nn.ReLU()
  63. ...
  64. ... def forward(self, input):
  65. ... x = self.relu1(self.fc1(input))
  66. ... y = self.relu2(self.fc2(x))
  67. ... return y
  68. >>> # constructs a network with lazy modules
  69. >>> lazy_mlp = LazyMLP()
  70. >>> # transforms the network's device and dtype
  71. >>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs'
  72. >>> lazy_mlp = lazy_mlp.cuda()
  73. >>> lazy_mlp
  74. LazyMLP( (fc1): LazyLinear(in_features=0, out_features=10, bias=True)
  75. (relu1): ReLU()
  76. (fc2): LazyLinear(in_features=0, out_features=1, bias=True)
  77. (relu2): ReLU()
  78. )
  79. >>> # performs a dry run to initialize the network's lazy modules
  80. >>> lazy_mlp(torch.ones(10, 10).cuda())
  81. >>> # after initialization, LazyLinear modules become regular Linear modules
  82. >>> lazy_mlp
  83. LazyMLP(
  84. (fc1): Linear(in_features=10, out_features=10, bias=True)
  85. (relu1): ReLU()
  86. (fc2): Linear(in_features=10, out_features=1, bias=True)
  87. (relu2): ReLU()
  88. )
  89. >>> # attaches an optimizer, since parameters can now be used as usual
  90. >>> optim = torch.optim.SGD(lazy_mlp.parameters(), lr=0.01)
  91. A final caveat when using lazy modules is that the order of initialization of a network's
  92. parameters may change, since the lazy modules are always initialized after other modules.
  93. For example, if the LazyMLP class defined above had a :class:`torch.nn.LazyLinear` module
  94. first and then a regular :class:`torch.nn.Linear` second, the second module would be
  95. initialized on construction and the first module would be initialized during the first dry run.
  96. This can cause the parameters of a network using lazy modules to be initialized differently
  97. than the parameters of a network without lazy modules as the order of parameter initializations,
  98. which often depends on a stateful random number generator, is different.
  99. Check :doc:`/notes/randomness` for more details.
  100. Lazy modules can be serialized with a state dict like other modules. For example:
  101. >>> lazy_mlp = LazyMLP()
  102. >>> # The state dict shows the uninitialized parameters
  103. >>> lazy_mlp.state_dict()
  104. OrderedDict({'fc1.weight': <UninitializedParameter>,
  105. 'fc1.bias': <UninitializedParameter>,
  106. 'fc2.weight': <UninitializedParameter>,
  107. 'fc2.bias': <UninitializedParameter>})
  108. Lazy modules can load regular :class:`torch.nn.Parameter` s (i.e. you can serialize/deserialize
  109. initialized LazyModules and they will remain initialized)
  110. >>> full_mlp = LazyMLP()
  111. >>> # Dry run to initialize another module
  112. >>> full_mlp.forward(torch.ones(10, 1))
  113. >>> # Load an initialized state into a lazy module
  114. >>> lazy_mlp.load_state_dict(full_mlp.state_dict())
  115. >>> # The state dict now holds valid values
  116. >>> lazy_mlp.state_dict()
  117. OrderedDict([('fc1.weight',
  118. tensor([[-0.3837],
  119. [ 0.0907],
  120. [ 0.6708],
  121. [-0.5223],
  122. [-0.9028],
  123. [ 0.2851],
  124. [-0.4537],
  125. [ 0.6813],
  126. [ 0.5766],
  127. [-0.8678]])),
  128. ('fc1.bias',
  129. tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30,
  130. 4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])),
  131. ('fc2.weight',
  132. tensor([[ 0.1320, 0.2938, 0.0679, 0.2793, 0.1088, -0.1795, -0.2301, 0.2807,
  133. 0.2479, 0.1091]])),
  134. ('fc2.bias', tensor([0.0019]))])
  135. Note, however, that the loaded parameters will not be replaced when doing a "dry run" if they are initialized
  136. when the state is loaded. This prevents using initialized modules in different contexts.
  137. """
  138. # modules inheriting from this will change their __class__ to the specified
  139. # one after they are fully initialized
  140. cls_to_become: type[Any] | None = None
  141. def __init__(self: _LazyProtocol, *args, **kwargs):
  142. # Mypy doesn't like this super call in a mixin
  143. super().__init__(*args, **kwargs) # type: ignore[misc]
  144. # pyrefly: ignore [read-only]
  145. self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
  146. # pyrefly: ignore [read-only]
  147. self._initialize_hook = self.register_forward_pre_hook(
  148. self._infer_parameters, with_kwargs=True
  149. )
  150. def _save_to_state_dict(self: _LazyProtocol, destination, prefix, keep_vars):
  151. # This should be ideally implemented as a hook,
  152. # but we should override `detach` in the UninitializedParameter to return itself
  153. # which is not clean
  154. for name, param in self._parameters.items():
  155. if param is not None:
  156. if not (is_lazy(param) or keep_vars):
  157. param = param.detach()
  158. destination[prefix + name] = param
  159. for name, buf in self._buffers.items():
  160. if buf is not None and name not in self._non_persistent_buffers_set:
  161. if not (is_lazy(buf) or keep_vars):
  162. buf = buf.detach()
  163. destination[prefix + name] = buf
  164. def _lazy_load_hook(
  165. self: _LazyProtocol,
  166. state_dict,
  167. prefix,
  168. local_metadata,
  169. strict,
  170. missing_keys,
  171. unexpected_keys,
  172. error_msgs,
  173. ):
  174. """load_state_dict pre-hook function for lazy buffers and parameters.
  175. The purpose of this hook is to adjust the current state and/or
  176. ``state_dict`` being loaded so that a module instance serialized in
  177. both un/initialized state can be deserialized onto both un/initialized
  178. module instance.
  179. See comment in ``torch.nn.Module._register_load_state_dict_pre_hook``
  180. for the details of the hook specification.
  181. """
  182. for name, param in itertools.chain(
  183. self._parameters.items(), self._buffers.items()
  184. ):
  185. key = prefix + name
  186. if key in state_dict and param is not None:
  187. input_param = state_dict[key]
  188. if is_lazy(param):
  189. # The current parameter is not initialized but the one being loaded one is
  190. # create a new parameter based on the uninitialized one
  191. if not is_lazy(input_param):
  192. with torch.no_grad():
  193. param.materialize(input_param.shape)
  194. def initialize_parameters(self: _LazyProtocol, *args, **kwargs):
  195. r"""Initialize parameters according to the input batch properties.
  196. This adds an interface to isolate parameter initialization from the
  197. forward pass when doing parameter shape inference.
  198. """
  199. raise NotImplementedError(
  200. f"initialize_parameters is not implemented for {self.__class__.__name__}"
  201. )
  202. def has_uninitialized_params(self: _LazyProtocol):
  203. r"""Check if a module has parameters that are not initialized."""
  204. # This is to avoid the JIT to track this parameter and force
  205. # custom modules __setstate__ to add it
  206. params = self._parameters.values()
  207. buffers = self._buffers.values()
  208. for param in itertools.chain(params, buffers):
  209. if is_lazy(param):
  210. return True
  211. return False
  212. # torchrec tests the code consistency with the following code
  213. # fmt: off
  214. def _infer_parameters(self: _LazyProtocol, module, args, kwargs=None):
  215. r"""Infers the size and initializes the parameters according to the provided input batch.
  216. Given a module that contains parameters that were declared inferable
  217. using :class:`torch.nn.parameter.ParameterMode.Infer`, runs a forward pass
  218. in the complete module using the provided input to initialize all the parameters
  219. as needed.
  220. The module is set into evaluation mode before running the forward pass in order
  221. to avoid saving statistics or calculating gradients
  222. """
  223. kwargs = kwargs if kwargs else {}
  224. module.initialize_parameters(*args, **kwargs)
  225. if module.has_uninitialized_params():
  226. raise RuntimeError(f'module {self._get_name()} has not been fully initialized')
  227. module._initialize_hook.remove()
  228. module._load_hook.remove()
  229. delattr(module, '_initialize_hook')
  230. delattr(module, '_load_hook')
  231. if module.cls_to_become is not None:
  232. module.__class__ = module.cls_to_become
  233. # fmt: on
  234. def _replicate_for_data_parallel(self: _LazyProtocol):
  235. raise RuntimeError(
  236. "Modules with uninitialized parameters can't be used with `DataParallel`. "
  237. "Run a dummy forward pass to correctly initialize the modules"
  238. )