lateral_blocks.py 354 B

12345678910111213141516
  1. import torch.nn as nn
  2. from config import Config
  # Module-level Config instance, created when this module is imported.
  # Nothing in this file reads it. NOTE(review): confirm whether any importer
  # relies on `lateral_blocks.config` before removing it.
  config = Config()
  class BasicLatBlk(nn.Module):
      """Block consisting of a single 2-D convolution.

      The name suggests it is used for lateral connections between feature
      maps. NOTE(review): confirm the intended role against callers.

      Args:
          in_channels: Number of channels in the input feature map.
          out_channels: Number of channels produced by the convolution.
          ks: Kernel size, forwarded to ``nn.Conv2d`` (int or tuple).
          s: Stride, forwarded to ``nn.Conv2d`` (int or tuple).
          p: Padding, forwarded to ``nn.Conv2d`` (int or tuple).
      """

      def __init__(self, in_channels=64, out_channels=64, ks=1, s=1, p=0):
          super().__init__()
          # Keyword arguments map exactly onto the positional order
          # (kernel_size, stride, padding) of nn.Conv2d.
          self.conv = nn.Conv2d(
              in_channels,
              out_channels,
              kernel_size=ks,
              stride=s,
              padding=p,
          )

      def forward(self, x):
          """Apply the convolution to ``x`` and return the result.

          ``x`` must be an input accepted by ``nn.Conv2d``, e.g. a tensor of
          shape ``(N, in_channels, H, W)`` or unbatched ``(in_channels, H, W)``.
          """
          return self.conv(x)