classifier.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. """ Classifier head and layer factory
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. from collections import OrderedDict
  5. from functools import partial
  6. from typing import Optional, Union, Callable
  7. import torch
  8. import torch.nn as nn
  9. from torch.nn import functional as F
  10. from .adaptive_avgmax_pool import SelectAdaptivePool2d
  11. from .create_act import get_act_layer
  12. from .create_norm import get_norm_layer
  13. def _create_pool(
  14. num_features: int,
  15. num_classes: int,
  16. pool_type: str = 'avg',
  17. use_conv: bool = False,
  18. input_fmt: Optional[str] = None,
  19. ):
  20. flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
  21. if not pool_type:
  22. flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)
  23. global_pool = SelectAdaptivePool2d(
  24. pool_type=pool_type,
  25. flatten=flatten_in_pool,
  26. input_fmt=input_fmt,
  27. )
  28. num_pooled_features = num_features * global_pool.feat_mult()
  29. return global_pool, num_pooled_features
  30. def _create_fc(num_features, num_classes, use_conv=False, device=None, dtype=None):
  31. if num_classes <= 0:
  32. fc = nn.Identity() # pass-through (no classifier)
  33. elif use_conv:
  34. fc = nn.Conv2d(num_features, num_classes, 1, bias=True, device=device, dtype=dtype)
  35. else:
  36. fc = nn.Linear(num_features, num_classes, bias=True, device=device, dtype=dtype)
  37. return fc
  38. def create_classifier(
  39. num_features: int,
  40. num_classes: int,
  41. pool_type: str = 'avg',
  42. use_conv: bool = False,
  43. input_fmt: str = 'NCHW',
  44. drop_rate: Optional[float] = None,
  45. device=None,
  46. dtype=None,
  47. ):
  48. global_pool, num_pooled_features = _create_pool(
  49. num_features,
  50. num_classes,
  51. pool_type,
  52. use_conv=use_conv,
  53. input_fmt=input_fmt,
  54. )
  55. fc = _create_fc(
  56. num_pooled_features,
  57. num_classes,
  58. use_conv=use_conv,
  59. device=device,
  60. dtype=dtype,
  61. )
  62. if drop_rate is not None:
  63. dropout = nn.Dropout(drop_rate)
  64. return global_pool, dropout, fc
  65. return global_pool, fc
  66. class ClassifierHead(nn.Module):
  67. """Classifier head w/ configurable global pooling and dropout."""
  68. def __init__(
  69. self,
  70. in_features: int,
  71. num_classes: int,
  72. pool_type: str = 'avg',
  73. drop_rate: float = 0.,
  74. use_conv: bool = False,
  75. input_fmt: str = 'NCHW',
  76. device=None,
  77. dtype=None,
  78. ):
  79. """
  80. Args:
  81. in_features: The number of input features.
  82. num_classes: The number of classes for the final classifier layer (output).
  83. pool_type: Global pooling type, pooling disabled if empty string ('').
  84. drop_rate: Pre-classifier dropout rate.
  85. """
  86. super().__init__()
  87. self.in_features = in_features
  88. self.use_conv = use_conv
  89. self.input_fmt = input_fmt
  90. global_pool, fc = create_classifier(
  91. in_features,
  92. num_classes,
  93. pool_type,
  94. use_conv=use_conv,
  95. input_fmt=input_fmt,
  96. device=device,
  97. dtype=dtype,
  98. )
  99. self.global_pool = global_pool
  100. self.drop = nn.Dropout(drop_rate)
  101. self.fc = fc
  102. self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
  103. def reset(self, num_classes: int, pool_type: Optional[str] = None):
  104. # FIXME get current device/dtype for reset?
  105. if pool_type is not None and pool_type != self.global_pool.pool_type:
  106. self.global_pool, self.fc = create_classifier(
  107. self.in_features,
  108. num_classes,
  109. pool_type=pool_type,
  110. use_conv=self.use_conv,
  111. input_fmt=self.input_fmt,
  112. )
  113. self.flatten = nn.Flatten(1) if self.use_conv and pool_type else nn.Identity()
  114. else:
  115. num_pooled_features = self.in_features * self.global_pool.feat_mult()
  116. self.fc = _create_fc(
  117. num_pooled_features,
  118. num_classes,
  119. use_conv=self.use_conv,
  120. )
  121. def forward(self, x, pre_logits: bool = False):
  122. x = self.global_pool(x)
  123. x = self.drop(x)
  124. if pre_logits:
  125. return self.flatten(x)
  126. x = self.fc(x)
  127. return self.flatten(x)
  128. class NormMlpClassifierHead(nn.Module):
  129. """ A Pool -> Norm -> Mlp Classifier Head for '2D' NCHW tensors
  130. """
  131. def __init__(
  132. self,
  133. in_features: int,
  134. num_classes: int,
  135. hidden_size: Optional[int] = None,
  136. pool_type: str = 'avg',
  137. drop_rate: float = 0.,
  138. norm_layer: Union[str, Callable] = 'layernorm2d',
  139. act_layer: Union[str, Callable] = 'tanh',
  140. device=None,
  141. dtype=None
  142. ):
  143. """
  144. Args:
  145. in_features: The number of input features.
  146. num_classes: The number of classes for the final classifier layer (output).
  147. hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
  148. pool_type: Global pooling type, pooling disabled if empty string ('').
  149. drop_rate: Pre-classifier dropout rate.
  150. norm_layer: Normalization layer type.
  151. act_layer: MLP activation layer type (only used if hidden_size is not None).
  152. """
  153. dd = {'device': device, 'dtype': dtype}
  154. super().__init__()
  155. self.in_features = in_features
  156. self.hidden_size = hidden_size
  157. self.num_features = in_features
  158. self.use_conv = not pool_type
  159. norm_layer = get_norm_layer(norm_layer)
  160. act_layer = get_act_layer(act_layer)
  161. linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
  162. self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
  163. self.norm = norm_layer(in_features, **dd)
  164. self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
  165. if hidden_size:
  166. self.pre_logits = nn.Sequential(OrderedDict([
  167. ('fc', linear_layer(in_features, hidden_size, **dd)),
  168. ('act', act_layer()),
  169. ]))
  170. self.num_features = hidden_size
  171. else:
  172. self.pre_logits = nn.Identity()
  173. self.drop = nn.Dropout(drop_rate)
  174. self.fc = linear_layer(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  175. def reset(self, num_classes: int, pool_type: Optional[str] = None):
  176. # FIXME handle device/dtype on reset
  177. if pool_type is not None:
  178. self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
  179. self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
  180. self.use_conv = self.global_pool.is_identity()
  181. linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
  182. if self.hidden_size:
  183. if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
  184. (isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
  185. with torch.no_grad():
  186. new_fc = linear_layer(self.in_features, self.hidden_size)
  187. new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
  188. new_fc.bias.copy_(self.pre_logits.fc.bias)
  189. self.pre_logits.fc = new_fc
  190. self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  191. def forward(self, x, pre_logits: bool = False):
  192. x = self.global_pool(x)
  193. x = self.norm(x)
  194. x = self.flatten(x)
  195. x = self.pre_logits(x)
  196. x = self.drop(x)
  197. if pre_logits:
  198. return x
  199. x = self.fc(x)
  200. return x
  201. class ClNormMlpClassifierHead(nn.Module):
  202. """ A Pool -> Norm -> Mlp Classifier Head for n-D NxxC tensors
  203. """
  204. def __init__(
  205. self,
  206. in_features: int,
  207. num_classes: int,
  208. hidden_size: Optional[int] = None,
  209. pool_type: str = 'avg',
  210. drop_rate: float = 0.,
  211. norm_layer: Union[str, Callable] = 'layernorm',
  212. act_layer: Union[str, Callable] = 'gelu',
  213. input_fmt: str = 'NHWC',
  214. device=None,
  215. dtype=None,
  216. ):
  217. """
  218. Args:
  219. in_features: The number of input features.
  220. num_classes: The number of classes for the final classifier layer (output).
  221. hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
  222. pool_type: Global pooling type, pooling disabled if empty string ('').
  223. drop_rate: Pre-classifier dropout rate.
  224. norm_layer: Normalization layer type.
  225. act_layer: MLP activation layer type (only used if hidden_size is not None).
  226. """
  227. dd = {'device': device, 'dtype': dtype}
  228. super().__init__()
  229. self.in_features = in_features
  230. self.hidden_size = hidden_size
  231. self.num_features = in_features
  232. assert pool_type in ('', 'avg', 'max', 'avgmax')
  233. self.pool_type = pool_type
  234. assert input_fmt in ('NHWC', 'NLC')
  235. self.pool_dim = 1 if input_fmt == 'NLC' else (1, 2)
  236. norm_layer = get_norm_layer(norm_layer)
  237. act_layer = get_act_layer(act_layer)
  238. self.norm = norm_layer(in_features, **dd)
  239. if hidden_size:
  240. self.pre_logits = nn.Sequential(OrderedDict([
  241. ('fc', nn.Linear(in_features, hidden_size, **dd)),
  242. ('act', act_layer()),
  243. ]))
  244. self.num_features = hidden_size
  245. else:
  246. self.pre_logits = nn.Identity()
  247. self.drop = nn.Dropout(drop_rate)
  248. self.fc = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  249. def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
  250. # FIXME extract dd on reset
  251. if pool_type is not None:
  252. self.pool_type = pool_type
  253. if reset_other:
  254. self.pre_logits = nn.Identity()
  255. self.norm = nn.Identity()
  256. self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  257. def _global_pool(self, x):
  258. if self.pool_type:
  259. if self.pool_type == 'avg':
  260. x = x.mean(dim=self.pool_dim)
  261. elif self.pool_type == 'max':
  262. x = x.amax(dim=self.pool_dim)
  263. elif self.pool_type == 'avgmax':
  264. x = 0.5 * (x.amax(dim=self.pool_dim) + x.mean(dim=self.pool_dim))
  265. return x
  266. def forward(self, x, pre_logits: bool = False):
  267. x = self._global_pool(x)
  268. x = self.norm(x)
  269. x = self.pre_logits(x)
  270. x = self.drop(x)
  271. if pre_logits:
  272. return x
  273. x = self.fc(x)
  274. return x