| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- """Implements several backbone networks."""
- import functools
- import operator
- from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, Union
- import torch
- import torch.nn.functional as F
- from torch import nn
- from torch.nn.functional import pixel_shuffle, softmax
- from kornia.core import Module, Tensor
class HourglassConfig(NamedTuple):
    """Immutable bundle of the settings that :func:`hg` forwards to ``HourglassNet``.

    The field order is part of the interface, since instances may be built positionally.
    """

    # Forwarded as the ``depth`` argument of ``HourglassNet``.
    depth: int
    # Forwarded as the ``num_stacks`` argument of ``HourglassNet``.
    num_stacks: int
    # Forwarded as the ``num_blocks`` argument of ``HourglassNet``.
    num_blocks: int
    # Forwarded as the ``num_classes`` argument of ``HourglassNet``.
    num_classes: int
    # Forwarded as the ``input_channels`` argument of ``HourglassNet``.
    input_channels: int
    # Module class (not an instance) forwarded as the ``head`` argument of ``HourglassNet``.
    head: Type[Module]
- # [Hourglass backbone classes]
class HourglassBackbone(Module):
    """Hourglass backbone, adapted from https://github.com/zhou13/lcnn.

    Thin wrapper that packs its arguments into a :class:`HourglassConfig` (always
    with :class:`MultitaskHead` as the head class) and builds the network via :func:`hg`.

    Args:
        input_channel: number of channels of the input images.
        depth: forwarded as the ``depth`` of every hourglass module.
        num_stacks: number of hourglass modules stacked together.
        num_blocks: forwarded as ``num_blocks`` of the hourglass network.
        num_classes: forwarded as ``num_classes`` of the hourglass network.
    """

    def __init__(
        self, input_channel: int = 1, depth: int = 4, num_stacks: int = 2, num_blocks: int = 1, num_classes: int = 5
    ) -> None:
        super().__init__()
        # The head is stored as a class; the network instantiates it itself.
        self.head = MultitaskHead
        config = HourglassConfig(
            depth=depth,
            num_stacks=num_stacks,
            num_blocks=num_blocks,
            num_classes=num_classes,
            input_channels=input_channel,
            head=self.head,
        )
        self.net = hg(config)

    def forward(self, input_images: Tensor) -> Tensor:
        """Run the wrapped hourglass network on ``input_images`` and return its output."""
        features = self.net(input_images)
        return features
class MultitaskHead(Module):
    """Bank of small convolutional heads whose outputs are concatenated channel-wise.

    Three heads are created, producing 2, 1 and 2 channels respectively, so the
    concatenated output always has 5 channels.

    Args:
        input_channels: number of channels of the incoming feature map.
    """

    def __init__(self, input_channels: int) -> None:
        super().__init__()
        hidden_channels = int(input_channels / 4)
        # Head sizes are declared in groups and flattened to one entry per head.
        grouped_sizes = [[2], [1], [2]]
        flat_sizes = [size for group in grouped_sizes for size in group]
        self.heads = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(input_channels, hidden_channels, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(hidden_channels, head_channels, kernel_size=1),
                )
                for head_channels in flat_sizes
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply every head to ``x`` and concatenate the results along dimension 1."""
        per_head = [head(x) for head in self.heads]
        return torch.cat(per_head, dim=1)
class Bottleneck2D(Module):
    """Pre-activation bottleneck residual block (BatchNorm -> ReLU -> Conv, three times).

    The block maps ``inplanes`` channels to ``2 * planes`` channels.

    Args:
        inplanes: number of input channels.
        planes: width of the two inner convolutions; the output has ``2 * planes`` channels.
        stride: stride of the middle 3x3 convolution.
        downsample: optional module applied to the input to produce the shortcut branch.
            When ``None`` the input itself is used, so it must already match the output shape.
    """

    def __init__(
        self, inplanes: int, planes: int, stride: Union[int, Tuple[int, int]] = 1, downsample: Optional[Module] = None
    ) -> None:
        super().__init__()
        out_planes = planes * 2

        # 1x1 reduction
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
        # 3x3 spatial convolution (carries the stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
        # 1x1 expansion
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Return the three-stage main branch of ``x`` summed with the shortcut branch."""
        stages = ((self.bn1, self.conv1), (self.bn2, self.conv2), (self.bn3, self.conv3))

        out = x
        for norm, conv in stages:
            # The in-place ReLU acts on the fresh BatchNorm output, never on ``x`` itself.
            out = conv(self.relu(norm(out)))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return out
class Hourglass(Module):
    """Single recursive hourglass module.

    Args:
        block: residual block class used for every residual unit.
        num_blocks: number of ``block`` instances chained in each residual unit.
        planes: ``planes`` argument handed to ``block``; every unit takes and is built
            for ``planes * expansion`` input channels.
        depth: number of nested levels, each adding one max-pool / interpolate pair.
        expansion: multiplier applied to ``planes`` to get a unit's input channel count.
    """

    def __init__(self, block: Type[Bottleneck2D], num_blocks: int, planes: int, depth: int, expansion: int = 2) -> None:
        super().__init__()
        self.depth = depth
        self.block = block
        self.expansion = expansion
        self.hg = self._make_hour_glass(block, num_blocks, planes, depth)

    def _make_residual(self, block: Type[Bottleneck2D], num_blocks: int, planes: int) -> Module:
        """Chain ``num_blocks`` blocks mapping ``planes * expansion`` input channels."""
        in_channels = planes * self.expansion
        return nn.Sequential(*[block(in_channels, planes) for _ in range(num_blocks)])

    def _make_hour_glass(self, block: Type[Bottleneck2D], num_blocks: int, planes: int, depth: int) -> nn.ModuleList:
        """Build the per-level residual units.

        Every level owns three units; level 0 (the innermost one) owns a fourth that is
        used at the bottom of the recursion.
        """
        levels = []
        for level in range(depth):
            unit_count = 4 if level == 0 else 3
            units = [self._make_residual(block, num_blocks, planes) for _ in range(unit_count)]
            levels.append(nn.ModuleList(units))
        return nn.ModuleList(levels)

    def _hour_glass_forward(self, n: int, x: Tensor) -> Tensor:
        """Evaluate level ``n`` (1-based) of the hourglass on ``x``."""
        units = self.hg[n - 1]

        # Skip branch at the current resolution.
        up1 = units[0](x)  # type: ignore[index]

        # Down branch: pool by 2, process, then recurse (or use the bottom unit at n == 1).
        pooled = F.max_pool2d(x, 2, stride=2)
        low1 = units[1](pooled)  # type: ignore[index]
        low2 = self._hour_glass_forward(n - 1, low1) if n > 1 else units[3](low1)  # type: ignore[index]
        low3 = units[2](low2)  # type: ignore[index]

        # Resize to the skip branch's exact spatial size before merging.
        up2 = F.interpolate(low3, size=up1.shape[2:])
        return up1 + up2

    def forward(self, x: Tensor) -> Tensor:
        """Run the full recursion starting from the outermost level."""
        return self._hour_glass_forward(self.depth, x)
class HourglassNet(Module):
    """Hourglass model from Newell et al ECCV 2016.

    A convolutional stem is followed by ``num_stacks`` hourglass modules. Between two
    consecutive stacks, the features and the head prediction of the earlier stack are
    projected by 1x1 convolutions and added to the input of the later one.

    Args:
        block: residual block class used throughout the network.
        head: module class instantiated once per stack as ``head(channels)``.
        depth: ``depth`` of every :class:`Hourglass` module.
        num_stacks: number of stacked hourglass modules; must be at least 1.
        num_blocks: number of blocks per residual unit in the hourglass modules.
        num_classes: input channel count of the 1x1 convolution that re-injects a head
            prediction into the next stack; it has to equal the number of channels
            produced by ``head``.
        input_channels: number of channels of the input images.
        expansion: channel multiplier applied to ``planes`` by ``block``.

    Raises:
        ValueError: if ``num_stacks`` is smaller than 1.
    """

    def __init__(
        self,
        block: Type[Bottleneck2D],
        head: Type[Module],
        depth: int,
        num_stacks: int,
        num_blocks: int,
        num_classes: int,
        input_channels: int,
        expansion: int = 2,
    ) -> None:
        super().__init__()
        if num_stacks < 1:
            # Without any stack ``forward`` would have nothing to return.
            raise ValueError(f"num_stacks must be >= 1. Got {num_stacks}.")

        self.inplanes = 64
        self.num_feats = 128
        self.num_stacks = num_stacks
        self.expansion = expansion

        # Stem: strided 7x7 convolution, three residual units and one max-pool.
        self.conv1 = nn.Conv2d(input_channels, self.inplanes, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_residual(block, self.inplanes, 1)
        self.layer2 = self._make_residual(block, self.inplanes, 1)
        self.layer3 = self._make_residual(block, self.num_feats, 1)
        self.maxpool = nn.MaxPool2d(2, stride=2)

        # Build hourglass modules
        ch = self.num_feats * self.expansion
        hgl, res, fc, score, fc_, score_ = [], [], [], [], [], []
        for i in range(num_stacks):
            hgl.append(Hourglass(block, num_blocks, self.num_feats, depth))
            res.append(self._make_residual(block, self.num_feats, num_blocks))
            fc.append(self._make_fc(ch, ch))
            score.append(head(ch))
            # The re-injection projections only exist between two consecutive stacks.
            if i < num_stacks - 1:
                fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
                score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
        self.hg = nn.ModuleList(hgl)
        self.res = nn.ModuleList(res)
        self.fc = nn.ModuleList(fc)
        self.score = nn.ModuleList(score)
        self.fc_ = nn.ModuleList(fc_)
        self.score_ = nn.ModuleList(score_)

    def _make_residual(
        self, block: Type[Bottleneck2D], planes: int, blocks: int, stride: Union[int, Tuple[int, int]] = 1
    ) -> Module:
        """Chain ``blocks`` residual blocks and update ``self.inplanes`` to their output width.

        The first block receives a 1x1 projection shortcut whenever the stride or the
        channel count changes; the remaining blocks keep the identity shortcut.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * self.expansion:
            downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * self.expansion, kernel_size=1, stride=stride))
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * self.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def _make_fc(self, inplanes: int, outplanes: int) -> Module:
        """Build a 1x1 convolution followed by BatchNorm and the shared ReLU.

        The BatchNorm is sized with ``inplanes`` although it runs on the convolution
        output, so this only works when ``inplanes == outplanes`` (as at every call site
        in this class).
        """
        bn = nn.BatchNorm2d(inplanes)
        conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
        return nn.Sequential(conv, bn, self.relu)

    def forward(self, x: Tensor) -> Tensor:
        """Return the feature map produced by the last stack.

        Head predictions are only evaluated between stacks, where they are fed back into
        the next stack's input; they are not part of the return value.
        """
        # Stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.maxpool(x)
        x = self.layer2(x)
        x = self.layer3(x)

        last_stack = self.num_stacks - 1
        for i in range(self.num_stacks):
            y = self.hg[i](x)
            y = self.res[i](y)
            y = self.fc[i](y)
            if i < last_stack:
                # Intermediate supervision path: project the features and the head
                # prediction back to feature width and add them to the stack input.
                score = self.score[i](y)
                x = x + self.fc_[i](y) + self.score_[i](score)
        return y
def hg(cfg: HourglassConfig) -> HourglassNet:
    """Build a :class:`HourglassNet` of :class:`Bottleneck2D` blocks from ``cfg``.

    Args:
        cfg: network settings; every field is forwarded to the constructor unchanged.

    Returns:
        the freshly constructed network.
    """
    network = HourglassNet(
        block=Bottleneck2D,
        head=cfg.head,
        depth=cfg.depth,
        num_stacks=cfg.num_stacks,
        num_blocks=cfg.num_blocks,
        num_classes=cfg.num_classes,
        input_channels=cfg.input_channels,
    )
    return network
- # [Backbone decoders]
class SuperpointDecoder(Module):
    """Junction decoder based on the SuperPoint architecture.

    Args:
        input_feat_dim: channel size of the input features.
        grid_size: upscale factor handed to ``pixel_shuffle`` when turning the per-cell
            scores into a dense map.

    Returns:
        the junction heatmap, with shape (B, H, W).
    """

    # NOTE(review): the score convolution always emits 65 channels (64 are kept after
    # dropping the last one), which looks tied to ``grid_size == 8`` -- confirm before
    # using another grid size.

    def __init__(self, input_feat_dim: int = 128, grid_size: int = 8) -> None:
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        # Perform strided convolution when using lcnn backbone.
        self.convPa = nn.Conv2d(input_feat_dim, 256, kernel_size=3, stride=2, padding=1)
        self.convPb = nn.Conv2d(256, 65, kernel_size=1, stride=1, padding=0)
        self.grid_size = grid_size

    def forward(self, input_features: Tensor) -> Tensor:
        """Decode ``input_features`` into a dense junction probability map."""
        hidden = self.convPa(input_features)
        hidden = self.relu(hidden)
        logits = self.convPb(hidden)

        # Convert from semi-dense to dense heatmap: normalise over channels, drop the
        # last channel, then rearrange the remaining ones spatially.
        probabilities = softmax(logits, dim=1)
        cell_probabilities = probabilities[:, :-1, :, :]
        dense = pixel_shuffle(cell_probabilities, self.grid_size)
        return dense[:, 0]
class PixelShuffleDecoder(Module):
    """Pixel shuffle decoder used to predict the line heatmap.

    Each convolutional block is followed by a x2 pixel shuffle; a final 1x1 convolution
    produces ``output_channel`` maps that are normalised with a channel-wise softmax.

    Args:
        input_feat_dim: channel size of the input features.
        num_upsample: how many upsamples are performed. ``2`` gives two upsampling
            blocks; any other value gives three.
        output_channel: number of output channels. ``forward`` reads channel 1, so at
            least two are required.

    Returns:
        the (B, H, W) line heatmap, i.e. channel 1 of the softmax output.
    """

    def __init__(self, input_feat_dim: int = 128, num_upsample: int = 2, output_channel: int = 2) -> None:
        super().__init__()
        # Get channel parameters
        self.channel_conf = self.get_channel_conf(num_upsample)
        # Define the pixel shuffle
        self.pixshuffle = nn.PixelShuffle(2)

        def conv_bn_relu(in_channels: int, out_channels: int) -> nn.Sequential:
            # 3x3 convolution that keeps the spatial size, then BatchNorm and ReLU.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        widths = self.channel_conf
        # The input block
        blocks = [conv_bn_relu(input_feat_dim, widths[0])]
        # Intermediate blocks keep their channel count; the pixel shuffle applied after
        # each block divides it by four.
        blocks.extend(conv_bn_relu(width, width) for width in widths[1:-1])
        # Output block
        blocks.append(nn.Sequential(nn.Conv2d(widths[-1], output_channel, kernel_size=1, stride=1, padding=0)))
        self.conv_block_lst = nn.ModuleList(blocks)

    def get_channel_conf(self, num_upsample: int) -> List[int]:
        """Get num of channels based on number of upsampling."""
        channels = [256, 64, 16]
        if num_upsample != 2:
            channels.append(4)
        return channels

    def forward(self, input_features: Tensor) -> Tensor:
        """Upsample ``input_features`` and return channel 1 of the softmax-normalised output."""
        *upsampling_blocks, output_block = self.conv_block_lst

        out = input_features
        for block in upsampling_blocks:
            out = self.pixshuffle(block(out))

        # Output layer
        logits = output_block(out)
        return softmax(logits, dim=1)[:, 1, :, :]
class SuperpointDescriptor(Module):
    """Descriptor decoder based on the SuperPoint architecture.

    Args:
        input_feat_dim: channel size of the input features.

    Returns:
        the semi-dense descriptors with 128 channels and the same spatial size as the
        input features (presumably (B, 128, H/4, W/4) relative to the original image --
        TODO confirm against the backbone's downsampling factor).
    """

    def __init__(self, input_feat_dim: int = 128) -> None:
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        # 3x3 convolution that preserves the spatial size, then a 1x1 projection.
        self.convPa = nn.Conv2d(input_feat_dim, 256, kernel_size=3, stride=1, padding=1)
        self.convPb = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0)

    def forward(self, input_features: Tensor) -> Tensor:
        """Project ``input_features`` to 128-channel descriptors (no normalisation applied)."""
        hidden = self.convPa(input_features)
        hidden = self.relu(hidden)
        return self.convPb(hidden)
- # [Combination of all previous models in one]
class SOLD2Net(Module):
    """Full network for SOLD².

    Args:
        model_cfg: the configuration as a Dict. The keys read here are ``"backbone_cfg"``
            (keyword arguments of :class:`HourglassBackbone`), ``"grid_size"`` and,
            optionally, ``"use_descriptor"``.

    Returns:
        a Dict with the following values:
            junctions: heatmap of junctions.
            heatmap: line heatmap.
            descriptors: semi-dense descriptors (only when the descriptor decoder exists).
    """

    # NOTE(review): the descriptor branch is switched on by the mere *presence* of the
    # "use_descriptor" key, whatever its value -- confirm that this is intended.

    def __init__(self, model_cfg: Dict[str, Any]) -> None:
        super().__init__()
        self.cfg = model_cfg

        # Backbone
        backbone_kwargs = self.cfg["backbone_cfg"]
        self.backbone_net = HourglassBackbone(**backbone_kwargs)
        # Channel count that every decoder is built for.
        feat_channel = 256

        # Junction decoder
        self.junction_decoder = SuperpointDecoder(feat_channel, self.cfg["grid_size"])
        # Line heatmap decoder
        self.heatmap_decoder = PixelShuffleDecoder(feat_channel, num_upsample=2)
        # Descriptor decoder
        if "use_descriptor" in self.cfg:
            self.descriptor_decoder = SuperpointDescriptor(feat_channel)

    def forward(self, input_images: Tensor) -> Dict[str, Tensor]:
        """Run the backbone once and feed its features to every decoder."""
        features = self.backbone_net(input_images)

        outputs = {
            "junctions": self.junction_decoder(features),
            "heatmap": self.heatmap_decoder(features),
        }
        if "use_descriptor" in self.cfg:
            outputs["descriptors"] = self.descriptor_decoder(features)
        return outputs
|