|
|
@@ -11,7 +11,7 @@ from PIL import Image
|
|
|
|
|
|
from romatch.utils import get_tuple_transform_ops
|
|
|
from romatch.utils.local_correlation import local_correlation
|
|
|
-from romatch.utils.utils import cls_to_flow_refine, get_autocast_params
|
|
|
+from romatch.utils.utils import check_rgb, cls_to_flow_refine, get_autocast_params, check_not_i16
|
|
|
from romatch.utils.kde import kde
|
|
|
|
|
|
class ConvRefiner(nn.Module):
|
|
|
@@ -604,13 +604,19 @@ class RegressionMatcher(nn.Module):
|
|
|
|
|
|
# Check if inputs are file paths or already loaded images
|
|
|
if isinstance(im_A_input, (str, os.PathLike)):
|
|
|
- im_A = Image.open(im_A_input).convert("RGB")
|
|
|
+ im_A = Image.open(im_A_input)
|
|
|
+ check_not_i16(im_A)
|
|
|
+ im_A = im_A.convert("RGB")
|
|
|
else:
|
|
|
+ check_rgb(im_A_input)
|
|
|
im_A = im_A_input
|
|
|
|
|
|
if isinstance(im_B_input, (str, os.PathLike)):
|
|
|
- im_B = Image.open(im_B_input).convert("RGB")
|
|
|
+ im_B = Image.open(im_B_input)
|
|
|
+ check_not_i16(im_B)
|
|
|
+ im_B = im_B.convert("RGB")
|
|
|
else:
|
|
|
+ check_rgb(im_B_input)
|
|
|
im_B = im_B_input
|
|
|
|
|
|
symmetric = self.symmetric
|