Prechádzať zdrojové kódy

add some checks for image type in match

Johan Edstedt 1 rok pred
rodič
commit
64f20c7ee6
2 zmenil súbory, kde vykonal 18 pridanie a 4 odobranie
  1. 9 3
      romatch/models/matcher.py
  2. 9 1
      romatch/utils/utils.py

+ 9 - 3
romatch/models/matcher.py

@@ -11,7 +11,7 @@ from PIL import Image
 
 from romatch.utils import get_tuple_transform_ops
 from romatch.utils.local_correlation import local_correlation
-from romatch.utils.utils import cls_to_flow_refine, get_autocast_params
+from romatch.utils.utils import check_rgb, cls_to_flow_refine, get_autocast_params, check_not_i16
 from romatch.utils.kde import kde
 
 class ConvRefiner(nn.Module):
@@ -604,13 +604,19 @@ class RegressionMatcher(nn.Module):
 
         # Check if inputs are file paths or already loaded images
         if isinstance(im_A_input, (str, os.PathLike)):
-            im_A = Image.open(im_A_input).convert("RGB")
+            im_A = Image.open(im_A_input)
+            check_not_i16(im_A)
+            im_A = im_A.convert("RGB")
         else:
+            check_rgb(im_A_input)
             im_A = im_A_input
 
         if isinstance(im_B_input, (str, os.PathLike)):
-            im_B = Image.open(im_B_input).convert("RGB")
+            im_B = Image.open(im_B_input)
+            check_not_i16(im_B)
+            im_B = im_B.convert("RGB")
         else:
+            check_rgb(im_B_input)
             im_B = im_B_input
 
         symmetric = self.symmetric

+ 9 - 1
romatch/utils/utils.py

@@ -651,4 +651,12 @@ def get_autocast_params(device=None, enabled=False, dtype=None):
         enabled = False
         # mps is not supported
         autocast_device = "cpu"
-    return autocast_device, enabled, out_dtype
+    return autocast_device, enabled, out_dtype
+
+def check_not_i16(im):
+    if im.mode == "I;16":
+        raise NotImplementedError("Can't handle 16 bit images")
+
+def check_rgb(im):
+    if im.mode != "RGB":
+        raise NotImplementedError("Can't handle non-RGB images")