| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- """ Split BatchNorm
- A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
- a separate BN layer. The first split is passed through the parent BN layer with weight/bias
- keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
- namespace.
- This allows easily removing the auxiliary BN layers after training to efficiently
- achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
- 'Disentangled Learning via An Auxiliary BN'
- Hacked together by / Copyright 2020 Ross Wightman
- """
- import torch
- import torch.nn as nn
class SplitBatchNorm2d(torch.nn.BatchNorm2d):
    """BatchNorm2d that normalizes ``num_splits`` equal slices of the batch separately while training.

    The first slice of the batch is normalized by this module's own parameters/buffers, so its
    state-dict keys are those of a plain ``nn.BatchNorm2d``. Every other slice is normalized by
    its own ``nn.BatchNorm2d`` held in ``self.aux_bn``. In eval mode only the primary BN is used,
    for the whole batch.

    Args:
        num_features: number of channels, forwarded to every BN layer.
        eps, momentum, affine, track_running_stats: forwarded unchanged to every BN layer.
        num_splits: number of batch slices; must be at least 2 (one primary + >= 1 aux BN).
        device, dtype: factory kwargs applied to the primary BN and to all aux BNs.
    """

    def __init__(
            self,
            num_features,
            eps=1e-5,
            momentum=0.1,
            affine=True,
            track_running_stats=True,
            num_splits=2,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        # The factory kwargs must reach the primary BN as well; otherwise its weight/bias/buffers
        # are created on the default device/dtype while the aux BNs follow `device`/`dtype`.
        super().__init__(num_features, eps, momentum, affine, track_running_stats, **dd)
        assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
        self.num_splits = num_splits
        # One aux BN per extra slice; slice i (i >= 1) is handled by aux_bn[i - 1].
        self.aux_bn = nn.ModuleList([
            nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats, **dd)
            for _ in range(num_splits - 1)
        ])

    def forward(self, input: torch.Tensor):
        """Normalize ``input``; in training mode each batch slice uses its own BN.

        In training mode the batch dimension (dim 0) is split into ``num_splits`` equal chunks.
        Chunk 0 goes through this module's BN and chunk i goes through ``aux_bn[i - 1]``. The
        results are concatenated back along dim 0 in the original order. In eval mode the whole
        batch goes through this module's BN only. The input is whatever ``nn.BatchNorm2d``
        accepts (presumably (N, C, H, W)).

        Raises:
            AssertionError: in training mode, if the batch size is not divisible by ``num_splits``.
        """
        if not self.training:
            # aux BN only relevant while training
            return super().forward(input)
        split_size = input.shape[0] // self.num_splits
        assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
        split_input = input.split(split_size)
        x = [super().forward(split_input[0])]
        x.extend(bn(chunk) for bn, chunk in zip(self.aux_bn, split_input[1:]))
        return torch.cat(x, dim=0)
def convert_splitbn_model(module, num_splits=2):
    """
    Recursively traverse module and its children to replace all instances of
    ``torch.nn.modules.batchnorm._BatchNorm`` with ``SplitBatchNorm2d``.

    Args:
        module (torch.nn.Module): input module
        num_splits: number of separate batchnorm layers to split input across

    Returns:
        The converted module. If ``module`` is itself a BatchNorm layer, a new
        ``SplitBatchNorm2d`` is returned. Otherwise ``module`` is returned with its
        children replaced in place.

    Example::

        >>> # model is an instance of torch.nn.Module
        >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
    """
    # Checked before _BatchNorm, presumably because some PyTorch versions derive InstanceNorm
    # from _BatchNorm; instance norm layers are left untouched.
    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
        return module
    buffer_names = ('running_mean', 'running_var', 'num_batches_tracked')
    mod = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        # NOTE(review): every _BatchNorm subclass (not only BatchNorm2d) becomes a
        # SplitBatchNorm2d; presumably only valid for 2d inputs. TODO confirm for 1d/3d BN.
        mod = SplitBatchNorm2d(
            module.num_features, module.eps, module.momentum, module.affine,
            module.track_running_stats, num_splits=num_splits)
        # The primary BN shares (does not copy) the original running-stat buffers.
        for buf_name in buffer_names:
            setattr(mod, buf_name, getattr(module, buf_name))
        if module.affine:
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()
        for aux in mod.aux_bn:
            # Each aux BN starts from an independent copy of the original state. The buffers
            # are None when the source BN does not track running stats, so only clone tensors.
            for buf_name in buffer_names:
                buf = getattr(module, buf_name)
                setattr(aux, buf_name, None if buf is None else buf.clone())
            if module.affine:
                aux.weight.data = module.weight.data.clone().detach()
                aux.bias.data = module.bias.data.clone().detach()
    for name, child in module.named_children():
        mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
    return mod
|