| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- """ Test Time Pooling (Average-Max Pool)
- Hacked together by / Copyright 2020 Ross Wightman
- """
- import logging
- from torch import nn
- import torch.nn.functional as F
- from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
- _logger = logging.getLogger(__name__)
- class TestTimePoolHead(nn.Module):
- def __init__(self, base, original_pool=7):
- super().__init__()
- self.base = base
- self.original_pool = original_pool
- base_fc = self.base.get_classifier()
- if isinstance(base_fc, nn.Conv2d):
- self.fc = base_fc
- else:
- self.fc = nn.Conv2d(
- self.base.num_features, self.base.num_classes, kernel_size=1, bias=True)
- self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
- self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
- self.base.reset_classifier(0) # delete original fc layer
- def forward(self, x):
- x = self.base.forward_features(x)
- x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
- x = self.fc(x)
- x = adaptive_avgmax_pool2d(x, 1)
- return x.view(x.size(0), -1)
- def apply_test_time_pool(model, config, use_test_size=False):
- test_time_pool = False
- if not hasattr(model, 'default_cfg') or not model.default_cfg:
- return model, False
- if use_test_size and 'test_input_size' in model.default_cfg:
- df_input_size = model.default_cfg['test_input_size']
- else:
- df_input_size = model.default_cfg['input_size']
- if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]:
- _logger.info('Target input size %s > pretrained default %s, using test time pooling' %
- (str(config['input_size'][-2:]), str(df_input_size[-2:])))
- model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
- test_time_pool = True
- return model, test_time_pool
|