|
|
@@ -1,5 +1,7 @@
|
|
|
-from PIL import Image
|
|
|
+import os
|
|
|
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
|
|
|
import torch
|
|
|
+from PIL import Image
|
|
|
import torch.nn.functional as F
|
|
|
import numpy as np
|
|
|
from romatch.utils.utils import tensor_to_pil
|
|
|
@@ -7,7 +9,8 @@ from romatch.utils.utils import tensor_to_pil
|
|
|
from romatch import roma_outdoor
|
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
-
|
|
|
+if torch.backends.mps.is_available():
|
|
|
+ device = torch.device('mps')
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
from argparse import ArgumentParser
|