xception.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
"""
Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch)

@author: tstandley
Adapted by cadene

Creates an Xception Model as defined in:

Francois Chollet
Xception: Deep Learning with Depthwise Separable Convolutions
https://arxiv.org/pdf/1610.02357.pdf

These weights were ported from the Keras implementation. They achieve the following performance
on the validation set:

Loss:0.9173 Prec@1:78.892 Prec@5:94.292

REMEMBER to set your image size to 3x299x299 for both test and validation

normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                  std=[0.5, 0.5, 0.5])

The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
"""
  16. import torch.jit
  17. import torch.nn as nn
  18. import torch.nn.functional as F
  19. from typing import Optional
  20. from timm.layers import create_classifier
  21. from ._builder import build_model_with_cfg
  22. from ._registry import register_model, generate_default_cfgs, register_model_deprecations
  23. __all__ = ['Xception']
  24. class SeparableConv2d(nn.Module):
  25. def __init__(
  26. self,
  27. in_channels: int,
  28. out_channels: int,
  29. kernel_size: int = 1,
  30. stride: int = 1,
  31. padding: int = 0,
  32. dilation: int = 1,
  33. device=None,
  34. dtype=None,
  35. ):
  36. dd = {'device': device, 'dtype': dtype}
  37. super().__init__()
  38. self.conv1 = nn.Conv2d(
  39. in_channels,
  40. in_channels,
  41. kernel_size,
  42. stride,
  43. padding,
  44. dilation,
  45. groups=in_channels,
  46. bias=False,
  47. **dd,
  48. )
  49. self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=False, **dd)
  50. def forward(self, x):
  51. x = self.conv1(x)
  52. x = self.pointwise(x)
  53. return x
  54. class Block(nn.Module):
  55. def __init__(
  56. self,
  57. in_channels: int,
  58. out_channels: int,
  59. reps: int,
  60. strides: int = 1,
  61. start_with_relu: bool = True,
  62. grow_first: bool = True,
  63. device=None,
  64. dtype=None,
  65. ):
  66. dd = {'device': device, 'dtype': dtype}
  67. super().__init__()
  68. if out_channels != in_channels or strides != 1:
  69. self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=strides, bias=False, **dd)
  70. self.skipbn = nn.BatchNorm2d(out_channels, **dd)
  71. else:
  72. self.skip = None
  73. rep = []
  74. for i in range(reps):
  75. if grow_first:
  76. inc = in_channels if i == 0 else out_channels
  77. outc = out_channels
  78. else:
  79. inc = in_channels
  80. outc = in_channels if i < (reps - 1) else out_channels
  81. rep.append(nn.ReLU(inplace=True))
  82. rep.append(SeparableConv2d(inc, outc, 3, stride=1, padding=1, **dd))
  83. rep.append(nn.BatchNorm2d(outc, **dd))
  84. if not start_with_relu:
  85. rep = rep[1:]
  86. else:
  87. rep[0] = nn.ReLU(inplace=False)
  88. if strides != 1:
  89. rep.append(nn.MaxPool2d(3, strides, 1))
  90. self.rep = nn.Sequential(*rep)
  91. def forward(self, inp):
  92. x = self.rep(inp)
  93. if self.skip is not None:
  94. skip = self.skip(inp)
  95. skip = self.skipbn(skip)
  96. else:
  97. skip = inp
  98. x += skip
  99. return x
  100. class Xception(nn.Module):
  101. """
  102. Xception optimized for the ImageNet dataset, as specified in
  103. https://arxiv.org/pdf/1610.02357.pdf
  104. """
  105. def __init__(
  106. self,
  107. num_classes: int = 1000,
  108. in_chans: int = 3,
  109. drop_rate: float = 0.,
  110. global_pool: str = 'avg',
  111. device=None,
  112. dtype=None,
  113. ):
  114. """ Constructor
  115. Args:
  116. num_classes: number of classes
  117. """
  118. super().__init__()
  119. dd = {'device': device, 'dtype': dtype}
  120. self.drop_rate = drop_rate
  121. self.global_pool = global_pool
  122. self.num_classes = num_classes
  123. self.in_chans = in_chans
  124. self.num_features = self.head_hidden_size = 2048
  125. self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False, **dd)
  126. self.bn1 = nn.BatchNorm2d(32, **dd)
  127. self.act1 = nn.ReLU(inplace=True)
  128. self.conv2 = nn.Conv2d(32, 64, 3, bias=False, **dd)
  129. self.bn2 = nn.BatchNorm2d(64, **dd)
  130. self.act2 = nn.ReLU(inplace=True)
  131. self.block1 = Block(64, 128, 2, 2, start_with_relu=False, **dd)
  132. self.block2 = Block(128, 256, 2, 2, **dd)
  133. self.block3 = Block(256, 728, 2, 2, **dd)
  134. self.block4 = Block(728, 728, 3, 1, **dd)
  135. self.block5 = Block(728, 728, 3, 1, **dd)
  136. self.block6 = Block(728, 728, 3, 1, **dd)
  137. self.block7 = Block(728, 728, 3, 1, **dd)
  138. self.block8 = Block(728, 728, 3, 1, **dd)
  139. self.block9 = Block(728, 728, 3, 1, **dd)
  140. self.block10 = Block(728, 728, 3, 1, **dd)
  141. self.block11 = Block(728, 728, 3, 1, **dd)
  142. self.block12 = Block(728, 1024, 2, 2, grow_first=False, **dd)
  143. self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1, **dd)
  144. self.bn3 = nn.BatchNorm2d(1536, **dd)
  145. self.act3 = nn.ReLU(inplace=True)
  146. self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1, **dd)
  147. self.bn4 = nn.BatchNorm2d(self.num_features, **dd)
  148. self.act4 = nn.ReLU(inplace=True)
  149. self.feature_info = [
  150. dict(num_chs=64, reduction=2, module='act2'),
  151. dict(num_chs=128, reduction=4, module='block2.rep.0'),
  152. dict(num_chs=256, reduction=8, module='block3.rep.0'),
  153. dict(num_chs=728, reduction=16, module='block12.rep.0'),
  154. dict(num_chs=2048, reduction=32, module='act4'),
  155. ]
  156. self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool, **dd)
  157. # #------- init weights --------
  158. for m in self.modules():
  159. if isinstance(m, nn.Conv2d):
  160. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  161. elif isinstance(m, nn.BatchNorm2d):
  162. m.weight.data.fill_(1)
  163. m.bias.data.zero_()
  164. @torch.jit.ignore
  165. def group_matcher(self, coarse=False):
  166. return dict(
  167. stem=r'^conv[12]|bn[12]',
  168. blocks=[
  169. (r'^block(\d+)', None),
  170. (r'^conv[34]|bn[34]', (99,)),
  171. ],
  172. )
  173. @torch.jit.ignore
  174. def set_grad_checkpointing(self, enable=True):
  175. assert not enable, "gradient checkpointing not supported"
  176. @torch.jit.ignore
  177. def get_classifier(self) -> nn.Module:
  178. return self.fc
  179. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  180. self.num_classes = num_classes
  181. self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
  182. def forward_features(self, x):
  183. x = self.conv1(x)
  184. x = self.bn1(x)
  185. x = self.act1(x)
  186. x = self.conv2(x)
  187. x = self.bn2(x)
  188. x = self.act2(x)
  189. x = self.block1(x)
  190. x = self.block2(x)
  191. x = self.block3(x)
  192. x = self.block4(x)
  193. x = self.block5(x)
  194. x = self.block6(x)
  195. x = self.block7(x)
  196. x = self.block8(x)
  197. x = self.block9(x)
  198. x = self.block10(x)
  199. x = self.block11(x)
  200. x = self.block12(x)
  201. x = self.conv3(x)
  202. x = self.bn3(x)
  203. x = self.act3(x)
  204. x = self.conv4(x)
  205. x = self.bn4(x)
  206. x = self.act4(x)
  207. return x
  208. def forward_head(self, x, pre_logits: bool = False):
  209. x = self.global_pool(x)
  210. if self.drop_rate:
  211. F.dropout(x, self.drop_rate, training=self.training)
  212. return x if pre_logits else self.fc(x)
  213. def forward(self, x):
  214. x = self.forward_features(x)
  215. x = self.forward_head(x)
  216. return x
  217. def _xception(variant, pretrained=False, **kwargs):
  218. return build_model_with_cfg(
  219. Xception, variant, pretrained,
  220. feature_cfg=dict(feature_cls='hook'),
  221. **kwargs)
  222. default_cfgs = generate_default_cfgs({
  223. 'legacy_xception.tf_in1k': {
  224. 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth',
  225. 'input_size': (3, 299, 299),
  226. 'pool_size': (10, 10),
  227. 'crop_pct': 0.8975,
  228. 'interpolation': 'bicubic',
  229. 'mean': (0.5, 0.5, 0.5),
  230. 'std': (0.5, 0.5, 0.5),
  231. 'num_classes': 1000,
  232. 'first_conv': 'conv1',
  233. 'classifier': 'fc',
  234. 'license': 'apache-2.0',
  235. # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
  236. }
  237. })
  238. @register_model
  239. def legacy_xception(pretrained=False, **kwargs) -> Xception:
  240. return _xception('legacy_xception', pretrained=pretrained, **kwargs)
  241. register_model_deprecations(__name__, {
  242. 'xception': 'legacy_xception',
  243. })