initialization.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import sys
  16. from collections import defaultdict
  17. from contextlib import contextmanager
  18. import torch
  19. # Record all the torch primitives in advance, so that we can use them without them being modified when we patch torch
  20. # in context managers
  21. TORCH_INIT_FUNCTIONS = {
  22. "uniform_": torch.nn.init.uniform_,
  23. "normal_": torch.nn.init.normal_,
  24. "constant_": torch.nn.init.constant_,
  25. "ones_": torch.nn.init.ones_,
  26. "zeros_": torch.nn.init.zeros_,
  27. "eye_": torch.nn.init.eye_,
  28. "dirac_": torch.nn.init.dirac_,
  29. "xavier_uniform_": torch.nn.init.xavier_uniform_,
  30. "xavier_normal_": torch.nn.init.xavier_normal_,
  31. "kaiming_uniform_": torch.nn.init.kaiming_uniform_,
  32. "kaiming_normal_": torch.nn.init.kaiming_normal_,
  33. "trunc_normal_": torch.nn.init.trunc_normal_,
  34. "orthogonal_": torch.nn.init.orthogonal_,
  35. "sparse_": torch.nn.init.sparse_,
  36. }
  37. def uniform_(
  38. tensor: torch.Tensor, a: float = 0.0, b: float = 1.0, generator: torch.Generator | None = None
  39. ) -> torch.Tensor:
  40. if not getattr(tensor, "_is_hf_initialized", False):
  41. return TORCH_INIT_FUNCTIONS["uniform_"](tensor, a=a, b=b, generator=generator)
  42. return tensor
  43. def normal_(
  44. tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, generator: torch.Generator | None = None
  45. ) -> torch.Tensor:
  46. if not getattr(tensor, "_is_hf_initialized", False):
  47. return TORCH_INIT_FUNCTIONS["normal_"](tensor, mean=mean, std=std, generator=generator)
  48. return tensor
  49. def constant_(tensor: torch.Tensor, val: float) -> torch.Tensor:
  50. if not getattr(tensor, "_is_hf_initialized", False):
  51. return TORCH_INIT_FUNCTIONS["constant_"](tensor, val=val)
  52. return tensor
  53. def ones_(tensor: torch.Tensor) -> torch.Tensor:
  54. if not getattr(tensor, "_is_hf_initialized", False):
  55. return TORCH_INIT_FUNCTIONS["ones_"](tensor)
  56. return tensor
  57. def zeros_(tensor: torch.Tensor) -> torch.Tensor:
  58. if not getattr(tensor, "_is_hf_initialized", False):
  59. return TORCH_INIT_FUNCTIONS["zeros_"](tensor)
  60. return tensor
  61. def eye_(tensor: torch.Tensor) -> torch.Tensor:
  62. if not getattr(tensor, "_is_hf_initialized", False):
  63. return TORCH_INIT_FUNCTIONS["eye_"](tensor)
  64. return tensor
  65. def dirac_(tensor: torch.Tensor, groups: int = 1) -> torch.Tensor:
  66. if not getattr(tensor, "_is_hf_initialized", False):
  67. return TORCH_INIT_FUNCTIONS["dirac_"](tensor, groups=groups)
  68. return tensor
  69. def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
  70. if not getattr(tensor, "_is_hf_initialized", False):
  71. return TORCH_INIT_FUNCTIONS["xavier_uniform_"](tensor, gain=gain, generator=generator)
  72. return tensor
  73. def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
  74. if not getattr(tensor, "_is_hf_initialized", False):
  75. return TORCH_INIT_FUNCTIONS["xavier_normal_"](tensor, gain=gain, generator=generator)
  76. return tensor
  77. def kaiming_uniform_(
  78. tensor: torch.Tensor,
  79. a: float = 0,
  80. mode: str = "fan_in",
  81. nonlinearity: str = "leaky_relu",
  82. generator: torch.Generator | None = None,
  83. ) -> torch.Tensor:
  84. if not getattr(tensor, "_is_hf_initialized", False):
  85. return TORCH_INIT_FUNCTIONS["kaiming_uniform_"](
  86. tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
  87. )
  88. return tensor
  89. def kaiming_normal_(
  90. tensor: torch.Tensor,
  91. a: float = 0,
  92. mode: str = "fan_in",
  93. nonlinearity: str = "leaky_relu",
  94. generator: torch.Generator | None = None,
  95. ) -> torch.Tensor:
  96. if not getattr(tensor, "_is_hf_initialized", False):
  97. return TORCH_INIT_FUNCTIONS["kaiming_normal_"](
  98. tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
  99. )
  100. return tensor
  101. def trunc_normal_(
  102. tensor: torch.Tensor,
  103. mean: float = 0.0,
  104. std: float = 1.0,
  105. a: float = -2.0,
  106. b: float = 2.0,
  107. generator: torch.Generator | None = None,
  108. ) -> torch.Tensor:
  109. if not getattr(tensor, "_is_hf_initialized", False):
  110. return TORCH_INIT_FUNCTIONS["trunc_normal_"](tensor, mean=mean, std=std, a=a, b=b, generator=generator)
  111. return tensor
  112. def orthogonal_(
  113. tensor: torch.Tensor,
  114. gain: float = 1,
  115. generator: torch.Generator | None = None,
  116. ) -> torch.Tensor:
  117. if not getattr(tensor, "_is_hf_initialized", False):
  118. return TORCH_INIT_FUNCTIONS["orthogonal_"](tensor, gain=gain, generator=generator)
  119. return tensor
  120. def sparse_(
  121. tensor: torch.Tensor, sparsity: float, std: float = 0.01, generator: torch.Generator | None = None
  122. ) -> torch.Tensor:
  123. if not getattr(tensor, "_is_hf_initialized", False):
  124. return TORCH_INIT_FUNCTIONS["sparse_"](tensor, sparsity=sparsity, std=std, generator=generator)
  125. return tensor
  126. def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
  127. if not getattr(tensor, "_is_hf_initialized", False):
  128. with torch.no_grad():
  129. return tensor.copy_(other)
  130. return tensor
  131. def _variance_scaling(tensor, mode="fan_in", distribution="normal"):
  132. fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(tensor)
  133. if mode == "fan_in":
  134. denom = fan_in
  135. elif mode == "fan_out":
  136. denom = fan_out
  137. elif mode == "fan_avg":
  138. denom = (fan_in + fan_out) / 2
  139. variance = 1.0 / denom
  140. if distribution == "truncated_normal":
  141. trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
  142. elif distribution == "normal":
  143. normal_(tensor, std=math.sqrt(variance))
  144. elif distribution == "uniform":
  145. bound = math.sqrt(3 * variance)
  146. uniform_(tensor, -bound, bound)
  147. else:
  148. raise ValueError(f"invalid distribution {distribution}")
  149. def lecun_normal_(tensor):
  150. if not getattr(tensor, "_is_hf_initialized", False):
  151. _variance_scaling(tensor, mode="fan_in", distribution="truncated_normal")
  152. return tensor
  153. def default_flax_embed_init_(tensor):
  154. if not getattr(tensor, "_is_hf_initialized", False):
  155. _variance_scaling(tensor, mode="fan_in", distribution="normal")
  156. return tensor
  157. # Here, we need to check several modules imported, and hot patch all of them, as sometimes torch does
  158. # something like `from torch.nn.init import xavier_uniform_` in their internals (e.g in torch.nn.modules.activations,
  159. # where MultiHeadAttention lives), so the function name is binded at import time and just doing
  160. # `setattr(torch.nn.init, name, globals()[name])` is thus not enough
  161. # The following list should be enough for all torch versions we work with
  162. TORCH_MODULES_TO_PATCH = (
  163. "torch.nn.init",
  164. "torch.nn.modules.activation",
  165. "torch.nn.modules.transformer",
  166. "torch.nn.modules.linear",
  167. "torch.nn.modules.loss",
  168. "torch.nn.modules.batchnorm",
  169. "torch.nn.modules.conv",
  170. "torch.nn.modules.normalization",
  171. "torch.nn.modules.rnn",
  172. "torch.nn.modules.sparse",
  173. )
  174. @contextmanager
  175. def guard_torch_init_functions():
  176. """
  177. Guard the `torch.nn.init` primitive functions to behave exactly like the functions in this file, i.e. be
  178. protected against the `_is_hf_initialized` flag to avoid re-init if the param was already loaded.
  179. Usually, all models are using the init from `transformers` which are already guarded, but just to make extra sure
  180. and for remote code, we also use this context manager.
  181. """
  182. originals = defaultdict(dict)
  183. try:
  184. # Replace all torch funcs by the ones in this file
  185. for module_name in TORCH_MODULES_TO_PATCH:
  186. if module_name in sys.modules:
  187. module = sys.modules[module_name]
  188. for func_name in TORCH_INIT_FUNCTIONS.keys():
  189. if hasattr(module, func_name):
  190. originals[module][func_name] = getattr(module, func_name)
  191. setattr(module, func_name, globals()[func_name])
  192. yield
  193. finally:
  194. # Set back the original functions on all modules
  195. for module, functions in originals.items():
  196. for func_name, func in functions.items():
  197. setattr(module, func_name, func)
  198. @contextmanager
  199. def no_init_weights():
  200. """
  201. Disable weight initialization both at the torch-level, and at the transformers-level (`init_weights`).
  202. This is used to speed-up initializing an empty model with deepspeed, as we do not initialize the model on meta device
  203. with deepspeed, but we still don't need to run expensive weight initializations as we are loading params afterwards.
  204. """
  205. from .modeling_utils import PreTrainedModel
  206. def empty_func(*args, **kwargs):
  207. pass
  208. originals = defaultdict(dict)
  209. try:
  210. # Replace all torch funcs by empty ones
  211. for module_name in TORCH_MODULES_TO_PATCH:
  212. if module_name in sys.modules:
  213. module = sys.modules[module_name]
  214. for func_name in TORCH_INIT_FUNCTIONS.keys():
  215. if hasattr(module, func_name):
  216. originals[module][func_name] = getattr(module, func_name)
  217. setattr(module, func_name, empty_func)
  218. # Also patch our own `init_weights`
  219. original_init_weights = PreTrainedModel.init_weights
  220. PreTrainedModel.init_weights = empty_func
  221. yield
  222. finally:
  223. # Set back the original torch functions on all modules
  224. for module, functions in originals.items():
  225. for func_name, func in functions.items():
  226. setattr(module, func_name, func)
  227. # Set back `init_weights`
  228. PreTrainedModel.init_weights = original_init_weights
  229. @contextmanager
  230. def no_tie_weights():
  231. """
  232. Disable weight tying during loading with `from_pretrained`. This is needed as we want to have access to ALL
  233. weights in the state_dict during `from_pretrained`, and otherwise tying them would remove them from it, as it's
  234. called in `post_init` when instantiating.
  235. """
  236. from .modeling_utils import PreTrainedModel
  237. def empty_func(*args, **kwargs):
  238. pass
  239. try:
  240. original_tie_weights = PreTrainedModel.tie_weights
  241. PreTrainedModel.tie_weights = empty_func
  242. yield
  243. finally:
  244. # Set back the original
  245. PreTrainedModel.tie_weights = original_tie_weights
  246. @contextmanager
  247. def meta_device_safe_creation_ops():
  248. """
  249. During meta-device model initialisation, ``torch.linspace`` produces meta
  250. tensors that have no data. Custom models loaded from the Hub (remote code)
  251. often call ``.item()`` on these tensors to compute scalar hyperparameters
  252. (e.g. stochastic-depth / drop-path schedules). Native transformers models
  253. already pass ``device="cpu"`` explicitly for such calls (see e.g.
  254. ``modeling_swin.py``, ``modeling_pvt_v2.py``), but remote-code models
  255. written before v5 do not.
  256. This context manager patches ``torch.linspace`` to default to
  257. ``device="cpu"`` when no explicit device is requested, matching the best
  258. practice already used throughout transformers. Calls that supply an
  259. explicit ``device`` argument (e.g. ``device=self.logits.device``) are left
  260. untouched. ``torch.arange`` is intentionally NOT patched because it is
  261. used in RoPE computations where the device must match model parameters.
  262. """
  263. original_linspace = torch.linspace
  264. def _safe_linspace(*args, **kwargs):
  265. kwargs.setdefault("device", "cpu")
  266. return original_linspace(*args, **kwargs)
  267. torch.linspace = _safe_linspace
  268. try:
  269. yield
  270. finally:
  271. torch.linspace = original_linspace