_prune.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import os
  2. import pkgutil
  3. from copy import deepcopy
  4. from torch import nn as nn
  5. from timm.layers import Conv2dSame, BatchNormAct2d, Linear
# Names exported by `from <this module> import *`: the two dotted-path layer
# accessors and the two entry points that rebuild a model with pruned shapes.
__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
  7. def extract_layer(model, layer):
  8. """Extract a layer from a model using dot-separated path.
  9. Args:
  10. model: PyTorch model.
  11. layer: Dot-separated layer path (e.g., 'layer1.0.conv1').
  12. Returns:
  13. Extracted module.
  14. """
  15. layer = layer.split('.')
  16. module = model
  17. if hasattr(model, 'module') and layer[0] != 'module':
  18. module = model.module
  19. if not hasattr(model, 'module') and layer[0] == 'module':
  20. layer = layer[1:]
  21. for l in layer:
  22. if hasattr(module, l):
  23. if not l.isdigit():
  24. module = getattr(module, l)
  25. else:
  26. module = module[int(l)]
  27. else:
  28. return module
  29. return module
  30. def set_layer(model, layer, val):
  31. """Set a layer in a model using dot-separated path.
  32. Args:
  33. model: PyTorch model.
  34. layer: Dot-separated layer path.
  35. val: New value for the layer.
  36. """
  37. layer = layer.split('.')
  38. module = model
  39. if hasattr(model, 'module') and layer[0] != 'module':
  40. module = model.module
  41. lst_index = 0
  42. module2 = module
  43. for l in layer:
  44. if hasattr(module2, l):
  45. if not l.isdigit():
  46. module2 = getattr(module2, l)
  47. else:
  48. module2 = module2[int(l)]
  49. lst_index += 1
  50. lst_index -= 1
  51. for l in layer[:lst_index]:
  52. if not l.isdigit():
  53. module = getattr(module, l)
  54. else:
  55. module = module[int(l)]
  56. l = layer[lst_index]
  57. setattr(module, l, val)
  58. def adapt_model_from_string(parent_module, model_string):
  59. """Adapt a model to pruned structure from string specification.
  60. Args:
  61. parent_module: Original model to adapt.
  62. model_string: String containing layer shapes for pruned model.
  63. Returns:
  64. Adapted model with pruned layer dimensions.
  65. """
  66. separator = '***'
  67. state_dict = {}
  68. lst_shape = model_string.split(separator)
  69. for k in lst_shape:
  70. k = k.split(':')
  71. key = k[0]
  72. shape = k[1][1:-1].split(',')
  73. if shape[0] != '':
  74. state_dict[key] = [int(i) for i in shape]
  75. # Extract device and dtype from the parent module
  76. device = next(parent_module.parameters()).device
  77. dtype = next(parent_module.parameters()).dtype
  78. dd = {'device': device, 'dtype': dtype}
  79. new_module = deepcopy(parent_module)
  80. for n, m in parent_module.named_modules():
  81. old_module = extract_layer(parent_module, n)
  82. if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
  83. if isinstance(old_module, Conv2dSame):
  84. conv = Conv2dSame
  85. else:
  86. conv = nn.Conv2d
  87. s = state_dict[n + '.weight']
  88. in_channels = s[1]
  89. out_channels = s[0]
  90. g = 1
  91. if old_module.groups > 1:
  92. in_channels = out_channels
  93. g = in_channels
  94. new_conv = conv(
  95. in_channels=in_channels,
  96. out_channels=out_channels,
  97. kernel_size=old_module.kernel_size,
  98. bias=old_module.bias is not None,
  99. padding=old_module.padding,
  100. dilation=old_module.dilation,
  101. groups=g,
  102. stride=old_module.stride,
  103. **dd,
  104. )
  105. set_layer(new_module, n, new_conv)
  106. elif isinstance(old_module, BatchNormAct2d):
  107. new_bn = BatchNormAct2d(
  108. state_dict[n + '.weight'][0],
  109. eps=old_module.eps,
  110. momentum=old_module.momentum,
  111. affine=old_module.affine,
  112. track_running_stats=True,
  113. **dd,
  114. )
  115. new_bn.drop = old_module.drop
  116. new_bn.act = old_module.act
  117. set_layer(new_module, n, new_bn)
  118. elif isinstance(old_module, nn.BatchNorm2d):
  119. new_bn = nn.BatchNorm2d(
  120. num_features=state_dict[n + '.weight'][0],
  121. eps=old_module.eps,
  122. momentum=old_module.momentum,
  123. affine=old_module.affine,
  124. track_running_stats=True,
  125. **dd,
  126. )
  127. set_layer(new_module, n, new_bn)
  128. elif isinstance(old_module, nn.Linear):
  129. # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
  130. num_features = state_dict[n + '.weight'][1]
  131. new_fc = Linear(
  132. in_features=num_features,
  133. out_features=old_module.out_features,
  134. bias=old_module.bias is not None,
  135. **dd,
  136. )
  137. set_layer(new_module, n, new_fc)
  138. if hasattr(new_module, 'num_features'):
  139. if getattr(new_module, 'head_hidden_size', 0) == new_module.num_features:
  140. new_module.head_hidden_size = num_features
  141. new_module.num_features = num_features
  142. new_module.eval()
  143. parent_module.eval()
  144. return new_module
  145. def adapt_model_from_file(parent_module, model_variant):
  146. """Adapt a model to pruned structure from file specification.
  147. Args:
  148. parent_module: Original model to adapt.
  149. model_variant: Name of pruned model variant file.
  150. Returns:
  151. Adapted model with pruned layer dimensions.
  152. """
  153. adapt_data = pkgutil.get_data(__name__, os.path.join('_pruned', model_variant + '.txt'))
  154. return adapt_model_from_string(parent_module, adapt_data.decode('utf-8').strip())