|
|
@@ -627,12 +627,14 @@ def get_grid(b, h, w, device):
|
|
|
|
|
|
def get_autocast_params(device=None, enabled=False, dtype=None):
|
|
|
if device is None:
|
|
|
- device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
-
|
|
|
+ autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
+ else:
|
|
|
+ #strip :X from device
|
|
|
+ autocast_device = str(device).split(":")[0]
|
|
|
if 'cuda' in str(device):
|
|
|
out_dtype = dtype
|
|
|
enabled = True
|
|
|
else:
|
|
|
out_dtype = torch.bfloat16
|
|
|
enabled = False
|
|
|
- return str(device), enabled, out_dtype
|
|
|
+ return autocast_device, enabled, out_dtype
|