Pārlūkot izejas kodu

add mps example

Johan Edstedt 1 gadu atpakaļ
vecāks
revīzija
90ee1bf773
1 mainītis faili ar 5 papildinājumiem un 2 dzēšanām
  1. 5 2
      demo/demo_match.py

+ 5 - 2
demo/demo_match.py

@@ -1,5 +1,7 @@
-from PIL import Image
+import os
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
 import torch
+from PIL import Image
 import torch.nn.functional as F
 import numpy as np
 from romatch.utils.utils import tensor_to_pil
@@ -7,7 +9,8 @@ from romatch.utils.utils import tensor_to_pil
 from romatch import roma_outdoor
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
+if torch.backends.mps.is_available():
+    device = torch.device('mps')
 
 if __name__ == "__main__":
     from argparse import ArgumentParser