|
|
@@ -291,7 +291,6 @@ def test_hpatches(model, name):
|
|
|
if __name__ == "__main__":
|
|
|
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
|
|
|
os.environ["OMP_NUM_THREADS"] = "16"
|
|
|
- torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
|
|
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
|
|
import roma
|
|
|
parser = ArgumentParser()
|