| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
- # Copyright 2025 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- import sys
- from collections import defaultdict
- from contextlib import contextmanager
- import torch
# Snapshot the genuine `torch.nn.init` primitives at import time. The context managers in this file temporarily
# replace those attributes on the torch modules, so the guarded wrappers must reach the real implementations
# through this mapping rather than through `torch.nn.init` itself.
TORCH_INIT_FUNCTIONS = {
    name: getattr(torch.nn.init, name)
    for name in (
        "uniform_",
        "normal_",
        "constant_",
        "ones_",
        "zeros_",
        "eye_",
        "dirac_",
        "xavier_uniform_",
        "xavier_normal_",
        "kaiming_uniform_",
        "kaiming_normal_",
        "trunc_normal_",
        "orthogonal_",
        "sparse_",
    )
}
- def uniform_(
- tensor: torch.Tensor, a: float = 0.0, b: float = 1.0, generator: torch.Generator | None = None
- ) -> torch.Tensor:
- if not getattr(tensor, "_is_hf_initialized", False):
- return TORCH_INIT_FUNCTIONS["uniform_"](tensor, a=a, b=b, generator=generator)
- return tensor
- def normal_(
- tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, generator: torch.Generator | None = None
- ) -> torch.Tensor:
- if not getattr(tensor, "_is_hf_initialized", False):
- return TORCH_INIT_FUNCTIONS["normal_"](tensor, mean=mean, std=std, generator=generator)
- return tensor
def constant_(tensor: torch.Tensor, val: float) -> torch.Tensor:
    """
    Guarded counterpart of `torch.nn.init.constant_`.

    A tensor flagged with `_is_hf_initialized` is returned untouched; otherwise the call is forwarded to the
    original torch function saved in `TORCH_INIT_FUNCTIONS`.
    """
    if getattr(tensor, "_is_hf_initialized", False):
        # Flag is set: keep the existing values
        return tensor
    init_fn = TORCH_INIT_FUNCTIONS["constant_"]
    return init_fn(tensor, val=val)
def ones_(tensor: torch.Tensor) -> torch.Tensor:
    """
    Guarded counterpart of `torch.nn.init.ones_`.

    A tensor flagged with `_is_hf_initialized` is returned untouched; otherwise the call is forwarded to the
    original torch function saved in `TORCH_INIT_FUNCTIONS`.
    """
    if getattr(tensor, "_is_hf_initialized", False):
        # Flag is set: keep the existing values
        return tensor
    init_fn = TORCH_INIT_FUNCTIONS["ones_"]
    return init_fn(tensor)
def zeros_(tensor: torch.Tensor) -> torch.Tensor:
    """
    Guarded counterpart of `torch.nn.init.zeros_`.

    A tensor flagged with `_is_hf_initialized` is returned untouched; otherwise the call is forwarded to the
    original torch function saved in `TORCH_INIT_FUNCTIONS`.
    """
    if getattr(tensor, "_is_hf_initialized", False):
        # Flag is set: keep the existing values
        return tensor
    init_fn = TORCH_INIT_FUNCTIONS["zeros_"]
    return init_fn(tensor)
def eye_(tensor: torch.Tensor) -> torch.Tensor:
    """
    Guarded counterpart of `torch.nn.init.eye_`.

    A tensor flagged with `_is_hf_initialized` is returned untouched; otherwise the call is forwarded to the
    original torch function saved in `TORCH_INIT_FUNCTIONS`.
    """
    if getattr(tensor, "_is_hf_initialized", False):
        # Flag is set: keep the existing values
        return tensor
    init_fn = TORCH_INIT_FUNCTIONS["eye_"]
    return init_fn(tensor)
def dirac_(tensor: torch.Tensor, groups: int = 1) -> torch.Tensor:
    """
    Guarded counterpart of `torch.nn.init.dirac_`.

    A tensor flagged with `_is_hf_initialized` is returned untouched; otherwise the call is forwarded to the
    original torch function saved in `TORCH_INIT_FUNCTIONS`.
    """
    if getattr(tensor, "_is_hf_initialized", False):
        # Flag is set: keep the existing values
        return tensor
    init_fn = TORCH_INIT_FUNCTIONS["dirac_"]
    return init_fn(tensor, groups=groups)
- def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
- if not getattr(tensor, "_is_hf_initialized", False):
- return TORCH_INIT_FUNCTIONS["xavier_uniform_"](tensor, gain=gain, generator=generator)
- return tensor
- def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
- if not getattr(tensor, "_is_hf_initialized", False):
- return TORCH_INIT_FUNCTIONS["xavier_normal_"](tensor, gain=gain, generator=generator)
- return tensor
- def kaiming_uniform_(
- tensor: torch.Tensor,
- a: float = 0,
- mode: str = "fan_in",
- nonlinearity: str = "leaky_relu",
- generator: torch.Generator | None = None,
- ) -> torch.Tensor:
- if not getattr(tensor, "_is_hf_initialized", False):
- return TORCH_INIT_FUNCTIONS["kaiming_uniform_"](
- tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
- )
- return tensor
- def kaiming_normal_(
- tensor: torch.Tensor,
- a: float = 0,
- mode: str = "fan_in",
- nonlinearity: str = "leaky_relu",
- generator: torch.Generator | None = None,
- ) -> torch.Tensor:
- if not getattr(tensor, "_is_hf_initialized", False):
- return TORCH_INIT_FUNCTIONS["kaiming_normal_"](
- tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
- )
- return tensor
- def trunc_normal_(
- tensor: torch.Tensor,
- mean: float = 0.0,
- std: float = 1.0,
- a: float = -2.0,
- b: float = 2.0,
- generator: torch.Generator | None = None,
- ) -> torch.Tensor:
- if not getattr(tensor, "_is_hf_initialized", False):
- return TORCH_INIT_FUNCTIONS["trunc_normal_"](tensor, mean=mean, std=std, a=a, b=b, generator=generator)
- return tensor
- def orthogonal_(
- tensor: torch.Tensor,
- gain: float = 1,
- generator: torch.Generator | None = None,
- ) -> torch.Tensor:
- if not getattr(tensor, "_is_hf_initialized", False):
- return TORCH_INIT_FUNCTIONS["orthogonal_"](tensor, gain=gain, generator=generator)
- return tensor
- def sparse_(
- tensor: torch.Tensor, sparsity: float, std: float = 0.01, generator: torch.Generator | None = None
- ) -> torch.Tensor:
- if not getattr(tensor, "_is_hf_initialized", False):
- return TORCH_INIT_FUNCTIONS["sparse_"](tensor, sparsity=sparsity, std=std, generator=generator)
- return tensor
def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    """
    Copy `other` into `tensor` in-place, unless `tensor` is flagged with `_is_hf_initialized`.

    The copy runs under `torch.no_grad()`. A flagged tensor is returned untouched.
    """
    if getattr(tensor, "_is_hf_initialized", False):
        # Flag is set: keep the existing values
        return tensor
    with torch.no_grad():
        copied = tensor.copy_(other)
    return copied
- def _variance_scaling(tensor, mode="fan_in", distribution="normal"):
- fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(tensor)
- if mode == "fan_in":
- denom = fan_in
- elif mode == "fan_out":
- denom = fan_out
- elif mode == "fan_avg":
- denom = (fan_in + fan_out) / 2
- variance = 1.0 / denom
- if distribution == "truncated_normal":
- trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
- elif distribution == "normal":
- normal_(tensor, std=math.sqrt(variance))
- elif distribution == "uniform":
- bound = math.sqrt(3 * variance)
- uniform_(tensor, -bound, bound)
- else:
- raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
    """
    Apply variance-scaling initialization with `mode="fan_in"` and a truncated normal distribution.

    A tensor flagged with `_is_hf_initialized` is returned untouched. The tensor itself is always returned.
    """
    if getattr(tensor, "_is_hf_initialized", False):
        # Flag is set: keep the existing values
        return tensor
    _variance_scaling(tensor, mode="fan_in", distribution="truncated_normal")
    return tensor
def default_flax_embed_init_(tensor):
    """
    Apply variance-scaling initialization with `mode="fan_in"` and a normal distribution.

    A tensor flagged with `_is_hf_initialized` is returned untouched. The tensor itself is always returned.
    """
    if getattr(tensor, "_is_hf_initialized", False):
        # Flag is set: keep the existing values
        return tensor
    _variance_scaling(tensor, mode="fan_in", distribution="normal")
    return tensor
# Several torch modules have to be hot-patched, not just `torch.nn.init`: torch internals sometimes do
# `from torch.nn.init import xavier_uniform_` (e.g. in `torch.nn.modules.activation`, where `MultiheadAttention`
# lives). In that case the function name is bound at import time inside that module, so replacing the attribute
# on `torch.nn.init` alone is not enough.
# The following list should be enough for all torch versions we work with
TORCH_MODULES_TO_PATCH = (
    "torch.nn.init",
    *(
        f"torch.nn.modules.{submodule}"
        for submodule in (
            "activation",
            "transformer",
            "linear",
            "loss",
            "batchnorm",
            "conv",
            "normalization",
            "rnn",
            "sparse",
        )
    ),
)
@contextmanager
def guard_torch_init_functions():
    """
    Temporarily make the `torch.nn.init` primitives honor the `_is_hf_initialized` flag.

    Inside the context, every function named in `TORCH_INIT_FUNCTIONS` that is found on an already-imported module
    of `TORCH_MODULES_TO_PATCH` is replaced by the same-named guarded function of this file, so that params which
    were already loaded are not re-initialized. The previous attributes are put back on exit, including when an
    exception is raised.

    Models normally use the guarded init functions from `transformers` directly; this context manager is an extra
    safety net, in particular for remote code.
    """
    # Every replaced binding as (module, attribute name, previous value), in patch order
    saved_bindings = []
    try:
        for module_name in TORCH_MODULES_TO_PATCH:
            # Only modules that are already imported can hold stale references
            module = sys.modules.get(module_name)
            if module is None:
                continue
            for func_name in TORCH_INIT_FUNCTIONS:
                if not hasattr(module, func_name):
                    continue
                saved_bindings.append((module, func_name, getattr(module, func_name)))
                # Swap in the guarded function of the same name defined in this file
                setattr(module, func_name, globals()[func_name])
        yield
    finally:
        # Put the previous functions back on all modules
        for module, func_name, previous in saved_bindings:
            setattr(module, func_name, previous)
@contextmanager
def no_init_weights():
    """
    Disable weight initialization both at the torch-level, and at the transformers-level (`init_weights`).
    This is used to speed-up initializing an empty model with deepspeed, as we do not initialize the model on meta device
    with deepspeed, but we still don't need to run expensive weight initializations as we are loading params afterwards.

    Inside the context, every function named in `TORCH_INIT_FUNCTIONS` found on an already-imported module of
    `TORCH_MODULES_TO_PATCH` becomes a no-op that returns the tensor it was given, and
    `PreTrainedModel.init_weights` becomes a no-op. Everything is put back on exit, including when an exception is
    raised.
    """
    from .modeling_utils import PreTrainedModel

    def empty_func(*args, **kwargs):
        pass

    def skip_torch_init(*args, **kwargs):
        # The real torch init functions (and the guarded ones in this file) return the tensor they receive, so the
        # no-op does the same to keep patterns such as `nn.Parameter(nn.init.normal_(t))` working
        return args[0] if args else kwargs.get("tensor")

    # Captured before entering `try`, so that `finally` never references an unbound name if patching fails midway
    original_init_weights = PreTrainedModel.init_weights
    originals = defaultdict(dict)
    try:
        # Replace all torch funcs by no-op ones
        for module_name in TORCH_MODULES_TO_PATCH:
            module = sys.modules.get(module_name)
            if module is None:
                continue
            for func_name in TORCH_INIT_FUNCTIONS:
                if hasattr(module, func_name):
                    originals[module][func_name] = getattr(module, func_name)
                    setattr(module, func_name, skip_torch_init)
        # Also patch our own `init_weights`
        PreTrainedModel.init_weights = empty_func
        yield
    finally:
        # Set back the original torch functions on all modules
        for module, functions in originals.items():
            for func_name, func in functions.items():
                setattr(module, func_name, func)
        # Set back `init_weights`
        PreTrainedModel.init_weights = original_init_weights
@contextmanager
def no_tie_weights():
    """
    Disable weight tying during loading with `from_pretrained`. This is needed as we want to have access to ALL
    weights in the state_dict during `from_pretrained`, and otherwise tying them would remove them from it, as it's
    called in `post_init` when instantiating.

    Inside the context `PreTrainedModel.tie_weights` is a no-op; the original method is put back on exit, including
    when an exception is raised.
    """
    from .modeling_utils import PreTrainedModel

    def empty_func(*args, **kwargs):
        pass

    # Captured before entering `try`, so that `finally` never references an unbound name
    original_tie_weights = PreTrainedModel.tie_weights
    try:
        PreTrainedModel.tie_weights = empty_func
        yield
    finally:
        # Set back the original
        PreTrainedModel.tie_weights = original_tie_weights
@contextmanager
def meta_device_safe_creation_ops():
    """
    Make `torch.linspace` default to `device="cpu"` for the duration of the context.

    While a model is being initialised on the meta device, `torch.linspace` returns meta tensors, which hold no
    data. Remote-code models from the Hub often call `.item()` on such tensors to derive scalar hyperparameters
    (e.g. stochastic-depth / drop-path schedules). Native transformers models already pass `device="cpu"`
    explicitly in those places (see e.g. `modeling_swin.py`, `modeling_pvt_v2.py`), but remote-code models written
    before v5 do not.

    Only calls that do not pass a `device` keyword are affected; a call with an explicit `device` (e.g.
    `device=self.logits.device`) is forwarded unchanged. `torch.arange` is deliberately left alone, because it is
    used in RoPE computations where the device must match the model parameters.

    The original `torch.linspace` is put back on exit, including when an exception is raised.
    """
    linspace_before_patch = torch.linspace

    def _cpu_default_linspace(*args, **kwargs):
        # Respect an explicit `device=...`; only fill in the default when the caller gave none
        if "device" not in kwargs:
            kwargs["device"] = "cpu"
        return linspace_before_patch(*args, **kwargs)

    torch.linspace = _cpu_default_linspace
    try:
        yield
    finally:
        torch.linspace = linspace_before_patch
|