Bläddra i källkod

remove :... stuff from device in get_autocast_params, autocast doesnt like (#57)

Johan Edstedt 1 år sedan
förälder
incheckning
7541bd4cc7
1 ändrade filer med 5 tillägg och 3 borttagningar
  1. 5 3
      romatch/utils/utils.py

+ 5 - 3
romatch/utils/utils.py

@@ -627,12 +627,14 @@ def get_grid(b, h, w, device):
 
 def get_autocast_params(device=None, enabled=False, dtype=None):
     if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-
+        autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
+    else:
+        #strip :X from device
+        autocast_device = str(device).split(":")[0]
     if 'cuda' in str(device):
         out_dtype = dtype
         enabled = True
     else:
         out_dtype = torch.bfloat16
         enabled = False
-    return str(device), enabled, out_dtype
+    return autocast_device, enabled, out_dtype