smoke_test.py 472 B

1 2 3 4 5 6 7 8 9 10 11 12
  1. def test_smoke():
  2. import torch
  3. from romatch import roma_outdoor
  4. device = torch.device('cpu')
  5. model = roma_outdoor(device=device)
  6. assert model._get_device() == device
  7. assert model.w_resized == 560, f"Expected 560, got {model.w_resized}"
  8. assert model.h_resized == 560, f"Expected 560, got {model.h_resized}"
  9. assert model.upsample_res == (864, 864), f"Expected (864, 864), got {model.upsample_res}"
  10. if __name__ == "__main__":
  11. test_smoke()