| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
from typing import Any, List, Optional, Tuple

import torch
from torch import nn

from kornia.core import Module, Tensor

from .vgg import vgg19_bn
class VGG19(Module):
    """Truncated VGG19-BN encoder that exposes the activations feeding each max-pool layer.

    Only the first 40 entries of ``vgg19_bn().features`` are kept.

    Args:
        amp: whether the forward pass runs under ``torch.autocast``.
        amp_dtype: dtype handed to ``torch.autocast``.
    """

    def __init__(self, amp: bool = False, amp_dtype: torch.dtype = torch.float16) -> None:
        super().__init__()
        # Maxpool layers: 6, 13, 26, 39
        backbone_layers = vgg19_bn().features[:40]
        self.layers = nn.ModuleList(backbone_layers)  # type: ignore
        self.amp = amp
        self.amp_dtype = amp_dtype

    def forward(self, x: Tensor, **kwargs):  # type: ignore[no-untyped-def]
        """Run the layers in order, recording the tensor that enters every ``nn.MaxPool2d``.

        Args:
            x: input tensor; its last two dimensions are reported as the feature size.
            **kwargs: accepted but not used.

        Returns:
            A pair ``(feats, sizes)``: the recorded tensors and, for each of them, its ``shape[-2:]``.
        """
        pyramid = []
        resolutions = []
        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
            for block in self.layers:
                # Snapshot the activation *before* it is downsampled by the pooling layer.
                if isinstance(block, nn.MaxPool2d):
                    pyramid.append(x)
                    resolutions.append(x.shape[-2:])
                x = block(x)
        return pyramid, resolutions
class FrozenDINOv2(Module):
    """DINOv2 ViT-L/14 feature extractor evaluated under ``torch.inference_mode``.

    The backbone is kept inside a plain Python list rather than as a registered submodule,
    so its parameters are not exposed through this module (original intent: hide them from DDP).

    Args:
        amp: when true, the backbone and its inputs are cast to ``amp_dtype``; when false,
            the backbone keeps the dtype of the loaded weights and inputs are passed unchanged.
        amp_dtype: dtype used for the backbone and its inputs when ``amp`` is enabled.
        dinov2_weights: state dict for the ViT; when ``None`` it is downloaded from the
            DINOv2 public URL via ``torch.hub.load_state_dict_from_url``.
    """

    def __init__(self, amp: bool = True, amp_dtype: torch.dtype = torch.float16, dinov2_weights: Optional[Any] = None):
        super().__init__()
        if dinov2_weights is None:
            dinov2_weights = torch.hub.load_state_dict_from_url(
                "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu"
            )
        from .transformer import vit_large

        vit_kwargs = dict(
            img_size=518,
            patch_size=14,
            init_values=1.0,
            ffn_layer="mlp",
            block_chunks=0,
        )
        dinov2_vitl14 = vit_large(**vit_kwargs).eval()
        dinov2_vitl14.load_state_dict(dinov2_weights)
        self.amp = amp
        self.amp_dtype = amp_dtype
        if self.amp:
            dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
        self.dinov2_vitl14 = [dinov2_vitl14]  # ugly hack to not show parameters to DDP

    def forward(self, x: Tensor):  # type: ignore[no-untyped-def]
        """Extract patch-token features from ``x``.

        Args:
            x: 4-dimensional input, unpacked as ``(B, C, H, W)``.

        Returns:
            A pair of single-element lists: the patch features reshaped to
            ``(B, 1024, H // 14, W // 14)`` and the matching ``(H // 14, W // 14)`` size.
        """
        B, _C, H, W = x.shape
        backbone = self.dinov2_vitl14[0]
        # NOTE(review): relies on the ViT exposing a ``device`` attribute — confirm in .transformer.
        if backbone.device != x.device:
            # The backbone is hidden from ``Module.to``; follow the input's device lazily.
            backbone = backbone.to(x.device)
            if self.amp:
                backbone = backbone.to(self.amp_dtype)
            self.dinov2_vitl14[0] = backbone
        with torch.inference_mode():
            # Only cast the input when AMP is on, mirroring the conditional cast in ``__init__``;
            # an unconditional cast would feed ``amp_dtype`` data to uncast weights.
            inputs = x.to(self.amp_dtype) if self.amp else x
            dinov2_features_16 = backbone.forward_features(inputs)
            features_16 = dinov2_features_16["x_norm_patchtokens"].permute(0, 2, 1).reshape(B, 1024, H // 14, W // 14)
        return [features_16.clone()], [(H // 14, W // 14)]  # clone from inference mode to use in autograd
class VGG_DINOv2(Module):
    """Encoder that combines the ``VGG19`` features with the ``FrozenDINOv2`` features.

    Args:
        vgg_kwargs: keyword arguments forwarded to ``VGG19``.
        dinov2_kwargs: keyword arguments forwarded to ``FrozenDINOv2``.

    Raises:
        ValueError: if either ``vgg_kwargs`` or ``dinov2_kwargs`` is ``None``.
    """

    def __init__(self, vgg_kwargs=None, dinov2_kwargs=None):  # type: ignore[no-untyped-def]
        # Both configurations are mandatory even though the signature gives them ``None`` defaults.
        if (vgg_kwargs is None) or (dinov2_kwargs is None):
            raise ValueError("Input kwargs please")
        super().__init__()
        self.vgg = VGG19(**vgg_kwargs)
        self.frozen_dinov2 = FrozenDINOv2(**dinov2_kwargs)

    def forward(self, x: Tensor) -> Tuple[List[Tensor], List[Tuple[int, int]]]:
        """Run both encoders on ``x`` and concatenate their outputs.

        Args:
            x: input tensor passed unchanged to both encoders.

        Returns:
            ``(features, sizes)``: the VGG feature list followed by the DINOv2 feature,
            and the corresponding list of sizes in the same order.
        """
        feats_vgg, sizes_vgg = self.vgg(x)
        feat_dinov2, size_dinov2 = self.frozen_dinov2(x)
        # Both operands are lists, so ``+`` is list concatenation (VGG entries first).
        return feats_vgg + feat_dinov2, sizes_vgg + size_dinov2
|