|
|
@@ -8,7 +8,8 @@ import gc
|
|
|
|
|
|
|
|
|
class ResNet50(nn.Module):
|
|
|
- def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False) -> None:
|
|
|
+ def __init__(self, pretrained=False, high_res = False, weights = None,
|
|
|
+ dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False, amp_dtype = torch.float16) -> None:
|
|
|
super().__init__()
|
|
|
if dilation is None:
|
|
|
dilation = [False,False,False]
|
|
|
@@ -24,7 +25,7 @@ class ResNet50(nn.Module):
|
|
|
self.freeze_bn = freeze_bn
|
|
|
self.early_exit = early_exit
|
|
|
self.amp = amp
|
|
|
- self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
|
|
+ self.amp_dtype = amp_dtype
|
|
|
|
|
|
def forward(self, x, **kwargs):
|
|
|
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
|
|
@@ -56,11 +57,11 @@ class ResNet50(nn.Module):
|
|
|
pass
|
|
|
|
|
|
class VGG19(nn.Module):
|
|
|
- def __init__(self, pretrained=False, amp = False) -> None:
|
|
|
+ def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
|
|
|
super().__init__()
|
|
|
self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
|
|
|
self.amp = amp
|
|
|
- self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
|
|
+ self.amp_dtype = amp_dtype
|
|
|
|
|
|
def forward(self, x, **kwargs):
|
|
|
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
|
|
@@ -74,7 +75,7 @@ class VGG19(nn.Module):
|
|
|
return feats
|
|
|
|
|
|
class CNNandDinov2(nn.Module):
|
|
|
- def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None):
|
|
|
+ def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None, amp_dtype = torch.float16):
|
|
|
super().__init__()
|
|
|
if dinov2_weights is None:
|
|
|
dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
|
|
|
@@ -94,7 +95,7 @@ class CNNandDinov2(nn.Module):
|
|
|
else:
|
|
|
self.cnn = VGG19(**cnn_kwargs)
|
|
|
self.amp = amp
|
|
|
- self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
|
|
+ self.amp_dtype = amp_dtype
|
|
|
if self.amp:
|
|
|
dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
|
|
|
self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
|