| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300 |
- """ Classifier head and layer factory
- Hacked together by / Copyright 2020 Ross Wightman
- """
- from collections import OrderedDict
- from functools import partial
- from typing import Optional, Union, Callable
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from .adaptive_avgmax_pool import SelectAdaptivePool2d
- from .create_act import get_act_layer
- from .create_norm import get_norm_layer
def _create_pool(
        num_features: int,
        num_classes: int,
        pool_type: str = 'avg',
        use_conv: bool = False,
        input_fmt: Optional[str] = None,
):
    """Build the global pooling layer used in front of a classifier.

    Args:
        num_features: Number of feature channels entering the pool.
        num_classes: Not used by this function.
        pool_type: Pooling type passed to SelectAdaptivePool2d; a falsy value means
            pass-through (no pooling).
        use_conv: True when a conv classifier follows the pool, in which case the
            pooled output is not flattened inside the pool.
        input_fmt: Input layout forwarded to SelectAdaptivePool2d.

    Returns:
        Tuple of (pooling module, number of pooled features), where the feature count
        is num_features scaled by the pool's feat_mult().
    """
    # Flatten inside the pool only when real pooling feeds a Linear classifier:
    # a conv classifier needs the spatial dims, and pass-through has nothing to flatten.
    flatten_pooled = bool(pool_type) and not use_conv
    pool = SelectAdaptivePool2d(
        pool_type=pool_type,
        flatten=flatten_pooled,
        input_fmt=input_fmt,
    )
    return pool, num_features * pool.feat_mult()
- def _create_fc(num_features, num_classes, use_conv=False, device=None, dtype=None):
- if num_classes <= 0:
- fc = nn.Identity() # pass-through (no classifier)
- elif use_conv:
- fc = nn.Conv2d(num_features, num_classes, 1, bias=True, device=device, dtype=dtype)
- else:
- fc = nn.Linear(num_features, num_classes, bias=True, device=device, dtype=dtype)
- return fc
def create_classifier(
        num_features: int,
        num_classes: int,
        pool_type: str = 'avg',
        use_conv: bool = False,
        input_fmt: str = 'NCHW',
        drop_rate: Optional[float] = None,
        device=None,
        dtype=None,
):
    """Build the pooling + (optional dropout) + classifier layers of a head.

    Args:
        num_features: Number of feature channels entering the pool.
        num_classes: Output class count; <= 0 yields an nn.Identity classifier.
        pool_type: Global pooling type; falsy disables pooling.
        use_conv: Use a conv classifier (pooled output is not flattened).
        input_fmt: Input layout forwarded to the pooling layer.
        drop_rate: If not None, an nn.Dropout with this rate is also created.
        device: Device for the classifier parameters.
        dtype: Dtype for the classifier parameters.

    Returns:
        (global_pool, fc) when drop_rate is None, otherwise (global_pool, dropout, fc).
    """
    global_pool, pooled_dim = _create_pool(
        num_features,
        num_classes,
        pool_type,
        use_conv=use_conv,
        input_fmt=input_fmt,
    )
    fc = _create_fc(
        pooled_dim,
        num_classes,
        use_conv=use_conv,
        device=device,
        dtype=dtype,
    )
    if drop_rate is None:
        return global_pool, fc
    return global_pool, nn.Dropout(drop_rate), fc
class ClassifierHead(nn.Module):
    """Classifier head w/ configurable global pooling and dropout."""

    def __init__(
            self,
            in_features: int,
            num_classes: int,
            pool_type: str = 'avg',
            drop_rate: float = 0.,
            use_conv: bool = False,
            input_fmt: str = 'NCHW',
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_features: The number of input features.
            num_classes: The number of classes for the final classifier layer (output).
            pool_type: Global pooling type, pooling disabled if empty string ('').
            drop_rate: Pre-classifier dropout rate.
            use_conv: Use a kernel-size-1 Conv2d classifier instead of Linear; the pooled
                output is then flattened after the classifier rather than inside the pool.
            input_fmt: Input layout forwarded to the pooling layer.
            device: Device for the classifier parameters.
            dtype: Dtype for the classifier parameters.
        """
        super().__init__()
        self.in_features = in_features
        self.use_conv = use_conv
        self.input_fmt = input_fmt

        global_pool, fc = create_classifier(
            in_features,
            num_classes,
            pool_type,
            use_conv=use_conv,
            input_fmt=input_fmt,
            device=device,
            dtype=dtype,
        )
        self.global_pool = global_pool
        self.drop = nn.Dropout(drop_rate)
        self.fc = fc
        # The conv classifier sees unflattened pooled maps, so flatten its output instead.
        self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

    def _current_dd(self):
        """Device/dtype kwargs from the head's first parameter (None values if it has none)."""
        param = next(self.parameters(), None)
        if param is None:
            return {'device': None, 'dtype': None}
        return {'device': param.device, 'dtype': param.dtype}

    def reset(self, num_classes: int, pool_type: Optional[str] = None):
        """Replace the classifier, optionally rebuilding the pooling layer.

        New layers are created on the device/dtype of the head's existing parameters
        (default placement when the head currently has no parameters).

        Args:
            num_classes: New number of classes; <= 0 replaces fc with nn.Identity.
            pool_type: New pooling type; the pool is only rebuilt when this is not None
                and differs from the current pool's type.
        """
        # Capture placement before the old layers are replaced.
        dd = self._current_dd()
        if pool_type is not None and pool_type != self.global_pool.pool_type:
            self.global_pool, self.fc = create_classifier(
                self.in_features,
                num_classes,
                pool_type=pool_type,
                use_conv=self.use_conv,
                input_fmt=self.input_fmt,
                **dd,
            )
            self.flatten = nn.Flatten(1) if self.use_conv and pool_type else nn.Identity()
        else:
            num_pooled_features = self.in_features * self.global_pool.feat_mult()
            self.fc = _create_fc(
                num_pooled_features,
                num_classes,
                use_conv=self.use_conv,
                **dd,
            )

    def forward(self, x, pre_logits: bool = False):
        """Pool -> dropout -> (fc unless pre_logits) -> flatten."""
        x = self.global_pool(x)
        x = self.drop(x)
        if pre_logits:
            return self.flatten(x)
        x = self.fc(x)
        return self.flatten(x)
class NormMlpClassifierHead(nn.Module):
    """ A Pool -> Norm -> Mlp Classifier Head for '2D' NCHW tensors
    """

    def __init__(
            self,
            in_features: int,
            num_classes: int,
            hidden_size: Optional[int] = None,
            pool_type: str = 'avg',
            drop_rate: float = 0.,
            norm_layer: Union[str, Callable] = 'layernorm2d',
            act_layer: Union[str, Callable] = 'tanh',
            device=None,
            dtype=None
    ):
        """
        Args:
            in_features: The number of input features.
            num_classes: The number of classes for the final classifier layer (output).
            hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
            pool_type: Global pooling type, pooling disabled if empty string ('').
            drop_rate: Pre-classifier dropout rate.
            norm_layer: Normalization layer type.
            act_layer: MLP activation layer type (only used if hidden_size is not None).
            device: Device for the head's parameters.
            dtype: Dtype for the head's parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.num_features = in_features
        # Without pooling the spatial dims survive, so kernel-size-1 convs replace Linear layers.
        self.use_conv = not pool_type
        norm_layer = get_norm_layer(norm_layer)
        act_layer = get_act_layer(act_layer)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear

        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
        self.norm = norm_layer(in_features, **dd)
        self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
        if hidden_size:
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', linear_layer(in_features, hidden_size, **dd)),
                ('act', act_layer()),
            ]))
            self.num_features = hidden_size
        else:
            self.pre_logits = nn.Identity()
        self.drop = nn.Dropout(drop_rate)
        self.fc = linear_layer(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()

    def _current_dd(self):
        """Device/dtype kwargs from the head's first parameter (None values if it has none)."""
        param = next(self.parameters(), None)
        if param is None:
            return {'device': None, 'dtype': None}
        return {'device': param.device, 'dtype': param.dtype}

    def reset(self, num_classes: int, pool_type: Optional[str] = None):
        """Replace the classifier, optionally changing the pooling type.

        If the pooling change switches between Linear and kernel-size-1 Conv2d operation,
        the pre-logits fc weights are carried over into a layer of the new type on the
        same device/dtype. The new classifier is created on the device/dtype of the
        head's existing parameters.

        Args:
            num_classes: New number of classes; <= 0 replaces fc with nn.Identity.
            pool_type: New global pooling type, or None to keep the current pool.
        """
        # Capture placement before any layer is replaced.
        dd = self._current_dd()
        if pool_type is not None:
            self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
            self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
        # Mirrors `use_conv = not pool_type` in __init__, but queried from the pool itself.
        self.use_conv = self.global_pool.is_identity()
        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
        if self.hidden_size:
            old_fc = self.pre_logits.fc
            if ((isinstance(old_fc, nn.Conv2d) and not self.use_conv) or
                    (isinstance(old_fc, nn.Linear) and self.use_conv)):
                with torch.no_grad():
                    # Build on the old weights' device/dtype so the copy neither moves nor casts them.
                    new_fc = linear_layer(
                        self.in_features,
                        self.hidden_size,
                        device=old_fc.weight.device,
                        dtype=old_fc.weight.dtype,
                    )
                    # Linear (out, in) and 1x1 Conv2d (out, in, 1, 1) hold the same values.
                    new_fc.weight.copy_(old_fc.weight.reshape(new_fc.weight.shape))
                    new_fc.bias.copy_(old_fc.bias)
                self.pre_logits.fc = new_fc
        self.fc = linear_layer(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()

    def forward(self, x, pre_logits: bool = False):
        """Pool -> norm -> flatten -> pre-logits MLP -> dropout -> (fc unless pre_logits)."""
        x = self.global_pool(x)
        x = self.norm(x)
        x = self.flatten(x)
        x = self.pre_logits(x)
        x = self.drop(x)
        if pre_logits:
            return x
        x = self.fc(x)
        return x
class ClNormMlpClassifierHead(nn.Module):
    """ A Pool -> Norm -> Mlp Classifier Head for n-D NxxC tensors
    """
    # Pooling types understood by _global_pool ('' disables pooling).
    _POOL_TYPES = ('', 'avg', 'max', 'avgmax')

    def __init__(
            self,
            in_features: int,
            num_classes: int,
            hidden_size: Optional[int] = None,
            pool_type: str = 'avg',
            drop_rate: float = 0.,
            norm_layer: Union[str, Callable] = 'layernorm',
            act_layer: Union[str, Callable] = 'gelu',
            input_fmt: str = 'NHWC',
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_features: The number of input features.
            num_classes: The number of classes for the final classifier layer (output).
            hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
            pool_type: Global pooling type, pooling disabled if empty string ('').
            drop_rate: Pre-classifier dropout rate.
            norm_layer: Normalization layer type.
            act_layer: MLP activation layer type (only used if hidden_size is not None).
            input_fmt: 'NHWC' (pool over dims 1 and 2) or 'NLC' (pool over dim 1).
            device: Device for the head's parameters.
            dtype: Dtype for the head's parameters.
        """
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.num_features = in_features
        assert pool_type in self._POOL_TYPES
        self.pool_type = pool_type
        assert input_fmt in ('NHWC', 'NLC')
        self.pool_dim = 1 if input_fmt == 'NLC' else (1, 2)
        norm_layer = get_norm_layer(norm_layer)
        act_layer = get_act_layer(act_layer)

        self.norm = norm_layer(in_features, **dd)
        if hidden_size:
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(in_features, hidden_size, **dd)),
                ('act', act_layer()),
            ]))
            self.num_features = hidden_size
        else:
            self.pre_logits = nn.Identity()
        self.drop = nn.Dropout(drop_rate)
        self.fc = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()

    def _current_dd(self):
        """Device/dtype kwargs from the head's first parameter (None values if it has none)."""
        param = next(self.parameters(), None)
        if param is None:
            return {'device': None, 'dtype': None}
        return {'device': param.device, 'dtype': param.dtype}

    def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
        """Replace the classifier, optionally changing pooling and dropping norm/MLP.

        The new classifier is created on the device/dtype of the head's existing parameters.

        Args:
            num_classes: New number of classes; <= 0 replaces fc with nn.Identity.
            pool_type: New pooling type if not None; must be one of '', 'avg', 'max', 'avgmax'.
            reset_other: Also replace norm and pre_logits with nn.Identity, so fc is fed
                the pooled input features directly.

        Raises:
            ValueError: If pool_type is not None and not a supported pooling type.
        """
        if pool_type is not None:
            if pool_type not in self._POOL_TYPES:
                raise ValueError(
                    f'Unsupported pool_type {pool_type!r}, expected one of {self._POOL_TYPES}.')
            self.pool_type = pool_type
        # Capture placement before reset_other discards the norm / MLP parameters.
        dd = self._current_dd()
        if reset_other:
            self.pre_logits = nn.Identity()
            self.norm = nn.Identity()
            # Without the MLP, fc consumes in_features, not hidden_size.
            self.num_features = self.in_features
        self.fc = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()

    def _global_pool(self, x):
        """Reduce over self.pool_dim with the configured pooling ('' leaves x unchanged)."""
        if self.pool_type:
            if self.pool_type == 'avg':
                x = x.mean(dim=self.pool_dim)
            elif self.pool_type == 'max':
                x = x.amax(dim=self.pool_dim)
            elif self.pool_type == 'avgmax':
                x = 0.5 * (x.amax(dim=self.pool_dim) + x.mean(dim=self.pool_dim))
        return x

    def forward(self, x, pre_logits: bool = False):
        """Pool -> norm -> pre-logits MLP -> dropout -> (fc unless pre_logits)."""
        x = self._global_pool(x)
        x = self.norm(x)
        x = self.pre_logits(x)
        x = self.drop(x)
        if pre_logits:
            return x
        x = self.fc(x)
        return x
|