|
|
@@ -11,16 +11,16 @@ import json
|
|
|
import wandb
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
-from roma.benchmarks import MegadepthDenseBenchmark
|
|
|
-from roma.datasets.megadepth import MegadepthBuilder
|
|
|
-from roma.datasets.scannet import ScanNetBuilder
|
|
|
-from roma.losses.robust_loss import RobustLosses
|
|
|
-from roma.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
|
|
|
-from roma.train.train import train_k_steps
|
|
|
-from roma.models.matcher import *
|
|
|
-from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
|
|
|
-from roma.models.encoders import *
|
|
|
-from roma.checkpointing import CheckPoint
|
|
|
+from romatch.benchmarks import MegadepthDenseBenchmark
|
|
|
+from romatch.datasets.megadepth import MegadepthBuilder
|
|
|
+from romatch.datasets.scannet import ScanNetBuilder
|
|
|
+from romatch.losses.robust_loss import RobustLosses
|
|
|
+from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
|
|
|
+from romatch.train.train import train_k_steps
|
|
|
+from romatch.models.matcher import *
|
|
|
+from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
|
|
|
+from romatch.models.encoders import *
|
|
|
+from romatch.checkpointing import CheckPoint
|
|
|
|
|
|
resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
|
|
|
|
|
|
@@ -174,14 +174,14 @@ def train(args):
|
|
|
rank = dist.get_rank()
|
|
|
print(f"Start running DDP on rank {rank}")
|
|
|
device_id = rank % torch.cuda.device_count()
|
|
|
- roma.LOCAL_RANK = device_id
|
|
|
+ romatch.LOCAL_RANK = device_id
|
|
|
torch.cuda.set_device(device_id)
|
|
|
|
|
|
resolution = args.train_resolution
|
|
|
wandb_log = not args.dont_log_wandb
|
|
|
experiment_name = os.path.splitext(os.path.basename(__file__))[0]
|
|
|
wandb_mode = "online" if wandb_log and rank == 0 and False else "disabled"
|
|
|
- wandb.init(project="roma", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
|
|
|
+ wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
|
|
|
checkpoint_dir = "workspace/checkpoints/"
|
|
|
h,w = resolutions[resolution]
|
|
|
model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
|
|
|
@@ -189,11 +189,11 @@ def train(args):
|
|
|
global_step = 0
|
|
|
batch_size = args.gpu_batch_size
|
|
|
step_size = gpus*batch_size
|
|
|
- roma.STEP_SIZE = step_size
|
|
|
+ romatch.STEP_SIZE = step_size
|
|
|
|
|
|
N = (32 * 250000) # 250k steps of batch size 32
|
|
|
# checkpoint every
|
|
|
- k = 25000 // roma.STEP_SIZE
|
|
|
+ k = 25000 // romatch.STEP_SIZE
|
|
|
|
|
|
# Data
|
|
|
mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
|
|
|
@@ -233,20 +233,20 @@ def train(args):
|
|
|
alpha = 0.5,
|
|
|
c = 1e-4,)
|
|
|
parameters = [
|
|
|
- {"params": model.encoder.parameters(), "lr": roma.STEP_SIZE * 5e-6 / 8},
|
|
|
- {"params": model.decoder.parameters(), "lr": roma.STEP_SIZE * 1e-4 / 8},
|
|
|
+ {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
|
|
|
+ {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
|
|
|
]
|
|
|
optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
|
|
|
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
|
|
- optimizer, milestones=[(9*N/roma.STEP_SIZE)//10])
|
|
|
+ optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
|
|
|
megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
|
|
|
checkpointer = CheckPoint(checkpoint_dir, experiment_name)
|
|
|
model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
|
|
|
- roma.GLOBAL_STEP = global_step
|
|
|
+ romatch.GLOBAL_STEP = global_step
|
|
|
ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
|
|
|
grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
|
|
|
grad_clip_norm = 0.01
|
|
|
- for n in range(roma.GLOBAL_STEP, N, k * roma.STEP_SIZE):
|
|
|
+ for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
|
|
|
mega_sampler = torch.utils.data.WeightedRandomSampler(
|
|
|
mega_ws, num_samples = batch_size * k, replacement=False
|
|
|
)
|
|
|
@@ -269,15 +269,15 @@ def train(args):
|
|
|
num_workers=gpus * 8,
|
|
|
)
|
|
|
)
|
|
|
- for n_k in tqdm(range(n, n + 2 * k, 2),disable = roma.RANK > 0):
|
|
|
+ for n_k in tqdm(range(n, n + 2 * k, 2),disable = romatch.RANK > 0):
|
|
|
train_k_steps(
|
|
|
n_k, 1, mega_dataloader, ddp_model, depth_loss_mega, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
|
|
|
)
|
|
|
train_k_steps(
|
|
|
n_k + 1, 1, scannet_dataloader, ddp_model, depth_loss_scannet, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
|
|
|
)
|
|
|
- checkpointer.save(model, optimizer, lr_scheduler, roma.GLOBAL_STEP)
|
|
|
- wandb.log(megadense_benchmark.benchmark(model), step = roma.GLOBAL_STEP)
|
|
|
+ checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
|
|
|
+ wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)
|
|
|
|
|
|
def test_scannet(model, name, resolution, sample_mode):
|
|
|
scannet_benchmark = ScanNetBenchmark("data/scannet")
|
|
|
@@ -291,7 +291,7 @@ if __name__ == "__main__":
|
|
|
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
|
|
|
os.environ["OMP_NUM_THREADS"] = "16"
|
|
|
|
|
|
- import roma
|
|
|
+ import romatch
|
|
|
parser = ArgumentParser()
|
|
|
parser.add_argument("--test", action='store_true')
|
|
|
parser.add_argument("--debug_mode", action='store_true')
|
|
|
@@ -301,7 +301,7 @@ if __name__ == "__main__":
|
|
|
parser.add_argument("--wandb_entity", required = False)
|
|
|
|
|
|
args, _ = parser.parse_known_args()
|
|
|
- roma.DEBUG_MODE = args.debug_mode
|
|
|
+ romatch.DEBUG_MODE = args.debug_mode
|
|
|
if not args.test:
|
|
|
train(args)
|
|
|
experiment_name = os.path.splitext(os.path.basename(__file__))[0]
|