decoder_blocks.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import torch
  2. import torch.nn as nn
  3. from models.modules.aspp import ASPP, ASPPDeformable
  4. from config import Config
# Module-level configuration instance. The decoder blocks below read
# config.dec_channels_inter, config.dec_att and config.batch_size inside their
# __init__ methods, so these settings take effect when a block is constructed.
config = Config()
  6. class BasicDecBlk(nn.Module):
  7. def __init__(self, in_channels=64, out_channels=64, inter_channels=64):
  8. super(BasicDecBlk, self).__init__()
  9. inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64
  10. self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
  11. self.relu_in = nn.ReLU(inplace=True)
  12. if config.dec_att == 'ASPP':
  13. self.dec_att = ASPP(in_channels=inter_channels)
  14. elif config.dec_att == 'ASPPDeformable':
  15. self.dec_att = ASPPDeformable(in_channels=inter_channels)
  16. self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
  17. self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
  18. self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
  19. def forward(self, x):
  20. x = self.conv_in(x)
  21. x = self.bn_in(x)
  22. x = self.relu_in(x)
  23. if hasattr(self, 'dec_att'):
  24. x = self.dec_att(x)
  25. x = self.conv_out(x)
  26. x = self.bn_out(x)
  27. return x
  28. class ResBlk(nn.Module):
  29. def __init__(self, in_channels=64, out_channels=None, inter_channels=64):
  30. super(ResBlk, self).__init__()
  31. if out_channels is None:
  32. out_channels = in_channels
  33. inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64
  34. self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
  35. self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
  36. self.relu_in = nn.ReLU(inplace=True)
  37. if config.dec_att == 'ASPP':
  38. self.dec_att = ASPP(in_channels=inter_channels)
  39. elif config.dec_att == 'ASPPDeformable':
  40. self.dec_att = ASPPDeformable(in_channels=inter_channels)
  41. self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
  42. self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
  43. self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
  44. def forward(self, x):
  45. _x = self.conv_resi(x)
  46. x = self.conv_in(x)
  47. x = self.bn_in(x)
  48. x = self.relu_in(x)
  49. if hasattr(self, 'dec_att'):
  50. x = self.dec_att(x)
  51. x = self.conv_out(x)
  52. x = self.bn_out(x)
  53. return x + _x