| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from einops import rearrange
- from kornia.filters import laplacian
- from huggingface_hub import PyTorchModelHubMixin
- from config import Config
- from dataset import class_labels_TR_sorted
- from models.backbones.build_backbone import build_backbone
- from models.modules.decoder_blocks import BasicDecBlk, ResBlk
- from models.modules.lateral_blocks import BasicLatBlk
- from models.modules.aspp import ASPP, ASPPDeformable
def image2patches(image, grid_h=2, grid_w=2, patch_ref=None, transformation='b c (hg h) (wg w) -> (b hg wg) c h w'):
    """Cut ``image`` into a grid of patches using an einops pattern.

    Args:
        image: tensor whose last two dims are spatial (height, width).
        grid_h: number of patches along the height; ignored when ``patch_ref`` is given.
        grid_w: number of patches along the width; ignored when ``patch_ref`` is given.
        patch_ref: optional tensor whose last two dims define the patch size. The grid
            then becomes ``image`` size // ``patch_ref`` size (floor division).
        transformation: einops ``rearrange`` pattern that names the grid axes ``hg``
            and ``wg``. The default moves the patches into the batch dimension.

    Returns:
        The rearranged tensor.
    """
    if patch_ref is None:
        n_rows, n_cols = grid_h, grid_w
    else:
        # How many reference-sized tiles fit along each spatial axis.
        n_rows = image.shape[-2] // patch_ref.shape[-2]
        n_cols = image.shape[-1] // patch_ref.shape[-1]
    return rearrange(image, transformation, hg=n_rows, wg=n_cols)
def patches2image(patches, grid_h=2, grid_w=2, patch_ref=None, transformation='(b hg wg) c h w -> b c (hg h) (wg w)'):
    """Reassemble a grid of patches into an image; the inverse of ``image2patches``.

    Args:
        patches: patch tensor (or anything ``rearrange`` accepts that supports
            ``patches[0].shape``). With the default pattern, the patches sit in the
            leading dimension.
        grid_h: number of patches along the height; ignored when ``patch_ref`` is given.
        grid_w: number of patches along the width; ignored when ``patch_ref`` is given.
        patch_ref: optional tensor with the target full-image size (last two dims).
            The grid then becomes target size // patch size (floor division).
        transformation: einops ``rearrange`` pattern that names the grid axes ``hg``
            and ``wg``.

    Returns:
        The reassembled tensor.
    """
    if patch_ref is None:
        n_rows, n_cols = grid_h, grid_w
    else:
        # The patch size is read from the first patch.
        tile_h, tile_w = patches[0].shape[-2], patches[0].shape[-1]
        n_rows = patch_ref.shape[-2] // tile_h
        n_cols = patch_ref.shape[-1] // tile_w
    return rearrange(patches, transformation, hg=n_rows, wg=n_cols)
class BiRefNet(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="birefnet",
    repo_url="https://github.com/ZhengPeng7/BiRefNet",
    tags=['Image Segmentation', 'Background Removal', 'Mask Generation', 'Dichotomous Image Segmentation', 'Camouflaged Object Detection', 'Salient Object Detection']
):
    """BiRefNet: a backbone encoder followed by the multi-stage ``Decoder``.

    All architecture switches are read from ``Config()``. The
    ``PyTorchModelHubMixin`` base adds Hugging Face Hub save/load support.

    Args:
        bb_pretrained: forwarded to ``build_backbone`` as ``pretrained``.
    """

    def __init__(self, bb_pretrained=True):
        super(BiRefNet, self).__init__()
        self.config = Config()
        self.epoch = 1
        self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)

        channels = self.config.lateral_channels_in_collection

        if self.config.auxiliary_classification:
            # Optional image-level classifier on the globally pooled deepest map (x4).
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.cls_head = nn.Sequential(
                nn.Linear(channels[0], len(class_labels_TR_sorted))
            )

        if self.config.squeeze_block:
            # squeeze_block has the form '<BlockClass>_x<repeats>'.
            # NOTE(review): both parts are resolved with eval(); only safe while Config is trusted code.
            block_spec = self.config.squeeze_block.split('_x')
            n_repeats = eval(block_spec[1])
            squeeze_in_channels = channels[0] + sum(self.config.cxt)
            self.squeeze_module = nn.Sequential(*[
                # Only the first block sees the context-augmented x4. Later blocks receive
                # the previous block's channels[0]-wide output.
                eval(block_spec[0])(squeeze_in_channels if i == 0 else channels[0], channels[0])
                for i in range(n_repeats)
            ])

        self.decoder = Decoder(channels)

        if self.config.freeze_bb:
            # Freeze the backbone...
            for key, value in self.named_parameters():
                if 'bb.' in key and 'refiner.' not in key:
                    value.requires_grad = False

    @staticmethod
    def _resize_like(src, ref):
        """Bilinearly resize ``src`` to the spatial size (dims 2:) of ``ref``."""
        return F.interpolate(src, size=ref.shape[2:], mode='bilinear', align_corners=True)

    def _run_backbone(self, x):
        """Return the four backbone feature maps ``(x1, x2, x3, x4)`` for ``x``."""
        if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
            # These backbones are driven stage by stage through their conv1..conv4 modules.
            x1 = self.bb.conv1(x)
            x2 = self.bb.conv2(x1)
            x3 = self.bb.conv3(x2)
            x4 = self.bb.conv4(x3)
            return x1, x2, x3, x4
        x1, x2, x3, x4 = self.bb(x)
        return x1, x2, x3, x4

    def forward_enc(self, x):
        """Encode ``x`` into four feature maps and optional class logits.

        Returns:
            ``((x1, x2, x3, x4), class_preds)``. ``class_preds`` is ``None`` unless the
            model is training with ``auxiliary_classification``. When ``cxt`` is set,
            ``x4`` carries resized shallower maps concatenated before it on dim 1.
        """
        x1, x2, x3, x4 = self._run_backbone(x)
        if self.config.mul_scl_ipt:
            # Second backbone pass on a half-resolution copy of the input.
            _, _, H, W = x.shape
            x_pyramid = F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True)
            if self.config.mul_scl_ipt == 'cat':
                x1_, x2_, x3_, x4_ = self._run_backbone(x_pyramid)
                x1 = torch.cat([x1, self._resize_like(x1_, x1)], dim=1)
                x2 = torch.cat([x2, self._resize_like(x2_, x2)], dim=1)
                x3 = torch.cat([x3, self._resize_like(x3_, x3)], dim=1)
                x4 = torch.cat([x4, self._resize_like(x4_, x4)], dim=1)
            elif self.config.mul_scl_ipt == 'add':
                # Use the same backbone path as 'cat'. Calling self.bb directly would bypass
                # the conv1..conv4 path that vgg16/vgg16bn/resnet50 require.
                x1_, x2_, x3_, x4_ = self._run_backbone(x_pyramid)
                x1 = x1 + self._resize_like(x1_, x1)
                x2 = x2 + self._resize_like(x2_, x2)
                x3 = x3 + self._resize_like(x3_, x3)
                x4 = x4 + self._resize_like(x4_, x4)
        class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else None
        if self.config.cxt:
            # Resize the last len(cxt) of (x1, x2, x3) to x4's size and prepend them as context channels.
            context = [self._resize_like(feat, x4) for feat in (x1, x2, x3)][-len(self.config.cxt):]
            x4 = torch.cat((*context, x4), dim=1)
        return (x1, x2, x3, x4), class_preds

    def forward_ori(self, x):
        """Run the encoder, the optional squeeze module and the decoder.

        Returns:
            ``(scaled_preds, class_preds)``. ``scaled_preds`` is whatever ``self.decoder`` returns.
        """
        ########## Encoder ##########
        (x1, x2, x3, x4), class_preds = self.forward_enc(x)
        if self.config.squeeze_block:
            x4 = self.squeeze_module(x4)
        ########## Decoder ##########
        features = [x, x1, x2, x3, x4]
        if self.training and self.config.out_ref:
            # Gradient target for the decoder's out_ref branch: Laplacian of the channel-mean input.
            features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
        scaled_preds = self.decoder(features)
        return scaled_preds, class_preds

    def forward(self, x):
        """Return ``[scaled_preds, [class_preds]]`` in training mode, ``scaled_preds`` otherwise."""
        scaled_preds, class_preds = self.forward_ori(x)
        class_preds_lst = [class_preds]
        return [scaled_preds, class_preds_lst] if self.training else scaled_preds
class Decoder(nn.Module):
    """Top-down multi-stage decoder of BiRefNet.

    Fuses the four encoder maps from coarse (``x4``) to fine (``x1``) through
    lateral blocks and ``DecoderBlock``s. It then predicts a 1-channel map at the
    input resolution.

    Optional behaviour, all switched by ``Config()``:

    * ``dec_ipt``: re-inject the input image (optionally folded into patches) at
      every stage through ``SimpleConvs`` blocks.
    * ``ms_supervision``: in training, also emit 1-channel maps from stages 4, 3 and 2.
    * ``out_ref``: gate stages 4, 3 and 2 with a gradient-attention map. In training,
      also return the gradient predictions and labels.

    Args:
        channels: encoder channel widths, deepest stage first (``channels[0]``
            belongs to ``x4``).
    """

    def __init__(self, channels):
        super(Decoder, self).__init__()
        self.config = Config()
        # NOTE(review): eval() turns config strings into block classes; only safe while Config is trusted code.
        DecoderBlock = eval(self.config.dec_blk)
        LateralBlock = eval(self.config.lat_blk)

        # Backbones whose name contains one of these substrings get a "pyramid neck".
        # The neck's lateral blocks map their features to fixed widths before decoding.
        self.bbs_without_pyramid = ['vit', 'dino']
        self.use_pyramid_neck = any(tag in self.config.bb for tag in self.bbs_without_pyramid)
        if self.use_pyramid_neck:
            # Use the channels of swin_v1_l as default; doubled when multi-scale inputs are concatenated.
            self.manually_controlled_decoder_in_channels = [c * (1 + int(self.config.mul_scl_ipt == 'cat')) for c in (1536, 768, 384, 192)]
            self.pyramid_neck_x4 = LateralBlock(channels[0], self.manually_controlled_decoder_in_channels[0])
            self.pyramid_neck_x3 = LateralBlock(channels[1], self.manually_controlled_decoder_in_channels[1])
            self.pyramid_neck_x2 = LateralBlock(channels[2], self.manually_controlled_decoder_in_channels[2])
            self.pyramid_neck_x1 = LateralBlock(channels[3], self.manually_controlled_decoder_in_channels[3])

        if self.config.dec_ipt:
            self.split = self.config.dec_ipt_split
            N_dec_ipt = 64
            DBlock = SimpleConvs
            ic = 64
            ipt_cha_opt = 1  # index into [N_dec_ipt, channels[i] // 8] below
            # Split mode folds the image into patches stacked on channels: 3 * 2**k channels for a
            # 2**k-patch grid (3 image channels assumed). The list runs from ipt_blk5 (coarsest
            # stage) to ipt_blk1 (input resolution).
            ipt_blk_in_channels = [2**i*3 for i in (10, 8, 6, 4, 0)] if self.split else [3] * 5
            ipt_blk_out_channels = [[N_dec_ipt, channels[i]//8][ipt_cha_opt] for i in range(4)]
            # ipt_blk5 and ipt_blk4 share ipt_blk_out_channels[0].
            self.ipt_blk5 = DBlock(ipt_blk_in_channels[0], ipt_blk_out_channels[0], inter_channels=ic)
            self.ipt_blk4 = DBlock(ipt_blk_in_channels[1], ipt_blk_out_channels[0], inter_channels=ic)
            self.ipt_blk3 = DBlock(ipt_blk_in_channels[2], ipt_blk_out_channels[1], inter_channels=ic)
            self.ipt_blk2 = DBlock(ipt_blk_in_channels[3], ipt_blk_out_channels[2], inter_channels=ic)
            self.ipt_blk1 = DBlock(ipt_blk_in_channels[4], ipt_blk_out_channels[3], inter_channels=ic)
        else:
            self.split = None

        if self.use_pyramid_neck:
            bb_neck_out_channels = list(self.manually_controlled_decoder_in_channels)
        else:
            bb_neck_out_channels = channels.copy()
        # Each decoder stage outputs the next (finer) stage's width; the last stage halves it.
        dec_blk_out_channels = bb_neck_out_channels[1:] + [bb_neck_out_channels[-1] // 2]
        if self.config.dec_ipt:
            # Each stage's input also carries the re-injected image channels (same indexing as ipt_blk5..ipt_blk2).
            dec_blk_in_channels = [bb_neck_out_channels[i] + ipt_blk_out_channels[max(0, i - 1)] for i in range(len(bb_neck_out_channels))]
        else:
            # Without image re-injection the blocks see only the fused features. Previously this
            # name was left undefined, and __init__ raised NameError.
            dec_blk_in_channels = list(bb_neck_out_channels)
        self.decoder_block4 = DecoderBlock(dec_blk_in_channels[0], dec_blk_out_channels[0])
        self.decoder_block3 = DecoderBlock(dec_blk_in_channels[1], dec_blk_out_channels[1])
        self.decoder_block2 = DecoderBlock(dec_blk_in_channels[2], dec_blk_out_channels[2])
        self.decoder_block1 = DecoderBlock(dec_blk_in_channels[3], dec_blk_out_channels[3])
        self.conv_out1 = nn.Sequential(nn.Conv2d(dec_blk_out_channels[3] + (ipt_blk_out_channels[3] if self.config.dec_ipt else 0), 1, 1, 1, 0))

        # Backbone+PyramidNeck --> lateral block --> DecoderBlock
        self.lateral_block4 = LateralBlock(bb_neck_out_channels[1], dec_blk_out_channels[0])
        self.lateral_block3 = LateralBlock(bb_neck_out_channels[2], dec_blk_out_channels[1])
        self.lateral_block2 = LateralBlock(bb_neck_out_channels[3], dec_blk_out_channels[2])

        if self.config.ms_supervision:
            self.conv_ms_spvn_4 = nn.Conv2d(dec_blk_out_channels[0], 1, 1, 1, 0)
            self.conv_ms_spvn_3 = nn.Conv2d(dec_blk_out_channels[1], 1, 1, 1, 0)
            self.conv_ms_spvn_2 = nn.Conv2d(dec_blk_out_channels[2], 1, 1, 1, 0)

        if self.config.out_ref:
            _N = 16

            def _gdt_block(in_channels):
                # BatchNorm only when batch_size > 1 (presumably to avoid single-sample batch statistics).
                return nn.Sequential(
                    nn.Conv2d(in_channels, _N, 3, 1, 1),
                    nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(),
                    nn.ReLU(inplace=True),
                )

            self.gdt_convs_4 = _gdt_block(dec_blk_out_channels[0])
            self.gdt_convs_3 = _gdt_block(dec_blk_out_channels[1])
            self.gdt_convs_2 = _gdt_block(dec_blk_out_channels[2])
            self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
            self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
            self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))

            self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
            self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
            self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))

    @staticmethod
    def _resize_to(t, size):
        """Bilinearly resize ``t`` to spatial ``size`` (align_corners=True)."""
        return F.interpolate(t, size=size, mode='bilinear', align_corners=True)

    def _inject_image(self, feat, x, ipt_blk, size):
        """Concatenate ``ipt_blk(image resized to size)`` onto ``feat`` along dim 1.

        In split mode the image is first folded into a patch grid whose patch size
        matches ``feat``'s spatial size. The patches are stacked on the channel axis.
        """
        if self.split:
            img = image2patches(x, patch_ref=feat, transformation='b c (hg h) (wg w) -> b (c hg wg) h w')
        else:
            img = x
        return torch.cat((feat, ipt_blk(self._resize_to(img, size))), 1)

    def _gradient_attention(self, stage, p, m, gdt_gt, outs_gdt_pred, outs_gdt_label):
        """Gate stage-``stage`` features ``p`` with a sigmoid gradient-attention map.

        In training, this also appends to the output lists:

        * the gradient label: ``gdt_gt`` times the stage's ms-supervision output ``m``,
          resized to ``gdt_gt``. No dilation is applied to ``m``.
        * the gradient prediction from ``gdt_convs_pred_<stage>``.
        """
        p_gdt = getattr(self, f'gdt_convs_{stage}')(p)
        if self.training:
            # NOTE(review): m is None unless ms_supervision is enabled, so out_ref training needs it.
            outs_gdt_label.append(gdt_gt * self._resize_to(m, gdt_gt.shape[2:]))
            outs_gdt_pred.append(getattr(self, f'gdt_convs_pred_{stage}')(p_gdt))
        gdt_attn = getattr(self, f'gdt_convs_attn_{stage}')(p_gdt).sigmoid()
        return p * gdt_attn

    def forward(self, features):
        """Decode encoder features into 1-channel prediction maps.

        Args:
            features: ``[x, x1, x2, x3, x4]``, i.e. the input image (presumably
                (B, 3, H, W)) followed by the four encoder maps, finest first. When
                training with ``out_ref``, a sixth element ``gdt_gt`` is appended
                (the gradient target map).

        Returns:
            ``outs``: ``[m4, m3, m2, p1_out]`` when training with ``ms_supervision``,
            otherwise ``[p1_out]``. ``p1_out`` has the input's spatial size. When
            training with ``out_ref``, returns
            ``([outs_gdt_pred, outs_gdt_label], outs)`` instead.
        """
        if self.training and self.config.out_ref:
            outs_gdt_pred = []
            outs_gdt_label = []
            x, x1, x2, x3, x4, gdt_gt = features
        else:
            x, x1, x2, x3, x4 = features
            outs_gdt_pred = outs_gdt_label = gdt_gt = None

        if self.use_pyramid_neck:
            # Resample the maps to 1/4, 1/8, 1/16 and 1/32 of the input size, then project them.
            size_x1_to_x4_template = [(x.shape[2] // (2 ** i), x.shape[3] // (2 ** i)) for i in (2, 3, 4, 5)]
            x1 = self.pyramid_neck_x1(self._resize_to(x1, size_x1_to_x4_template[0]))
            x2 = self.pyramid_neck_x2(self._resize_to(x2, size_x1_to_x4_template[1]))
            x3 = self.pyramid_neck_x3(self._resize_to(x3, size_x1_to_x4_template[2]))
            x4 = self.pyramid_neck_x4(self._resize_to(x4, size_x1_to_x4_template[3]))

        use_ms = self.config.ms_supervision and self.training

        # ---- Stage 4 (coarsest) ----
        if self.config.dec_ipt:
            x4 = self._inject_image(x4, x, self.ipt_blk5, x4.shape[2:])
        p4 = self.decoder_block4(x4)
        m4 = self.conv_ms_spvn_4(p4) if use_ms else None
        if self.config.out_ref:
            p4 = self._gradient_attention(4, p4, m4, gdt_gt, outs_gdt_pred, outs_gdt_label)
        _p3 = self._resize_to(p4, x3.shape[2:]) + self.lateral_block4(x3)

        # ---- Stage 3 ----
        if self.config.dec_ipt:
            _p3 = self._inject_image(_p3, x, self.ipt_blk4, x3.shape[2:])
        p3 = self.decoder_block3(_p3)
        m3 = self.conv_ms_spvn_3(p3) if use_ms else None
        if self.config.out_ref:
            p3 = self._gradient_attention(3, p3, m3, gdt_gt, outs_gdt_pred, outs_gdt_label)
        _p2 = self._resize_to(p3, x2.shape[2:]) + self.lateral_block3(x2)

        # ---- Stage 2 ----
        if self.config.dec_ipt:
            _p2 = self._inject_image(_p2, x, self.ipt_blk3, x2.shape[2:])
        p2 = self.decoder_block2(_p2)
        m2 = self.conv_ms_spvn_2(p2) if use_ms else None
        if self.config.out_ref:
            p2 = self._gradient_attention(2, p2, m2, gdt_gt, outs_gdt_pred, outs_gdt_label)
        _p1 = self._resize_to(p2, x1.shape[2:]) + self.lateral_block2(x1)

        # ---- Stage 1 and output head (input resolution) ----
        if self.config.dec_ipt:
            _p1 = self._inject_image(_p1, x, self.ipt_blk2, x1.shape[2:])
        _p1 = self.decoder_block1(_p1)
        _p1 = self._resize_to(_p1, x.shape[2:])
        if self.config.dec_ipt:
            _p1 = self._inject_image(_p1, x, self.ipt_blk1, x.shape[2:])
        p1_out = self.conv_out1(_p1)

        outs = [m4, m3, m2] if use_ms else []
        outs.append(p1_out)
        if self.config.out_ref and self.training:
            return [outs_gdt_pred, outs_gdt_label], outs
        return outs
class SimpleConvs(nn.Module):
    """Two stacked 3x3 convolutions (stride 1, padding 1) with no activation between them.

    The spatial size is preserved. Channels go
    ``in_channels -> inter_channels -> out_channels``.
    """

    def __init__(
        self, in_channels: int, out_channels: int, inter_channels=64
    ) -> None:
        super().__init__()
        same_size = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(in_channels, inter_channels, **same_size)
        self.conv_out = nn.Conv2d(inter_channels, out_channels, **same_size)

    def forward(self, x):
        hidden = self.conv1(x)
        return self.conv_out(hidden)
|