separable_conv.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
""" Depthwise Separable Conv Modules

Basic depthwise-separable (DWS) convs. Other variations of DWS exist with batch norm or activations between the
DW (depthwise) and PW (pointwise) convs, such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception.

Hacked together by / Copyright 2020 Ross Wightman
"""
  6. from typing import Optional, Type, Union
  7. from torch import nn as nn
  8. from .create_conv2d import create_conv2d
  9. from .create_norm_act import get_norm_act_layer
  10. class SeparableConvNormAct(nn.Module):
  11. """ Separable Conv w/ trailing Norm and Activation
  12. """
  13. def __init__(
  14. self,
  15. in_channels: int,
  16. out_channels: int,
  17. kernel_size: int = 3,
  18. stride: int = 1,
  19. dilation: int = 1,
  20. padding: str = '',
  21. bias: bool = False,
  22. channel_multiplier: float = 1.0,
  23. pw_kernel_size: int = 1,
  24. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  25. act_layer: Type[nn.Module] = nn.ReLU,
  26. apply_act: bool = True,
  27. drop_layer: Optional[Type[nn.Module]] = None,
  28. device=None,
  29. dtype=None,
  30. ):
  31. dd = {'device': device, 'dtype': dtype}
  32. super().__init__()
  33. self.conv_dw = create_conv2d(
  34. in_channels,
  35. int(in_channels * channel_multiplier),
  36. kernel_size,
  37. stride=stride,
  38. dilation=dilation,
  39. padding=padding,
  40. depthwise=True,
  41. **dd,
  42. )
  43. self.conv_pw = create_conv2d(
  44. int(in_channels * channel_multiplier),
  45. out_channels,
  46. pw_kernel_size,
  47. padding=padding,
  48. bias=bias,
  49. **dd,
  50. )
  51. norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
  52. norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
  53. self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs, **dd)
  54. @property
  55. def in_channels(self):
  56. return self.conv_dw.in_channels
  57. @property
  58. def out_channels(self):
  59. return self.conv_pw.out_channels
  60. def forward(self, x):
  61. x = self.conv_dw(x)
  62. x = self.conv_pw(x)
  63. x = self.bn(x)
  64. return x
  65. SeparableConvBnAct = SeparableConvNormAct
  66. class SeparableConv2d(nn.Module):
  67. """ Separable Conv
  68. """
  69. def __init__(
  70. self,
  71. in_channels,
  72. out_channels,
  73. kernel_size=3,
  74. stride=1,
  75. dilation=1,
  76. padding='',
  77. bias=False,
  78. channel_multiplier=1.0,
  79. pw_kernel_size=1,
  80. device=None,
  81. dtype=None,
  82. ):
  83. dd = {'device': device, 'dtype': dtype}
  84. super().__init__()
  85. self.conv_dw = create_conv2d(
  86. in_channels,
  87. int(in_channels * channel_multiplier),
  88. kernel_size,
  89. stride=stride,
  90. dilation=dilation,
  91. padding=padding,
  92. depthwise=True,
  93. **dd,
  94. )
  95. self.conv_pw = create_conv2d(
  96. int(in_channels * channel_multiplier),
  97. out_channels,
  98. pw_kernel_size,
  99. padding=padding,
  100. bias=bias,
  101. **dd,
  102. )
  103. @property
  104. def in_channels(self):
  105. return self.conv_dw.in_channels
  106. @property
  107. def out_channels(self):
  108. return self.conv_pw.out_channels
  109. def forward(self, x):
  110. x = self.conv_dw(x)
  111. x = self.conv_pw(x)
  112. return x