split_batchnorm.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
""" Split BatchNorm

A PyTorch BatchNorm layer that splits the input batch into N equal parts and passes each through
a separate BN layer. The first split is passed through the parent BN layer, with weight/bias
keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
namespace.

This allows easily removing the auxiliary BN layers after training to efficiently
achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
'Disentangled Learning via An Auxiliary BN'

Hacked together by / Copyright 2020 Ross Wightman
"""
  11. import torch
  12. import torch.nn as nn
  13. class SplitBatchNorm2d(torch.nn.BatchNorm2d):
  14. def __init__(
  15. self,
  16. num_features,
  17. eps=1e-5,
  18. momentum=0.1,
  19. affine=True,
  20. track_running_stats=True,
  21. num_splits=2,
  22. device=None,
  23. dtype=None,
  24. ):
  25. dd = {'device': device, 'dtype': dtype}
  26. super().__init__(num_features, eps, momentum, affine, track_running_stats)
  27. assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
  28. self.num_splits = num_splits
  29. self.aux_bn = nn.ModuleList([
  30. nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats, **dd)
  31. for _ in range(num_splits - 1)
  32. ])
  33. def forward(self, input: torch.Tensor):
  34. if self.training: # aux BN only relevant while training
  35. split_size = input.shape[0] // self.num_splits
  36. assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
  37. split_input = input.split(split_size)
  38. x = [super().forward(split_input[0])]
  39. for i, a in enumerate(self.aux_bn):
  40. x.append(a(split_input[i + 1]))
  41. return torch.cat(x, dim=0)
  42. else:
  43. return super().forward(input)
  44. def convert_splitbn_model(module, num_splits=2):
  45. """
  46. Recursively traverse module and its children to replace all instances of
  47. ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`.
  48. Args:
  49. module (torch.nn.Module): input module
  50. num_splits: number of separate batchnorm layers to split input across
  51. Example::
  52. >>> # model is an instance of torch.nn.Module
  53. >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
  54. """
  55. mod = module
  56. if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
  57. return module
  58. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  59. mod = SplitBatchNorm2d(
  60. module.num_features, module.eps, module.momentum, module.affine,
  61. module.track_running_stats, num_splits=num_splits)
  62. mod.running_mean = module.running_mean
  63. mod.running_var = module.running_var
  64. mod.num_batches_tracked = module.num_batches_tracked
  65. if module.affine:
  66. mod.weight.data = module.weight.data.clone().detach()
  67. mod.bias.data = module.bias.data.clone().detach()
  68. for aux in mod.aux_bn:
  69. aux.running_mean = module.running_mean.clone()
  70. aux.running_var = module.running_var.clone()
  71. aux.num_batches_tracked = module.num_batches_tracked.clone()
  72. if module.affine:
  73. aux.weight.data = module.weight.data.clone().detach()
  74. aux.bias.data = module.bias.data.clone().detach()
  75. for name, child in module.named_children():
  76. mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
  77. del module
  78. return mod