conv_bn_act.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. """ Conv2d + BN + Act
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. from typing import Any, Dict, Optional, Type
  5. from torch import nn as nn
  6. from .typing import LayerType, PadType
  7. from .blur_pool import create_aa
  8. from .create_conv2d import create_conv2d
  9. from .create_norm_act import get_norm_act_layer
  10. class ConvNormAct(nn.Module):
  11. def __init__(
  12. self,
  13. in_channels: int,
  14. out_channels: int,
  15. kernel_size: int = 1,
  16. stride: int = 1,
  17. padding: PadType = '',
  18. dilation: int = 1,
  19. groups: int = 1,
  20. bias: bool = False,
  21. apply_norm: bool = True,
  22. apply_act: bool = True,
  23. norm_layer: LayerType = nn.BatchNorm2d,
  24. act_layer: Optional[LayerType] = nn.ReLU,
  25. aa_layer: Optional[LayerType] = None,
  26. drop_layer: Optional[Type[nn.Module]] = None,
  27. conv_kwargs: Optional[Dict[str, Any]] = None,
  28. norm_kwargs: Optional[Dict[str, Any]] = None,
  29. act_kwargs: Optional[Dict[str, Any]] = None,
  30. device=None,
  31. dtype=None,
  32. ):
  33. dd = {'device': device, 'dtype': dtype}
  34. super().__init__()
  35. conv_kwargs = {**dd, **(conv_kwargs or {})}
  36. norm_kwargs = {**dd, **(norm_kwargs or {})}
  37. act_kwargs = act_kwargs or {}
  38. use_aa = aa_layer is not None and stride > 1
  39. self.conv = create_conv2d(
  40. in_channels,
  41. out_channels,
  42. kernel_size,
  43. stride=1 if use_aa else stride,
  44. padding=padding,
  45. dilation=dilation,
  46. groups=groups,
  47. bias=bias,
  48. **conv_kwargs,
  49. )
  50. if apply_norm:
  51. # NOTE for backwards compatibility with models that use separate norm and act layer definitions
  52. norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
  53. # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
  54. if drop_layer:
  55. norm_kwargs['drop_layer'] = drop_layer
  56. self.bn = norm_act_layer(
  57. out_channels,
  58. apply_act=apply_act,
  59. act_kwargs=act_kwargs,
  60. **norm_kwargs,
  61. )
  62. else:
  63. self.bn = nn.Sequential()
  64. if drop_layer:
  65. norm_kwargs['drop_layer'] = drop_layer
  66. self.bn.add_module('drop', drop_layer())
  67. self.aa = create_aa(
  68. aa_layer,
  69. out_channels,
  70. stride=stride,
  71. enable=use_aa,
  72. noop=None,
  73. **dd,
  74. )
  75. @property
  76. def in_channels(self):
  77. return self.conv.in_channels
  78. @property
  79. def out_channels(self):
  80. return self.conv.out_channels
  81. def forward(self, x):
  82. x = self.conv(x)
  83. x = self.bn(x)
  84. aa = getattr(self, 'aa', None)
  85. if aa is not None:
  86. x = self.aa(x)
  87. return x
  88. ConvBnAct = ConvNormAct
  89. ConvNormActAa = ConvNormAct # backwards compat, when they were separate