# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, List, Optional, Tuple

import torch
from torch import nn

from kornia.core import Module, Tensor

from .vgg import vgg19_bn


class VGG19(Module):
    """Encoder built from the first 40 feature layers of ``vgg19_bn``.

    Args:
        amp: whether to run the forward pass under ``torch.autocast``.
        amp_dtype: dtype handed to ``torch.autocast`` when ``amp`` is enabled.
    """

    def __init__(self, amp: bool = False, amp_dtype: torch.dtype = torch.float16) -> None:
        super().__init__()
        self.layers = nn.ModuleList(vgg19_bn().features[:40])  # type: ignore
        # Maxpool layers: 6, 13, 26, 39
        self.amp = amp
        self.amp_dtype = amp_dtype

    def forward(self, x: Tensor, **kwargs):  # type: ignore[no-untyped-def]
        """Run the layers and collect the activations entering each max-pool.

        Args:
            x: input tensor; its last two dimensions are reported as the size.
            **kwargs: accepted for call compatibility and ignored.

        Returns:
            A pair ``(feats, sizes)``: ``feats`` holds the tensor seen right
            before every ``nn.MaxPool2d`` layer (in layer order) and ``sizes``
            holds the matching ``x.shape[-2:]`` values.
        """
        # NOTE(review): autocast is requested for the "cuda" device type only,
        # regardless of where ``x`` lives -- confirm this is intended for CPU use.
        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
            feats = []
            sizes = []
            for layer in self.layers:
                if isinstance(layer, nn.MaxPool2d):
                    # Record the pre-pooling activation, then pool.
                    feats.append(x)
                    sizes.append(x.shape[-2:])
                x = layer(x)
            return feats, sizes


class FrozenDINOv2(Module):
    """DINOv2 ViT-L/14 backbone evaluated under ``torch.inference_mode``.

    The backbone is stored inside a plain list, so it is not registered as a
    submodule (its parameters stay hidden from DDP). As a consequence it does
    not follow ``.to()`` calls on this module and is moved lazily in
    :meth:`forward`.

    Args:
        amp: whether to run the backbone in ``amp_dtype``.
        amp_dtype: dtype used for the backbone and its input when ``amp`` is
            enabled.
        dinov2_weights: state dict for the backbone. When ``None`` the
            pretrained ViT-L/14 weights are downloaded.
    """

    # Patch size passed to ``vit_large`` and channel count used when reshaping
    # the patch tokens into a feature map.
    _PATCH_SIZE = 14
    _NUM_CHANNELS = 1024

    def __init__(self, amp: bool = True, amp_dtype: torch.dtype = torch.float16, dinov2_weights: Optional[Any] = None):
        super().__init__()
        if dinov2_weights is None:
            dinov2_weights = torch.hub.load_state_dict_from_url(
                "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu"
            )
        from .transformer import vit_large

        vit_kwargs = dict(
            img_size=518,
            patch_size=self._PATCH_SIZE,
            init_values=1.0,
            ffn_layer="mlp",
            block_chunks=0,
        )
        dinov2_vitl14 = vit_large(**vit_kwargs).eval()
        dinov2_vitl14.load_state_dict(dinov2_weights)
        self.amp = amp
        self.amp_dtype = amp_dtype
        if self.amp:
            dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
        self.dinov2_vitl14 = [dinov2_vitl14]  # ugly hack to not show parameters to DDP

    def forward(self, x: Tensor):  # type: ignore[no-untyped-def]
        """Compute the patch-token feature map of ``x``.

        Args:
            x: 4-dimensional input, unpacked as ``(B, C, H, W)``.

        Returns:
            A pair of single-element lists: the feature map reshaped to
            ``(B, 1024, H // 14, W // 14)`` and the tuple ``(H // 14, W // 14)``.
        """
        B, _C, H, W = x.shape
        patch = self._PATCH_SIZE
        model = self.dinov2_vitl14[0]
        # Only switch to ``amp_dtype`` when AMP is enabled. Previously the input
        # (and, after a device move, the backbone) was cast to ``amp_dtype``
        # unconditionally, which mismatched the un-cast backbone for amp=False.
        dtype = self.amp_dtype if self.amp else next(model.parameters()).dtype
        if model.device != x.device:
            # The backbone is not a registered submodule, so move it by hand.
            model = model.to(x.device).to(dtype)
            self.dinov2_vitl14[0] = model
        with torch.inference_mode():
            dinov2_features_16 = model.forward_features(x.to(dtype))
            patch_tokens = dinov2_features_16["x_norm_patchtokens"]
            features_16 = patch_tokens.permute(0, 2, 1).reshape(B, self._NUM_CHANNELS, H // patch, W // patch)
        return [features_16.clone()], [(H // patch, W // patch)]  # clone from inference mode to use in autograd


class VGG_DINOv2(Module):
    """Encoder that concatenates the outputs of :class:`VGG19` and :class:`FrozenDINOv2`.

    Args:
        vgg_kwargs: keyword arguments forwarded to :class:`VGG19`.
        dinov2_kwargs: keyword arguments forwarded to :class:`FrozenDINOv2`.

    Raises:
        ValueError: if either keyword-argument mapping is ``None``.
    """

    def __init__(self, vgg_kwargs=None, dinov2_kwargs=None):  # type: ignore[no-untyped-def]
        if (vgg_kwargs is None) or (dinov2_kwargs is None):
            raise ValueError("Input kwargs please")
        super().__init__()
        self.vgg = VGG19(**vgg_kwargs)
        self.frozen_dinov2 = FrozenDINOv2(**dinov2_kwargs)

    def forward(self, x: Tensor) -> Tuple[List[Tensor], List[Tuple[int, int]]]:
        """Run both encoders on ``x`` and join their results.

        Returns:
            A pair ``(feats, sizes)`` where the VGG entries come first and the
            single DINOv2 entry is appended last.
        """
        feats_vgg, sizes_vgg = self.vgg(x)
        feat_dinov2, size_dinov2 = self.frozen_dinov2(x)
        return feats_vgg + feat_dinov2, sizes_vgg + size_dinov2