| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339 |
- import collections.abc
- import math
- import re
- from collections import defaultdict
- from itertools import chain
- from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union
- import torch
- import torch.utils.checkpoint
- from torch import nn as nn
- from torch import Tensor
- from timm.layers import use_reentrant_ckpt
# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    'model_parameters',
    'named_apply',
    'named_modules',
    'named_modules_with_params',
    'adapt_input_conv',
    'group_with_matcher',
    'group_modules',
    'group_parameters',
    'flatten_modules',
    'checkpoint_seq',
    'checkpoint',
    'reinit_non_persistent_buffers',
]
def model_parameters(model: nn.Module, exclude_head: bool = False):
    """Return the parameters of ``model``, optionally dropping the last two.

    Args:
        model: Module whose parameters are returned.
        exclude_head: If True, return a list of every parameter except the final two
            in ``model.parameters()`` order. This is an ordering-based heuristic meant
            to skip the classifier head; it presumably assumes the head is registered
            last and owns exactly two parameters (weight and bias). Verify per model.

    Returns:
        ``model.parameters()`` (an iterator) when ``exclude_head`` is False, otherwise a list.
    """
    if not exclude_head:
        return model.parameters()
    # FIXME quick and dirty hack: skip classifier head params purely by ordering
    all_params = list(model.parameters())
    return all_params[:-2]
def named_apply(
        fn: Callable,
        module: nn.Module,
        name='',
        depth_first: bool = True,
        include_root: bool = False,
) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` on ``module`` and its descendants.

    Args:
        fn: Callable invoked with keyword arguments ``module`` and ``name``, where
            ``name`` is the dotted path of the submodule.
        module: Root module of the walk.
        name: Dotted name of ``module``; prepended to every descendant's name.
        depth_first: If True, ``fn`` runs on children before their parent (post-order);
            otherwise on the parent before its children (pre-order).
        include_root: Whether ``fn`` is also called on ``module`` itself. Descendants
            are always visited.

    Returns:
        The same ``module`` object that was passed in.
    """
    call_before_children = include_root and not depth_first
    call_after_children = include_root and depth_first

    if call_before_children:
        fn(module=module, name=name)

    for sub_name, sub_module in module.named_children():
        qualified = '.'.join((name, sub_name)) if name else sub_name
        # Every descendant is "root" of its own sub-walk, so always include it.
        named_apply(
            fn=fn,
            module=sub_module,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )

    if call_after_children:
        fn(module=module, name=name)

    return module
def named_modules(
        module: nn.Module,
        name: str = '',
        depth_first: bool = True,
        include_root: bool = False,
):
    """Yield ``(dotted_name, submodule)`` pairs for the module tree rooted at ``module``.

    Args:
        module: Root module of the walk.
        name: Dotted name of ``module``; prepended to every descendant's name.
        depth_first: If True, children are yielded before their parent (post-order);
            otherwise the parent is yielded before its children (pre-order).
        include_root: Whether ``module`` itself is yielded. Descendants always are.

    Yields:
        ``(name, module)`` tuples.
    """
    if include_root and not depth_first:
        yield name, module

    for sub_name, sub_module in module.named_children():
        qualified = '.'.join((name, sub_name)) if name else sub_name
        yield from named_modules(
            module=sub_module,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )

    if include_root and depth_first:
        yield name, module
def named_modules_with_params(
        module: nn.Module,
        name: str = '',
        depth_first: bool = True,
        include_root: bool = False,
):
    """Like ``named_modules`` but only yields modules whose own ``_parameters`` dict is non-empty.

    The whole tree is still walked; modules without directly registered parameters
    (e.g. pure containers or activation layers) are simply not yielded.

    Args:
        module: Root module of the walk.
        name: Dotted name of ``module``; prepended to every descendant's name.
        depth_first: If True, children are yielded before their parent (post-order);
            otherwise the parent is yielded before its children (pre-order).
        include_root: Whether ``module`` itself may be yielded.

    Yields:
        ``(name, module)`` tuples.
    """
    if include_root and not depth_first and module._parameters:
        yield name, module

    for sub_name, sub_module in module.named_children():
        qualified = '.'.join((name, sub_name)) if name else sub_name
        yield from named_modules_with_params(
            module=sub_module,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )

    if include_root and depth_first and module._parameters:
        yield name, module
# Suffix marker: a group key whose last element equals MATCH_PREV_GROUP[0] is merged into
# the preceding group instead of starting a new one (see the id remapping loop below).
MATCH_PREV_GROUP = (99999,)


def group_with_matcher(
        named_objects: Iterator[Tuple[str, Any]],
        group_matcher: Union[Dict, Callable],
        return_values: bool = False,
        reverse: bool = False
):
    """Assign named objects (e.g. parameters or modules) to consecutive integer group ids.

    Each name is mapped to a sortable ordinal tuple. Distinct ordinals are sorted and
    numbered 0, 1, 2, ...; an ordinal whose last element equals ``MATCH_PREV_GROUP[0]``
    joins the group before it.

    Args:
        named_objects: Iterable of ``(name, value)`` pairs.
        group_matcher: One of:
            * dict of group name -> spec, where spec is a regex string, a list/tuple of
              ``(regex, suffix)`` pairs, or None (skipped). The dict's iteration order
              gives each entry's leading ordinal.
            * list/tuple of ``(compiled_regex, prefix, suffix)`` triples.
            * callable mapping a name to an ordinal (a number or an iterable of numbers).
            With regex matchers, numeric capture groups become part of the ordinal and
            names matching no regex are placed in a final group (ordinal ``inf``).
        return_values: Collect values instead of names in each group.
        reverse: Return ``{name: group_id}`` instead of ``{group_id: [names]}``.
            Only valid with ``return_values=False``.

    Returns:
        ``defaultdict(list)`` of group id -> names (or values), or a dict of
        name -> group id when ``reverse`` is True.
    """
    if isinstance(group_matcher, dict):
        # dictionary matcher contains a dict of raw-string regex expr that must be compiled
        compiled = []
        for group_ordinal, mspec in enumerate(group_matcher.values()):
            if mspec is None:
                continue
            # map all matching specifications into 3-tuple (compiled re, prefix, suffix)
            if isinstance(mspec, (tuple, list)):
                # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
                for sspec in mspec:
                    compiled.append((re.compile(sspec[0]), (group_ordinal,), sspec[1]))
            else:
                compiled.append((re.compile(mspec), (group_ordinal,), None))
        group_matcher = compiled

    def _get_grouping(name):
        if isinstance(group_matcher, (list, tuple)):
            for match_fn, prefix, suffix in group_matcher:
                r = match_fn.match(name)
                if r:
                    parts = (prefix, r.groups(), suffix)
                    # Flatten prefix + captures + suffix to floats for numeric sorting. Skip
                    # absent parts (e.g. suffix None) and None captures from optional groups
                    # that did not participate in the match -- float(None) would raise.
                    return tuple(
                        float(v) for v in chain.from_iterable(filter(None, parts)) if v is not None)
            return float('inf'),  # un-matched layers (neck, head) mapped to largest ordinal
        ordinal = group_matcher(name)
        if not isinstance(ordinal, collections.abc.Iterable):
            return ordinal,
        return tuple(ordinal)

    # map layers into groups via ordinals (ints or tuples of ints) from matcher
    grouping = defaultdict(list)
    for k, v in named_objects:
        grouping[_get_grouping(k)].append(v if return_values else k)

    # remap sorted ordinals to consecutive integers, merging MATCH_PREV_GROUP keys backwards
    layer_id_to_param = defaultdict(list)
    lid = -1
    for k in sorted(grouping.keys()):
        if lid < 0 or not k or k[-1] != MATCH_PREV_GROUP[0]:
            lid += 1
        layer_id_to_param[lid].extend(grouping[k])

    if reverse:
        assert not return_values, "reverse mapping only sensible for name output"
        # output reverse mapping
        return {n: layer_id for layer_id, names in layer_id_to_param.items() for n in names}

    return layer_id_to_param
def group_parameters(
        module: nn.Module,
        group_matcher,
        return_values: bool = False,
        reverse: bool = False,
):
    """Group ``module``'s named parameters into integer layer groups.

    Thin adapter over ``group_with_matcher`` using ``module.named_parameters()``.

    Args:
        module: Module whose parameters are grouped.
        group_matcher: Matcher spec accepted by ``group_with_matcher``.
        return_values: Collect parameter tensors instead of their names.
        reverse: Return a ``{name: group_id}`` mapping instead.

    Returns:
        Whatever ``group_with_matcher`` returns for these arguments.
    """
    named_params = module.named_parameters()
    return group_with_matcher(
        named_params,
        group_matcher,
        return_values=return_values,
        reverse=reverse,
    )
def group_modules(
        module: nn.Module,
        group_matcher,
        return_values: bool = False,
        reverse: bool = False,
):
    """Group the submodules of ``module`` that directly own parameters into layer groups.

    Candidates come from ``named_modules_with_params(module)`` (root excluded), and are
    grouped with ``group_with_matcher``.

    Args:
        module: Module whose submodules are grouped.
        group_matcher: Matcher spec accepted by ``group_with_matcher``.
        return_values: Collect module objects instead of their names.
        reverse: Return a ``{name: group_id}`` mapping instead.

    Returns:
        Whatever ``group_with_matcher`` returns for these arguments.
    """
    candidates = named_modules_with_params(module)
    return group_with_matcher(
        candidates,
        group_matcher,
        return_values=return_values,
        reverse=reverse,
    )
def flatten_modules(
        named_modules: Iterator[Tuple[str, nn.Module]],
        depth: int = 1,
        prefix: Union[str, Tuple[str, ...]] = '',
        module_types: Union[str, Tuple[Type[nn.Module], ...]] = 'sequential',
):
    """Flatten container modules into a single stream of ``(name, module)`` pairs.

    Args:
        named_modules: Iterable of ``(name, module)`` pairs, e.g. ``model.named_children()``.
        depth: Number of container levels to descend into; 0 flattens nothing.
        prefix: Prefix for yielded names. A str prefix yields dotted names (``'a.b'``);
            a tuple prefix yields tuples of name components (``('a', 'b')``).
        module_types: Container types to flatten. ``'container'`` selects
            Sequential/ModuleList/ModuleDict, any other string selects only Sequential,
            and a tuple of types is used as given.

    Yields:
        ``(name, module)`` pairs. Each name carries the full path from ``prefix`` down
        through every flattened container level.
    """
    prefix_is_tuple = isinstance(prefix, tuple)
    if isinstance(module_types, str):
        if module_types == 'container':
            module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
        else:
            module_types = (nn.Sequential,)

    def _qualify(child_name):
        # Extend (not replace) the incoming prefix so nested levels keep their full path.
        if prefix_is_tuple:
            return prefix + (child_name,)
        return '.'.join([prefix, child_name]) if prefix else child_name

    for name, module in named_modules:
        full_name = _qualify(name)
        if depth and isinstance(module, module_types):
            yield from flatten_modules(
                module.named_children(),
                depth - 1,
                prefix=full_name,
                module_types=module_types,
            )
        else:
            yield full_name, module
def checkpoint(
        function,
        *args,
        use_reentrant: Optional[bool] = None,
        **kwargs,
):
    """Thin wrapper around :func:`torch.utils.checkpoint.checkpoint`.

    The wrapper exists so the ``use_reentrant`` default is chosen by timm rather than torch.

    Args:
        function: Callable to run under activation checkpointing.
        *args: Positional inputs forwarded to ``function``.
        use_reentrant: Passed through to torch. When None, the value returned by
            ``use_reentrant_ckpt()`` is used instead.
        **kwargs: Extra keyword arguments forwarded to torch's ``checkpoint``.

    Returns:
        The result of ``torch.utils.checkpoint.checkpoint``.
    """
    reentrant = use_reentrant_ckpt() if use_reentrant is None else use_reentrant
    return torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=reentrant, **kwargs)
def checkpoint_seq(
        functions,
        x,
        every: int = 1,
        flatten: bool = False,
        skip_last: bool = False,
        use_reentrant: Optional[bool] = None,
):
    r"""Run a sequence of modules/functions with activation checkpointing.

    The sequence is split into segments of ``every`` consecutive functions. Each segment
    is run through :func:`torch.utils.checkpoint.checkpoint`, so activations inside the
    segment are not kept; the segment's input is saved and the segment is re-run during
    the backward pass. See :func:`~torch.utils.checkpoint.checkpoint` for details.

    .. warning::
        Per the PyTorch documentation, reentrant checkpointing (``use_reentrant=True``)
        only supports :func:`torch.autograd.backward` without its ``inputs`` argument;
        :func:`torch.autograd.grad` is not supported.

    .. warning::
        With reentrant checkpointing, at least one input needs ``requires_grad=True`` if
        grads are needed for model inputs, otherwise the checkpointed part of the model
        won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential`, or a list/iterable of modules or
            functions, to run sequentially.
        x: Input passed to the first function.
        every: Number of consecutive functions per checkpointed segment (default: 1).
            Must be >= 1.
        flatten: Flatten one level of nesting (e.g. an nn.Sequential of nn.Sequentials).
        skip_last: If True, run the last function outside of checkpointing.
        use_reentrant: Use re-entrant checkpointing. If None, the value returned by
            ``use_reentrant_ckpt()`` is used.

    Returns:
        Output of running ``functions`` sequentially on ``x``.

    Raises:
        ValueError: If ``every`` is less than 1.

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_seq(model, input_var, every=2)
    """
    if every < 1:
        # A non-positive step would leave range() empty and silently skip every function.
        raise ValueError(f'every must be >= 1, got {every}')
    if use_reentrant is None:
        use_reentrant = use_reentrant_ckpt()

    def run_function(start, end, functions):
        # Build a callable that applies functions[start..end] (inclusive) in order.
        def forward(_x):
            for j in range(start, end + 1):
                _x = functions[j](_x)
            return _x
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = functions.children()
    if flatten:
        functions = chain.from_iterable(functions)
    if not isinstance(functions, (tuple, list)):
        functions = tuple(functions)

    num_checkpointed = len(functions)
    if skip_last:
        num_checkpointed -= 1
    end = -1
    for start in range(0, num_checkpointed, every):
        end = min(start + every - 1, num_checkpointed - 1)
        x = torch.utils.checkpoint.checkpoint(
            run_function(start, end, functions),
            x,
            use_reentrant=use_reentrant,
        )
    if skip_last:
        # Remaining tail (the last function) runs without checkpointing.
        return run_function(end + 1, len(functions) - 1, functions)(x)
    return x
def adapt_input_conv(in_chans: int, conv_weight: Tensor) -> Tensor:
    """Adapt a first-layer conv weight pretrained for 3-channel (RGB) input to ``in_chans`` channels.

    Args:
        in_chans: Desired number of input channels.
        conv_weight: Conv weight of shape ``(out, in, kh, kw)``.

    Returns:
        The adapted weight, cast back to ``conv_weight``'s original dtype:
          * ``in_chans == 1``: input channels are summed. When the weight has more than
            3 input channels (space2depth stems), consecutive triples are summed instead,
            giving ``in // 3`` channels.
          * ``in_chans == 3``: the weight is returned as-is (after the dtype round trip).
          * otherwise: the 3 RGB channels are tiled up to ``in_chans`` and rescaled by
            ``3 / in_chans``.

    Raises:
        NotImplementedError: If ``in_chans`` is neither 1 nor 3 and the weight does not
            have exactly 3 input channels.
    """
    orig_dtype = conv_weight.dtype
    # Some weights are stored in half precision; upcast so the channel sums work on CPU.
    weight = conv_weight.float()
    out_ch, src_chans, kh, kw = weight.shape

    if in_chans == 1 and src_chans > 3:
        assert weight.shape[1] % 3 == 0
        # For models with space2depth stems: collapse each group of 3 consecutive channels.
        weight = weight.reshape(out_ch, src_chans // 3, 3, kh, kw).sum(dim=2, keepdim=False)
    elif in_chans == 1:
        weight = weight.sum(dim=1, keepdim=True)
    elif in_chans != 3 and src_chans != 3:
        raise NotImplementedError('Weight format not supported by conversion.')
    elif in_chans != 3:
        # NOTE tiling the RGB weights should beat random init, though other combinations of
        # the original RGB filters could work better for specific cases.
        num_tiles = int(math.ceil(in_chans / 3))
        weight = weight.repeat(1, num_tiles, 1, 1)[:, :in_chans, :, :]
        weight *= (3 / float(in_chans))

    return weight.to(orig_dtype)
def reinit_non_persistent_buffers(model: nn.Module) -> List[str]:
    """Call ``init_non_persistent_buffers()`` on every submodule of ``model`` that defines it.

    Intended for recomputing buffers that are marked non-persistent (so they are not
    stored in checkpoints), e.g. after loading weights or moving to a new device.

    Args:
        model: Model to walk; ``model`` itself is included.

    Returns:
        Names (from ``model.named_modules()``) of the modules that were reinitialized,
        in traversal order, with the root reported as ``'(root)'``.

    Example:
        >>> model = create_model('vit_base', pretrained=True)
        >>> # After loading checkpoint or moving to new device
        >>> reinitialized = reinit_non_persistent_buffers(model)
        >>> print(f"Reinitialized {len(reinitialized)} modules")
    """
    touched: List[str] = []
    for mod_name, mod in model.named_modules():
        if not hasattr(mod, 'init_non_persistent_buffers'):
            continue
        mod.init_non_persistent_buffers()
        # The root module's name from named_modules() is the empty string.
        touched.append(mod_name or '(root)')
    return touched
|