dla.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608
  1. """ Deep Layer Aggregation and DLA w/ Res2Net
  2. DLA original adapted from Official Pytorch impl at: https://github.com/ucbdrive/dla
  3. DLA Paper: `Deep Layer Aggregation` - https://arxiv.org/abs/1707.06484
  4. Res2Net additions from: https://github.com/gasvn/Res2Net/
  5. Res2Net Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169
  6. """
  7. import math
  8. from typing import List, Optional, Tuple, Type
  9. import torch
  10. import torch.nn as nn
  11. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  12. from timm.layers import create_classifier
  13. from ._builder import build_model_with_cfg
  14. from ._registry import register_model, generate_default_cfgs
  15. __all__ = ['DLA']
  16. class DlaBasic(nn.Module):
  17. """DLA Basic"""
  18. def __init__(
  19. self,
  20. inplanes: int,
  21. planes: int,
  22. stride: int = 1,
  23. dilation: int = 1,
  24. device=None,
  25. dtype=None,
  26. **_,
  27. ):
  28. dd = {'device': device, 'dtype': dtype}
  29. super().__init__()
  30. self.conv1 = nn.Conv2d(
  31. inplanes,
  32. planes,
  33. kernel_size=3,
  34. stride=stride,
  35. padding=dilation,
  36. bias=False,
  37. dilation=dilation,
  38. **dd,
  39. )
  40. self.bn1 = nn.BatchNorm2d(planes, **dd)
  41. self.relu = nn.ReLU(inplace=True)
  42. self.conv2 = nn.Conv2d(
  43. planes,
  44. planes,
  45. kernel_size=3,
  46. stride=1,
  47. padding=dilation,
  48. bias=False,
  49. dilation=dilation,
  50. **dd,
  51. )
  52. self.bn2 = nn.BatchNorm2d(planes, **dd)
  53. self.stride = stride
  54. def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
  55. if shortcut is None:
  56. shortcut = x
  57. out = self.conv1(x)
  58. out = self.bn1(out)
  59. out = self.relu(out)
  60. out = self.conv2(out)
  61. out = self.bn2(out)
  62. out += shortcut
  63. out = self.relu(out)
  64. return out
  65. class DlaBottleneck(nn.Module):
  66. """DLA/DLA-X Bottleneck"""
  67. expansion = 2
  68. def __init__(
  69. self,
  70. inplanes: int,
  71. outplanes: int,
  72. stride: int = 1,
  73. dilation: int = 1,
  74. cardinality: int = 1,
  75. base_width: int = 64,
  76. device=None,
  77. dtype=None,
  78. ):
  79. dd = {'device': device, 'dtype': dtype}
  80. super().__init__()
  81. self.stride = stride
  82. mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality)
  83. mid_planes = mid_planes // self.expansion
  84. self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False, **dd)
  85. self.bn1 = nn.BatchNorm2d(mid_planes, **dd)
  86. self.conv2 = nn.Conv2d(
  87. mid_planes,
  88. mid_planes,
  89. kernel_size=3,
  90. stride=stride,
  91. padding=dilation,
  92. bias=False,
  93. dilation=dilation,
  94. groups=cardinality,
  95. **dd,
  96. )
  97. self.bn2 = nn.BatchNorm2d(mid_planes, **dd)
  98. self.conv3 = nn.Conv2d(mid_planes, outplanes, kernel_size=1, bias=False, **dd)
  99. self.bn3 = nn.BatchNorm2d(outplanes, **dd)
  100. self.relu = nn.ReLU(inplace=True)
  101. def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
  102. if shortcut is None:
  103. shortcut = x
  104. out = self.conv1(x)
  105. out = self.bn1(out)
  106. out = self.relu(out)
  107. out = self.conv2(out)
  108. out = self.bn2(out)
  109. out = self.relu(out)
  110. out = self.conv3(out)
  111. out = self.bn3(out)
  112. out += shortcut
  113. out = self.relu(out)
  114. return out
  115. class DlaBottle2neck(nn.Module):
  116. """ Res2Net/Res2NeXT DLA Bottleneck
  117. Adapted from https://github.com/gasvn/Res2Net/blob/master/dla.py
  118. """
  119. expansion = 2
  120. def __init__(
  121. self,
  122. inplanes: int,
  123. outplanes: int,
  124. stride: int = 1,
  125. dilation: int = 1,
  126. scale: int = 4,
  127. cardinality: int = 8,
  128. base_width: int = 4,
  129. device=None,
  130. dtype=None,
  131. ):
  132. dd = {'device': device, 'dtype': dtype}
  133. super().__init__()
  134. self.is_first = stride > 1
  135. self.scale = scale
  136. mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality)
  137. mid_planes = mid_planes // self.expansion
  138. self.width = mid_planes
  139. self.conv1 = nn.Conv2d(inplanes, mid_planes * scale, kernel_size=1, bias=False, **dd)
  140. self.bn1 = nn.BatchNorm2d(mid_planes * scale, **dd)
  141. num_scale_convs = max(1, scale - 1)
  142. convs = []
  143. bns = []
  144. for _ in range(num_scale_convs):
  145. convs.append(nn.Conv2d(
  146. mid_planes,
  147. mid_planes,
  148. kernel_size=3,
  149. stride=stride,
  150. padding=dilation,
  151. dilation=dilation,
  152. groups=cardinality,
  153. bias=False,
  154. **dd,
  155. ))
  156. bns.append(nn.BatchNorm2d(mid_planes, **dd))
  157. self.convs = nn.ModuleList(convs)
  158. self.bns = nn.ModuleList(bns)
  159. self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) if self.is_first else None
  160. self.conv3 = nn.Conv2d(mid_planes * scale, outplanes, kernel_size=1, bias=False, **dd)
  161. self.bn3 = nn.BatchNorm2d(outplanes, **dd)
  162. self.relu = nn.ReLU(inplace=True)
  163. def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
  164. if shortcut is None:
  165. shortcut = x
  166. out = self.conv1(x)
  167. out = self.bn1(out)
  168. out = self.relu(out)
  169. spx = torch.split(out, self.width, 1)
  170. spo = []
  171. sp = spx[0] # redundant, for torchscript
  172. for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
  173. if i == 0 or self.is_first:
  174. sp = spx[i]
  175. else:
  176. sp = sp + spx[i]
  177. sp = conv(sp)
  178. sp = bn(sp)
  179. sp = self.relu(sp)
  180. spo.append(sp)
  181. if self.scale > 1:
  182. if self.pool is not None: # self.is_first == True, None check for torchscript
  183. spo.append(self.pool(spx[-1]))
  184. else:
  185. spo.append(spx[-1])
  186. out = torch.cat(spo, 1)
  187. out = self.conv3(out)
  188. out = self.bn3(out)
  189. out += shortcut
  190. out = self.relu(out)
  191. return out
  192. class DlaRoot(nn.Module):
  193. def __init__(
  194. self,
  195. in_channels: int,
  196. out_channels: int,
  197. kernel_size: int,
  198. shortcut: bool,
  199. device=None,
  200. dtype=None,
  201. ):
  202. dd = {'device': device, 'dtype': dtype}
  203. super().__init__()
  204. self.conv = nn.Conv2d(
  205. in_channels,
  206. out_channels,
  207. 1,
  208. stride=1,
  209. bias=False,
  210. padding=(kernel_size - 1) // 2,
  211. **dd,
  212. )
  213. self.bn = nn.BatchNorm2d(out_channels, **dd)
  214. self.relu = nn.ReLU(inplace=True)
  215. self.shortcut = shortcut
  216. def forward(self, x_children: List[torch.Tensor]):
  217. x = self.conv(torch.cat(x_children, 1))
  218. x = self.bn(x)
  219. if self.shortcut:
  220. x += x_children[0]
  221. x = self.relu(x)
  222. return x
  223. class DlaTree(nn.Module):
  224. def __init__(
  225. self,
  226. levels: int,
  227. block: Type[nn.Module],
  228. in_channels: int,
  229. out_channels: int,
  230. stride: int = 1,
  231. dilation: int = 1,
  232. cardinality: int = 1,
  233. base_width: int = 64,
  234. level_root: bool = False,
  235. root_dim: int = 0,
  236. root_kernel_size: int = 1,
  237. root_shortcut: bool = False,
  238. device=None,
  239. dtype=None,
  240. ):
  241. dd = {'device': device, 'dtype': dtype}
  242. super().__init__()
  243. if root_dim == 0:
  244. root_dim = 2 * out_channels
  245. if level_root:
  246. root_dim += in_channels
  247. self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else nn.Identity()
  248. self.project = nn.Identity()
  249. cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width, **dd)
  250. if levels == 1:
  251. self.tree1 = block(in_channels, out_channels, stride, **cargs)
  252. self.tree2 = block(out_channels, out_channels, 1, **cargs)
  253. if in_channels != out_channels:
  254. # NOTE the official impl/weights have project layers in levels > 1 case that are never
  255. # used, I've moved the project layer here to avoid wasted params but old checkpoints will
  256. # need strict=False while loading.
  257. self.project = nn.Sequential(
  258. nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False, **dd),
  259. nn.BatchNorm2d(out_channels, **dd))
  260. self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut, **dd)
  261. else:
  262. cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut))
  263. self.tree1 = DlaTree(
  264. levels - 1,
  265. block,
  266. in_channels,
  267. out_channels,
  268. stride,
  269. root_dim=0,
  270. **cargs,
  271. )
  272. self.tree2 = DlaTree(
  273. levels - 1,
  274. block,
  275. out_channels,
  276. out_channels,
  277. root_dim=root_dim + out_channels,
  278. **cargs,
  279. )
  280. self.root = None
  281. self.level_root = level_root
  282. self.root_dim = root_dim
  283. self.levels = levels
  284. def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
  285. if children is None:
  286. children = []
  287. bottom = self.downsample(x)
  288. shortcut = self.project(bottom)
  289. if self.level_root:
  290. children.append(bottom)
  291. x1 = self.tree1(x, shortcut)
  292. if self.root is not None: # levels == 1
  293. x2 = self.tree2(x1)
  294. x = self.root([x2, x1] + children)
  295. else:
  296. children.append(x1)
  297. x = self.tree2(x1, None, children)
  298. return x
  299. class DLA(nn.Module):
  300. def __init__(
  301. self,
  302. levels: Tuple[int, ...],
  303. channels: Tuple[int, ...],
  304. output_stride: int = 32,
  305. num_classes: int = 1000,
  306. in_chans: int = 3,
  307. global_pool: str = 'avg',
  308. cardinality: int = 1,
  309. base_width: int = 64,
  310. block: Type[nn.Module] = DlaBottle2neck,
  311. shortcut_root: bool = False,
  312. drop_rate: float = 0.0,
  313. device=None,
  314. dtype=None,
  315. ):
  316. dd = {'device': device, 'dtype': dtype}
  317. super().__init__()
  318. self.channels = channels
  319. self.num_classes = num_classes
  320. self.in_chans = in_chans
  321. self.cardinality = cardinality
  322. self.base_width = base_width
  323. assert output_stride == 32 # FIXME support dilation
  324. self.base_layer = nn.Sequential(
  325. nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False, **dd),
  326. nn.BatchNorm2d(channels[0], **dd),
  327. nn.ReLU(inplace=True),
  328. )
  329. self.level0 = self._make_conv_level(channels[0], channels[0], levels[0], **dd)
  330. self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2, **dd)
  331. cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root, **dd)
  332. self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs)
  333. self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs)
  334. self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs)
  335. self.level5 = DlaTree(levels[5], block, channels[4], channels[5], 2, level_root=True, **cargs)
  336. self.feature_info = [
  337. dict(num_chs=channels[0], reduction=1, module='level0'), # rare to have a meaningful stride 1 level
  338. dict(num_chs=channels[1], reduction=2, module='level1'),
  339. dict(num_chs=channels[2], reduction=4, module='level2'),
  340. dict(num_chs=channels[3], reduction=8, module='level3'),
  341. dict(num_chs=channels[4], reduction=16, module='level4'),
  342. dict(num_chs=channels[5], reduction=32, module='level5'),
  343. ]
  344. self.num_features = self.head_hidden_size = channels[-1]
  345. self.global_pool, self.head_drop, self.fc = create_classifier(
  346. self.num_features,
  347. self.num_classes,
  348. pool_type=global_pool,
  349. use_conv=True,
  350. drop_rate=drop_rate,
  351. **dd,
  352. )
  353. self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
  354. for m in self.modules():
  355. if isinstance(m, nn.Conv2d):
  356. n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  357. m.weight.data.normal_(0, math.sqrt(2. / n))
  358. elif isinstance(m, nn.BatchNorm2d):
  359. m.weight.data.fill_(1)
  360. m.bias.data.zero_()
  361. def _make_conv_level(self, inplanes: int, planes: int, convs: int, stride: int = 1, dilation: int = 1, device=None, dtype=None):
  362. dd = {'device': device, 'dtype': dtype}
  363. modules = []
  364. for i in range(convs):
  365. modules.extend([
  366. nn.Conv2d(
  367. inplanes,
  368. planes,
  369. kernel_size=3,
  370. stride=stride if i == 0 else 1,
  371. padding=dilation,
  372. bias=False,
  373. dilation=dilation,
  374. **dd,
  375. ),
  376. nn.BatchNorm2d(planes, **dd),
  377. nn.ReLU(inplace=True)])
  378. inplanes = planes
  379. return nn.Sequential(*modules)
  380. @torch.jit.ignore
  381. def group_matcher(self, coarse=False):
  382. matcher = dict(
  383. stem=r'^base_layer',
  384. blocks=r'^level(\d+)' if coarse else [
  385. # an unusual arch, this achieves somewhat more granularity without getting super messy
  386. (r'^level(\d+)\.tree(\d+)', None),
  387. (r'^level(\d+)\.root', (2,)),
  388. (r'^level(\d+)', (1,))
  389. ]
  390. )
  391. return matcher
  392. @torch.jit.ignore
  393. def set_grad_checkpointing(self, enable=True):
  394. assert not enable, 'gradient checkpointing not supported'
  395. @torch.jit.ignore
  396. def get_classifier(self) -> nn.Module:
  397. return self.fc
  398. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  399. self.num_classes = num_classes
  400. self.global_pool, self.fc = create_classifier(
  401. self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
  402. self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
  403. def forward_features(self, x):
  404. x = self.base_layer(x)
  405. x = self.level0(x)
  406. x = self.level1(x)
  407. x = self.level2(x)
  408. x = self.level3(x)
  409. x = self.level4(x)
  410. x = self.level5(x)
  411. return x
  412. def forward_head(self, x, pre_logits: bool = False):
  413. x = self.global_pool(x)
  414. x = self.head_drop(x)
  415. if pre_logits:
  416. return self.flatten(x)
  417. x = self.fc(x)
  418. return self.flatten(x)
  419. def forward(self, x):
  420. x = self.forward_features(x)
  421. x = self.forward_head(x)
  422. return x
  423. def _create_dla(variant, pretrained=False, **kwargs):
  424. return build_model_with_cfg(
  425. DLA,
  426. variant,
  427. pretrained,
  428. pretrained_strict=False,
  429. feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)),
  430. **kwargs,
  431. )
  432. def _cfg(url='', **kwargs):
  433. return {
  434. 'url': url,
  435. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  436. 'crop_pct': 0.875, 'interpolation': 'bilinear',
  437. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  438. 'first_conv': 'base_layer.0', 'classifier': 'fc', 'license': 'bsd-3-clause',
  439. **kwargs
  440. }
  441. default_cfgs = generate_default_cfgs({
  442. 'dla34.in1k': _cfg(hf_hub_id='timm/'),
  443. 'dla46_c.in1k': _cfg(hf_hub_id='timm/'),
  444. 'dla46x_c.in1k': _cfg(hf_hub_id='timm/'),
  445. 'dla60x_c.in1k': _cfg(hf_hub_id='timm/'),
  446. 'dla60.in1k': _cfg(hf_hub_id='timm/'),
  447. 'dla60x.in1k': _cfg(hf_hub_id='timm/'),
  448. 'dla102.in1k': _cfg(hf_hub_id='timm/'),
  449. 'dla102x.in1k': _cfg(hf_hub_id='timm/'),
  450. 'dla102x2.in1k': _cfg(hf_hub_id='timm/'),
  451. 'dla169.in1k': _cfg(hf_hub_id='timm/'),
  452. 'dla60_res2net.in1k': _cfg(hf_hub_id='timm/', license='unknown'),
  453. 'dla60_res2next.in1k': _cfg(hf_hub_id='timm/', license='unknown'),
  454. })
  455. @register_model
  456. def dla60_res2net(pretrained=False, **kwargs) -> DLA:
  457. model_args = dict(
  458. levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
  459. block=DlaBottle2neck, cardinality=1, base_width=28)
  460. return _create_dla('dla60_res2net', pretrained, **dict(model_args, **kwargs))
  461. @register_model
  462. def dla60_res2next(pretrained=False,**kwargs):
  463. model_args = dict(
  464. levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
  465. block=DlaBottle2neck, cardinality=8, base_width=4)
  466. return _create_dla('dla60_res2next', pretrained, **dict(model_args, **kwargs))
  467. @register_model
  468. def dla34(pretrained=False, **kwargs) -> DLA: # DLA-34
  469. model_args = dict(
  470. levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512], block=DlaBasic)
  471. return _create_dla('dla34', pretrained, **dict(model_args, **kwargs))
  472. @register_model
  473. def dla46_c(pretrained=False, **kwargs) -> DLA: # DLA-46-C
  474. model_args = dict(
  475. levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], block=DlaBottleneck)
  476. return _create_dla('dla46_c', pretrained, **dict(model_args, **kwargs))
  477. @register_model
  478. def dla46x_c(pretrained=False, **kwargs) -> DLA: # DLA-X-46-C
  479. model_args = dict(
  480. levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256],
  481. block=DlaBottleneck, cardinality=32, base_width=4)
  482. return _create_dla('dla46x_c', pretrained, **dict(model_args, **kwargs))
  483. @register_model
  484. def dla60x_c(pretrained=False, **kwargs) -> DLA: # DLA-X-60-C
  485. model_args = dict(
  486. levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 64, 64, 128, 256],
  487. block=DlaBottleneck, cardinality=32, base_width=4)
  488. return _create_dla('dla60x_c', pretrained, **dict(model_args, **kwargs))
  489. @register_model
  490. def dla60(pretrained=False, **kwargs) -> DLA: # DLA-60
  491. model_args = dict(
  492. levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
  493. block=DlaBottleneck)
  494. return _create_dla('dla60', pretrained, **dict(model_args, **kwargs))
  495. @register_model
  496. def dla60x(pretrained=False, **kwargs) -> DLA: # DLA-X-60
  497. model_args = dict(
  498. levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
  499. block=DlaBottleneck, cardinality=32, base_width=4)
  500. return _create_dla('dla60x', pretrained, **dict(model_args, **kwargs))
  501. @register_model
  502. def dla102(pretrained=False, **kwargs) -> DLA: # DLA-102
  503. model_args = dict(
  504. levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
  505. block=DlaBottleneck, shortcut_root=True)
  506. return _create_dla('dla102', pretrained, **dict(model_args, **kwargs))
  507. @register_model
  508. def dla102x(pretrained=False, **kwargs) -> DLA: # DLA-X-102
  509. model_args = dict(
  510. levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
  511. block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True)
  512. return _create_dla('dla102x', pretrained, **dict(model_args, **kwargs))
  513. @register_model
  514. def dla102x2(pretrained=False, **kwargs) -> DLA: # DLA-X-102 64
  515. model_args = dict(
  516. levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
  517. block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True)
  518. return _create_dla('dla102x2', pretrained, **dict(model_args, **kwargs))
  519. @register_model
  520. def dla169(pretrained=False, **kwargs) -> DLA: # DLA-169
  521. model_args = dict(
  522. levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024],
  523. block=DlaBottleneck, shortcut_root=True)
  524. return _create_dla('dla169', pretrained, **dict(model_args, **kwargs))