| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- import os
- import pkgutil
- from copy import deepcopy
- from torch import nn as nn
- from timm.layers import Conv2dSame, BatchNormAct2d, Linear
# Names exported by a star-import of this module.
__all__ = [
    'extract_layer',
    'set_layer',
    'adapt_model_from_string',
    'adapt_model_from_file',
]
def extract_layer(model, layer):
    """Resolve a dot-separated path to an object nested inside ``model``.

    Args:
        model: Root object to walk. If it exposes a ``module`` attribute and
            the path does not start with ``'module'``, the walk starts at
            ``model.module`` instead of ``model``.
        layer: Dot-separated path (e.g., 'layer1.0.conv1'). Segments made up
            only of digits are applied as integer indices; all others are
            read with ``getattr``. A leading ``'module'`` segment is dropped
            when ``model`` has no ``module`` attribute.

    Returns:
        The object reached at the end of the path. If some segment is not an
        attribute of the current object, the walk stops there and the object
        reached so far is returned (so an empty path yields the start object).
    """
    parts = layer.split('.')
    is_wrapped = hasattr(model, 'module')
    has_module_prefix = parts[0] == 'module'

    current = model
    if is_wrapped and not has_module_prefix:
        current = model.module
    if has_module_prefix and not is_wrapped:
        parts = parts[1:]

    for part in parts:
        if not hasattr(current, part):
            # Unresolvable segment: stop and hand back what we have so far.
            break
        current = current[int(part)] if part.isdigit() else getattr(current, part)
    return current
def set_layer(model, layer, val):
    """Replace the object addressed by a dot-separated path inside ``model``.

    Args:
        model: Root object to walk. If it exposes a ``module`` attribute and
            the path does not start with ``'module'``, the walk starts at
            ``model.module``.
        layer: Dot-separated layer path. Segments made up only of digits are
            applied as integer indices while walking; all others are read
            with ``getattr``. As in ``extract_layer``, a leading ``'module'``
            segment is dropped when ``model`` has no ``module`` attribute.
        val: New value for the layer; assigned with ``setattr`` on the parent
            of the last resolvable path segment.
    """
    path = layer.split('.')
    module = model
    if hasattr(model, 'module') and path[0] != 'module':
        module = model.module
    # Same prefix handling as extract_layer: without this, a 'module.xxx' path
    # on an unwrapped model would create a stray ``module`` attribute instead
    # of replacing ``xxx``. A bare 'module' path is left alone so the path
    # never becomes empty.
    if not hasattr(model, 'module') and path[0] == 'module' and len(path) > 1:
        path = path[1:]

    # First pass: count how many segments resolve. Segments that are not
    # attributes of the object reached so far are skipped, not fatal.
    probe = module
    resolved = 0
    for name in path:
        if hasattr(probe, name):
            probe = probe[int(name)] if name.isdigit() else getattr(probe, name)
            resolved += 1
    target_index = resolved - 1

    # Second pass: walk to the parent of the last resolved segment and
    # overwrite that segment on it.
    parent = module
    for name in path[:target_index]:
        parent = parent[int(name)] if name.isdigit() else getattr(parent, name)
    setattr(parent, path[target_index], val)
def adapt_model_from_string(parent_module, model_string):
    """Build a copy of ``parent_module`` whose layers match a pruned shape spec.

    Args:
        parent_module: Original model to adapt. Its first parameter supplies
            the device and dtype used for every newly created layer.
        model_string: String containing layer shapes for the pruned model.
            Entries are separated by ``'***'``; each entry is
            ``<key>:<shape>``, where the first and last characters of
            ``<shape>`` are discarded and the remainder is read as
            comma-separated integers. Entries with an empty shape are ignored.

    Returns:
        A deep copy of ``parent_module`` in which every Conv2d / Conv2dSame,
        BatchNormAct2d, BatchNorm2d and Linear layer has been replaced by a
        freshly constructed layer sized from the ``<name>.weight`` entry of
        the spec. Both the copy and ``parent_module`` are put in eval mode.
    """
    # Parse the spec into {parameter key: [dim0, dim1, ...]}.
    pruned_shapes = {}
    for entry in model_string.split('***'):
        fields = entry.split(':')
        param_key = fields[0]
        dims = fields[1][1:-1].split(',')
        if dims[0] != '':
            pruned_shapes[param_key] = [int(dim) for dim in dims]

    # New layers are created on the same device / dtype as the parent's first parameter.
    reference_param = next(parent_module.parameters())
    factory_kwargs = {'device': reference_param.device, 'dtype': reference_param.dtype}

    adapted = deepcopy(parent_module)
    for name, _ in parent_module.named_modules():
        source = extract_layer(parent_module, name)
        weight_key = name + '.weight'

        if isinstance(source, (nn.Conv2d, Conv2dSame)):
            conv_cls = Conv2dSame if isinstance(source, Conv2dSame) else nn.Conv2d
            weight_shape = pruned_shapes[weight_key]
            in_channels = weight_shape[1]
            out_channels = weight_shape[0]
            groups = 1
            if source.groups > 1:
                # Any grouped conv is rebuilt with groups == in == out channels.
                in_channels = out_channels
                groups = out_channels
            replacement = conv_cls(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=source.kernel_size,
                bias=source.bias is not None,
                padding=source.padding,
                dilation=source.dilation,
                groups=groups,
                stride=source.stride,
                **factory_kwargs,
            )
            set_layer(adapted, name, replacement)
        elif isinstance(source, BatchNormAct2d):
            # Checked before nn.BatchNorm2d so this branch gets first pick.
            replacement = BatchNormAct2d(
                pruned_shapes[weight_key][0],
                eps=source.eps,
                momentum=source.momentum,
                affine=source.affine,
                track_running_stats=True,
                **factory_kwargs,
            )
            # Carry over the original's drop and activation sub-modules.
            replacement.drop = source.drop
            replacement.act = source.act
            set_layer(adapted, name, replacement)
        elif isinstance(source, nn.BatchNorm2d):
            replacement = nn.BatchNorm2d(
                num_features=pruned_shapes[weight_key][0],
                eps=source.eps,
                momentum=source.momentum,
                affine=source.affine,
                track_running_stats=True,
                **factory_kwargs,
            )
            set_layer(adapted, name, replacement)
        elif isinstance(source, nn.Linear):
            # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
            num_features = pruned_shapes[weight_key][1]
            replacement = Linear(
                in_features=num_features,
                out_features=source.out_features,
                bias=source.bias is not None,
                **factory_kwargs,
            )
            set_layer(adapted, name, replacement)
            if hasattr(adapted, 'num_features'):
                # Keep the copy's advertised feature width in sync with the new
                # Linear input; head_hidden_size follows only if it matched the
                # previous num_features.
                if getattr(adapted, 'head_hidden_size', 0) == adapted.num_features:
                    adapted.head_hidden_size = num_features
                adapted.num_features = num_features

    adapted.eval()
    parent_module.eval()
    return adapted
def adapt_model_from_file(parent_module, model_variant):
    """Adapt a model to pruned structure from file specification.

    The specification is read from the package resource
    ``_pruned/<model_variant>.txt``, decoded as UTF-8, stripped of surrounding
    whitespace, and passed to ``adapt_model_from_string``.

    Args:
        parent_module: Original model to adapt.
        model_variant: Name of pruned model variant file (without the
            ``.txt`` extension).

    Returns:
        Adapted model with pruned layer dimensions.

    Raises:
        FileNotFoundError: If the resource cannot be loaded.
    """
    # pkgutil.get_data expects a '/'-separated relative resource name on every
    # platform, so the name is built explicitly rather than with os.path.join.
    resource = '_pruned/' + model_variant + '.txt'
    adapt_data = pkgutil.get_data(__name__, resource)
    if adapt_data is None:
        # get_data returns None when the package cannot be located or its
        # loader does not support get_data; fail with a clear message instead
        # of an AttributeError on ``.decode``.
        raise FileNotFoundError(
            'Could not load pruned model definition ' + repr(resource) + ' for package ' + repr(__name__)
        )
    return adapt_model_from_string(parent_module, adapt_data.decode('utf-8').strip())
|