backbones.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. """Implements several backbone networks."""
  18. import functools
  19. import operator
  20. from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, Union
  21. import torch
  22. import torch.nn.functional as F
  23. from torch import nn
  24. from torch.nn.functional import pixel_shuffle, softmax
  25. from kornia.core import Module, Tensor
  26. class HourglassConfig(NamedTuple):
  27. depth: int
  28. num_stacks: int
  29. num_blocks: int
  30. num_classes: int
  31. input_channels: int
  32. head: Type[Module]
  33. # [Hourglass backbone classes]
  34. class HourglassBackbone(Module):
  35. """Hourglass network, taken from https://github.com/zhou13/lcnn.
  36. Args:
  37. input_channel: number of input channels.
  38. depth: number of residual blocks per hourglass module.
  39. num_stacks: number of hourglass modules stacked together.
  40. num_blocks: number of layers in each residual block.
  41. num_classes: number of heads for the output of a hourglass module.
  42. """
  43. def __init__(
  44. self, input_channel: int = 1, depth: int = 4, num_stacks: int = 2, num_blocks: int = 1, num_classes: int = 5
  45. ) -> None:
  46. super().__init__()
  47. self.head = MultitaskHead
  48. self.net = hg(HourglassConfig(depth, num_stacks, num_blocks, num_classes, input_channel, head=self.head))
  49. def forward(self, input_images: Tensor) -> Tensor:
  50. return self.net(input_images)
  51. class MultitaskHead(Module):
  52. def __init__(self, input_channels: int) -> None:
  53. super().__init__()
  54. m = int(input_channels / 4)
  55. head_size = [[2], [1], [2]]
  56. heads = []
  57. _iter: list[int] = functools.reduce(operator.iconcat, head_size, [])
  58. for output_channels in _iter:
  59. heads.append(
  60. nn.Sequential(
  61. nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
  62. nn.ReLU(inplace=True),
  63. nn.Conv2d(m, output_channels, kernel_size=1),
  64. )
  65. )
  66. self.heads = nn.ModuleList(heads)
  67. def forward(self, x: Tensor) -> Tensor:
  68. return torch.cat([head(x) for head in self.heads], dim=1)
  69. class Bottleneck2D(Module):
  70. def __init__(
  71. self, inplanes: int, planes: int, stride: Union[int, Tuple[int, int]] = 1, downsample: Optional[Module] = None
  72. ) -> None:
  73. super().__init__()
  74. self.bn1 = nn.BatchNorm2d(inplanes)
  75. self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
  76. self.bn2 = nn.BatchNorm2d(planes)
  77. self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
  78. self.bn3 = nn.BatchNorm2d(planes)
  79. self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
  80. self.relu = nn.ReLU(inplace=True)
  81. self.downsample = downsample
  82. self.stride = stride
  83. def forward(self, x: Tensor) -> Tensor:
  84. residual = x
  85. out = self.bn1(x)
  86. out = self.relu(out)
  87. out = self.conv1(out)
  88. out = self.bn2(out)
  89. out = self.relu(out)
  90. out = self.conv2(out)
  91. out = self.bn3(out)
  92. out = self.relu(out)
  93. out = self.conv3(out)
  94. if self.downsample is not None:
  95. residual = self.downsample(x)
  96. out += residual
  97. return out
  98. class Hourglass(Module):
  99. def __init__(self, block: Type[Bottleneck2D], num_blocks: int, planes: int, depth: int, expansion: int = 2) -> None:
  100. super().__init__()
  101. self.depth = depth
  102. self.block = block
  103. self.expansion = expansion
  104. self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
  105. def _make_residual(self, block: Type[Bottleneck2D], num_blocks: int, planes: int) -> Module:
  106. layers = []
  107. for _ in range(0, num_blocks):
  108. layers.append(block(planes * self.expansion, planes))
  109. return nn.Sequential(*layers)
  110. def _make_hour_glass(self, block: Type[Bottleneck2D], num_blocks: int, planes: int, depth: int) -> nn.ModuleList:
  111. hgl = []
  112. for i in range(depth):
  113. res = []
  114. for _ in range(3):
  115. res.append(self._make_residual(block, num_blocks, planes))
  116. if i == 0:
  117. res.append(self._make_residual(block, num_blocks, planes))
  118. hgl.append(nn.ModuleList(res))
  119. return nn.ModuleList(hgl)
  120. def _hour_glass_forward(self, n: int, x: Tensor) -> Tensor:
  121. up1 = self.hg[n - 1][0](x) # type: ignore[index]
  122. low1 = F.max_pool2d(x, 2, stride=2)
  123. low1 = self.hg[n - 1][1](low1) # type: ignore[index]
  124. if n > 1:
  125. low2 = self._hour_glass_forward(n - 1, low1)
  126. else:
  127. low2 = self.hg[n - 1][3](low1) # type: ignore[index]
  128. low3 = self.hg[n - 1][2](low2) # type: ignore[index]
  129. up2 = F.interpolate(low3, size=up1.shape[2:])
  130. out = up1 + up2
  131. return out
  132. def forward(self, x: Tensor) -> Tensor:
  133. return self._hour_glass_forward(self.depth, x)
  134. class HourglassNet(Module):
  135. """Hourglass model from Newell et al ECCV 2016."""
  136. def __init__(
  137. self,
  138. block: Type[Bottleneck2D],
  139. head: Type[Module],
  140. depth: int,
  141. num_stacks: int,
  142. num_blocks: int,
  143. num_classes: int,
  144. input_channels: int,
  145. expansion: int = 2,
  146. ) -> None:
  147. super().__init__()
  148. self.inplanes = 64
  149. self.num_feats = 128
  150. self.num_stacks = num_stacks
  151. self.expansion = expansion
  152. self.conv1 = nn.Conv2d(input_channels, self.inplanes, kernel_size=7, stride=2, padding=3)
  153. self.bn1 = nn.BatchNorm2d(self.inplanes)
  154. self.relu = nn.ReLU(inplace=True)
  155. self.layer1 = self._make_residual(block, self.inplanes, 1)
  156. self.layer2 = self._make_residual(block, self.inplanes, 1)
  157. self.layer3 = self._make_residual(block, self.num_feats, 1)
  158. self.maxpool = nn.MaxPool2d(2, stride=2)
  159. # Build hourglass modules
  160. ch = self.num_feats * self.expansion
  161. hgl, res, fc, score, fc_, score_ = [], [], [], [], [], []
  162. for i in range(num_stacks):
  163. hgl.append(Hourglass(block, num_blocks, self.num_feats, depth))
  164. res.append(self._make_residual(block, self.num_feats, num_blocks))
  165. fc.append(self._make_fc(ch, ch))
  166. score.append(head(ch))
  167. if i < num_stacks - 1:
  168. fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
  169. score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
  170. self.hg = nn.ModuleList(hgl)
  171. self.res = nn.ModuleList(res)
  172. self.fc = nn.ModuleList(fc)
  173. self.score = nn.ModuleList(score)
  174. self.fc_ = nn.ModuleList(fc_)
  175. self.score_ = nn.ModuleList(score_)
  176. def _make_residual(
  177. self, block: Type[Bottleneck2D], planes: int, blocks: int, stride: Union[int, Tuple[int, int]] = 1
  178. ) -> Module:
  179. downsample = None
  180. if stride != 1 or self.inplanes != planes * self.expansion:
  181. downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * self.expansion, kernel_size=1, stride=stride))
  182. layers = []
  183. layers.append(block(self.inplanes, planes, stride, downsample))
  184. self.inplanes = planes * self.expansion
  185. for _ in range(1, blocks):
  186. layers.append(block(self.inplanes, planes))
  187. return nn.Sequential(*layers)
  188. def _make_fc(self, inplanes: int, outplanes: int) -> Module:
  189. bn = nn.BatchNorm2d(inplanes)
  190. conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
  191. return nn.Sequential(conv, bn, self.relu)
  192. def forward(self, x: Tensor) -> Tensor:
  193. out = []
  194. x = self.conv1(x)
  195. x = self.bn1(x)
  196. x = self.relu(x)
  197. x = self.layer1(x)
  198. x = self.maxpool(x)
  199. x = self.layer2(x)
  200. x = self.layer3(x)
  201. for i in range(self.num_stacks):
  202. y = self.hg[i](x)
  203. y = self.res[i](y)
  204. y = self.fc[i](y)
  205. score = self.score[i](y)
  206. out.append(score)
  207. if i < self.num_stacks - 1:
  208. fc_ = self.fc_[i](y)
  209. score_ = self.score_[i](score)
  210. x = x + fc_ + score_
  211. return y
  212. def hg(cfg: HourglassConfig) -> HourglassNet:
  213. """Create HourglassNet."""
  214. return HourglassNet(
  215. Bottleneck2D,
  216. head=cfg.head,
  217. depth=cfg.depth,
  218. num_stacks=cfg.num_stacks,
  219. num_blocks=cfg.num_blocks,
  220. num_classes=cfg.num_classes,
  221. input_channels=cfg.input_channels,
  222. )
  223. # [Backbone decoders]
  224. class SuperpointDecoder(Module):
  225. """Junction decoder based on the SuperPoint architecture.
  226. Args:
  227. input_feat_dim: channel size of the input features.
  228. Returns:
  229. the junction heatmap, with shape (B, H, W).
  230. """
  231. def __init__(self, input_feat_dim: int = 128, grid_size: int = 8) -> None:
  232. super().__init__()
  233. self.relu = nn.ReLU(inplace=True)
  234. # Perform strided convolution when using lcnn backbone.
  235. self.convPa = nn.Conv2d(input_feat_dim, 256, kernel_size=3, stride=2, padding=1)
  236. self.convPb = nn.Conv2d(256, 65, kernel_size=1, stride=1, padding=0)
  237. self.grid_size = grid_size
  238. def forward(self, input_features: Tensor) -> Tensor:
  239. feat = self.relu(self.convPa(input_features))
  240. semi = self.convPb(feat)
  241. # Convert from semi-dense to dense heatmap
  242. junc_prob = softmax(semi, dim=1)
  243. junc_pred = pixel_shuffle(junc_prob[:, :-1, :, :], self.grid_size)[:, 0]
  244. return junc_pred
  245. class PixelShuffleDecoder(Module):
  246. """Pixel shuffle decoder used to predict the line heatmap.
  247. Args:
  248. input_feat_dim: channel size of the input features.
  249. num_upsample: how many upsamples are performed.
  250. output_channel: number of output channels.
  251. Returns:
  252. the (B, 1, H, W) line heatmap.
  253. """
  254. def __init__(self, input_feat_dim: int = 128, num_upsample: int = 2, output_channel: int = 2) -> None:
  255. super().__init__()
  256. # Get channel parameters
  257. self.channel_conf = self.get_channel_conf(num_upsample)
  258. # Define the pixel shuffle
  259. self.pixshuffle = nn.PixelShuffle(2)
  260. # Process the feature
  261. conv_block_lst = []
  262. # The input block
  263. conv_block_lst.append(
  264. nn.Sequential(
  265. nn.Conv2d(input_feat_dim, self.channel_conf[0], kernel_size=3, stride=1, padding=1),
  266. nn.BatchNorm2d(self.channel_conf[0]),
  267. nn.ReLU(inplace=True),
  268. )
  269. )
  270. # Intermediate block
  271. for channel in self.channel_conf[1:-1]:
  272. conv_block_lst.append(
  273. nn.Sequential(
  274. nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),
  275. nn.BatchNorm2d(channel),
  276. nn.ReLU(inplace=True),
  277. )
  278. )
  279. # Output block
  280. conv_block_lst.append(
  281. nn.Sequential(nn.Conv2d(self.channel_conf[-1], output_channel, kernel_size=1, stride=1, padding=0))
  282. )
  283. self.conv_block_lst = nn.ModuleList(conv_block_lst)
  284. def get_channel_conf(self, num_upsample: int) -> List[int]:
  285. """Get num of channels based on number of upsampling."""
  286. if num_upsample == 2:
  287. return [256, 64, 16]
  288. return [256, 64, 16, 4]
  289. def forward(self, input_features: Tensor) -> Tensor:
  290. # Iterate til output block
  291. out = input_features
  292. for block in self.conv_block_lst[:-1]:
  293. out = block(out)
  294. out = self.pixshuffle(out)
  295. # Output layer
  296. out = self.conv_block_lst[-1](out)
  297. heatmap = softmax(out, dim=1)[:, 1, :, :]
  298. return heatmap
  299. class SuperpointDescriptor(Module):
  300. """Descriptor decoder based on the SuperPoint arcihtecture.
  301. Args:
  302. input_feat_dim: channel size of the input features.
  303. Returns:
  304. the semi-dense descriptors with shape (B, 128, H/4, W/4).
  305. """
  306. def __init__(self, input_feat_dim: int = 128) -> None:
  307. super().__init__()
  308. self.relu = nn.ReLU(inplace=True)
  309. self.convPa = nn.Conv2d(input_feat_dim, 256, kernel_size=3, stride=1, padding=1)
  310. self.convPb = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0)
  311. def forward(self, input_features: Tensor) -> Tensor:
  312. feat = self.relu(self.convPa(input_features))
  313. semi = self.convPb(feat)
  314. return semi
  315. # [Combination of all previous models in one]
  316. class SOLD2Net(Module):
  317. """Full network for SOLD².
  318. Args:
  319. model_cfg: the configuration as a Dict.
  320. Returns:
  321. a Dict with the following values:
  322. junctions: heatmap of junctions.
  323. heatmap: line heatmap.
  324. descriptors: semi-dense descriptors.
  325. """
  326. def __init__(self, model_cfg: Dict[str, Any]) -> None:
  327. super().__init__()
  328. self.cfg = model_cfg
  329. # Backbone
  330. self.backbone_net = HourglassBackbone(**self.cfg["backbone_cfg"])
  331. feat_channel = 256
  332. # Junction decoder
  333. self.junction_decoder = SuperpointDecoder(feat_channel, self.cfg["grid_size"])
  334. # Line heatmap decoder
  335. self.heatmap_decoder = PixelShuffleDecoder(feat_channel, num_upsample=2)
  336. # Descriptor decoder
  337. if "use_descriptor" in self.cfg:
  338. self.descriptor_decoder = SuperpointDescriptor(feat_channel)
  339. def forward(self, input_images: Tensor) -> Dict[str, Tensor]:
  340. # The backbone
  341. features = self.backbone_net(input_images)
  342. # junction decoder
  343. junctions = self.junction_decoder(features)
  344. # heatmap decoder
  345. heatmaps = self.heatmap_decoder(features)
  346. outputs = {"junctions": junctions, "heatmap": heatmaps}
  347. # Descriptor decoder
  348. if "use_descriptor" in self.cfg:
  349. outputs["descriptors"] = self.descriptor_decoder(features)
  350. return outputs