_efficientnet_builder.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. """ EfficientNet, MobileNetV3, etc Builder
Assembles EfficientNet and related network feature blocks from string definitions.
  3. Handles stride, dilation calculations, and selects feature extraction points.
  4. Hacked together by / Copyright 2019, Ross Wightman
  5. """
  6. from typing import Callable, Optional
  7. import logging
  8. import math
  9. import re
  10. from copy import deepcopy
  11. from functools import partial
  12. from typing import Any, Dict, List
  13. import torch.nn as nn
  14. from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible, LayerType
  15. from ._efficientnet_blocks import *
  16. from ._manipulate import named_modules
# Names exported by a star-import of this module.
__all__ = ["EfficientNetBuilder", "BlockArgs", "decode_arch_def", "efficientnet_init_weights",
           'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']

# Module-level logger, used by _log_info_if() and the builder's deprecation warning.
_logger = logging.getLogger(__name__)

# Initial value of EfficientNetBuilder.verbose; when True the builder logs every stack / block it creates.
_DEBUG_BUILDER = False

# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (/w .999 in search space) for paper
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
BN_EPS_TF_DEFAULT = 1e-3
# Template dict handed out (as a copy) by get_bn_args_tf().
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)

# Type alias for decoded architectures: outer list = stages, inner list = per-block kwarg dicts.
BlockArgs = List[List[Dict[str, Any]]]
  30. def get_bn_args_tf():
  31. return _BN_ARGS_TF.copy()
  32. def resolve_bn_args(kwargs):
  33. bn_args = {}
  34. bn_momentum = kwargs.pop('bn_momentum', None)
  35. if bn_momentum is not None:
  36. bn_args['momentum'] = bn_momentum
  37. bn_eps = kwargs.pop('bn_eps', None)
  38. if bn_eps is not None:
  39. bn_args['eps'] = bn_eps
  40. return bn_args
  41. def resolve_act_layer(kwargs, default='relu'):
  42. return get_act_layer(kwargs.pop('act_layer', default))
  43. def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
  44. """Round number of filters based on depth multiplier."""
  45. if not multiplier:
  46. return channels
  47. return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit)
  48. def _log_info_if(msg, condition):
  49. if condition:
  50. _logger.info(msg)
  51. def _parse_ksize(ss):
  52. if ss.isdigit():
  53. return int(ss)
  54. else:
  55. return [int(k) for k in ss.split('.')]
  56. def _decode_block_str(block_str):
  57. """ Decode block definition string
  58. Gets a list of block arg (dicts) through a string notation of arguments.
  59. E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
  60. All args can exist in any order with the exception of the leading string which
  61. is assumed to indicate the block type.
  62. leading string - block type (
  63. ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
  64. r - number of repeat blocks,
  65. k - kernel size,
  66. s - strides (1-9),
  67. e - expansion ratio,
  68. c - output channels,
  69. se - squeeze/excitation ratio
  70. n - activation fn ('re', 'r6', 'hs', or 'sw')
  71. Args:
  72. block_str: a string representation of block arguments.
  73. Returns:
  74. A list of block args (dicts)
  75. Raises:
  76. ValueError: if the string def not properly specified (TODO)
  77. """
  78. assert isinstance(block_str, str)
  79. ops = block_str.split('_')
  80. block_type = ops[0] # take the block type off the front
  81. ops = ops[1:]
  82. options = {}
  83. skip = None
  84. for op in ops:
  85. # string options being checked on individual basis, combine if they grow
  86. if op == 'noskip':
  87. skip = False # force no skip connection
  88. elif op == 'skip':
  89. skip = True # force a skip connection
  90. elif op.startswith('n'):
  91. # activation fn
  92. key = op[0]
  93. v = op[1:]
  94. if v == 're':
  95. value = get_act_layer('relu')
  96. elif v == 'r6':
  97. value = get_act_layer('relu6')
  98. elif v == 'hs':
  99. value = get_act_layer('hard_swish')
  100. elif v == 'sw':
  101. value = get_act_layer('swish') # aka SiLU
  102. elif v == 'mi':
  103. value = get_act_layer('mish')
  104. else:
  105. continue
  106. options[key] = value
  107. else:
  108. # all numeric options
  109. splits = re.split(r'(\d.*)', op)
  110. if len(splits) >= 2:
  111. key, value = splits[:2]
  112. options[key] = value
  113. # if act_layer is None, the model default (passed to model init) will be used
  114. act_layer = options['n'] if 'n' in options else None
  115. start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
  116. end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
  117. force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
  118. num_repeat = int(options['r'])
  119. # each type of block has different valid arguments, fill accordingly
  120. block_args = dict(
  121. block_type=block_type,
  122. out_chs=int(options['c']),
  123. stride=int(options['s']),
  124. act_layer=act_layer,
  125. )
  126. if block_type == 'ir':
  127. block_args.update(dict(
  128. dw_kernel_size=_parse_ksize(options['k']),
  129. exp_kernel_size=start_kernel_size,
  130. pw_kernel_size=end_kernel_size,
  131. exp_ratio=float(options['e']),
  132. se_ratio=float(options.get('se', 0.)),
  133. noskip=skip is False,
  134. s2d=int(options.get('d', 0)) > 0,
  135. ))
  136. if 'cc' in options:
  137. block_args['num_experts'] = int(options['cc'])
  138. elif block_type == 'ds' or block_type == 'dsa':
  139. block_args.update(dict(
  140. dw_kernel_size=_parse_ksize(options['k']),
  141. pw_kernel_size=end_kernel_size,
  142. se_ratio=float(options.get('se', 0.)),
  143. pw_act=block_type == 'dsa',
  144. noskip=block_type == 'dsa' or skip is False,
  145. s2d=int(options.get('d', 0)) > 0,
  146. ))
  147. elif block_type == 'er':
  148. block_args.update(dict(
  149. exp_kernel_size=_parse_ksize(options['k']),
  150. pw_kernel_size=end_kernel_size,
  151. exp_ratio=float(options['e']),
  152. force_in_chs=force_in_chs,
  153. se_ratio=float(options.get('se', 0.)),
  154. noskip=skip is False,
  155. ))
  156. elif block_type == 'cn':
  157. block_args.update(dict(
  158. kernel_size=int(options['k']),
  159. skip=skip is True,
  160. ))
  161. elif block_type == 'uir':
  162. # override exp / proj kernels for start/end in uir block
  163. start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 0
  164. end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 0
  165. block_args.update(dict(
  166. dw_kernel_size_start=start_kernel_size, # overload exp ks arg for dw start
  167. dw_kernel_size_mid=_parse_ksize(options['k']),
  168. dw_kernel_size_end=end_kernel_size, # overload pw ks arg for dw end
  169. exp_ratio=float(options['e']),
  170. se_ratio=float(options.get('se', 0.)),
  171. noskip=skip is False,
  172. ))
  173. elif block_type == 'mha':
  174. kv_dim = int(options['d'])
  175. block_args.update(dict(
  176. dw_kernel_size=_parse_ksize(options['k']),
  177. num_heads=int(options['h']),
  178. key_dim=kv_dim,
  179. value_dim=kv_dim,
  180. kv_stride=int(options.get('v', 1)),
  181. noskip=skip is False,
  182. ))
  183. elif block_type == 'mqa':
  184. kv_dim = int(options['d'])
  185. block_args.update(dict(
  186. dw_kernel_size=_parse_ksize(options['k']),
  187. num_heads=int(options['h']),
  188. key_dim=kv_dim,
  189. value_dim=kv_dim,
  190. kv_stride=int(options.get('v', 1)),
  191. noskip=skip is False,
  192. ))
  193. else:
  194. assert False, 'Unknown block type (%s)' % block_type
  195. if 'gs' in options:
  196. block_args['group_size'] = int(options['gs'])
  197. return block_args, num_repeat
  198. def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
  199. """ Per-stage depth scaling
  200. Scales the block repeats in each stage. This depth scaling impl maintains
  201. compatibility with the EfficientNet scaling method, while allowing sensible
  202. scaling for other models that may have multiple block arg definitions in each stage.
  203. """
  204. # We scale the total repeat count for each stage, there may be multiple
  205. # block arg defs per stage so we need to sum.
  206. num_repeat = sum(repeats)
  207. if depth_trunc == 'round':
  208. # Truncating to int by rounding allows stages with few repeats to remain
  209. # proportionally smaller for longer. This is a good choice when stage definitions
  210. # include single repeat stages that we'd prefer to keep that way as long as possible
  211. num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
  212. else:
  213. # The default for EfficientNet truncates repeats to int via 'ceil'.
  214. # Any multiplier > 1.0 will result in an increased depth for every stage.
  215. num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
  216. # Proportionally distribute repeat count scaling to each block definition in the stage.
  217. # Allocation is done in reverse as it results in the first block being less likely to be scaled.
  218. # The first block makes less sense to repeat in most of the arch definitions.
  219. repeats_scaled = []
  220. for r in repeats[::-1]:
  221. rs = max(1, round((r / num_repeat * num_repeat_scaled)))
  222. repeats_scaled.append(rs)
  223. num_repeat -= r
  224. num_repeat_scaled -= rs
  225. repeats_scaled = repeats_scaled[::-1]
  226. # Apply the calculated scaling to each block arg in the stage
  227. sa_scaled = []
  228. for ba, rep in zip(stack_args, repeats_scaled):
  229. sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
  230. return sa_scaled
  231. def decode_arch_def(
  232. arch_def,
  233. depth_multiplier=1.0,
  234. depth_trunc='ceil',
  235. experts_multiplier=1,
  236. fix_first_last=False,
  237. group_size=None,
  238. ):
  239. """ Decode block architecture definition strings -> block kwargs
  240. Args:
  241. arch_def: architecture definition strings, list of list of strings
  242. depth_multiplier: network depth multiplier
  243. depth_trunc: networ depth truncation mode when applying multiplier
  244. experts_multiplier: CondConv experts multiplier
  245. fix_first_last: fix first and last block depths when multiplier is applied
  246. group_size: group size override for all blocks that weren't explicitly set in arch string
  247. Returns:
  248. list of list of block kwargs
  249. """
  250. arch_args = []
  251. if isinstance(depth_multiplier, tuple):
  252. assert len(depth_multiplier) == len(arch_def)
  253. else:
  254. depth_multiplier = (depth_multiplier,) * len(arch_def)
  255. for stack_idx, (block_strings, multiplier) in enumerate(zip(arch_def, depth_multiplier)):
  256. assert isinstance(block_strings, list)
  257. stack_args = []
  258. repeats = []
  259. for block_str in block_strings:
  260. assert isinstance(block_str, str)
  261. ba, rep = _decode_block_str(block_str)
  262. if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
  263. ba['num_experts'] *= experts_multiplier
  264. if group_size is not None:
  265. ba.setdefault('group_size', group_size)
  266. stack_args.append(ba)
  267. repeats.append(rep)
  268. if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
  269. arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
  270. else:
  271. arch_args.append(_scale_stage_depth(stack_args, repeats, multiplier, depth_trunc))
  272. return arch_args
class EfficientNetBuilder:
    """ Build Trunk Blocks

    Callable that turns decoded block args (list of stages, each a list of per-block
    arg dicts) into a list of nn.Sequential stages, handling stride -> dilation
    conversion and recording feature extraction points in `self.features`.

    This ended up being somewhat of a cross between
    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
    and
    https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
    """
    def __init__(
            self,
            output_stride: int = 32,
            pad_type: str = '',
            round_chs_fn: Callable = round_channels,
            se_from_exp: bool = False,
            act_layer: Optional[LayerType] = None,
            norm_layer: Optional[LayerType] = None,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[LayerType] = None,
            drop_path_rate: float = 0.,
            layer_scale_init_value: Optional[float] = None,
            feature_location: str = '',
            device=None,
            dtype=None,
    ):
        """
        Args:
            output_stride: maximum cumulative stride; block strides that would exceed it
                are converted to dilation during the build.
            pad_type: passed to every block as `pad_type`.
            round_chs_fn: applied to each block's `out_chs` (and `force_in_chs` when set).
            se_from_exp: calculate se channel reduction from expanded (mid) chs.
            act_layer: default activation, used by blocks whose own `act_layer` is None.
            norm_layer: passed to every block as `norm_layer`.
            aa_layer: passed to every block as `aa_layer` when not None.
            se_layer: attention layer spec, resolved through get_attn().
            drop_path_rate: max drop path rate; each block gets it scaled by block_idx / block_count.
            layer_scale_init_value: passed to 'uir', 'mqa' and 'mha' blocks.
            feature_location: one of 'bottleneck', 'expansion', '' ('depthwise' is a deprecated
                alias of 'expansion').
            device: forwarded to every block as the `device` kwarg.
            dtype: forwarded to every block as the `dtype` kwarg.
        """
        self.output_stride = output_stride
        self.pad_type = pad_type
        self.round_chs_fn = round_chs_fn
        self.se_from_exp = se_from_exp  # calculate se channel reduction from expanded (mid) chs
        self.act_layer = act_layer
        self.norm_layer = norm_layer
        self.aa_layer = aa_layer
        self.se_layer = get_attn(se_layer)
        # Probe by instantiating the layer once; a TypeError is taken to mean rd_ratio is unsupported.
        try:
            self.se_layer(8, rd_ratio=1.0)  # test if attn layer accepts rd_ratio arg
            self.se_has_ratio = True
        except TypeError:
            self.se_has_ratio = False
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value
        if feature_location == 'depthwise':
            # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
            _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
            feature_location = 'expansion'
        self.feature_location = feature_location
        assert feature_location in ('bottleneck', 'expansion', '')
        self.dd = {'device': device, 'dtype': dtype}  # device/dtype factory kwargs
        self.verbose = _DEBUG_BUILDER

        # state updated during build, consumed by model
        self.in_chs = None
        self.features = []

    def _make_block(self, ba, block_idx, block_count):
        """Instantiate a single block from its arg dict.

        `ba` is mutated in place: 'block_type' and 'se_ratio' are popped and builder-level
        settings (channels, padding, act/norm/aa/se layers, drop path, device/dtype) are
        filled in before the dict is splatted into the block constructor.
        `self.in_chs` is updated to the block's output channels for the next block.

        Args:
            ba: block arg dict (must contain 'block_type', 'out_chs' and 'act_layer').
            block_idx: global index of the block across all stacks.
            block_count: total number of blocks, used to scale the drop path rate.

        Returns:
            The constructed block module.
        """
        # drop path rate grows linearly with the global block index
        drop_path_rate = self.drop_path_rate * block_idx / block_count
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
        s2d = ba.get('s2d', 0)
        if s2d > 0:
            # adjust while space2depth active
            ba['out_chs'] *= 4
        if 'force_in_chs' in ba and ba['force_in_chs']:
            # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
            ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
        ba['pad_type'] = self.pad_type
        # block act fn overrides the model default
        ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
        assert ba['act_layer'] is not None
        ba['norm_layer'] = self.norm_layer
        ba['drop_path_rate'] = drop_path_rate
        if self.aa_layer is not None:
            ba['aa_layer'] = self.aa_layer

        # se_ratio is consumed here and never passed to the block; a falsy ratio means no SE layer
        se_ratio = ba.pop('se_ratio', None)
        if se_ratio and self.se_layer is not None:
            if not self.se_from_exp:
                # adjust se_ratio by expansion ratio if calculating se channels from block input
                se_ratio /= ba.get('exp_ratio', 1.0)
            if s2d == 1:
                # adjust for start of space2depth
                se_ratio /= 4
            if self.se_has_ratio:
                ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
            else:
                ba['se_layer'] = self.se_layer
        ba.update(self.dd)  # device/type factory kwargs

        # dispatch on block type
        if bt == 'ir':
            _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            # a non-zero num_experts selects the CondConv variant
            block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba)
        elif bt == 'ds' or bt == 'dsa':
            _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = DepthwiseSeparableConv(**ba)
        elif bt == 'er':
            _log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = EdgeResidual(**ba)
        elif bt == 'cn':
            _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = ConvBnAct(**ba)
        elif bt == 'uir':
            _log_info_if(' UniversalInvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = UniversalInvertedResidual(**ba, layer_scale_init_value=self.layer_scale_init_value)
        elif bt == 'mqa':
            _log_info_if(' MobileMultiQueryAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = MobileAttention(**ba, use_multi_query=True, layer_scale_init_value=self.layer_scale_init_value)
        elif bt == 'mha':
            _log_info_if(' MobileMultiHeadAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = MobileAttention(**ba, layer_scale_init_value=self.layer_scale_init_value)
        else:
            assert False, 'Unknown block type (%s) while building model.' % bt

        self.in_chs = ba['out_chs']  # update in_chs for arg of next block
        return block

    def __call__(self, in_chs, model_block_args):
        """ Build the blocks

        Side effects: appends feature info dicts to `self.features`, updates `self.in_chs`,
        and mutates the block arg dicts in `model_block_args` in place.

        Args:
            in_chs: Number of input-channels passed to first block
            model_block_args: A list of lists, outer list defines stages, inner
                list contains the block arg dicts (each indexed by keys such as 'stride')
        Return:
             List of block stacks (each stack wrapped in nn.Sequential)
        """
        _log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose)
        self.in_chs = in_chs
        total_block_count = sum([len(x) for x in model_block_args])
        total_block_idx = 0
        # NOTE(review): starting stride of 2 presumably accounts for a stride-2 stem - confirm against the models
        current_stride = 2
        current_dilation = 1
        stages = []
        if model_block_args[0][0]['stride'] > 1:
            # if the first block starts with a stride, we need to extract first level feat from stem
            feature_info = dict(module='bn1', num_chs=in_chs, stage=0, reduction=current_stride)
            self.features.append(feature_info)

        # outer list of block_args defines the stacks
        # space2depth state: 0 = inactive, 1 = region starts at the current block, 2 = region active
        space2depth = 0
        for stack_idx, stack_args in enumerate(model_block_args):
            last_stack = stack_idx + 1 == len(model_block_args)  # not read again in this method
            _log_info_if('Stack: {}'.format(stack_idx), self.verbose)
            assert isinstance(stack_args, list)

            blocks = []
            # each stack (stage of blocks) contains a list of block arguments
            for block_idx, block_args in enumerate(stack_args):
                last_block = block_idx + 1 == len(stack_args)
                _log_info_if(' Block: {}'.format(block_idx), self.verbose)

                assert block_args['stride'] in (1, 2)
                if block_idx >= 1:  # only the first block in any stack can have a stride > 1
                    block_args['stride'] = 1

                # the 's2d' flag is always popped when no region is active; a truthy flag opens one
                if not space2depth and block_args.pop('s2d', False):
                    assert block_args['stride'] == 1
                    space2depth = 1

                if space2depth > 0:
                    # FIXME s2d is a WIP
                    if space2depth == 2 and block_args['stride'] == 2:
                        # a stride-2 block inside an active region ends the region instead of striding
                        block_args['stride'] = 1
                        # to end s2d region, need to correct expansion and se ratio relative to input
                        block_args['exp_ratio'] /= 4
                        space2depth = 0
                    else:
                        block_args['s2d'] = space2depth

                # features are taken from the last block of a stack when it is the final stack
                # or the next stack begins with a stride > 1
                extract_features = False
                if last_block:
                    next_stack_idx = stack_idx + 1
                    extract_features = next_stack_idx >= len(model_block_args) or \
                        model_block_args[next_stack_idx][0]['stride'] > 1

                next_dilation = current_dilation
                if block_args['stride'] > 1:
                    next_output_stride = current_stride * block_args['stride']
                    if next_output_stride > self.output_stride:
                        # striding would exceed output_stride: keep resolution, dilate later blocks instead
                        next_dilation = current_dilation * block_args['stride']
                        block_args['stride'] = 1
                        _log_info_if(' Converting stride to dilation to maintain output_stride=={}'.format(
                            self.output_stride), self.verbose)
                    else:
                        current_stride = next_output_stride
                # this block uses the dilation in effect before it; the increased value applies from the next block
                block_args['dilation'] = current_dilation
                if next_dilation != current_dilation:
                    current_dilation = next_dilation

                # create the block
                block = self._make_block(block_args, total_block_idx, total_block_count)
                blocks.append(block)

                # the block that opened the s2d region has been built, the region is now active
                if space2depth == 1:
                    space2depth = 2

                # stash feature module name and channel info for model feature extraction
                if extract_features:
                    feature_info = dict(
                        stage=stack_idx + 1,
                        reduction=current_stride,
                        **block.feature_info(self.feature_location),
                    )
                    leaf_name = feature_info.get('module', '')
                    if leaf_name:
                        # feature comes from a named submodule inside the block
                        feature_info['module'] = '.'.join([f'blocks.{stack_idx}.{block_idx}', leaf_name])
                    else:
                        # no submodule name: the feature is the output of the whole stack
                        assert last_block
                        feature_info['module'] = f'blocks.{stack_idx}'
                    self.features.append(feature_info)

                total_block_idx += 1  # incr global block idx (across all stacks)
            stages.append(nn.Sequential(*blocks))
        return stages
  466. def _init_weight_goog(m, n='', fix_group_fanout=True):
  467. """ Weight initialization as per Tensorflow official implementations.
  468. Args:
  469. m (nn.Module): module to init
  470. n (str): module name
  471. fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
  472. Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
  473. * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
  474. * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
  475. """
  476. if isinstance(m, CondConv2d):
  477. fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  478. if fix_group_fanout:
  479. fan_out //= m.groups
  480. init_weight_fn = get_condconv_initializer(
  481. lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
  482. init_weight_fn(m.weight)
  483. if m.bias is not None:
  484. nn.init.zeros_(m.bias)
  485. elif isinstance(m, nn.Conv2d):
  486. fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  487. if fix_group_fanout:
  488. fan_out //= m.groups
  489. nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
  490. if m.bias is not None:
  491. nn.init.zeros_(m.bias)
  492. elif isinstance(m, nn.BatchNorm2d):
  493. nn.init.ones_(m.weight)
  494. nn.init.zeros_(m.bias)
  495. elif isinstance(m, nn.Linear):
  496. fan_out = m.weight.size(0) # fan-out
  497. fan_in = 0
  498. if 'routing_fn' in n:
  499. fan_in = m.weight.size(1)
  500. init_range = 1.0 / math.sqrt(fan_in + fan_out)
  501. nn.init.uniform_(m.weight, -init_range, init_range)
  502. nn.init.zeros_(m.bias)
  503. def efficientnet_init_weights(model: nn.Module, init_fn=None):
  504. init_fn = init_fn or _init_weight_goog
  505. for n, m in model.named_modules():
  506. init_fn(m, n)
  507. # iterate and call any module.init_weights() fn, children first
  508. for n, m in named_modules(model):
  509. if hasattr(m, 'init_weights'):
  510. m.init_weights()