|
|
@@ -1,7 +1,6 @@
|
|
|
from typing import Union
|
|
|
import torch
|
|
|
from .roma_models import roma_model, tiny_roma_v1_model
|
|
|
-from loguru import logger
|
|
|
|
|
|
|
|
|
weight_urls = {
|
|
|
@@ -37,20 +36,9 @@ def roma_outdoor(
|
|
|
upsample_res: Union[int, tuple[int, int]] = 864,
|
|
|
amp_dtype: torch.dtype = torch.float16,
|
|
|
symmetric=True,
|
|
|
- use_custom_corr=False,
|
|
|
+ use_custom_corr=True,
|
|
|
upsample_preds=True,
|
|
|
):
|
|
|
- if isinstance(coarse_res, int):
|
|
|
- coarse_res = (coarse_res, coarse_res)
|
|
|
- if isinstance(upsample_res, int):
|
|
|
- upsample_res = (upsample_res, upsample_res)
|
|
|
-
|
|
|
- if str(device) == "cpu":
|
|
|
- amp_dtype = torch.float32
|
|
|
-
|
|
|
- assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
|
|
|
- assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
|
|
|
-
|
|
|
if weights is None:
|
|
|
weights = torch.hub.load_state_dict_from_url(
|
|
|
weight_urls["romatch"]["outdoor"], map_location=device
|
|
|
@@ -68,10 +56,7 @@ def roma_outdoor(
|
|
|
amp_dtype=amp_dtype,
|
|
|
symmetric=symmetric,
|
|
|
use_custom_corr=use_custom_corr,
|
|
|
- )
|
|
|
- model.upsample_res = upsample_res
|
|
|
- logger.info(
|
|
|
- f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}"
|
|
|
+ upsample_res=upsample_res,
|
|
|
)
|
|
|
return model
|
|
|
|
|
|
@@ -83,15 +68,10 @@ def roma_indoor(
|
|
|
coarse_res: Union[int, tuple[int, int]] = 560,
|
|
|
upsample_res: Union[int, tuple[int, int]] = 864,
|
|
|
amp_dtype: torch.dtype = torch.float16,
|
|
|
+ symmetric=True,
|
|
|
+ use_custom_corr=True,
|
|
|
+ upsample_preds=True,
|
|
|
):
|
|
|
- if isinstance(coarse_res, int):
|
|
|
- coarse_res = (coarse_res, coarse_res)
|
|
|
- if isinstance(upsample_res, int):
|
|
|
- upsample_res = (upsample_res, upsample_res)
|
|
|
-
|
|
|
- assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
|
|
|
- assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
|
|
|
-
|
|
|
if weights is None:
|
|
|
weights = torch.hub.load_state_dict_from_url(
|
|
|
weight_urls["romatch"]["indoor"], map_location=device
|
|
|
@@ -102,14 +82,13 @@ def roma_indoor(
|
|
|
)
|
|
|
model = roma_model(
|
|
|
resolution=coarse_res,
|
|
|
- upsample_preds=True,
|
|
|
+ upsample_preds=upsample_preds,
|
|
|
weights=weights,
|
|
|
dinov2_weights=dinov2_weights,
|
|
|
device=device,
|
|
|
amp_dtype=amp_dtype,
|
|
|
+ symmetric=symmetric,
|
|
|
+ use_custom_corr=use_custom_corr,
|
|
|
+ upsample_res=upsample_res,
|
|
|
)
|
|
|
- model.upsample_res = upsample_res
|
|
|
- logger.info(
|
|
|
- f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}"
|
|
|
- )
|
|
|
- return model
|
|
|
+ return model
|