test_time_pool.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. """ Test Time Pooling (Average-Max Pool)
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. import logging
  5. from torch import nn
  6. import torch.nn.functional as F
  7. from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
  8. _logger = logging.getLogger(__name__)
  9. class TestTimePoolHead(nn.Module):
  10. def __init__(self, base, original_pool=7):
  11. super().__init__()
  12. self.base = base
  13. self.original_pool = original_pool
  14. base_fc = self.base.get_classifier()
  15. if isinstance(base_fc, nn.Conv2d):
  16. self.fc = base_fc
  17. else:
  18. self.fc = nn.Conv2d(
  19. self.base.num_features, self.base.num_classes, kernel_size=1, bias=True)
  20. self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
  21. self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
  22. self.base.reset_classifier(0) # delete original fc layer
  23. def forward(self, x):
  24. x = self.base.forward_features(x)
  25. x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
  26. x = self.fc(x)
  27. x = adaptive_avgmax_pool2d(x, 1)
  28. return x.view(x.size(0), -1)
  29. def apply_test_time_pool(model, config, use_test_size=False):
  30. test_time_pool = False
  31. if not hasattr(model, 'default_cfg') or not model.default_cfg:
  32. return model, False
  33. if use_test_size and 'test_input_size' in model.default_cfg:
  34. df_input_size = model.default_cfg['test_input_size']
  35. else:
  36. df_input_size = model.default_cfg['input_size']
  37. if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]:
  38. _logger.info('Target input size %s > pretrained default %s, using test time pooling' %
  39. (str(config['input_size'][-2:]), str(df_input_size[-2:])))
  40. model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
  41. test_time_pool = True
  42. return model, test_time_pool