yichael 1 месяц назад
Родитель
Сommit
2c3d2e2966
100 измененных файлов с 2218 добавлено и 6402 удалено
  1. 5 2
      config.js
  2. 3 12
      electron/main.js
  3. 86 8
      nodejs/ai/ai.js
  4. 21 2
      nodejs/ai/config.js
  5. 57 8
      nodejs/ai/request/img2img.js
  6. 86 26
      nodejs/ai/request/img2text.js
  7. 11 5
      nodejs/ai/request/text2img.js
  8. 12 6
      nodejs/ai/request/text2text.js
  9. 88 0
      nodejs/ai/scripts/verify-img-center-gpt54.js
  10. 41 0
      nodejs/ef-compiler/actions/fun/IO/create-forder.js
  11. 0 0
      nodejs/ef-compiler/actions/fun/IO/read-txt.js
  12. 52 0
      nodejs/ef-compiler/actions/fun/IO/remove-forder.js
  13. 0 0
      nodejs/ef-compiler/actions/fun/IO/save-txt.js
  14. 3 3
      nodejs/ef-compiler/actions/fun/adb/adb-parser.js
  15. 0 0
      nodejs/ef-compiler/actions/fun/adb/click.js
  16. 8 4
      nodejs/ef-compiler/actions/fun/adb/input.js
  17. 0 0
      nodejs/ef-compiler/actions/fun/adb/keyevent.js
  18. 1 1
      nodejs/ef-compiler/actions/fun/adb/locate.js
  19. 1 1
      nodejs/ef-compiler/actions/fun/adb/press.js
  20. 0 0
      nodejs/ef-compiler/actions/fun/adb/scroll.js
  21. 3 2
      nodejs/ef-compiler/actions/fun/adb/send-img-to-device.js
  22. 0 0
      nodejs/ef-compiler/actions/fun/adb/string-press.js
  23. 0 0
      nodejs/ef-compiler/actions/fun/adb/swipe.js
  24. 0 0
      nodejs/ef-compiler/actions/fun/adb/utils.js
  25. 4 13
      nodejs/ef-compiler/actions/fun/ai/img2img.js
  26. 16 9
      nodejs/ef-compiler/actions/fun/ai/img2text.js
  27. 60 0
      nodejs/ef-compiler/actions/fun/ai/shared.js
  28. 4 13
      nodejs/ef-compiler/actions/fun/ai/text2img.js
  29. 3 6
      nodejs/ef-compiler/actions/fun/ai/text2text.js
  30. 40 14
      nodejs/ef-compiler/actions/fun/download-img.js
  31. 82 0
      nodejs/ef-compiler/actions/fun/fun-adb-json-bridge.js
  32. 28 1
      nodejs/ef-compiler/actions/fun/fun-node-registry.js
  33. 67 38
      nodejs/ef-compiler/actions/fun/fun-parser.js
  34. 0 120
      nodejs/ef-compiler/actions/fun/img-center-point-location.js
  35. 0 104
      nodejs/ef-compiler/actions/fun/img-cropping.js
  36. 0 0
      nodejs/ef-compiler/actions/fun/img/img-bounding-box-location.js
  37. 1035 0
      nodejs/ef-compiler/actions/fun/img/img-center-point-location.js
  38. 140 0
      nodejs/ef-compiler/actions/fun/img/img-cropping.js
  39. 101 0
      nodejs/ef-compiler/actions/fun/img/img-scale.js
  40. 17 2
      nodejs/ef-compiler/actions/fun/json/json-to-arr.js
  41. 18 13
      nodejs/ef-compiler/actions/fun/ocr.js
  42. 19 10
      nodejs/ef-compiler/ef-compiler.js
  43. 45 22
      nodejs/ef-compiler/sequence-runner.js
  44. 2 0
      nodejs/ef-compiler/variable-parser.js
  45. 11 3
      nodejs/ef-compiler/workflow-json-parser.js
  46. 24 0
      nodejs/python-exe-from-config.js
  47. 11 5
      nodejs/run-process.js
  48. 5 2
      package/pack-resources/config.js
  49. 8 3
      package/pack-resources/electron-pack-win.js
  50. 0 25
      python/RoMa/.github/actions/uv-build/action.yml
  51. 0 22
      python/RoMa/.github/workflows/build.yml
  52. 0 39
      python/RoMa/.github/workflows/publish.yml
  53. 0 12
      python/RoMa/.gitignore
  54. 0 1
      python/RoMa/.python-version
  55. 0 21
      python/RoMa/LICENSE
  56. 0 163
      python/RoMa/README.md
  57. BIN
      python/RoMa/assets/sacre_coeur_A.jpg
  58. BIN
      python/RoMa/assets/sacre_coeur_B.jpg
  59. BIN
      python/RoMa/assets/toronto_A.jpg
  60. BIN
      python/RoMa/assets/toronto_B.jpg
  61. 0 2
      python/RoMa/data/.gitignore
  62. 0 47
      python/RoMa/demo/demo_3D_effect.py
  63. 0 34
      python/RoMa/demo/demo_fundamental.py
  64. 0 50
      python/RoMa/demo/demo_match.py
  65. 0 43
      python/RoMa/demo/demo_match_opencv_sift.py
  66. 0 77
      python/RoMa/demo/demo_match_tiny.py
  67. 0 2
      python/RoMa/demo/gif/.gitignore
  68. 0 59
      python/RoMa/experiments/eval_roma_outdoor.py
  69. 0 84
      python/RoMa/experiments/eval_tiny_roma_v1_outdoor.py
  70. 0 322
      python/RoMa/experiments/roma_indoor.py
  71. 0 308
      python/RoMa/experiments/train_roma_outdoor.py
  72. 0 498
      python/RoMa/experiments/train_tiny_roma_v1_outdoor.py
  73. 0 42
      python/RoMa/pyproject.toml
  74. 0 8
      python/RoMa/romatch/__init__.py
  75. 0 6
      python/RoMa/romatch/benchmarks/__init__.py
  76. 0 113
      python/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py
  77. 0 105
      python/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py
  78. 0 116
      python/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py
  79. 0 116
      python/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py
  80. 0 143
      python/RoMa/romatch/benchmarks/scannet_benchmark.py
  81. 0 1
      python/RoMa/romatch/checkpointing/__init__.py
  82. 0 60
      python/RoMa/romatch/checkpointing/checkpoint.py
  83. 0 2
      python/RoMa/romatch/datasets/__init__.py
  84. 0 232
      python/RoMa/romatch/datasets/megadepth.py
  85. 0 160
      python/RoMa/romatch/datasets/scannet.py
  86. 0 1
      python/RoMa/romatch/losses/__init__.py
  87. 0 161
      python/RoMa/romatch/losses/robust_loss.py
  88. 0 160
      python/RoMa/romatch/losses/robust_loss_tiny_roma.py
  89. 0 1
      python/RoMa/romatch/models/__init__.py
  90. 0 68
      python/RoMa/romatch/models/encoders.py
  91. 0 1001
      python/RoMa/romatch/models/matcher.py
  92. 0 110
      python/RoMa/romatch/models/model_zoo/__init__.py
  93. 0 399
      python/RoMa/romatch/models/model_zoo/roma_models.py
  94. 0 304
      python/RoMa/romatch/models/tiny.py
  95. 0 48
      python/RoMa/romatch/models/transformer/__init__.py
  96. 0 359
      python/RoMa/romatch/models/transformer/dinov2.py
  97. 0 12
      python/RoMa/romatch/models/transformer/layers/__init__.py
  98. 0 96
      python/RoMa/romatch/models/transformer/layers/attention.py
  99. 0 252
      python/RoMa/romatch/models/transformer/layers/block.py
  100. 0 59
      python/RoMa/romatch/models/transformer/layers/dino_head.py

+ 5 - 2
config.js

@@ -9,10 +9,12 @@ const projectRoot = (typeof __dirname !== 'undefined' && __dirname.includes('app
 // Node.js:便携版固定 nodejs/node
 const nodeDir = path.join(projectRoot, 'nodejs', 'node')
 
+const isWin = process.platform === 'win32'
+
 // Python:嵌入式解释器在 python/py(依赖装入 py/Lib/site-packages,不使用虚拟环境)
 const pythonDir = path.join(projectRoot, 'python', 'py')
-
-const isWin = process.platform === 'win32'
+/** 解释器可执行文件:业务代码请通过 nodejs/python-exe-from-config 或此字段获取,禁止自行拼路径 */
+const pythonExePath = path.join(pythonDir, isWin ? 'python.exe' : 'python')
 
 module.exports = {
   // 项目根目录:开发时为仓库根,打包后由 package/pack-resources/electron-pack-win 流程写入 exe 同目录的 config.js
@@ -38,6 +40,7 @@ module.exports = {
     path: pythonDir,
   },
   pythonDir,
+  pythonExePath,
 
   adbPath: {
     path: path.join(projectRoot, 'lib/scrcpy-adb/adb.exe'),

+ 3 - 12
electron/main.js

@@ -31,6 +31,7 @@ try {
   }
 }
 const isDev = process.env.NODE_ENV === 'development' || !app.isPackaged
+const { getPythonExeFromConfig } = require('../nodejs/python-exe-from-config.js')
 
 // Node 可执行文件路径:统一从根目录 config.js 的 nodejsPath 读取(源码与打包一致)
 function getNodeExecutable() {
@@ -289,20 +290,10 @@ ipcMain.handle('check-scrcpy-running', async () => {
   }
 })
 
-// Execute Python script(python 路径与 config:优先 env/Scripts、py、目录下 python.exe
+// Execute Python script(解释器路径仅来自根目录 config.js,见 nodejs/python-exe-from-config
 ipcMain.handle('run-python-script', async (event, scriptName, ...parameters) => {
   return new Promise((resolve, reject) => {
-    let pythonPath = 'python'
-    if (config.pythonPath?.path) {
-      const base = config.pythonPath.path
-      const candidates = [
-        path.join(base, 'python.exe'),
-        path.join(base, 'python')
-      ]
-      for (const p of candidates) {
-        if (fs.existsSync(p)) { pythonPath = p; break }
-      }
-    }
+    const pythonPath = getPythonExeFromConfig(config)
     
     const scriptPath = path.join(unpackedRoot, 'python', 'scripts', `${scriptName}.py`)
     

+ 86 - 8
nodejs/ai/ai.js

@@ -6,6 +6,14 @@ const REQUEST_TIMEOUT_IMG_MS = 180000;
 module.exports.REQUEST_TIMEOUT_MS = REQUEST_TIMEOUT_MS;
 module.exports.REQUEST_TIMEOUT_IMG_MS = REQUEST_TIMEOUT_IMG_MS;
 
+/**
+ * run(action, ...args, options?)
+ * options: { model?: string, timeoutMs?: number } 可选;未传 model 时各 request 使用 config 默认。
+ * 例:run('text2text', '你好', { model: 'gpt-4o-mini' })
+ *     run('img2text', prompt, imageUrl, { model: 'gpt-4o', timeoutMs: 60000 })
+ *     run('img2text', prompt, [urlScreen, urlTpl], { model: 'gpt-5.4' })  // 多图,如 img-center ROI
+ */
+
 function request(path, body, timeoutMs) {
   const baseUrl = (config.BASE_URL || '').replace(/\/$/, '');
   const url = path.startsWith('http') ? path : `${baseUrl}/${path.replace(/^\//, '')}`;
@@ -34,6 +42,34 @@ function request(path, body, timeoutMs) {
     });
 }
 
+/** images/edits 等:multipart,勿设 Content-Type(由 fetch 带 boundary) */
+function requestMultipart (subPath, formData, timeoutMs) {
+  const baseUrl = (config.BASE_URL || '').replace(/\/$/, '');
+  const url = subPath.startsWith('http') ? subPath : `${baseUrl}/${String(subPath).replace(/^\//, '')}`;
+  const controller = new AbortController();
+  const timeoutId = setTimeout(() => controller.abort(), timeoutMs);
+  return fetch(url, {
+    method: 'POST',
+    headers: {
+      Authorization: `Bearer ${config.API_KEY || ''}`,
+    },
+    body: formData,
+    signal: controller.signal,
+  })
+    .then((res) => {
+      clearTimeout(timeoutId);
+      return res.json().catch(() => ({})).then((data) => {
+        if (!res.ok) throw new Error(data.error?.message || data.error || `HTTP ${res.status}`);
+        return data;
+      });
+    })
+    .catch((e) => {
+      clearTimeout(timeoutId);
+      if (e.name === 'AbortError') throw new Error(`请求超时 (${timeoutMs / 1000} 秒)`);
+      throw e;
+    });
+}
+
 function doubaoRequest(path, body, timeoutMs) {
   const baseUrl = (config.DOUBAO_BASE_URL || '').replace(/\/$/, '');
   const url = path.startsWith('http') ? path : `${baseUrl}/${path.replace(/^\//, '')}`;
@@ -69,15 +105,32 @@ const img2img = require('./request/img2img');
 
 const REQ = { text2text, img2text, text2img, img2img };
 
-async function run(action, ...args) {
+/** 最后一个参数可为 { model?: string, timeoutMs?: number },未传 model 则用各 request 内默认 */
+function popRunOptions (args) {
+  const last = args[args.length - 1];
+  if (
+    last != null &&
+    typeof last === 'object' &&
+    !Array.isArray(last) &&
+    (Object.prototype.hasOwnProperty.call(last, 'model') ||
+      Object.prototype.hasOwnProperty.call(last, 'timeoutMs'))
+  ) {
+    return { options: args.pop(), rest: args };
+  }
+  return { options: {}, rest: args };
+}
+
+async function run (action, ...args) {
+  const { options, rest } = popRunOptions(args);
+  const callArgs = rest;
   if (action.startsWith('doubao_')) {
-    return await doRequest(action.substring(7), true, args);
+    return await doRequest(action.substring(7), true, callArgs, options);
   }
-  return await doRequest(action, false, args);
+  return await doRequest(action, false, callArgs, options);
 }
 
 /** 从 request 模块取参数,在 ai.js 内发请求 */
-async function doRequest(action, isDoubao, args) {
+async function doRequest (action, isDoubao, args, options = {}) {
   const req = REQ[action];
   if (!req) return { success: false, error: 'Unknown action: ' + action };
   if (isDoubao && (!config.DOUBAO_MODEL || !config.DOUBAO_MODEL.trim())) {
@@ -85,12 +138,37 @@ async function doRequest(action, isDoubao, args) {
   }
 
   const path = req.path;
-  const timeoutMs = req.timeoutMs;
-  const body = action === 'text2text'
-    ? (isDoubao ? req.getDoubaoBody(args[0]) : req.getBody(args[0]))
-    : (isDoubao ? req.getDoubaoBody(args[0], args[1]) : req.getBody(args[0], args[1]));
+  const modelOverride =
+    options.model != null && String(options.model).trim() !== ''
+      ? String(options.model).trim()
+      : undefined;
+  const timeoutMs =
+    options.timeoutMs != null && Number.isFinite(Number(options.timeoutMs))
+      ? Number(options.timeoutMs)
+      : req.timeoutMs;
 
   try {
+    if (action === 'img2img' && !isDoubao && typeof req.buildFormData === 'function') {
+      const form = req.buildFormData(args[0], args[1], modelOverride);
+      const data = await requestMultipart(path, form, timeoutMs);
+      return { success: true, data };
+    }
+
+    let body;
+    if (action === 'text2text') {
+      body = isDoubao
+        ? req.getDoubaoBody(args[0], modelOverride)
+        : req.getBody(args[0], modelOverride);
+    } else if (action === 'text2img') {
+      body = isDoubao
+        ? req.getDoubaoBody(args[0], args[1], modelOverride)
+        : req.getBody(args[0], args[1], modelOverride);
+    } else {
+      body = isDoubao
+        ? req.getDoubaoBody(args[0], args[1], modelOverride)
+        : req.getBody(args[0], args[1], modelOverride);
+    }
+
     const data = isDoubao
       ? await doubaoRequest(path, body, timeoutMs)
       : await request(path, body, timeoutMs);

+ 21 - 2
nodejs/ai/config.js

@@ -1,14 +1,30 @@
 // ---------- 一般配置(OpenAI 兼容) ----------
-const API_KEY = process.env.API_KEY || 'sk-j32LgDixK6pfESYGfJtgc2Tzlmszx5NZhSH0sOzpLQkYuKek';
-const BASE_URL = process.env.BASE_URL || 'https://api.chatanywhere.tech/v1';
+const API_KEY =
+  process.env.OPENAI_API_KEY ||
+  process.env.VLM_API_KEY ||
+  process.env.API_KEY ||
+  'sk-j32LgDixK6pfESYGfJtgc2Tzlmszx5NZhSH0sOzpLQkYuKek';
+const BASE_URL =
+  (process.env.OPENAI_API_URL || process.env.BASE_URL || 'https://api.chatanywhere.tech/v1').replace(
+    /\/$/,
+    ''
+  );
 const MODEL_NAME = process.env.MODEL_NAME || 'gpt-4.1';
 
+/** 图生文(img2text):默认用强视觉模型;勿与纯文本 MODEL_NAME 混用时可单独设 IMG2TEXT_MODEL */
+const IMG2TEXT_MODEL = process.env.IMG2TEXT_MODEL || 'gpt-4o';
+
+/** 截图模板 ROI 视觉模型:默认用多模态旗舰档;可用环境变量 IMG_CENTER_MODEL 覆盖(如网关无 gpt-5.4 可改为 gpt-4o) */
+const IMG_CENTER_MODEL = process.env.IMG_CENTER_MODEL || 'gpt-5.4';
+
 // ---------- 豆包配置(火山引擎) ----------
 // 需在 火山引擎控制台 → 模型推理 → 模型接入 创建接入点,将 endpoint ID 填到 DOUBAO_MODEL(或设置环境变量 DOUBAO_MODEL)
 const DOUBAO_BASE_URL = process.env.DOUBAO_BASE_URL || 'https://ark.cn-beijing.volces.com/api/v3';
 const DOUBAO_API_KEY = process.env.DOUBAO_API_KEY || 'f13f97be-c990-4a43-8d17-4816357f2e47'; // id: api-key-yichael
 // 模型名称(官方示例):doubao-seed-2-0-pro-260215 对应控制台「Doubao-Seed-2.0-pro」已开通
 const DOUBAO_MODEL = process.env.DOUBAO_MODEL || 'doubao-seed-2-0-pro-260215';
+/** img-center 第二次 ROI(置信度不足):豆包接入点,默认同 DOUBAO_MODEL,可单独设更强识图端点 */
+const IMG_CENTER_DOUBAO_ROI_MODEL = process.env.IMG_CENTER_DOUBAO_ROI_MODEL || DOUBAO_MODEL;
 // 豆包文生图模型:需在火山引擎控制台开通「图像生成」类模型并创建接入点,将 endpoint ID 填于此(或环境变量 DOUBAO_IMAGE_MODEL)
 const DOUBAO_IMAGE_MODEL = process.env.DOUBAO_IMAGE_MODEL || '';
 
@@ -16,6 +32,9 @@ module.exports = {
   API_KEY,
   BASE_URL,
   MODEL_NAME,
+  IMG2TEXT_MODEL,
+  IMG_CENTER_MODEL,
+  IMG_CENTER_DOUBAO_ROI_MODEL,
   DOUBAO_BASE_URL,
   DOUBAO_API_KEY,
   DOUBAO_MODEL,

+ 57 - 8
nodejs/ai/request/img2img.js

@@ -1,26 +1,75 @@
+const fs = require('fs');
+const path = require('path');
 const config = require('../config');
 
 const PATH = 'images/edits';
 const TIMEOUT_MS = 180000;
 
-// 普通 AI 请求参数
-function getBody(prompt, imageUrl) {
+/**
+ * @param {string} imageUrl data:image/...;base64,... 或本地绝对/相对路径
+ */
+function parseImageInput (imageUrl) {
+  const s = String(imageUrl || '').trim();
+  const m = s.match(/^data:([^;]+);base64,([\s\S]+)$/i);
+  if (m) {
+    return {
+      buffer: Buffer.from(m[2], 'base64'),
+      filename: 'image.png',
+      mime: (m[1] || 'image/png').split(';')[0].trim() || 'image/png',
+    };
+  }
+  if (fs.existsSync(s)) {
+    const buf = fs.readFileSync(s);
+    const base = path.basename(s) || 'image.png';
+    return { buffer: buf, filename: base, mime: 'image/png' };
+  }
+  throw new Error('img2img: image 需为 data URL 或存在的本地图片路径');
+}
+
+/**
+ * OpenAI 官方 /v1/images/edits 要求 multipart/form-data(非 JSON)
+ * 字段:prompt, image(文件), n, size;可选 model
+ */
+function buildFormData (prompt, imageUrl, modelOverride) {
+  const form = new FormData();
+  form.append('prompt', String(prompt || ''));
+  form.append('n', '1');
+  form.append('size', '1024x1024');
+  const { buffer, filename, mime } = parseImageInput(imageUrl);
+  const blob = new Blob([buffer], { type: mime });
+  form.append('image', blob, filename.endsWith('.png') || filename.endsWith('.jpg') || filename.endsWith('.webp') ? filename : `${filename}.png`);
+  if (modelOverride && String(modelOverride).trim()) {
+    form.append('model', String(modelOverride).trim());
+  }
+  return form;
+}
+
+/** @deprecated 仅结构占位;真实请求请用 buildFormData + multipart */
+function getBody (prompt, imageUrl) {
   return {
     prompt,
-    image: imageUrl,
+    image: typeof imageUrl === 'string' && imageUrl.length > 80 ? `${imageUrl.slice(0, 40)}…` : imageUrl,
     n: 1,
-    size: '1024x1024'
+    size: '1024x1024',
   };
 }
 
-// 豆包请求参数(与普通共用结构,model 若豆包需要可再扩展)
-function getDoubaoBody(prompt, imageUrl) {
+function getDoubaoBody (prompt, imageUrl, modelOverride) {
+  const model =
+    (modelOverride && String(modelOverride).trim()) || config.DOUBAO_MODEL;
   return {
+    model,
     prompt,
     image: imageUrl,
     n: 1,
-    size: '1024x1024'
+    size: '1024x1024',
   };
 }
 
-module.exports = { path: PATH, getBody, getDoubaoBody, timeoutMs: TIMEOUT_MS };
+module.exports = {
+  path: PATH,
+  getBody,
+  getDoubaoBody,
+  buildFormData,
+  timeoutMs: TIMEOUT_MS,
+};

+ 86 - 26
nodejs/ai/request/img2text.js

@@ -1,38 +1,98 @@
 const config = require('../config');
 
 const PATH = 'chat/completions';
-const TIMEOUT_MS = 120000;
+/** 读图较慢,与 ai.js REQUEST_TIMEOUT_IMG_MS 对齐 */
+const TIMEOUT_MS = 180000;
 
-// 普通 AI 请求参数
-function getBody(prompt, imageUrl) {
+/**
+ * 将单 URL 或 URL 数组规范为字符串数组(OpenAI 兼容:同一 user 消息里多段 image_url = 多图)
+ */
+function normalizeImageUrls (imageUrlOrUrls) {
+  if (Array.isArray(imageUrlOrUrls)) {
+    return imageUrlOrUrls.map((u) => String(u || '').trim()).filter(Boolean);
+  }
+  const s = String(imageUrlOrUrls || '').trim();
+  return s ? [s] : [];
+}
+
+function buildUserContent (prompt, imageUrlOrUrls) {
+  const urls = normalizeImageUrls(imageUrlOrUrls);
+  const content = [{ type: 'text', text: String(prompt || '') }];
+  for (const url of urls) {
+    content.push({
+      type: 'image_url',
+      image_url: { url, detail: 'high' },
+    });
+  }
+  return content;
+}
+
+function resolveModel (modelOverride) {
+  return (
+    (modelOverride && String(modelOverride).trim()) ||
+    (config.IMG2TEXT_MODEL && String(config.IMG2TEXT_MODEL).trim()) ||
+    config.MODEL_NAME ||
+    'gpt-4o'
+  );
+}
+
+/** 截图 ROI(img-center):默认读 config.IMG_CENTER_MODEL,与 img2text 默认解耦 */
+function resolveImgCenterModel (modelOverride) {
+  return (
+    (modelOverride && String(modelOverride).trim()) ||
+    (config.IMG_CENTER_MODEL && String(config.IMG_CENTER_MODEL).trim()) ||
+    (config.IMG2TEXT_MODEL && String(config.IMG2TEXT_MODEL).trim()) ||
+    config.MODEL_NAME ||
+    'gpt-5.4'
+  );
+}
+
+/**
+ * OpenAI 兼容 POST /v1/chat/completions
+ * @param {string} prompt
+ * @param {string|string[]} imageUrlOrUrls 单张 data URL / https URL,或多张(多附件)
+ * @param {string} [modelOverride]
+ */
+function getBody (prompt, imageUrlOrUrls, modelOverride) {
+  const urls = normalizeImageUrls(imageUrlOrUrls);
+  if (urls.length === 0) {
+    throw new Error('img2text: 至少提供一张图片(字符串 URL 或 URL 数组)');
+  }
+  const model = resolveModel(modelOverride);
+  const content = buildUserContent(prompt, imageUrlOrUrls);
   return {
-    model: config.MODEL_NAME || 'gpt-4o',
-    messages: [{
-      role: 'user',
-      content: [
-        { type: 'text', text: prompt },
-        { type: 'image_url', image_url: { url: imageUrl, detail: 'high' } }
-      ]
-    }],
-    max_tokens: 300,
-    stream: false
+    model,
+    messages: [{ role: 'user', content }],
+    // 勿与 max_completion_tokens 同时传:部分聚合网关(如 ChatAnywhere)会报错
+    max_tokens: 1024,
+    stream: false,
   };
 }
 
-// 豆包请求参数
-function getDoubaoBody(prompt, imageUrl) {
+function getDoubaoBody (prompt, imageUrlOrUrls, modelOverride) {
+  const urls = normalizeImageUrls(imageUrlOrUrls);
+  if (urls.length === 0) {
+    throw new Error('img2text: 至少提供一张图片');
+  }
+  const model =
+    (modelOverride && String(modelOverride).trim()) || config.DOUBAO_MODEL;
+  const content = buildUserContent(prompt, imageUrlOrUrls);
   return {
-    model: config.DOUBAO_MODEL,
-    messages: [{
-      role: 'user',
-      content: [
-        { type: 'text', text: prompt },
-        { type: 'image_url', image_url: { url: imageUrl, detail: 'high' } }
-      ]
-    }],
-    max_tokens: 300,
-    stream: false
+    model,
+    messages: [{ role: 'user', content }],
+    max_tokens: 1024,
+    stream: false,
   };
 }
 
-module.exports = { path: PATH, getBody, getDoubaoBody, timeoutMs: TIMEOUT_MS };
+module.exports = {
+  path: PATH,
+  getBody,
+  getDoubaoBody,
+  /** img-center:无 override 时用 config.IMG_CENTER_MODEL(环境变量 IMG_CENTER_MODEL) */
+  resolveImgCenterModel,
+  timeoutMs: TIMEOUT_MS,
+  /** 测试/调试:看 content 结构 */
+  buildUserContent,
+  normalizeImageUrls,
+};

+ 11 - 5
nodejs/ai/request/text2img.js

@@ -3,10 +3,12 @@ const config = require('../config');
 const PATH = 'images/generations';
 const TIMEOUT_MS = 180000;
 
-// 普通 AI 请求参数
-function getBody(prompt, outputPath) {
+// 普通 AI 请求参数(modelOverride 有值则优先,否则 dall-e-2)
+function getBody (prompt, outputPath, modelOverride) {
+  const model =
+    (modelOverride && String(modelOverride).trim()) || 'dall-e-2';
   return {
-    model: 'dall-e-2',
+    model,
     prompt,
     n: 1,
     size: '1024x1024',
@@ -15,8 +17,12 @@ function getBody(prompt, outputPath) {
 }
 
 // 豆包文生图请求参数(优先用 DOUBAO_IMAGE_MODEL,未配置则用 DOUBAO_MODEL)
-function getDoubaoBody(prompt, outputPath) {
-  const model = (config.DOUBAO_IMAGE_MODEL && config.DOUBAO_IMAGE_MODEL.trim()) ? config.DOUBAO_IMAGE_MODEL.trim() : config.DOUBAO_MODEL;
+function getDoubaoBody (prompt, outputPath, modelOverride) {
+  const model =
+    (modelOverride && String(modelOverride).trim()) ||
+    (config.DOUBAO_IMAGE_MODEL && config.DOUBAO_IMAGE_MODEL.trim()
+      ? config.DOUBAO_IMAGE_MODEL.trim()
+      : config.DOUBAO_MODEL);
   return {
     model,
     prompt,

+ 12 - 6
nodejs/ai/request/text2text.js

@@ -3,19 +3,25 @@ const config = require('../config');
 const PATH = 'chat/completions';
 const TIMEOUT_MS = 120000;
 
-// 普通 AI 请求参数
-function getBody(prompt) {
+// 普通 AI 请求参数(modelOverride 有值则优先,否则 config.MODEL_NAME,再否则默认)
+function getBody (prompt, modelOverride) {
+  const model =
+    (modelOverride && String(modelOverride).trim()) ||
+    config.MODEL_NAME ||
+    'gpt-4.1';
   return {
-    model: config.MODEL_NAME || 'gpt-4.1',
+    model,
     messages: [{ role: 'user', content: prompt }],
     stream: false
   };
 }
 
-// 豆包请求参数
-function getDoubaoBody(prompt) {
+// 豆包请求参数(可选 modelOverride 覆盖接入点 ID)
+function getDoubaoBody (prompt, modelOverride) {
+  const model =
+    (modelOverride && String(modelOverride).trim()) || config.DOUBAO_MODEL;
   return {
-    model: config.DOUBAO_MODEL,
+    model,
     messages: [{ role: 'user', content: prompt }],
     stream: false
   };

+ 88 - 0
nodejs/ai/scripts/verify-img-center-gpt54.js

@@ -0,0 +1,88 @@
+/**
+ * 与 img-center-point-location 相同栈:nodejs/ai ai.run('img2text', prompt, [url1,url2], { model })。
+ * 默认模型用 nodejs/ai/config IMG_CENTER_MODEL(可用 VERIFY_FORCE_MODEL 覆盖)。
+ */
+const path = require('path')
+const fs = require('fs')
+
+const aiModule = require(path.join(__dirname, '..', 'ai.js'))
+const aiConfig = require(path.join(__dirname, '..', 'config.js'))
+const img2textReq = require(path.join(__dirname, '..', 'request', 'img2text.js'))
+
+function fileToDataUrlPng (absPath) {
+  const buf = fs.readFileSync(absPath)
+  return `data:image/png;base64,${buf.toString('base64')}`
+}
+
+function defaultFixtureDir () {
+  const projectRoot = path.resolve(__dirname, '..', '..', '..')
+  const d = path.join(
+    projectRoot,
+    'static',
+    'process',
+    'GenerateNote',
+    'tmp',
+    'img-center-1774080885011'
+  )
+  return fs.existsSync(path.join(d, 'screenshot.png')) ? d : null
+}
+
+async function main () {
+  const model = (
+    process.env.VERIFY_FORCE_MODEL ||
+    img2textReq.resolveImgCenterModel(undefined)
+  ).trim()
+
+  const def = defaultFixtureDir()
+  const screenPath =
+    process.argv[2] ||
+    (def && path.join(def, 'screenshot.png')) ||
+    path.join(__dirname, 'fixture-verify.png')
+  const tplPath =
+    process.argv[3] ||
+    (def && path.join(def, 'template.png')) ||
+    screenPath
+
+  if (!fs.existsSync(screenPath)) {
+    console.error('缺少截图:', screenPath)
+    process.exit(1)
+  }
+  if (!fs.existsSync(tplPath)) {
+    console.error('缺少模板:', tplPath)
+    process.exit(1)
+  }
+
+  const verifyPrompt = `你收到两张图(与 img-center ROI 流程相同的多模态调用)。
+请严格只输出纯文本三行,不要 markdown:
+第1行:你认为本次 API 实际为你分配的模型名称或系列(若不确定写 未知)。
+第2行:若你认为是 GPT-5.4 系列则写 YES,否则写 NO。
+第3行:简短说明依据(不超过 40 字)。
+
+注意:以服务端返回的 JSON 里 model 字段为准;你的自称可能不准。`
+
+  console.log('IMG_CENTER_MODEL (resolveImgCenterModel default):', img2textReq.resolveImgCenterModel(undefined))
+  console.log('request options.model:', model)
+  const r = await aiModule.run('img2text', verifyPrompt, [fileToDataUrlPng(screenPath), fileToDataUrlPng(tplPath)], {
+    model,
+    timeoutMs: 300000,
+  })
+  if (!r.success) {
+    console.error(r.error || '请求失败')
+    process.exit(1)
+  }
+  const resp = r.data
+  console.log('response.model (authoritative):', resp.model)
+  const content = resp?.choices?.[0]?.message?.content
+  console.log('assistant (self-report, not authoritative):\n---\n', content, '\n---')
+  const ok =
+    String(resp.model || '')
+      .toLowerCase()
+      .includes('gpt-5.4') || String(resp.model || '').toLowerCase().includes('5.4')
+  console.log('response.model looks like gpt-5.4:', ok ? 'YES' : 'NO (check gateway mapping)')
+  console.log('BASE_URL:', aiConfig.BASE_URL)
+}
+
+main().catch((e) => {
+  console.error(e)
+  process.exit(1)
+})

+ 41 - 0
nodejs/ef-compiler/actions/fun/IO/create-forder.js

@@ -0,0 +1,41 @@
+/**
+ * fun 结点:create-folder(脚本名 create-forder 为历史拼写)
+ * inVars[0] 或字段 path:目录路径(相对当前流程目录或绝对路径)
+ * 使用 fs.mkdirSync(..., { recursive: true })
+ */
+
+const path = require('path')
+const fs = require('fs')
+
+function buildAbsolutePath (p, folderPath) {
+  if (p == null || p === '') return null
+  const s = typeof p === 'string' ? p.trim() : String(p).trim()
+  if (!s) return null
+  if (path.isAbsolute(s) || /^[A-Za-z]:/.test(s)) return path.normalize(s)
+  return folderPath ? path.join(folderPath, s) : path.resolve(s)
+}
+
+/**
+ * @param {{ path?: string, dirPath?: string, folderPath?: string }} input
+ */
+async function executeCreateFolder ({ path: p, dirPath, folderPath }) {
+  const rel = p != null && String(p).trim() !== '' ? p : dirPath
+  if (rel == null || String(rel).trim() === '') {
+    return { success: false, error: 'create-folder 缺少 path(inVars[0] 或字段 path)' }
+  }
+  const abs = buildAbsolutePath(String(rel).trim(), folderPath)
+  if (!abs) return { success: false, error: 'create-folder 路径无效' }
+  try {
+    fs.mkdirSync(abs, { recursive: true })
+  } catch (e) {
+    return { success: false, error: e && e.message ? e.message : String(e) }
+  }
+  return {
+    success: true,
+    path: abs,
+    value: abs,
+    result: abs,
+  }
+}
+
+module.exports = { executeCreateFolder }

+ 0 - 0
nodejs/ef-compiler/actions/fun/read-txt.js → nodejs/ef-compiler/actions/fun/IO/read-txt.js


+ 52 - 0
nodejs/ef-compiler/actions/fun/IO/remove-forder.js

@@ -0,0 +1,52 @@
+/**
+ * fun 结点:remove-folder(脚本名 remove-forder 为历史拼写)
+ * inVars[0] 或字段 path:目录路径(相对当前流程目录或绝对路径)
+ * 默认递归删除目录及内容(等同 rm -rf)。若仅需删除空目录,设 recursive 为 0 / false / no / empty-only
+ */
+
+const path = require('path')
+const fs = require('fs')
+
+function buildAbsolutePath (p, folderPath) {
+  if (p == null || p === '') return null
+  const s = typeof p === 'string' ? p.trim() : String(p).trim()
+  if (!s) return null
+  if (path.isAbsolute(s) || /^[A-Za-z]:/.test(s)) return path.normalize(s)
+  return folderPath ? path.join(folderPath, s) : path.resolve(s)
+}
+
+/** 未指定时默认 true(非空目录也可删);0 / false / no / empty-only 则仅删空目录 */
+function parseRecursive (raw) {
+  if (raw == null || raw === '') return true
+  const t = String(raw).trim().toLowerCase()
+  if (t === '') return true
+  if (t === '0' || t === 'false' || t === 'no' || t === 'empty-only') return false
+  return true
+}
+
+/**
+ * @param {{ path?: string, dirPath?: string, recursive?: string, folderPath?: string }} input
+ */
+async function executeRemoveFolder ({ path: p, dirPath, recursive, folderPath }) {
+  const rel = p != null && String(p).trim() !== '' ? p : dirPath
+  if (rel == null || String(rel).trim() === '') {
+    return { success: false, error: 'remove-folder 缺少 path(inVars[0] 或字段 path)' }
+  }
+  const abs = buildAbsolutePath(String(rel).trim(), folderPath)
+  if (!abs) return { success: false, error: 'remove-folder 路径无效' }
+  const rec = parseRecursive(recursive)
+  try {
+    if (!fs.existsSync(abs)) return { success: true, skipped: true, path: abs }
+    const st = fs.statSync(abs)
+    if (!st.isDirectory()) return { success: false, error: 'remove-folder 目标不是目录' }
+    fs.rmSync(abs, { recursive: rec, force: false })
+  } catch (e) {
+    if (!rec && (e.code === 'ENOTEMPTY' || e.code === 'EISDIR')) {
+      return { success: false, error: '目录非空或无法以非递归方式删除:请去掉 recursive=false/0,或不要设置 recursive(默认递归删除整目录)' }
+    }
+    return { success: false, error: e && e.message ? e.message : String(e) }
+  }
+  return { success: true, path: abs }
+}
+
+module.exports = { executeRemoveFolder }

+ 0 - 0
nodejs/ef-compiler/actions/fun/save-txt.js → nodejs/ef-compiler/actions/fun/IO/save-txt.js


+ 3 - 3
nodejs/ef-compiler/actions/adb/adb-parser.js → nodejs/ef-compiler/actions/fun/adb/adb-parser.js

@@ -70,7 +70,7 @@ async function execute(action, ctx) {
     if (method === 'image') {
       const imagePath = action.target.startsWith('/') || action.target.includes(':') ? action.target : `${folderPath}/${action.target}`
       if (!api?.matchImageAndGetCoordinate) return { success: false, error: '图像匹配 API 不可用' }
-      const matchResult = await api.matchImageAndGetCoordinate(device, imagePath)
+      const matchResult = await api.matchImageAndGetCoordinate(device, imagePath, folderPath)
       if (!matchResult.success) return { success: false, error: `图像匹配失败: ${matchResult.error != null ? matchResult.error : 'unknown'}` }
       position = matchResult.clickPosition
     } else if (method === 'text') {
@@ -93,7 +93,7 @@ async function execute(action, ctx) {
     else if (method === 'image') {
       const imagePath = action.target.startsWith('/') || action.target.includes(':') ? action.target : `${folderPath}/${action.target}`
       if (!api?.matchImageAndGetCoordinate) return { success: false, error: '图像匹配 API 不可用' }
-      const matchResult = await api.matchImageAndGetCoordinate(device, imagePath)
+      const matchResult = await api.matchImageAndGetCoordinate(device, imagePath, folderPath)
       if (!matchResult.success) return { success: false, error: `图像匹配失败: ${matchResult.error != null ? matchResult.error : 'unknown'}` }
       position = matchResult.clickPosition
     } else if (method === 'text') {
@@ -113,7 +113,7 @@ async function execute(action, ctx) {
   if (action.type === 'press') {
     const imagePath = `${folderPath}/resources/${action.value}`
     if (!api?.matchImageAndGetCoordinate) return { success: false, error: '图像匹配 API 不可用' }
-    const matchResult = await api.matchImageAndGetCoordinate(device, imagePath)
+    const matchResult = await api.matchImageAndGetCoordinate(device, imagePath, folderPath)
     if (!matchResult.success) return { success: false, error: `图像匹配失败: ${matchResult.error != null ? matchResult.error : 'unknown'}` }
     const { x, y } = matchResult.clickPosition
     if (!api?.sendTap) return { success: false, error: '点击 API 不可用' }

+ 0 - 0
nodejs/ef-compiler/actions/adb/click.js → nodejs/ef-compiler/actions/fun/adb/click.js


+ 8 - 4
nodejs/ef-compiler/actions/adb/input.js → nodejs/ef-compiler/actions/fun/adb/input.js

@@ -12,16 +12,20 @@ const B64_CHUNK_CHARS = 200
 function getProjectRoot(ctx) {
   const root = ctx.compilerConfig?.projectRoot
   if (root && fs.existsSync(root)) return root
-  const defaultRoot = path.resolve(__dirname, '..', '..', '..', '..')
-  return defaultRoot
+  // 本文件在 .../nodejs/ef-compiler/actions/fun/adb:上溯 **5** 层才到仓库根(4 层只会停在 nodejs/)
+  return path.resolve(__dirname, '..', '..', '..', '..', '..')
 }
 
 function getAdbPath(projectRoot) {
   try {
-    const configPath = path.join(projectRoot, 'config.js')
+    const configPath = process.env.STATIC_ROOT
+      ? path.join(path.dirname(path.resolve(process.env.STATIC_ROOT)), 'config.js')
+      : path.join(projectRoot, 'config.js')
     const config = fs.existsSync(configPath) ? require(configPath) : {}
+    const cfgRoot =
+      (config.projectRoot && fs.existsSync(config.projectRoot)) ? config.projectRoot : projectRoot
     const p = config.adbPath?.path
-    if (p) return path.isAbsolute(p) ? p : path.resolve(projectRoot, p)
+    if (p) return path.isAbsolute(p) ? p : path.resolve(cfgRoot, p)
   } catch (e) {}
   return path.join(projectRoot, 'lib', 'scrcpy-adb', process.platform === 'win32' ? 'adb.exe' : 'adb')
 }

+ 0 - 0
nodejs/ef-compiler/actions/adb/keyevent.js → nodejs/ef-compiler/actions/fun/adb/keyevent.js


+ 1 - 1
nodejs/ef-compiler/actions/adb/locate.js → nodejs/ef-compiler/actions/fun/adb/locate.js

@@ -12,7 +12,7 @@ async function run(action, ctx) {
     if (!imagePath) return { success: false, error: 'locate 操作(image)缺少图片路径' }
     const fullPath = imagePath.startsWith('/') || imagePath.includes(':') ? imagePath : `${folderPath}/resources/${imagePath}`
     if (!api?.matchImageAndGetCoordinate) return { success: false, error: '图像匹配 API 不可用' }
-    const matchResult = await api.matchImageAndGetCoordinate(device, fullPath)
+    const matchResult = await api.matchImageAndGetCoordinate(device, fullPath, folderPath)
     if (!matchResult.success) return { success: false, error: `图像匹配失败: ${matchResult.error != null ? matchResult.error : 'unknown'}` }
     position = matchResult.clickPosition
   } else if (locateMethod === 'text') {

+ 1 - 1
nodejs/ef-compiler/actions/adb/press.js → nodejs/ef-compiler/actions/fun/adb/press.js

@@ -8,7 +8,7 @@ async function run(action, ctx) {
   if (!imagePath) return { success: false, error: 'press 操作缺少图片路径' }
   const fullPath = imagePath.startsWith('/') || imagePath.includes(':') ? imagePath : `${folderPath}/${imagePath}`
   if (!api?.matchImageAndGetCoordinate) return { success: false, error: '图像匹配 API 不可用' }
-  const matchResult = await api.matchImageAndGetCoordinate(device, fullPath)
+  const matchResult = await api.matchImageAndGetCoordinate(device, fullPath, folderPath)
   if (!matchResult.success) return { success: false, error: `图像匹配失败: ${matchResult.error != null ? matchResult.error : 'unknown'}` }
   const { x, y } = matchResult.clickPosition
   if (!api?.sendTap) return { success: false, error: '点击 API 不可用' }

+ 0 - 0
nodejs/ef-compiler/actions/adb/scroll.js → nodejs/ef-compiler/actions/fun/adb/scroll.js


+ 3 - 2
nodejs/ef-compiler/actions/adb/send-img-to-device.js → nodejs/ef-compiler/actions/fun/adb/send-img-to-device.js

@@ -6,9 +6,10 @@ const { spawnSync } = require('child_process')
 const path = require('path')
 const fs = require('fs')
 
-const defaultRoot = path.resolve(__dirname, '..', '..', '..', '..')
+// .../nodejs/ef-compiler/actions/fun/adb/send-img-to-device.js → 上溯 5 层到仓库根
+const defaultRoot = path.resolve(__dirname, '..', '..', '..', '..', '..')
 const configPath = process.env.STATIC_ROOT
-  ? path.join(path.dirname(process.env.STATIC_ROOT), 'config.js')
+  ? path.join(path.dirname(path.resolve(process.env.STATIC_ROOT)), 'config.js')
   : path.join(defaultRoot, 'config.js')
 const config = fs.existsSync(configPath) ? require(configPath) : {}
 const projectRoot = (config.projectRoot && fs.existsSync(config.projectRoot)) ? config.projectRoot : defaultRoot

+ 0 - 0
nodejs/ef-compiler/actions/adb/string-press.js → nodejs/ef-compiler/actions/fun/adb/string-press.js


+ 0 - 0
nodejs/ef-compiler/actions/adb/swipe.js → nodejs/ef-compiler/actions/fun/adb/swipe.js


+ 0 - 0
nodejs/ef-compiler/actions/adb/utils.js → nodejs/ef-compiler/actions/fun/adb/utils.js


+ 4 - 13
nodejs/ef-compiler/actions/fun/ai/img2img.js

@@ -1,24 +1,15 @@
-const path = require('path')
 const fs = require('fs')
-const aiModule = require(path.join(__dirname, '../../../../ai/ai.js'))
-
-function resolveSavePath(savePath, folderPath) {
-  if (!savePath || typeof savePath !== 'string') return null
-  const trimmed = savePath.trim()
-  if (path.isAbsolute(trimmed) || /^[A-Za-z]:/.test(trimmed)) return trimmed
-  return folderPath ? path.join(folderPath, trimmed) : path.resolve(trimmed)
-}
+const path = require('path')
+const { runWithModel, resolveSavePath } = require('./shared')
 
 /** 入参:prompt, model, imageUrl(参考图地址), savePath(可选,生成图保存路径) */
-async function executeImg2img({ prompt, model, imageUrl, savePath, folderPath }) {
+async function executeImg2img ({ prompt, model, imageUrl, savePath, folderPath }) {
   const p = prompt != null ? String(prompt).trim() : ''
   const url = imageUrl != null ? String(imageUrl).trim() : ''
   if (!url) return { success: false, error: 'img2img 缺少 imageUrl' }
-  const m = model != null ? String(model).trim().toLowerCase() : ''
-  const action = m === 'doubao' ? 'doubao_img2img' : 'img2img'
   const outPath = savePath ? resolveSavePath(savePath, folderPath) : null
   try {
-    const result = await aiModule.run(action, p, url)
+    const result = await runWithModel('img2img', 'doubao_img2img', [p, url], model)
     if (!result.success) return { success: false, error: result.error || 'img2img 失败' }
     const data = result.data
     const item = data?.data?.[0]

+ 16 - 9
nodejs/ef-compiler/actions/fun/ai/img2text.js

@@ -1,15 +1,22 @@
-const path = require('path')
-const aiModule = require(path.join(__dirname, '../../../../ai/ai.js'))
+const { runWithModel } = require('./shared')
 
-/** 入参:prompt, model, imageUrl(参考图地址,如 data URL 或 http URL) */
-async function executeImg2text({ prompt, model, imageUrl, folderPath }) {
+/**
+ * 入参:prompt, model, imageUrl
+ * imageUrl:单张 data URL / https,或 **多张 URL 的数组**(与 OpenAI 多图 content 一致)
+ */
+async function executeImg2text ({ prompt, model, imageUrl, folderPath }) {
   const p = prompt != null ? String(prompt).trim() : ''
-  const url = imageUrl != null ? String(imageUrl).trim() : ''
-  if (!url) return { success: false, error: 'img2text 缺少 imageUrl' }
-  const m = model != null ? String(model).trim().toLowerCase() : ''
-  const action = m === 'doubao' ? 'doubao_img2text' : 'img2text'
+  let url
+  if (Array.isArray(imageUrl)) {
+    url = imageUrl.map((u) => String(u || '').trim()).filter(Boolean)
+  } else {
+    url = imageUrl != null ? String(imageUrl).trim() : ''
+  }
+  if (!url || (Array.isArray(url) && url.length === 0)) {
+    return { success: false, error: 'img2text 缺少 imageUrl(或非空 URL 数组)' }
+  }
   try {
-    const result = await aiModule.run(action, p, url)
+    const result = await runWithModel('img2text', 'doubao_img2text', [p, url], model)
     if (!result.success) return { success: false, error: result.error || 'img2text 失败' }
     const data = result.data
     const text = data?.choices?.[0]?.message?.content ?? data?.choices?.[0]?.text ?? ''

+ 60 - 0
nodejs/ef-compiler/actions/fun/ai/shared.js

@@ -0,0 +1,60 @@
+const path = require('path')
+
+/** 与 fun/ai 各 execute 一致:统一走 nodejs/ai/ai.js */
+const aiModule = require(path.join(__dirname, '../../../../ai/ai.js'))
+
+/**
+ * @param {string} [model] 流程传入的 model;'doubao' 走豆包分支
+ * @param {{ openai: string, doubao: string }} actions openai 为 ai.run 名;doubao 为 doubao_ 前缀全名
+ */
+function resolveModelAction (model, actions) {
+  const mRaw = model != null ? String(model).trim() : ''
+  const m = mRaw.toLowerCase()
+  if (m === 'doubao') {
+    return { action: actions.doubao, opts: {} }
+  }
+  const opts = {}
+  if (mRaw) opts.model = mRaw
+  return { action: actions.openai, opts }
+}
+
+function mergeRunOptions (opts, extraOpts) {
+  const merged = { ...opts, ...(extraOpts || {}) }
+  const out = {}
+  for (const [k, v] of Object.entries(merged)) {
+    if (v === undefined || v === null) continue
+    if (k === 'model' && String(v).trim() === '') continue
+    out[k] = v
+  }
+  return out
+}
+
+/**
+ * @param {string} openaiAction
+ * @param {string} doubaoAction 如 doubao_img2text
+ * @param {unknown[]} runArgs ai.run 中位于 options 之前的参数
+ * @param {string} [model]
+ * @param {Record<string, unknown>} [extraOpts] 如 { timeoutMs }
+ */
+async function runWithModel (openaiAction, doubaoAction, runArgs, model, extraOpts) {
+  const { action, opts } = resolveModelAction(model, { openai: openaiAction, doubao: doubaoAction })
+  const finalOpts = mergeRunOptions(opts, extraOpts)
+  if (Object.keys(finalOpts).length === 0) {
+    return aiModule.run(action, ...runArgs)
+  }
+  return aiModule.run(action, ...runArgs, finalOpts)
+}
+
+function resolveSavePath (savePath, folderPath) {
+  if (!savePath || typeof savePath !== 'string') return null
+  const trimmed = savePath.trim()
+  if (path.isAbsolute(trimmed) || /^[A-Za-z]:/.test(trimmed)) return trimmed
+  return folderPath ? path.join(folderPath, trimmed) : path.resolve(trimmed)
+}
+
+module.exports = {
+  aiModule,
+  resolveModelAction,
+  runWithModel,
+  resolveSavePath,
+}

+ 4 - 13
nodejs/ef-compiler/actions/fun/ai/text2img.js

@@ -1,22 +1,13 @@
-const path = require('path')
 const fs = require('fs')
-const aiModule = require(path.join(__dirname, '../../../../ai/ai.js'))
-
-function resolveSavePath(savePath, folderPath) {
-  if (!savePath || typeof savePath !== 'string') return null
-  const trimmed = savePath.trim()
-  if (path.isAbsolute(trimmed) || /^[A-Za-z]:/.test(trimmed)) return trimmed
-  return folderPath ? path.join(folderPath, trimmed) : path.resolve(trimmed)
-}
+const path = require('path')
+const { runWithModel, resolveSavePath } = require('./shared')
 
 /** 入参:prompt, model, savePath(可选,图片保存路径;有则返回 b64 并写入文件,否则返回 url) */
-async function executeText2img({ prompt, model, savePath, folderPath }) {
+async function executeText2img ({ prompt, model, savePath, folderPath }) {
   const p = prompt != null ? String(prompt).trim() : ''
-  const m = model != null ? String(model).trim().toLowerCase() : ''
-  const action = m === 'doubao' ? 'doubao_text2img' : 'text2img'
   const outPath = savePath ? resolveSavePath(savePath, folderPath) : null
   try {
-    const result = await aiModule.run(action, p, outPath)
+    const result = await runWithModel('text2img', 'doubao_text2img', [p, outPath], model)
     if (!result.success) return { success: false, error: result.error || 'text2img 失败' }
     const data = result.data
     const item = data?.data?.[0]

+ 3 - 6
nodejs/ef-compiler/actions/fun/ai/text2text.js

@@ -1,12 +1,9 @@
-const path = require('path')
-const aiModule = require(path.join(__dirname, '../../../../ai/ai.js'))
+const { runWithModel } = require('./shared')
 
-async function executeText2text({ prompt, model, folderPath }) {
+async function executeText2text ({ prompt, model, folderPath }) {
   const p = prompt != null ? String(prompt).trim() : ''
-  const m = model != null ? String(model).trim().toLowerCase() : ''
-  const action = m === 'doubao' ? 'doubao_text2text' : 'text2text'
   try {
-    const result = await aiModule.run(action, p)
+    const result = await runWithModel('text2text', 'doubao_text2text', [p], model)
     if (!result.success) return { success: false, error: result.error || 'text2text 失败' }
     const data = result.data
     const text = data?.choices?.[0]?.message?.content ?? data?.choices?.[0]?.text ?? ''

+ 40 - 14
nodejs/ef-compiler/actions/fun/download-img.js

@@ -5,25 +5,19 @@
 const path = require('path')
 const fs = require('fs')
 const { spawnSync } = require('child_process')
+const { getPythonExeFromConfig } = require('../../../python-exe-from-config.js')
 
 const configPath = process.env.STATIC_ROOT
-  ? path.join(path.dirname(process.env.STATIC_ROOT), 'config.js')
+  ? path.join(path.dirname(path.resolve(process.env.STATIC_ROOT)), 'config.js')
   : path.join(__dirname, '..', '..', '..', '..', 'config.js')
-const projectRoot = path.dirname(path.resolve(configPath))
 const config = fs.existsSync(configPath) ? require(configPath) : {}
+const projectRoot = (config.projectRoot && fs.existsSync(config.projectRoot))
+  ? config.projectRoot
+  : path.dirname(path.resolve(configPath))
 const scriptPath = path.join(projectRoot, 'python', 'scripts', 'download-img-by-prompt.py')
 const imagedlParent = path.join(projectRoot, 'python')
 const imagedlRequirements = path.join(projectRoot, 'python', 'imagedl', 'requirements.txt')
 
-function getPythonPath() {
-  const base = config.pythonPath?.path || path.join(projectRoot, 'python', 'py')
-  const winPy = path.join(base, 'python.exe')
-  if (fs.existsSync(winPy)) return winPy
-  const unixPy = path.join(base, 'python')
-  if (fs.existsSync(unixPy)) return unixPy
-  return 'python'
-}
-
 function buildSavePath(savePath, folderPath) {
   if (!savePath || typeof savePath !== 'string') return null
   const trimmed = savePath.trim()
@@ -31,6 +25,36 @@ function buildSavePath(savePath, folderPath) {
   return folderPath ? path.join(folderPath, trimmed) : path.resolve(projectRoot, trimmed)
 }
 
+/** spawnSync 非 0 退出时拼详细日志(含启动失败、信号、截断后的 stdout/stderr) */
+function formatSpawnFailure (r, pythonPath, scriptPath) {
+  const maxLen = 1500
+  const trunc = (s) => {
+    const t = String(s || '').trim()
+    if (!t) return ''
+    return t.length > maxLen ? `${t.slice(0, maxLen)}…(truncated)` : t
+  }
+  const parts = []
+  if (pythonPath) {
+    const abs = path.isAbsolute(pythonPath) || /^[A-Za-z]:/.test(pythonPath)
+    const missing = abs && !fs.existsSync(pythonPath)
+    parts.push(`python=${pythonPath}${missing ? ' (file missing)' : ''}`)
+  }
+  if (scriptPath) parts.push(`script=${scriptPath}`)
+  if (r.error) {
+    const e = r.error
+    const code = e.code != null ? String(e.code) : ''
+    parts.push(`spawnError${code ? `[${code}]` : ''}: ${e.message || e}`)
+  }
+  if (r.signal) parts.push(`signal=${r.signal}`)
+  if (r.status !== null && r.status !== undefined) parts.push(`exitCode=${r.status}`)
+  else parts.push('exitCode=(null)')
+  const stderr = trunc(r.stderr)
+  const stdout = trunc(r.stdout)
+  if (stderr) parts.push(`stderr: ${stderr}`)
+  if (stdout) parts.push(`stdout: ${stdout}`)
+  return parts.length ? parts.join(' | ') : 'download-img 执行失败(无子进程输出)'
+}
+
 async function executeDownloadImg({ prompt, savePath, folderPath }) {
   if (!prompt || typeof prompt !== 'string' || !prompt.trim()) return { success: false, error: 'download-img 缺少 prompt 参数' }
   if (savePath == null) return { success: false, error: 'download-img 缺少 savePath 参数' }
@@ -38,7 +62,7 @@ async function executeDownloadImg({ prompt, savePath, folderPath }) {
   if (!absolutePath) return { success: false, error: 'download-img savePath 无效' }
   if (!fs.existsSync(scriptPath)) return { success: false, error: `脚本不存在: ${scriptPath}` }
 
-  const pythonPath = getPythonPath()
+  const pythonPath = getPythonExeFromConfig(config)
   const args = [scriptPath, '--prompt', prompt.trim(), '--save-path', absolutePath.replace(/\\/g, '/')]
   const runScript = () => spawnSync(pythonPath, args, {
     encoding: 'utf-8',
@@ -63,7 +87,7 @@ async function executeDownloadImg({ prompt, savePath, folderPath }) {
   }
 
   if (r.status !== 0) {
-    return { success: false, error: (r.stderr || r.stdout || '').trim() || 'download-img 执行失败' }
+    return { success: false, error: formatSpawnFailure(r, pythonPath, scriptPath) }
   }
   // stdout 可能包含进度条等,取最后一行以 { 开头的行作为 JSON
   const lines = out.split(/\r?\n/).map(s => s.trim()).filter(Boolean)
@@ -76,7 +100,9 @@ async function executeDownloadImg({ prompt, savePath, folderPath }) {
   try {
     result = JSON.parse(jsonStr)
   } catch (_) {
-    return { success: false, error: out.slice(-300) || '无法解析输出' }
+    const raw = String(out || '').trim()
+    const tail = raw.length > 800 ? raw.slice(-800) : raw
+    return { success: false, error: `无法解析 JSON 输出 | stdout(tail): ${tail || '(empty)'}` }
   }
   if (!result.success) return { success: false, error: result.error || '未下载到图片' }
   return { success: true, path: result.path }

+ 82 - 0
nodejs/ef-compiler/actions/fun/fun-adb-json-bridge.js

@@ -0,0 +1,82 @@
+/**
+ * 供 type: fun 的 method 桥接:adb-* → adb-parser;json-to-arr / json-json-to-arr → json-to-arr.js
+ */
+const path = require('path')
+
+const adbParserExecute = require(path.join(__dirname, 'adb', 'adb-parser.js')).execute
+
+const ADB_SUBMETHODS = new Set([
+  'input',
+  'click',
+  'locate',
+  'swipe',
+  'scroll',
+  'keyevent',
+  'press',
+  'string-press',
+  'send-img-to-device',
+])
+
+function buildSyntheticAdbAction (action, adbMethod) {
+  return {
+    type: 'adb',
+    method: adbMethod,
+    inVars: Array.isArray(action.inVars) ? action.inVars : [],
+    outVars: Array.isArray(action.outVars) ? action.outVars : [],
+    target: action.target,
+    value: action.value,
+    variable: action.variable,
+    clear: action.clear,
+    area: action.area,
+    avatar: action.avatar,
+  }
+}
+
+function buildAdbExecuteCtx (ctx, device, folderPath) {
+  const resolution = ctx.resolution || { width: 1080, height: 1920 }
+  return {
+    device,
+    folderPath,
+    resolution,
+    variableContext: ctx.variableContext,
+    api: ctx.electronAPI,
+    extractVarName: ctx.extractVarName,
+    resolveValue: ctx.resolveValue,
+    logOutVars: ctx.logOutVars,
+    DEFAULT_SCROLL_DISTANCE: ctx.DEFAULT_SCROLL_DISTANCE ?? 100,
+  }
+}
+
+/**
+ * @returns {Promise<null|{success:boolean}>} null 表示非桥接 method,由 run() 继续分发
+ */
+async function runFunBridgedMethod (actionType, action, ctx, device, folderPath) {
+  if (typeof actionType !== 'string') return null
+
+  if (actionType.startsWith('adb-')) {
+    const sub = actionType.slice(4)
+    if (!ADB_SUBMETHODS.has(sub)) {
+      return { success: false, error: `未知的 fun.method adb 子命令: ${sub}(支持: ${[...ADB_SUBMETHODS].join(', ')})` }
+    }
+    const syn = buildSyntheticAdbAction(action, sub)
+    const adbCtx = buildAdbExecuteCtx(ctx, device, folderPath)
+    return adbParserExecute(syn, adbCtx)
+  }
+
+  if (actionType === 'json-to-arr' || actionType === 'json-json-to-arr') {
+    const { executeJsonToArr } = require(path.join(__dirname, 'json', 'json-to-arr.js'))
+    const jsonString = action.inVars && action.inVars.length > 0 ? action.inVars[0] : undefined
+    const result = await executeJsonToArr({ jsonString })
+    if (!result.success) return { success: false, error: result.error }
+    const { extractVarName, variableContext, logOutVars } = ctx
+    const outputVarName =
+      action.outVars && action.outVars.length > 0 ? extractVarName(String(action.outVars[0]).trim()) : null
+    if (outputVarName && Array.isArray(result.result)) variableContext[outputVarName] = result.result
+    await logOutVars(action, variableContext, folderPath)
+    return { success: true, ...result }
+  }
+
+  return null
+}
+
+module.exports = { runFunBridgedMethod, ADB_SUBMETHODS }

+ 28 - 1
nodejs/ef-compiler/actions/fun/fun-node-registry.js

@@ -17,6 +17,33 @@ module.exports = [
   { type: 'img2text', category: 'io', in: ['prompt', 'model', 'imageUrl'], execute: 'executeImg2text', script: 'ai/img2text.js', displayName: 'ai img2text' },
   { type: 'text2img', category: 'io', in: ['prompt', 'model', 'savePath'], execute: 'executeText2img', script: 'ai/text2img.js', displayName: 'ai text2img' },
   { type: 'img2img', category: 'io', in: ['prompt', 'model', 'imageUrl', 'savePath'], execute: 'executeImg2img', script: 'ai/img2img.js', displayName: 'ai img2img' },
-  { type: 'json', category: 'io', in: ['jsonString'], execute: 'executeJsonToArr', script: '../json/json-to-arr.js', displayName: 'json to arr' },
+  { type: 'json', category: 'io', in: ['jsonString'], execute: 'executeJsonToArr', script: 'json/json-to-arr.js', displayName: 'json to arr' },
   { type: 'download-img', category: 'io', in: ['prompt', 'savePath'], inAlt: { savePath: 'save-path' }, execute: 'executeDownloadImg', script: 'download-img.js', displayName: 'download img by prompt' },
+  {
+    type: 'img-scale',
+    category: 'io',
+    in: ['imagePath', 'savePath', 'scale'],
+    inAlt: { savePath: 'save-path', imagePath: 'image-path', scale: 'scale-factor' },
+    execute: 'executeImgScale',
+    script: 'img/img-scale.js',
+    displayName: 'img scale proportional',
+  },
+  {
+    type: 'create-folder',
+    category: 'io',
+    in: ['path'],
+    inAlt: { path: 'dirPath' },
+    execute: 'executeCreateFolder',
+    script: 'IO/create-forder.js',
+    displayName: 'create folder',
+  },
+  {
+    type: 'remove-folder',
+    category: 'io',
+    in: ['path'],
+    inAlt: { path: 'dirPath' },
+    execute: 'executeRemoveFolder',
+    script: 'IO/remove-forder.js',
+    displayName: 'remove folder',
+  },
 ]

+ 67 - 38
nodejs/ef-compiler/actions/fun/fun-parser.js

@@ -5,9 +5,23 @@
 const path = require('path')
 const variableParser = require('../../variable-parser.js')
 const FUN_NODE_REGISTRY = require('./fun-node-registry.js')
+const funAdbJsonBridge = require('./fun-adb-json-bridge.js')
+
+/** type: io + method 与常见拼写 → 注册表 type */
+function normalizeRegistryMethodName (name) {
+  if (name == null || name === '') return name
+  const key = String(name).trim().toLowerCase().replace(/_/g, '-')
+  const map = {
+    'remov-folder': 'remove-folder',
+    'remove-forder': 'remove-folder',
+    'creat-folder': 'create-folder',
+    'create-forder': 'create-folder',
+  }
+  return map[key] || String(name).trim()
+}
 
 const LEGACY_FUN_TYPES = [
-  'fun', 'ai',
+  'fun', 'ai', 'io',
   'read-txt', 'read-text', 'save-txt', 'save-text',
   'img-bounding-box-location', 'img-center-point-location', 'img-cropping',
   'ocr',
@@ -74,8 +88,8 @@ function parse(action, parseContext) {
   switch (action.type) {
     case 'fun':
       parsed.method = action.method
-      parsed.inVars = action.inVars && Array.isArray(action.inVars) ? action.inVars.map(v => extractVarName(v)) : []
-      parsed.outVars = action.outVars && Array.isArray(action.outVars) ? action.outVars.map(v => extractVarName(v)) : []
+      parsed.inVars = Array.isArray(action.inVars) ? action.inVars.map((v) => extractVarName(v)) : []
+      parsed.outVars = Array.isArray(action.outVars) ? action.outVars.map((v) => extractVarName(v)) : []
       break
     case 'extract-messages':
     case 'ocr-chat':
@@ -149,9 +163,12 @@ function parse(action, parseContext) {
       break
     }
     case 'img-cropping':
-      parsed.inVars = action.inVars && Array.isArray(action.inVars) ? action.inVars.map(v => extractVarName(v)) : []
-      parsed.area = action.inVars?.[0] ?? action.area
+      parsed.inVars = action.inVars && Array.isArray(action.inVars)
+        ? action.inVars.map((v, i) => (i === 2 && Array.isArray(v) ? v : extractVarName(v)))
+        : []
+      parsed.imagePath = action.inVars?.[0] ?? action.imagePath
       parsed.savePath = action.inVars?.[1] ?? action.savePath
+      parsed.squareSpec = action.inVars?.[2] ?? action.squareSpec
       if (action.outVars && action.outVars.length > 0) parsed.variable = extractVarName(action.outVars[0])
       break
     case 'ocr':
@@ -178,14 +195,22 @@ async function runAction(action, device, folderPath, resolution, ctx) {
   }
 
   const resolvedAction = variableParser.resolveActionInputs(action, variableContext)
+  ctx.resolution = resolution || ctx.resolution || { width: 1080, height: 1920 }
 
-  if (resolvedAction.type === 'fun' && resolvedAction.method) {
-    return run(resolvedAction.method, resolvedAction, ctx, device, folderPath)
+  if (resolvedAction.type === 'fun') {
+    const m = resolvedAction.method != null ? String(resolvedAction.method).trim() : ''
+    if (!m) return { success: false, error: 'fun 结点缺少 method(如 adb-click、json-to-arr)' }
+    return run(m, resolvedAction, ctx, device, folderPath)
   }
   if (resolvedAction.type === 'ai' && resolvedAction.method) {
     return run(resolvedAction.method, resolvedAction, ctx, device, folderPath)
   }
 
+  if (resolvedAction.type === 'io' && resolvedAction.method) {
+    const m = normalizeRegistryMethodName(resolvedAction.method)
+    return run(m, resolvedAction, ctx, device, folderPath)
+  }
+
   if (supports(resolvedAction.type)) {
     return run(resolvedAction.type, resolvedAction, ctx, device, folderPath)
   }
@@ -227,9 +252,9 @@ function get(funcDir, category) {
   switch (category) {
     case 'img':
       mod = {
-        executeImgBoundingBoxLocation: require(path.join(funcDir, 'img-bounding-box-location.js')).executeImgBoundingBoxLocation,
-        executeImgCenterPointLocation: require(path.join(funcDir, 'img-center-point-location.js')).executeImgCenterPointLocation,
-        executeImgCropping: require(path.join(funcDir, 'img-cropping.js')).executeImgCropping,
+        executeImgBoundingBoxLocation: require(path.join(funcDir, 'img', 'img-bounding-box-location.js')).executeImgBoundingBoxLocation,
+        executeImgCenterPointLocation: require(path.join(funcDir, 'img', 'img-center-point-location.js')).executeImgCenterPointLocation,
+        executeImgCropping: require(path.join(funcDir, 'img', 'img-cropping.js')).executeImgCropping,
         executeOcr: require(path.join(funcDir, 'ocr.js')).executeOcr,
         executeOcrFindText: require(path.join(funcDir, 'ocr.js')).executeOcrFindText,
       }
@@ -237,9 +262,9 @@ function get(funcDir, category) {
     case 'io':
       mod = {
         executeReadLastMessage: require(path.join(funcDir, 'chat', 'read-last-message.js')).executeReadLastMessage,
-        executeReadTxt: require(path.join(funcDir, 'read-txt.js')).executeReadTxt,
+        executeReadTxt: require(path.join(funcDir, 'IO', 'read-txt.js')).executeReadTxt,
         executeSmartChatAppend: require(path.join(funcDir, 'chat', 'smart-chat-append.js')).executeSmartChatAppend,
-        executeSaveTxt: require(path.join(funcDir, 'save-txt.js')).executeSaveTxt,
+        executeSaveTxt: require(path.join(funcDir, 'IO', 'save-txt.js')).executeSaveTxt,
       }
       ;(FUN_NODE_REGISTRY || []).filter((r) => r.category === 'io').forEach((def) => {
         const scriptPath = path.join(funcDir, def.script || def.type + '.js')
@@ -287,6 +312,9 @@ async function run(actionType, action, ctx, device, folderPath) {
   const funcDir = ctx.compilerConfig && ctx.compilerConfig.funcDir
   if (!funcDir) return { success: false, error: 'compilerConfig.funcDir 未提供' }
 
+  const bridged = await funAdbJsonBridge.runFunBridgedMethod(actionType, action, ctx, device, folderPath)
+  if (bridged != null) return bridged
+
   switch (actionType) {
     case 'img-bounding-box-location': {
       const { executeImgBoundingBoxLocation } = get(funcDir, 'img')
@@ -327,25 +355,11 @@ async function run(actionType, action, ctx, device, folderPath) {
       }
       if (!templatePath) templatePath = action.template
       if (!templatePath) return { success: false, error: '缺少模板图片路径' }
-      let scaleRange = action.inVars?.[1] ?? action.scaleRange
-      if (!Array.isArray(scaleRange) || scaleRange.length < 2) return { success: false, error: 'img-center-point-location 必须填写 inVars[1] 缩放比范围 [min, max],如 [0.2, 1.6]' }
-      const minS = Number(scaleRange[0])
-      const maxS = Number(scaleRange[1])
-      if (Number.isNaN(minS) || Number.isNaN(maxS) || minS >= maxS) return { success: false, error: 'img-center-point-location inVars[1] 缩放比范围无效,需为两个数字且 min < max' }
-      let centerRatio = 1
-      if (action.inVars?.length > 2 && action.inVars[2] != null) {
-        const third = action.inVars[2]
-        if (Array.isArray(third) && third.length >= 2) {
-          const p = Number(third[0])
-          const b = (third[1] != null && String(third[1]).trim().toLowerCase()) || ''
-          if (!Number.isNaN(p) && p > 0 && (b === 'w' || b === 'h')) centerRatio = [p, b]
-        } else {
-          const v = Number(third)
-          if (!Number.isNaN(v) && v > 0 && v <= 1) centerRatio = v
-        }
+      if (!Array.isArray(action.inVars) || action.inVars.length < 1) {
+        return { success: false, error: 'img-center-point-location 至少填写 inVars[0]=模板路径' }
       }
       if (!device) return { success: false, error: '缺少设备 ID,无法自动获取截图' }
-      const result = await executeImgCenterPointLocation({ device, template: templatePath, folderPath, scaleRange: [minS, maxS], centerRatio })
+      const result = await executeImgCenterPointLocation({ device, template: templatePath, folderPath })
       if (!result.success) return { success: false, error: `图像中心点定位失败: ${result.error}` }
       const outputVarName = action.outVars?.[0] != null ? extractVarName(String(action.outVars[0]).trim()) : (action.variable ? extractVarName(action.variable) : null)
       if (outputVarName) {
@@ -361,22 +375,28 @@ async function run(actionType, action, ctx, device, folderPath) {
 
     case 'img-cropping': {
       const { executeImgCropping } = get(funcDir, 'img')
-      let area = action.area
+      let imagePath = action.imagePath
+      let squareSpec = action.squareSpec
       let savePath = action.savePath
       if (action.inVars && Array.isArray(action.inVars)) {
-        if (action.inVars.length > 0) area = action.inVars[0]
+        if (action.inVars.length > 0) imagePath = action.inVars[0]
         if (action.inVars.length > 1) savePath = action.inVars[1]
+        if (action.inVars.length > 2) squareSpec = action.inVars[2]
+      }
+      if (!imagePath) return { success: false, error: 'img-cropping 缺少 imagePath(inVars[0])' }
+      if (!savePath) return { success: false, error: 'img-cropping 缺少 savePath(inVars[1])' }
+      if (squareSpec === undefined || squareSpec === null || squareSpec === '') {
+        return { success: false, error: 'img-cropping 缺少 squareSpec(inVars[2],如 [0.8,"w"])' }
       }
-      if (!area) return { success: false, error: 'img-cropping 缺少 area 参数' }
-      if (!savePath) return { success: false, error: 'img-cropping 缺少 savePath 参数' }
-      const result = await executeImgCropping({ area, savePath, folderPath, device })
+      const result = await executeImgCropping({ imagePath, squareSpec, savePath, folderPath })
       if (!result.success) return { success: false, error: result.error }
-      if (action.outVars?.[0] != null) {
-        const outputVarName = extractVarName(String(action.outVars[0]).trim())
-        if (outputVarName) variableContext[outputVarName] = result.success ? '1' : '0'
+      const outputVarName = action.outVars?.[0] != null ? extractVarName(String(action.outVars[0]).trim()) : null
+      if (outputVarName) {
+        const outVal = result.path ?? result.value ?? result.result
+        if (outVal !== undefined && outVal !== null) variableContext[outputVarName] = String(outVal)
       }
       await logOutVars(action, variableContext, folderPath)
-      return { success: true }
+      return { success: true, ...result }
     }
 
     case 'ocr': {
@@ -693,6 +713,15 @@ async function run(actionType, action, ctx, device, folderPath) {
           input[key] = val != null ? String(val).trim() : val
         })
         input.folderPath = folderPath
+        if (actionType === 'remove-folder') {
+          let rec = action.recursive
+          if ((rec === undefined || rec === null || rec === '') && action.inVars && action.inVars.length > 1 && action.inVars[1] !== undefined) {
+            rec = action.inVars[1]
+          }
+          if (rec !== undefined && rec !== null && rec !== '') {
+            input.recursive = typeof rec === 'string' ? rec.trim() : String(rec).trim()
+          }
+        }
         const mod = get(funcDir, regDef.category)
         const fn = mod[regDef.execute]
         if (!fn) {

+ 0 - 120
nodejs/ef-compiler/actions/fun/img-center-point-location.js

@@ -1,120 +0,0 @@
-/**
- * fun 标签:img-center-point-location
- * 图像匹配:识别模板图片在截图中的位置,返回中心点坐标
- */
-
-const path = require('path')
-const fs = require('fs')
-const os = require('os')
-const { spawnSync } = require('child_process')
-
-const configPath = process.env.STATIC_ROOT
-  ? path.join(path.dirname(process.env.STATIC_ROOT), 'config.js')
-  : path.join(__dirname, '..', '..', '..', '..', 'config.js')
-const config = fs.existsSync(configPath) ? require(configPath) : {}
-// 打包后优先使用 config.projectRoot(如 package/pack-resources/config.js 拷贝到输出后的配置),否则按路径推导
-const projectRoot = (config.projectRoot && fs.existsSync(config.projectRoot))
-  ? config.projectRoot
-  : path.dirname(path.resolve(configPath))
-const imageMatchScriptPath = path.join(projectRoot, 'python', 'scripts', 'image-match.py')
-
-const tagName = 'img-center-point-location'
-
-const schema = {
-  description: '在屏幕截图中查找模板图片的位置并返回中心点坐标。inVars[1] 缩放比 [min,max]。inVars[2] 可选:数字 0–1 为旧版中心比例;或 [裁剪百分比, "w"|"h"] 表示以模板宽/高的该比例作为正方形边长取中心方形裁剪后匹配,如 [1,"w"]、[0.1,"h"]。',
-  inputs: { template: '模板图片路径', scaleRange: '缩放比范围数组 [min, max]', centerRatio: '可选:数字 0–1 或 [percent, "w"|"h"] 方形裁剪', variable: '输出变量名' },
-  outputs: { variable: '中心点坐标(JSON 字符串格式,如:{"x":123,"y":456})' },
-}
-
-/** 解析 Python 可执行路径(与 config.pythonPath:python/py)。 */
-function getPythonPath() {
-  const base = config.pythonPath?.path || path.join(projectRoot, 'python', 'py')
-  const winPy = path.join(base, 'python.exe')
-  if (fs.existsSync(winPy)) return winPy
-  const unixPy = path.join(base, 'python')
-  if (fs.existsSync(unixPy)) return unixPy
-  return 'python'
-}
-
-/** 在设备截图中匹配模板,返回坐标与中心点。scaleRange 为 [min, max]。centerRatio 为数字 0–1(旧版)或 [percent, 'w'|'h'] 方形裁剪。 */
-function matchImageAndGetCoordinate(device, imagePath, scaleRange, centerRatio) {
-  if (!imagePath || typeof imagePath !== 'string') return { success: false, error: '模板路径为空' }
-  if (!Array.isArray(scaleRange) || scaleRange.length < 2) return { success: false, error: '缩放比范围 scaleRange 必填,且为 [min, max] 数组,如 [0.2, 1.6]' }
-  const minScale = Number(scaleRange[0])
-  const maxScale = Number(scaleRange[1])
-  if (Number.isNaN(minScale) || Number.isNaN(maxScale) || minScale >= maxScale) return { success: false, error: '缩放比范围无效,需为两个数字且 min < max' }
-  const templatePath = path.isAbsolute(imagePath) ? imagePath : path.resolve(projectRoot, imagePath)
-  if (!fs.existsSync(templatePath)) return { success: false, error: `模板文件不存在: ${templatePath}` }
-  const ts = Date.now()
-  const templateDir = path.dirname(templatePath)
-  const templateBase = path.basename(templatePath, path.extname(templatePath))
-  const screenshotPath = path.join(templateDir, `Screenshot-${templateBase}.png`)
-  try { fs.mkdirSync(templateDir, { recursive: true }) } catch (_) {}
-  const templateCopyPath = path.join(os.tmpdir(), `ef-template-${ts}.png`)
-  fs.copyFileSync(templatePath, templateCopyPath)
-
-  const pythonPath = getPythonPath()
-  const adbPath = config.adbPath?.path
-    ? (path.isAbsolute(config.adbPath.path) ? config.adbPath.path : path.resolve(projectRoot, config.adbPath.path))
-    : path.join(projectRoot, 'lib', 'scrcpy-adb', process.platform === 'win32' ? 'adb.exe' : 'adb')
-  const cropOutputPath = path.join(templateDir, `Matched-${templateBase}.png`)
-  const args = [imageMatchScriptPath, '--adb', adbPath, '--device', device, '--screenshot', screenshotPath.replace(/\\/g, '/'), '--template', templateCopyPath.replace(/\\/g, '/'), '--method', 'feature', '--scale-min', String(minScale), '--scale-max', String(maxScale), '--crop-output', cropOutputPath.replace(/\\/g, '/')]
-  const isCropSquare = Array.isArray(centerRatio) && centerRatio.length >= 2 && typeof centerRatio[0] === 'number' && centerRatio[0] > 0
-  const cropBase = isCropSquare ? String(centerRatio[1]).trim().toLowerCase() : ''
-  const hasCrop = isCropSquare && (cropBase === 'w' || cropBase === 'h') || (centerRatio != null && typeof centerRatio === 'number' && centerRatio > 0 && centerRatio < 1)
-  if (isCropSquare && (cropBase === 'w' || cropBase === 'h')) {
-    args.push('--crop-square', String(centerRatio[0]), cropBase)
-  } else {
-    const ratio = centerRatio != null && typeof centerRatio === 'number' && centerRatio > 0 && centerRatio <= 1 ? centerRatio : 1
-    if (ratio < 1) args.push('--center-ratio', String(ratio))
-  }
-  if (hasCrop) args.push('--template-output', templatePath.replace(/\\/g, '/'))
-  const env = { ...process.env, PYTHONIOENCODING: 'utf-8' }
-  if (process.platform === 'win32') {
-    const pyDir = path.dirname(pythonPath)
-    const pyRoot = path.dirname(path.dirname(pyDir))
-    env.PATH = [pyDir, pyRoot, process.env.PATH].filter(Boolean).join(path.delimiter)
-  }
-  const spawnOpts = { encoding: 'utf-8', timeout: 20000, env, cwd: projectRoot }
-  const r = spawnSync(pythonPath, args, spawnOpts)
-  try { fs.unlinkSync(templateCopyPath) } catch (_) {}
-  // 截图保留在模板同级目录(Screenshot-pic0.png 等),便于排查匹配失败原因
-
-  if (r.status !== 0) {
-    const msg = [r.stderr, r.stdout].filter(Boolean).map(s => String(s).trim()).join('\n') || '图像匹配失败'
-    const extra = r.signal ? ` [signal: ${r.signal}]` : (r.error ? ` [${r.error.message}]` : '')
-    const scriptExists = fs.existsSync(imageMatchScriptPath)
-    const pyExists = fs.existsSync(pythonPath)
-    const diag = ` | projectRoot=${projectRoot} script存在=${scriptExists} python存在=${pyExists} status=${r.status}`
-    return { success: false, error: msg + extra + diag }
-  }
-  let out
-  try {
-    out = JSON.parse(r.stdout.trim())
-  } catch (e) {
-    return { success: false, error: `脚本输出非 JSON: ${(r.stdout || r.stderr || '').slice(0, 200)}` }
-  }
-  if (!out.success) return { success: false, error: out.error || '未找到图片' }
-  return {
-    success: true,
-    coordinate: { x: out.x, y: out.y, width: out.width, height: out.height },
-    clickPosition: { x: out.center_x, y: out.center_y }
-  }
-}
-
-async function executeImgCenterPointLocation({ device, template, folderPath, scaleRange, centerRatio }) {
-  if (!device) return { success: false, error: '缺少设备 ID,无法自动获取截图' }
-  if (!template || typeof template !== 'string') return { success: false, error: '缺少模板图片路径' }
-  if (!Array.isArray(scaleRange) || scaleRange.length < 2) return { success: false, error: 'img-center-point-location 必须填写 inVars[1] 缩放比范围 [min, max],如 [0.2, 1.6]' }
-  const baseDir = folderPath && typeof folderPath === 'string' ? folderPath : projectRoot
-  // 绝对路径或带盘符的保持原样;已含子路径(如 tmp/pic0.png)相对 baseDir;否则视为 resources 下文件名
-  const isAbsoluteOrDrive = template.startsWith('/') || template.includes(':')
-  const hasSubPath = template.includes('/') || template.includes(path.sep)
-  const templatePath = isAbsoluteOrDrive ? template : (hasSubPath ? path.join(baseDir, template) : path.join(baseDir, 'resources', template))
-  const result = matchImageAndGetCoordinate(device, templatePath, scaleRange, centerRatio)
-  if (!result.success) return { success: false, error: result.error }
-  const center = result.clickPosition || { x: result.coordinate.x + result.coordinate.width / 2, y: result.coordinate.y + result.coordinate.height / 2 }
-  return { success: true, center, coordinate: result.coordinate }
-}
-
-module.exports = { tagName, schema, executeImgCenterPointLocation, matchImageAndGetCoordinate }

+ 0 - 104
nodejs/ef-compiler/actions/fun/img-cropping.js

@@ -1,104 +0,0 @@
-/**
- * fun 标签:img-cropping
- * 根据区域坐标裁剪截图指定区域并保存
- */
-
-const path = require('path')
-const fs = require('fs')
-const { spawnSync } = require('child_process')
-const { captureScreenshot } = require('../../../adb/adb-screencap.js')
-
-const tagName = 'img-cropping'
-const configPath = process.env.STATIC_ROOT
-  ? path.join(path.dirname(process.env.STATIC_ROOT), 'config.js')
-  : path.join(__dirname, '..', '..', '..', '..', 'config.js')
-const projectRoot = path.dirname(path.resolve(configPath))
-const config = fs.existsSync(configPath) ? require(configPath) : {}
-const imgCropScriptPath = path.join(projectRoot, 'python', 'scripts', 'img-crop.py')
-
-/** 解析 area 为 { x, y, width, height } */
-function parseAreaToRect(area) {
-  const obj = typeof area === 'string' ? JSON.parse(area) : area
-  if (obj.topLeft && obj.bottomRight) {
-    return {
-      x: parseInt(obj.topLeft.x, 10),
-      y: parseInt(obj.topLeft.y, 10),
-      width: parseInt(obj.bottomRight.x - obj.topLeft.x, 10),
-      height: parseInt(obj.bottomRight.y - obj.topLeft.y, 10),
-    }
-  }
-  if (obj.topLeft && obj.topRight && obj.bottomLeft && obj.bottomRight) {
-    return {
-      x: parseInt(obj.topLeft.x, 10),
-      y: parseInt(obj.topLeft.y, 10),
-      width: parseInt(obj.bottomRight.x - obj.topLeft.x, 10),
-      height: parseInt(obj.bottomRight.y - obj.topLeft.y, 10),
-    }
-  }
-  if (obj.x != null && obj.y != null && obj.width != null && obj.height != null) {
-    return {
-      x: parseInt(obj.x, 10),
-      y: parseInt(obj.y, 10),
-      width: parseInt(obj.width, 10),
-      height: parseInt(obj.height, 10),
-    }
-  }
-  return null
-}
-
-/** 构建截图路径 */
-function buildScreenshotPath(folderPath) {
-  return path.join(folderPath, 'history', 'ScreenShot.png')
-}
-
-/** 构建保存路径 */
-function buildSavePath(savePath, folderPath) {
-  return savePath.includes(':') ? savePath : path.join(folderPath, savePath)
-}
-
-/** 解析 Python 可执行路径(与 config.pythonPath:python/py)。 */
-function getPythonExe() {
-  const base = config.pythonPath?.path || path.join(projectRoot, 'python', 'py')
-  const winPy = path.join(base, 'python.exe')
-  if (fs.existsSync(winPy)) return winPy
-  const unixPy = path.join(base, 'python')
-  if (fs.existsSync(unixPy)) return unixPy
-  return 'python'
-}
-
-/** 调用 Python 裁剪图片并保存 */
-function cropAndSaveImage(inputPath, outputPath, x, y, width, height) {
-  const pythonExe = getPythonExe()
-  const r = spawnSync(pythonExe, [imgCropScriptPath, inputPath, outputPath, String(x), String(y), String(width), String(height)], {
-    encoding: 'utf-8',
-    timeout: 10000,
-    cwd: projectRoot
-  })
-  if (r.status !== 0) {
-    return { success: false, error: (r.stderr || r.stdout || '').trim() || '裁剪失败' }
-  }
-  return { success: true }
-}
-
-/** 执行 img-cropping */
-async function executeImgCropping({ area, savePath, folderPath, device }) {
-  const rect = parseAreaToRect(area)
-  if (!rect || rect.width <= 0 || rect.height <= 0) {
-    return { success: false, error: '区域坐标格式不正确' }
-  }
-
-  const screenshotPath = buildScreenshotPath(folderPath)
-  const outputPath = buildSavePath(savePath, folderPath)
-
-  if (device) {
-    const cap = captureScreenshot(device, screenshotPath)
-    if (!cap.success) return { success: false, error: cap.error }
-  }
-
-  const crop = cropAndSaveImage(screenshotPath, outputPath, rect.x, rect.y, rect.width, rect.height)
-  if (!crop.success) return { success: false, error: crop.error }
-
-  return { success: true }
-}
-
-module.exports = { tagName, executeImgCropping, cropAndSaveImage }

+ 0 - 0
nodejs/ef-compiler/actions/fun/img-bounding-box-location.js → nodejs/ef-compiler/actions/fun/img/img-bounding-box-location.js


+ 1035 - 0
nodejs/ef-compiler/actions/fun/img/img-center-point-location.js

@@ -0,0 +1,1035 @@
+/**
+ * fun 标签:img-center-point-location
+ * 仅需 inVars[0]:模板路径。
+ *
+ * 【当前逻辑】tmp → ADB 截图 + 模板复制 → VLM img2text(主模型 + 最多 2 个备用模型,**共≤3 次**;
+ * 仅在「请求失败」或「无有效中心点坐标」时换模型重试)→ 归一化转像素 →
+ * 可选:Python 在截图上画**绿色圆圈**标出中心点,保存为 `screenshot_center_marked.png`(同一次 `img-center-时间戳` 目录)。
+ *
+ * 【已注释保留】原流程:VLM(ROI + template_crop / need_center_crop / template_scale)→ 外扩 roi →
+ * Python 预处理 → NCC(img-center-orb-akaze.py)。恢复时取消下方对应块注释并改回 runPipeline 即可。
+ *
+ * 密钥与网关:优先根目录 config.js(openaiApiKey 等同步到环境变量),否则 nodejs/ai/config.js;模型常量见文件顶部。
+ */
+
+// ---------------------------------------------------------------------------
+// 配置:模型与超时(优先改这里或对应环境变量 / 根目录 config.js)
+// ---------------------------------------------------------------------------
+
+/**
+ * 多模态模型 id(**仅请求一次** img2text):在此直接写字符串;写 '' 则依次用
+ * process.env.IMG_CENTER_OPENAI_MODEL、config.imgCenterOpenAiModel、nodejs/ai 默认链。
+ */
+// const IMG_CENTER_OPENAI_MODEL = 'gemini-3.1-pro-preview'
+const IMG_CENTER_OPENAI_MODEL = 'gpt-5.4'
+
+/**
+ * 中心点 VLM:主模型请求失败或 JSON 无法解析出有效 center_rx/center_ry 时,依次用备用 1、备用 2 再请求。
+ * 链为 [主模型, 备用1, 备用2] 去重后取前 3 个,**最多 3 次** img2text。
+ * 写 '' 表示该档跳过;环境变量 IMG_CENTER_FALLBACK_MODEL_1 / IMG_CENTER_FALLBACK_MODEL_2 优先于常量。
+ *
+ * Claude 4.6 选型(多模态「对齐两点坐标」类任务,公开对比多认为 Opus 图像分析略强于 Sonnet,但本任务输出极短 JSON):
+ * - 首选备用:**claude-sonnet-4-6**(延迟/成本更友好,视觉足够)。
+ * - 第二轮:**claude-opus-4-6**(能力上限略高,仍不稳再上)。
+ * - *-thinking:推理链更长、更慢更贵;非复杂推理时一般**不必**作默认备用。
+ * OpenAI 网关可改回如 gpt-4o / gpt-4.1。
+ */
+const IMG_CENTER_FALLBACK_MODEL_1 = 'claude-opus-4-6'
+const IMG_CENTER_FALLBACK_MODEL_2 = 'gemini-3.1-pro-preview'
+
+/** 中心点 img2text 最多调用次数(主 + 备用,去重后截断) */
+const IMG_CENTER_CENTER_POINT_MAX_VLM_CALLS = 3
+
+/**
+ * 视觉 API(img2text)请求超时(毫秒)
+ * 环境变量 IMG_CENTER_AI_TIMEOUT_MS 可覆盖
+ */
+const IMG_CENTER_AI_TIMEOUT_MS = Math.max(
+  10_000,
+  parseInt(String(process.env.IMG_CENTER_AI_TIMEOUT_MS || '').trim(), 10) || 300_000
+)
+
+/** 在截图上绘制中心点绿圈的 Python 脚本超时(毫秒);环境变量 IMG_CENTER_MARK_SCRIPT_TIMEOUT_MS */
+const IMG_CENTER_MARK_SCRIPT_TIMEOUT_MS = Math.max(
+  5000,
+  parseInt(String(process.env.IMG_CENTER_MARK_SCRIPT_TIMEOUT_MS || '').trim(), 10) || 30_000
+)
+
+/** 【旧 NCC 流程】runPipeline 内等待 Matched.png 就绪的上限(毫秒) */
+const MATCHED_PNG_MAX_WAIT_MS = 60_000
+
+/** 【旧 NCC 流程】Python NCC 匹配脚本 spawn 超时(毫秒) */
+const PYTHON_ORB_SCRIPT_TIMEOUT_MS = 120_000
+
+/**
+ * 【旧 NCC 流程】NCC 最低分:传给 Python(环境变量 IMG_CENTER_NCC_MIN_SCORE)。
+ */
+const IMG_CENTER_NCC_MIN_SCORE_DEFAULT = 0.34
+
+/** 【旧预处理】模板预处理脚本 spawn 超时(毫秒);环境变量 IMG_CENTER_PREPROCESS_TIMEOUT_MS 可覆盖 */
+const PYTHON_PREPROCESS_TEMPLATE_TIMEOUT_MS = Math.max(
+  5000,
+  parseInt(String(process.env.IMG_CENTER_PREPROCESS_TIMEOUT_MS || '').trim(), 10) || 60_000
+)
+
+/**
+ * 【旧 VLM ROI】外扩比例,环境变量 IMG_CENTER_ROI_PAD 可覆盖
+ */
+const IMG_CENTER_ROI_PAD = Math.max(
+  0,
+  Math.min(
+    0.12,
+    parseFloat(String(process.env.IMG_CENTER_ROI_PAD || '').trim()) || 0.03
+  )
+)
+
+/**
+ * 【旧 VLM ROI】归一化高度下限;IMG_CENTER_ROI_MIN_REL_H,默认 0.15
+ */
+const IMG_CENTER_ROI_MIN_REL_H = Math.max(
+  0.06,
+  Math.min(
+    0.45,
+    parseFloat(String(process.env.IMG_CENTER_ROI_MIN_REL_H || '').trim()) || 0.15
+  )
+)
+
+/**
+ * 【旧 VLM ROI】归一化宽度下限(0 表示不强制);IMG_CENTER_ROI_MIN_REL_W
+ */
+const IMG_CENTER_ROI_MIN_REL_W = (() => {
+  const v = parseFloat(String(process.env.IMG_CENTER_ROI_MIN_REL_W || '').trim())
+  if (!Number.isFinite(v) || v <= 0) return 0
+  return Math.max(0.06, Math.min(0.95, v))
+})()
+
+// ---------------------------------------------------------------------------
+// 依赖与工程路径
+// ---------------------------------------------------------------------------
+
+const path = require('path')
+const fs = require('fs')
+const { spawnSync } = require('child_process')
+const { getPythonExeFromConfig } = require('../../../../python-exe-from-config.js')
+
+const configPath = process.env.STATIC_ROOT
+  ? path.join(path.dirname(path.resolve(process.env.STATIC_ROOT)), 'config.js')
+  : path.join(__dirname, '..', '..', '..', '..', '..', 'config.js')
+const config = fs.existsSync(configPath) ? require(configPath) : {}
+const projectRoot = (config.projectRoot && fs.existsSync(config.projectRoot))
+  ? config.projectRoot
+  : path.dirname(path.resolve(configPath))
+
+/** 在加载 nodejs/ai 之前同步,使 ai/config 能读到应用级密钥与网关 */
+function syncProcessEnvFromAppConfig () {
+  const k = config.openaiApiKey || config.vlmApiKey
+  if (
+    k &&
+    !String(process.env.API_KEY || '').trim() &&
+    !String(process.env.OPENAI_API_KEY || '').trim() &&
+    !String(process.env.VLM_API_KEY || '').trim()
+  ) {
+    process.env.API_KEY = String(k).trim()
+  }
+  const u = config.openaiApiUrl
+  if (
+    u &&
+    !String(process.env.BASE_URL || '').trim() &&
+    !String(process.env.OPENAI_API_URL || '').trim()
+  ) {
+    process.env.BASE_URL = String(u).trim().replace(/\/$/, '')
+  }
+}
+
+syncProcessEnvFromAppConfig()
+
+const aiRoot = path.join(__dirname, '..', '..', '..', '..', 'ai')
+const aiModule = require(path.join(aiRoot, 'ai.js'))
+const img2textRequest = require(path.join(aiRoot, 'request', 'img2text.js'))
+const aiPackageConfig = require(path.join(aiRoot, 'config.js'))
+
+/** 【旧 NCC / 预处理】脚本路径(当前 runPipeline 不调用;恢复旧流程时用) */
+const orbScriptPath = path.join(projectRoot, 'python', 'scripts', 'img-center-orb-akaze.py')
+const preprocessTemplateScriptPath = path.join(projectRoot, 'python', 'scripts', 'img-center-preprocess-template.py')
+/** 在截图上标出 VLM 中心点(绿圈)的可视化脚本 */
+const markCenterScriptPath = path.join(projectRoot, 'python', 'scripts', 'img-center-mark-center.py')
+
+/**
+ * 【当前】仅问中心点:图1=截图,图2=模板;返回 center_rx、center_ry ∈ [0,1](相对图1 宽高)。
+ */
+const VLM_CENTER_ONLY_PROMPT = `你收到两张图,顺序固定:
+图1:Android 手机完整截图(与 adb screencap 一致),逻辑像素宽约 W、高约 H;坐标原点在左上角,x 向右、y 向下。
+图2:模板图。请在图1 中找到与图2 视觉上对应的同一区域(同一图标、缩略图格子、按钮等)。
+
+任务:给出图2 在图1 中**匹配可见区域的几何中心点**(该区域中心,不是图2 整张文件的画布中心)。
+
+只输出一个 JSON 对象,必须包含:
+- "center_rx"、"center_ry":数字,取值在 [0,1],分别为该中心点在图1 上相对宽度 W、高度 H 的归一化坐标(左缘=0,上缘=0)。
+若图1 中完全无法对应图2,两个键均填 null。
+
+禁止 markdown、禁止代码围栏、禁止 JSON 以外的任何文字。`
+
+/*
+ * ---------- 旧 ROI + 模板几何 VLM 提示(NCC 流程用,保留勿删)----------
+const VLM_USER_PROMPT = `你收到两张图,顺序固定:
+图1:Android 手机竖屏完整截图,逻辑像素宽 W、高 H。
+图2:模板图——要在图1画面里定位的同一内容。常见情况:图2 是 PC/自动化侧通过 ADB 推到手机上的**原始文件**(与磁盘上打开的同一像素内容),**不是**相册 App 已处理后的内部版本。图1 里若出现该图,往往是**系统相册/图库**里的**网格缩略图**:会先对原图做**缩小**,且方格多为**近似正方形**,竖图/横图常被**居中裁成方图**再显示,与图2 原始长宽比可能完全不同。
+
+任务:在图1中找到与图2视觉上对应的那一块区域(同一控件、同一相册格里的缩略图等),给出「搜索矩形」roi_hint;并在适用时给出 template_crop、template_scale、need_center_crop,使下游能**复现相册里那一格的「裁切 + 缩放」**,用于模板匹配。
+
+硬性规则(roi_hint):
+- roi_hint 的四个数必须是相对图1 的归一化坐标:rx0,ry0 为矩形左上角,rx1,ry1 为右下角,均在 [0,1],且 rx0<rx1、ry0<ry1。
+- 【只框图2】图2 里**只有**某一图标/按钮时,roi 应主要覆盖图1里**与之对应的那一块**,不要为了「多装点内容」而把**竖直方向**上、与图2无关的相邻 Tab 图标(例如在图1里与目标**横向并排**的其它底栏图标)一起框进竖长条;底栏场景下各图标是**横排**的,roi 应是**偏横向的条带**包住目标及窄边距,而不是上下堆叠多个无关图标。
+- 【必须完整】图2 模板在图1里所对应的那一整块 UI(含圆角、描边、阴影等可见像素)必须**全部**落在 roi 矩形内部,**任何一边都不得裁切**到模板上的图形;若宁可 roi 明显大一点也要保证完整。
+- 【自检】若你意识到按当前四个数裁图会「切掉」图2 上任意可见部分(例如只框到红色按钮的下半截、圆角被切、+ 号缺一截),**必须**把 ry0 上移或 ry1 下移、或放宽 rx,直到不会裁切。
+- 【底部栏】目标在屏底导航栏时:roi 的纵向高度 (ry1−ry0) 建议至少为屏高的 **14%~22%**,且 ry0 要足够靠上,使整块圆形/圆角按钮(含完整外轮廓)都在框内;**禁止**高度小于屏高 **12%** 的扁条。横向 (rx1−rx0) 以刚好包住目标按钮宽度 + 左右各约 5%~15% 屏宽为宜。
+- 【可大一些】在已完整包含上述目标的前提下,roi 宁可略大勿小:各方向外扩边距避免贴边裁切;图2 近似正方形时,roi **勿**做成「竖远长于横」的窄竖条(除非图2 本身就是竖长条)。
+- 若图1中存在多处相似元素,选择与图2内容最一致、最可能是用户意图的一处;若完全无法对应,四个坐标全填 null。
+
+template_scale(模拟相册把图缩小进格子的比例):
+- 数字,范围建议 0.05~1.0,表示在**已按下方顺序处理完 template_crop 与 need_center_crop 之后**,对图2 再做**等比线性缩放**(宽、高同比例);1 表示不在此步缩小。
+- 估法:对比**图1 里目标格中的缩略图**与**图2 原始文件**——若屏上格子里的内容明显是「整图缩小后的局部/整体」,应给 **小于 1**(相册格常见约 **0.2~0.5**,视分辨率与格大小而定);若图1 里几乎 1:1 对应原图2 像素内容则接近 1。
+- **不要**把 template_scale 理解成随意数字;应绑定:**对齐「ADB 原图 → 相册网格里显示尺寸」的缩放**。
+
+template_crop(相对**原始图2** 宽高的归一化矩形 cx0,cy0,cx1,cy1 ∈ [0,1],cx0<cx1、cy0<cy1):
+- **相册/多宫格场景**:表示「图1 里那一格缩略图**所对应的原图2 上的可见区域**」——即:若把原图2 按相册逻辑裁切后才会得到与格内一致的画面,应用此矩形框出原图2 上的该区域。竖长图在方格里通常只显示**中间一条/一块**,此时 crop 应是**接近正方形或略竖/略横的矩形**,**不要**默认填 0,0,1,1 除非图1 明确显示的是「整图缩进格内、无裁切」。
+- **非相册场景**(图标、按钮、整段 UI):可表示去掉白边、只保留主体;无需裁切则 0,0,1,1。
+- 若你能较准给出「格内所见 ↔ 原图2」的对应关系,应优先给出**非全图**的 template_crop;与 need_center_crop 配合见下。
+
+need_center_crop:
+- 当图1 明显是「相册/多宫格选图」、缩略图为**方格**且图2 与格内显示的长宽比不一致(典型:竖图进方格)时为 **true**;否则 **false**。
+- 为 true 时:roi_hint 仍只框**与图2 对应的那一格**(含格线外极少边距)。下游会在 template_crop 结果上再作**居中取最大正方形**,以逼近系统相册方格裁切;因此若你已在 template_crop 里给出了**精确的方格可见区域**(本身已接近正方形),可将 need_center_crop 设为 **false**,避免几何重复。
+- 若相册场景下**无法**可靠估计 template_crop,可 **template_crop 填 0,0,1,1** 且 **need_center_crop 为 true**,由程序用「整图居中裁方」兜底。
+
+confidence:0~1,可随意填;下游**不使用**该字段做拒识,仅作记录。
+
+只输出一个 JSON 对象,顶层键必须包含:roi_hint、need_center_crop、confidence、template_scale、template_crop。
+roi_hint 为对象,键 "rx0","ry0","rx1","ry1"(数字或 null)。
+template_crop 为对象,键 "cx0","cy0","cx1","cy1"(数字)。
+禁止 markdown、禁止代码围栏、禁止 JSON 以外的任何文字。`
+ * ---------- 旧 VLM_USER_PROMPT 结束 ----------
+ */
+
+const tagName = 'img-center-point-location'
+
+const schema = {
+  description:
+    '在屏幕截图中查找模板并返回中心点;主模型无效坐标时自动换备用模型(最多 3 次)。成功时保存 screenshot_center_marked.png。原 ROI+预处理+NCC 已注释保留。',
+  inputs: { template: '模板图片路径(inVars[0])', variable: '输出变量名(outVars)' },
+  outputs: { variable: '中心点 {x,y}(对象)' },
+}
+
+// ---------------------------------------------------------------------------
+// 以下为 runPipeline 主流程中的调用顺序(自上而下与执行顺序一致)
+// ---------------------------------------------------------------------------
+
+/** 流程目录下的 tmp,例如 static/process/GenerateNote/tmp */
+function resolveWorkflowTmpRoot (folderPath) {
+  if (folderPath && typeof folderPath === 'string') {
+    const fp = path.isAbsolute(folderPath) ? folderPath : path.join(projectRoot, folderPath)
+    return path.join(fp, 'tmp')
+  }
+  return path.join(projectRoot, 'tmp')
+}
+
+function getAdbPath () {
+  return config.adbPath?.path
+    ? (path.isAbsolute(config.adbPath.path) ? config.adbPath.path : path.resolve(projectRoot, config.adbPath.path))
+    : path.join(projectRoot, 'lib', 'scrcpy-adb', process.platform === 'win32' ? 'adb.exe' : 'adb')
+}
+
+function adbScreencapPng (adbPath, device, outFile) {
+  const r = spawnSync(adbPath, ['-s', device, 'exec-out', 'screencap', '-p'], {
+    encoding: 'buffer',
+    maxBuffer: 40 * 1024 * 1024,
+    windowsHide: true,
+  })
+  if (r.status !== 0 || !r.stdout || r.stdout.length < 100) return false
+  fs.mkdirSync(path.dirname(outFile), { recursive: true })
+  fs.writeFileSync(outFile, r.stdout)
+  return true
+}
+
+function fileToDataUrlPng (absPath) {
+  const buf = fs.readFileSync(absPath)
+  const b64 = buf.toString('base64')
+  return `data:image/png;base64,${b64}`
+}
+
+/** 从模型返回文本中抽出 JSON 对象 */
+function parseVlmJson (text) {
+  let s = String(text || '').trim()
+  const fence = s.match(/```(?:json)?\s*([\s\S]*?)```/i)
+  if (fence) s = fence[1].trim()
+  const m = s.match(/\{[\s\S]*\}/)
+  if (!m) return null
+  try {
+    return JSON.parse(m[0])
+  } catch (_) {
+    return null
+  }
+}
+
+/** 读取 PNG IHDR 宽高(adb screencap -p 为 PNG) */
+function readPngIhdrDimensions (absPath) {
+  try {
+    const fd = fs.openSync(absPath, 'r')
+    const buf = Buffer.allocUnsafe(24)
+    fs.readSync(fd, buf, 0, 24, 0)
+    fs.closeSync(fd)
+    if (buf.length < 24 || buf[0] !== 0x89) return null
+    const w = buf.readUInt32BE(16)
+    const h = buf.readUInt32BE(20)
+    if (!Number.isFinite(w) || !Number.isFinite(h) || w < 1 || h < 1) return null
+    return { width: w, height: h }
+  } catch (_) {
+    return null
+  }
+}
+
+function clamp01 (x) {
+  if (x === null || x === undefined) return null
+  const n = Number(x)
+  if (!Number.isFinite(n)) return null
+  return Math.max(0, Math.min(1, n))
+}
+
+/**
+ * 从 VLM JSON 解析归一化中心;优先 center_rx/center_ry,其次 cx/cy;
+ * 若仅有 center_x/center_y:大于 1 时按像素除以 width/height,否则按归一化。
+ */
+function parseCenterNormalizedFromVlm (parsed, width, height) {
+  if (!parsed || typeof parsed !== 'object') return null
+  let crx = clamp01(parsed.center_rx)
+  let cry = clamp01(parsed.center_ry)
+  if (crx != null && cry != null) return { center_rx: crx, center_ry: cry }
+
+  crx = clamp01(parsed.cx)
+  cry = clamp01(parsed.cy)
+  if (crx != null && cry != null) return { center_rx: crx, center_ry: cry }
+
+  const px = Number(parsed.center_x)
+  const py = Number(parsed.center_y)
+  if (!Number.isFinite(px) || !Number.isFinite(py)) return null
+  if (width > 0 && height > 0 && (px > 1 || py > 1)) {
+    crx = clamp01(px / width)
+    cry = clamp01(py / height)
+    if (crx != null && cry != null) return { center_rx: crx, center_ry: cry }
+  }
+  crx = clamp01(px)
+  cry = clamp01(py)
+  if (crx != null && cry != null) return { center_rx: crx, center_ry: cry }
+  return null
+}
+
+/** 实际发往 API 的模型名:见顶部 IMG_CENTER_OPENAI_MODEL → env / config → nodejs/ai。 */
+function getImgCenterModel () {
+  const explicit =
+    String(IMG_CENTER_OPENAI_MODEL || '').trim() ||
+    String(process.env.IMG_CENTER_OPENAI_MODEL || '').trim() ||
+    (config.imgCenterOpenAiModel && String(config.imgCenterOpenAiModel).trim()) ||
+    ''
+  return img2textRequest.resolveImgCenterModel(explicit || undefined)
+}
+
+/** 第一次仅用 OpenAI 兼容网关;若配置成 doubao 则改用 nodejs/ai 的 IMG_CENTER_MODEL(避免首轮即豆包) */
+function getPrimaryOpenAiImgCenterModel () {
+  const m = getImgCenterModel()
+  if (m && String(m).toLowerCase() === 'doubao') {
+    const fb = String(aiPackageConfig.IMG_CENTER_MODEL || '').trim()
+    if (fb && fb.toLowerCase() !== 'doubao') return fb
+    return 'gpt-5.4'
+  }
+  return m
+}
+
+function resolveFallbackCenterModelId (envKey, constVal) {
+  const e = String(process.env[envKey] || '').trim()
+  if (e) return e
+  return String(constVal != null ? constVal : '').trim()
+}
+
+/**
+ * 中心点 img2text 模型链:[主模型, 备用1, 备用2] 去重后取前 IMG_CENTER_CENTER_POINT_MAX_VLM_CALLS 个。
+ */
+function getCenterPointVlmModelChain () {
+  const primary = String(getPrimaryOpenAiImgCenterModel() || '').trim()
+  const fb1 = resolveFallbackCenterModelId('IMG_CENTER_FALLBACK_MODEL_1', IMG_CENTER_FALLBACK_MODEL_1)
+  const fb2 = resolveFallbackCenterModelId('IMG_CENTER_FALLBACK_MODEL_2', IMG_CENTER_FALLBACK_MODEL_2)
+  const chain = []
+  const push = (id) => {
+    const x = String(id || '').trim()
+    if (!x) return
+    if (!chain.includes(x)) chain.push(x)
+  }
+  push(primary)
+  push(fb1)
+  push(fb2)
+  return chain.slice(0, IMG_CENTER_CENTER_POINT_MAX_VLM_CALLS)
+}
+
+function getImgCenterAiMeta () {
+  return {
+    model: getImgCenterModel(),
+    primaryOpenAiModel: getPrimaryOpenAiImgCenterModel(),
+    centerPointVlmModelChain: getCenterPointVlmModelChain(),
+    fallbackModel1: resolveFallbackCenterModelId('IMG_CENTER_FALLBACK_MODEL_1', IMG_CENTER_FALLBACK_MODEL_1),
+    fallbackModel2: resolveFallbackCenterModelId('IMG_CENTER_FALLBACK_MODEL_2', IMG_CENTER_FALLBACK_MODEL_2),
+    baseUrl: aiPackageConfig.BASE_URL,
+    openAiKeyConfigured: !!(aiPackageConfig.API_KEY && String(aiPackageConfig.API_KEY).trim()),
+  }
+}
+
+/**
+ * 【当前】单次 img2text(指定 model);每次尝试写入 openai_raw_attempt_{i}.json。
+ * @returns {{ ok: boolean, data?: object, error?: string, model: string, rawResp?: object }}
+ */
+async function callOpenAiCenterPointWithModel (workDir, screenshotPath, templatePath, modelName, attemptIndex) {
+  syncProcessEnvFromAppConfig()
+  const openAiKey = String(aiPackageConfig.API_KEY || '').trim()
+  const model = String(modelName || '').trim()
+  const screenUrl = fileToDataUrlPng(screenshotPath)
+  const tplUrl = fileToDataUrlPng(templatePath)
+  const imageUrls = [screenUrl, tplUrl]
+  try {
+    if (!openAiKey) {
+      return { ok: false, error: '缺少 OpenAI 兼容 API_KEY', model }
+    }
+    if (!model) {
+      return { ok: false, error: '模型 id 为空', model }
+    }
+    const result = await aiModule.run('img2text', VLM_CENTER_ONLY_PROMPT, imageUrls, {
+      timeoutMs: IMG_CENTER_AI_TIMEOUT_MS,
+      model,
+    })
+    const resp = result.data
+    const attemptPayload = {
+      model,
+      attemptIndex,
+      httpSuccess: result.success,
+      httpError: result.success ? null : (result.error || null),
+      response: resp,
+    }
+    fs.writeFileSync(
+      path.join(workDir, `openai_raw_attempt_${attemptIndex}.json`),
+      JSON.stringify(attemptPayload, null, 2),
+      'utf8'
+    )
+    if (!result.success) {
+      return { ok: false, error: result.error || 'VLM 请求失败', model }
+    }
+    const content = resp?.choices?.[0]?.message?.content
+    const parsed = parseVlmJson(content)
+    if (!parsed || typeof parsed !== 'object') {
+      return { ok: false, error: '无法解析模型返回为 JSON', model, rawResp: resp }
+    }
+    return { ok: true, data: parsed, model, rawResp: resp }
+  } catch (e) {
+    const msg = e && e.message ? e.message : String(e)
+    try {
+      fs.writeFileSync(path.join(workDir, `openai_error_attempt_${attemptIndex}.txt`), msg, 'utf8')
+    } catch (_) {}
+    return { ok: false, error: msg, model }
+  }
+}
+
+/*
+ * ---------- 旧 callOpenAiRoi(ROI + 模板几何,NCC 流程用,保留勿删)----------
+async function callOpenAiRoi (workDir, screenshotPath, templatePath) {
+  syncProcessEnvFromAppConfig()
+  const openAiKey = String(aiPackageConfig.API_KEY || '').trim()
+
+  const screenUrl = fileToDataUrlPng(screenshotPath)
+  const tplUrl = fileToDataUrlPng(templatePath)
+  const imageUrls = [screenUrl, tplUrl]
+
+  const emptyFallback = {
+    roi_hint: { rx0: 0, ry0: 0, rx1: 1, ry1: 1 },
+    need_center_crop: false,
+    confidence: 0,
+    template_scale: 1,
+    template_crop: { cx0: 0, cy0: 0, cx1: 1, cy1: 1 },
+    template_vlm_preprocessed: false,
+  }
+
+  try {
+    if (!openAiKey) {
+      return { ok: false, error: '缺少 OpenAI 兼容 API_KEY', fallback: emptyFallback }
+    }
+    const openAiModel = getPrimaryOpenAiImgCenterModel()
+    const result = await aiModule.run('img2text', VLM_USER_PROMPT, imageUrls, {
+      timeoutMs: IMG_CENTER_AI_TIMEOUT_MS,
+      model: openAiModel,
+    })
+    if (!result.success) {
+      return { ok: false, error: result.error || 'VLM 请求失败', fallback: emptyFallback }
+    }
+    const resp = result.data
+    const content = resp?.choices?.[0]?.message?.content
+    const parsed = parseVlmJson(content)
+    if (!parsed || typeof parsed !== 'object') {
+      return { ok: false, error: '无法解析模型返回为 JSON', fallback: emptyFallback }
+    }
+    fs.writeFileSync(path.join(workDir, 'openai_raw.json'), JSON.stringify(resp, null, 2), 'utf8')
+    return { ok: true, data: parsed }
+  } catch (e) {
+    const msg = e && e.message ? e.message : String(e)
+    return { ok: false, error: msg, fallback: emptyFallback }
+  }
+}
+ * ---------- 旧 callOpenAiRoi 结束 ----------
+ */
+
+/*
+ * ---------- 旧 normalizeTemplateGeometry / normalizeVlmPayload / expandRoiHintNormalized / isFullScreenRoiHint(保留勿删)----------
+function normalizeTemplateGeometry (obj) {
+  const c = obj && obj.template_crop
+  let cx0 = 0
+  let cy0 = 0
+  let cx1 = 1
+  let cy1 = 1
+  if (c && typeof c === 'object') {
+    const n = (k, d) => {
+      const v = Number(c[k])
+      return Number.isFinite(v) ? Math.max(0, Math.min(1, v)) : d
+    }
+    cx0 = n('cx0', 0)
+    cy0 = n('cy0', 0)
+    cx1 = n('cx1', 1)
+    cy1 = n('cy1', 1)
+    if (cx1 <= cx0) {
+      cx0 = 0
+      cx1 = 1
+    }
+    if (cy1 <= cy0) {
+      cy0 = 0
+      cy1 = 1
+    }
+  }
+  let sc = Number(obj && obj.template_scale)
+  if (!Number.isFinite(sc) || sc <= 0) sc = 1
+  sc = Math.max(0.05, Math.min(1, sc))
+  return {
+    template_crop: { cx0, cy0, cx1, cy1 },
+    template_scale: sc,
+  }
+}
+
+function normalizeVlmPayload (obj) {
+  const rh = obj.roi_hint || {}
+  const nums = ['rx0', 'ry0', 'rx1', 'ry1']
+  let bad = false
+  for (const k of nums) {
+    const v = rh[k]
+    if (v === null || v === undefined) bad = true
+  }
+  const tg = normalizeTemplateGeometry(obj)
+  if (bad) {
+    return {
+      roi_hint: { rx0: 0, ry0: 0, rx1: 1, ry1: 1 },
+      need_center_crop: false,
+      confidence: 0,
+      template_scale: tg.template_scale,
+      template_crop: tg.template_crop,
+      template_vlm_preprocessed: false,
+    }
+  }
+  const c = Number(obj.confidence)
+  const conf = Number.isFinite(c) ? Math.max(0, Math.min(1, c)) : 0
+  return {
+    roi_hint: {
+      rx0: Number(rh.rx0),
+      ry0: Number(rh.ry0),
+      rx1: Number(rh.rx1),
+      ry1: Number(rh.ry1),
+    },
+    need_center_crop: !!obj.need_center_crop,
+    confidence: conf,
+    template_scale: tg.template_scale,
+    template_crop: tg.template_crop,
+    template_vlm_preprocessed: false,
+  }
+}
+
+function expandRoiHintNormalized (rh) {
+  let rx0 = Number(rh.rx0)
+  let ry0 = Number(rh.ry0)
+  let rx1 = Number(rh.rx1)
+  let ry1 = Number(rh.ry1)
+  if (!(rx1 > rx0 && ry1 > ry0)) return rh
+
+  const pad = IMG_CENTER_ROI_PAD
+  rx0 = Math.max(0, rx0 - pad)
+  ry0 = Math.max(0, ry0 - pad)
+  rx1 = Math.min(1, rx1 + pad)
+  ry1 = Math.min(1, ry1 + pad)
+
+  let h = ry1 - ry0
+  const hMin = IMG_CENTER_ROI_MIN_REL_H
+  if (h < hMin) {
+    const deficit = hMin - h
+    const bottomAnchored = ry1 >= 0.88
+    if (bottomAnchored) {
+      ry0 = Math.max(0, ry0 - deficit)
+      h = ry1 - ry0
+      if (h < hMin) {
+        ry1 = Math.min(1, ry0 + hMin)
+        h = ry1 - ry0
+        if (h < hMin) ry0 = Math.max(0, ry1 - hMin)
+      }
+    } else {
+      const cy = (ry0 + ry1) / 2
+      ry0 = Math.max(0, cy - hMin / 2)
+      ry1 = Math.min(1, ry0 + hMin)
+      if (ry1 >= 1 - 1e-9) ry0 = Math.max(0, 1 - hMin)
+    }
+  }
+
+  const wMin = IMG_CENTER_ROI_MIN_REL_W
+  if (wMin > 0) {
+    let w = rx1 - rx0
+    if (w < wMin) {
+      const cx = (rx0 + rx1) / 2
+      rx0 = Math.max(0, cx - wMin / 2)
+      rx1 = Math.min(1, rx0 + wMin)
+      if (rx1 >= 1 - 1e-9) rx0 = Math.max(0, 1 - wMin)
+    }
+  }
+
+  return { rx0, ry0, rx1, ry1 }
+}
+
+function isFullScreenRoiHint (rh) {
+  return (
+    Math.abs(Number(rh.rx0)) < 1e-9 &&
+    Math.abs(Number(rh.ry0)) < 1e-9 &&
+    Math.abs(Number(rh.rx1) - 1) < 1e-9 &&
+    Math.abs(Number(rh.ry1) - 1) < 1e-9
+  )
+}
+ * ---------- 旧几何归一化结束 ----------
+ */
+
+/*
+ * ---------- 旧 getPythonPath / runTemplatePreprocess / waitUntilMatchedWritten(保留勿删)----------
+function getPythonPath () {
+  const base = config.pythonPath?.path || config.pythonVenvPath || path.join(projectRoot, 'python', process.arch === 'arm64' ? 'arm64' : 'x64')
+  const envPy = path.join(base, 'env', 'Scripts', 'python.exe')
+  const scriptsPy = path.join(base, 'Scripts', 'python.exe')
+  const pyEmbedded = path.join(base, 'py', 'python.exe')
+  if (fs.existsSync(envPy)) return envPy
+  if (fs.existsSync(scriptsPy)) return scriptsPy
+  if (fs.existsSync(pyEmbedded)) return pyEmbedded
+  return 'python'
+}
+
+function runTemplatePreprocess (pythonPath, templateAbsPath, workDir) {
+  if (!fs.existsSync(preprocessTemplateScriptPath)) {
+    return { ok: false, error: `未找到 ${preprocessTemplateScriptPath}` }
+  }
+  const env = { ...process.env, PYTHONIOENCODING: 'utf-8' }
+  if (process.platform === 'win32') {
+    const pyDir = path.dirname(pythonPath)
+    const pyRoot = path.dirname(path.dirname(pyDir))
+    env.PATH = [pyDir, pyRoot, process.env.PATH].filter(Boolean).join(path.delimiter)
+  }
+  const r = spawnSync(
+    pythonPath,
+    [preprocessTemplateScriptPath, '--src', templateAbsPath, '--work-dir', workDir],
+    { encoding: 'utf-8', timeout: PYTHON_PREPROCESS_TEMPLATE_TIMEOUT_MS, env, cwd: projectRoot, windowsHide: true }
+  )
+  if (r.status !== 0) {
+    const msg = [r.stderr, r.stdout].filter(Boolean).join('\n').trim() || '模板预处理失败'
+    return { ok: false, error: msg }
+  }
+  try {
+    const tail = String(r.stdout || '').trim().split('\n').filter(Boolean).pop() || ''
+    const j = JSON.parse(tail)
+    if (!j.success) return { ok: false, error: j.error || '模板预处理失败' }
+    return { ok: true, meta: j }
+  } catch (e) {
+    return { ok: false, error: '模板预处理输出非 JSON' }
+  }
+}
+
+function waitUntilMatchedWritten (absPath, maxMs) {
+  const matchedWaitMs =
+    maxMs != null
+      ? maxMs
+      : (process.env.IMG_MATCH_MATCHED_WAIT_MS
+        ? Math.max(5000, parseInt(process.env.IMG_MATCH_MATCHED_WAIT_MS, 10) || 30000)
+        : 30000)
+  if (!absPath) return true
+  const t0 = Date.now()
+  let lastSize = -1
+  let stableStart = 0
+  const STABLE_MS = 120
+  while (Date.now() - t0 < matchedWaitMs) {
+    try {
+      if (fs.existsSync(absPath)) {
+        const st = fs.statSync(absPath)
+        if (st.size >= 32) {
+          if (st.size === lastSize) {
+            if (Date.now() - stableStart >= STABLE_MS) return true
+          } else {
+            lastSize = st.size
+            stableStart = Date.now()
+          }
+        }
+      }
+    } catch (_) {}
+    const until = Date.now() + 35
+    while (Date.now() < until) {}
+  }
+  try {
+    return fs.existsSync(absPath) && fs.statSync(absPath).size >= 32
+  } catch (_) {
+    return false
+  }
+}
+ * ---------- 旧 getPythonPath / 预处理 / Matched 等待结束 ----------
+ */
+
+function resolvePythonExecutable () {
+  return getPythonExeFromConfig(config)
+}
+
+/**
+ * 在 screenshot 上以绿色空心圆标出中心点,写入 workDir/screenshot_center_marked.png。
+ * 圆半径可由环境变量 IMG_CENTER_MARK_RADIUS(正整数像素)覆盖,否则由脚本按分辨率估算。
+ * @returns {{ ok: boolean, outPath?: string, error?: string }}
+ */
+function drawCenterMarkOnScreenshot (workDir, screenshotPath, centerX, centerY) {
+  if (!fs.existsSync(markCenterScriptPath)) {
+    return { ok: false, error: `未找到 ${markCenterScriptPath}` }
+  }
+  const outPath = path.join(workDir, 'screenshot_center_marked.png')
+  const pythonPath = resolvePythonExecutable()
+  const args = [
+    markCenterScriptPath,
+    '--input', screenshotPath,
+    '--output', outPath,
+    '--x', String(Math.round(centerX)),
+    '--y', String(Math.round(centerY)),
+  ]
+  const rEnv = parseInt(String(process.env.IMG_CENTER_MARK_RADIUS || '').trim(), 10)
+  if (Number.isFinite(rEnv) && rEnv > 0) {
+    args.push('--radius', String(rEnv))
+  }
+  const env = { ...process.env, PYTHONIOENCODING: 'utf-8' }
+  if (process.platform === 'win32') {
+    const pyDir = path.dirname(pythonPath)
+    const pyRoot = path.dirname(path.dirname(pyDir))
+    env.PATH = [pyDir, pyRoot, process.env.PATH].filter(Boolean).join(path.delimiter)
+  }
+  const r = spawnSync(pythonPath, args, {
+    encoding: 'utf-8',
+    timeout: IMG_CENTER_MARK_SCRIPT_TIMEOUT_MS,
+    env,
+    cwd: projectRoot,
+    windowsHide: true,
+  })
+  if (r.status !== 0) {
+    const msg = [r.stderr, r.stdout].filter(Boolean).join('\n').trim() || '绘制中心点标记失败'
+    return { ok: false, error: msg }
+  }
+  try {
+    const tail = String(r.stdout || '').trim().split('\n').filter(Boolean).pop() || ''
+    const j = JSON.parse(tail)
+    if (!j.success) return { ok: false, error: j.error || '绘制中心点标记失败' }
+  } catch (_) {
+    if (!fs.existsSync(outPath) || fs.statSync(outPath).size < 32) {
+      return { ok: false, error: '标记脚本输出非 JSON 或未写出有效 PNG' }
+    }
+  }
+  return { ok: true, outPath }
+}
+
+async function runPipeline (device, templateAbsPath, folderPath) {
+  if (!device) return { success: false, error: '缺少设备 ID' }
+  if (!templateAbsPath || !fs.existsSync(templateAbsPath)) {
+    return { success: false, error: `模板不存在: ${templateAbsPath}` }
+  }
+
+  const tmpRoot = resolveWorkflowTmpRoot(folderPath)
+  fs.mkdirSync(tmpRoot, { recursive: true })
+  const workDir = path.join(tmpRoot, `img-center-${Date.now()}`)
+  fs.mkdirSync(workDir, { recursive: true })
+
+  const screenshotPath = path.join(workDir, 'screenshot.png')
+  const templateInWork = path.join(workDir, 'template.png')
+
+  const adbPath = getAdbPath()
+  if (!adbScreencapPng(adbPath, device, screenshotPath)) {
+    return { success: false, error: 'ADB 截图失败' }
+  }
+  fs.copyFileSync(templateAbsPath, templateInWork)
+
+  const dims = readPngIhdrDimensions(screenshotPath)
+  if (!dims) {
+    return { success: false, error: '无法读取截图 PNG 尺寸(IHDR)', workDir }
+  }
+
+  const modelChain = getCenterPointVlmModelChain()
+  if (modelChain.length === 0) {
+    const err = '未配置可用 VLM 模型'
+    fs.writeFileSync(path.join(workDir, 'openai_error.txt'), err, 'utf8')
+    return { success: false, error: err, workDir }
+  }
+
+  const attemptLog = []
+  let lastError = ''
+  /** @type {{ center_rx: number, center_ry: number } | null} */
+  let norm = null
+  let aiData = null
+  let successRaw = null
+  let successModel = null
+
+  for (let i = 0; i < modelChain.length; i++) {
+    const m = modelChain[i]
+    const ai = await callOpenAiCenterPointWithModel(workDir, screenshotPath, templateInWork, m, i)
+    const normTry = ai.ok ? parseCenterNormalizedFromVlm(ai.data, dims.width, dims.height) : null
+    attemptLog.push({
+      index: i,
+      model: m,
+      requestOk: ai.ok,
+      error: ai.ok ? null : ai.error,
+      hasValidCenter: !!normTry,
+    })
+    if (!ai.ok) {
+      lastError = ai.error || 'VLM 中心点失败'
+      continue
+    }
+    if (normTry) {
+      norm = normTry
+      aiData = ai.data
+      successRaw = ai.rawResp
+      successModel = m
+      break
+    }
+    lastError = '模型未返回有效中心点(需 center_rx/center_ry 或兼容字段)'
+  }
+
+  try {
+    fs.writeFileSync(
+      path.join(workDir, 'vlm_center_model_attempts.json'),
+      JSON.stringify(
+        {
+          model_chain: modelChain,
+          success_model: successModel,
+          attempts: attemptLog,
+        },
+        null,
+        2
+      ),
+      'utf8'
+    )
+  } catch (_) {}
+
+  if (!norm || !aiData) {
+    fs.writeFileSync(path.join(workDir, 'openai_error.txt'), String(lastError || 'unknown'), 'utf8')
+    fs.writeFileSync(path.join(workDir, 'center_parse_error.txt'), String(lastError || 'unknown'), 'utf8')
+    return { success: false, error: lastError || 'VLM 中心点失败', workDir }
+  }
+
+  try {
+    if (successRaw) {
+      fs.writeFileSync(path.join(workDir, 'openai_raw.json'), JSON.stringify(successRaw, null, 2), 'utf8')
+    }
+    fs.writeFileSync(path.join(workDir, 'vlm_center_parsed.json'), JSON.stringify(aiData, null, 2), 'utf8')
+  } catch (_) {}
+
+  const px = Math.round(norm.center_rx * dims.width)
+  const py = Math.round(norm.center_ry * dims.height)
+  const ix = Math.max(0, Math.min(dims.width - 1, px))
+  const iy = Math.max(0, Math.min(dims.height - 1, py))
+
+  let markedScreenshotPath = null
+  const mark = drawCenterMarkOnScreenshot(workDir, screenshotPath, ix, iy)
+  if (mark.ok) {
+    markedScreenshotPath = mark.outPath || null
+  } else {
+    fs.writeFileSync(
+      path.join(workDir, 'screenshot_center_mark_error.txt'),
+      String(mark.error || 'unknown'),
+      'utf8'
+    )
+  }
+
+  fs.writeFileSync(
+    path.join(workDir, 'vlm_center_result.json'),
+    JSON.stringify(
+      {
+        center_rx: norm.center_rx,
+        center_ry: norm.center_ry,
+        pixel_x: ix,
+        pixel_y: iy,
+        screenshot_width: dims.width,
+        screenshot_height: dims.height,
+        marked_screenshot: markedScreenshotPath
+          ? path.basename(markedScreenshotPath)
+          : null,
+        marked_screenshot_error: markedScreenshotPath ? null : (mark.error || '未生成标记图'),
+      },
+      null,
+      2
+    ),
+    'utf8'
+  )
+
+  return {
+    success: true,
+    coordinate: { x: ix, y: iy, width: 1, height: 1 },
+    clickPosition: { x: ix, y: iy },
+    workDir,
+    markedScreenshotPath,
+  }
+}
+
+/*
+ * ---------- 旧 runPipeline(ROI + 预处理 + NCC,保留勿删)----------
+async function runPipeline_OLD_NCC (device, templateAbsPath, folderPath) {
+  if (!device) return { success: false, error: '缺少设备 ID' }
+  if (!templateAbsPath || !fs.existsSync(templateAbsPath)) {
+    return { success: false, error: `模板不存在: ${templateAbsPath}` }
+  }
+  if (!fs.existsSync(orbScriptPath)) {
+    return { success: false, error: `未找到 ${orbScriptPath}` }
+  }
+  if (!fs.existsSync(preprocessTemplateScriptPath)) {
+    return { success: false, error: `未找到 ${preprocessTemplateScriptPath}` }
+  }
+
+  const tmpRoot = resolveWorkflowTmpRoot(folderPath)
+  fs.mkdirSync(tmpRoot, { recursive: true })
+  const workDir = path.join(tmpRoot, `img-center-${Date.now()}`)
+  fs.mkdirSync(workDir, { recursive: true })
+
+  const screenshotPath = path.join(workDir, 'screenshot.png')
+  const templateInWork = path.join(workDir, 'template.png')
+  const matchedPath = path.join(workDir, 'Matched.png')
+
+  const adbPath = getAdbPath()
+  if (!adbScreencapPng(adbPath, device, screenshotPath)) {
+    return { success: false, error: 'ADB 截图失败' }
+  }
+  fs.copyFileSync(templateAbsPath, templateInWork)
+
+  const ai = await callOpenAiRoi(workDir, screenshotPath, templateInWork)
+  if (!ai.ok) {
+    fs.writeFileSync(path.join(workDir, 'openai_error.txt'), String(ai.error || 'unknown'), 'utf8')
+    return { success: false, error: ai.error || 'VLM ROI 失败', workDir }
+  }
+  let payload = normalizeVlmPayload(ai.data)
+  if (!isFullScreenRoiHint(payload.roi_hint)) {
+    payload = {
+      ...payload,
+      roi_hint: expandRoiHintNormalized(payload.roi_hint),
+    }
+  }
+  fs.writeFileSync(path.join(workDir, 'vlm_roi.json'), JSON.stringify(payload, null, 2), 'utf8')
+
+  const pythonPath = getPythonPath()
+  const prep = runTemplatePreprocess(pythonPath, templateAbsPath, workDir)
+  if (!prep.ok) {
+    fs.copyFileSync(templateAbsPath, templateInWork)
+    try {
+      const vrPath = path.join(workDir, 'vlm_roi.json')
+      const vr = JSON.parse(fs.readFileSync(vrPath, 'utf8'))
+      vr.template_vlm_preprocessed = false
+      delete vr.template_preprocess_paths
+      fs.writeFileSync(vrPath, JSON.stringify(vr, null, 2), 'utf8')
+    } catch (_) {}
+    fs.writeFileSync(path.join(workDir, 'template_preprocess_error.txt'), String(prep.error || ''), 'utf8')
+  }
+
+  const env = { ...process.env, PYTHONIOENCODING: 'utf-8' }
+  if (!String(env.IMG_CENTER_NCC_MIN_SCORE || '').trim()) {
+    env.IMG_CENTER_NCC_MIN_SCORE = String(IMG_CENTER_NCC_MIN_SCORE_DEFAULT)
+  }
+  if (process.platform === 'win32') {
+    const pyDir = path.dirname(pythonPath)
+    const pyRoot = path.dirname(path.dirname(pyDir))
+    env.PATH = [pyDir, pyRoot, process.env.PATH].filter(Boolean).join(path.delimiter)
+  }
+  const r = spawnSync(
+    pythonPath,
+    [orbScriptPath, '--work-dir', workDir],
+    { encoding: 'utf-8', timeout: PYTHON_ORB_SCRIPT_TIMEOUT_MS, env, cwd: projectRoot }
+  )
+  if (r.status !== 0) {
+    const msg = [r.stderr, r.stdout].filter(Boolean).join('\n').trim() || 'NCC 匹配脚本失败'
+    return { success: false, error: msg, workDir }
+  }
+  let out
+  try {
+    out = JSON.parse(r.stdout.trim())
+  } catch (e) {
+    return { success: false, error: `脚本输出非 JSON: ${(r.stdout || '').slice(0, 300)}`, workDir }
+  }
+  if (!out.success) return { success: false, error: out.error || '匹配失败', workDir }
+  if (!waitUntilMatchedWritten(matchedPath, MATCHED_PNG_MAX_WAIT_MS)) {
+    return { success: false, error: `Matched.png 未就绪: ${matchedPath}`, workDir }
+  }
+  return {
+    success: true,
+    coordinate: { x: out.x, y: out.y, width: out.width, height: out.height },
+    clickPosition: { x: out.center_x, y: out.center_y },
+    workDir,
+  }
+}
+ * ---------- 旧 runPipeline 结束 ----------
+ */
+
+/**
+ * press/locate:Electron 侧 await;返回 { success, coordinate?, clickPosition?, error? }
+ * @param {string} [folderPath] 当前流程目录(如 .../static/process/GenerateNote),临时文件写入其下 tmp/
+ */
+async function matchImageAndGetCoordinate (device, imagePath, folderPath) {
+  const templatePath = path.isAbsolute(imagePath) ? imagePath : path.resolve(projectRoot, imagePath)
+  const r = await runPipeline(device, templatePath, folderPath)
+  if (!r.success) return { success: false, error: r.error }
+  return {
+    success: true,
+    coordinate: r.coordinate,
+    clickPosition: r.clickPosition,
+    markedScreenshotPath: r.markedScreenshotPath || null,
+  }
+}
+
+async function executeImgCenterPointLocation ({ device, template, folderPath }) {
+  if (!device) return { success: false, error: '缺少设备 ID,无法自动获取截图' }
+  if (!template || typeof template !== 'string') return { success: false, error: '缺少模板图片路径(inVars[0])' }
+  const baseDir = folderPath && typeof folderPath === 'string' ? folderPath : projectRoot
+  const isAbsoluteOrDrive = template.startsWith('/') || template.includes(':')
+  const hasSubPath = template.includes('/') || template.includes(path.sep)
+  const templatePath = isAbsoluteOrDrive ? template : (hasSubPath ? path.join(baseDir, template) : path.join(baseDir, 'resources', template))
+
+  const result = await runPipeline(device, templatePath, folderPath)
+  if (!result.success) return { success: false, error: result.error }
+  const center = result.clickPosition || {
+    x: result.coordinate.x + result.coordinate.width / 2,
+    y: result.coordinate.y + result.coordinate.height / 2,
+  }
+  return {
+    success: true,
+    center,
+    coordinate: result.coordinate,
+    workDir: result.workDir,
+    markedScreenshotPath: result.markedScreenshotPath || null,
+  }
+}
+
+module.exports = {
+  tagName,
+  schema,
+  executeImgCenterPointLocation,
+  matchImageAndGetCoordinate,
+  /** @deprecated 与 matchImageAndGetCoordinate 相同 */
+  matchImageAndGetCoordinateAsync: matchImageAndGetCoordinate,
+  /** 解析后的模型名 + 当前 ai 包 baseUrl(测试 / 调试) */
+  getImgCenterModel,
+  getImgCenterAiMeta,
+  /** 中心点 VLM 实际调用链(主 + 备用,≤3) */
+  getCenterPointVlmModelChain,
+}

+ 140 - 0
nodejs/ef-compiler/actions/fun/img/img-cropping.js

@@ -0,0 +1,140 @@
+/**
+ * fun 标签:img-cropping — 按正方形区域居中裁剪图片并保存
+ *
+ * 入参(inVars 顺序):
+ * - imagePath:源图片路径(相对当前流程目录或绝对路径)
+ * - savePath:输出路径(相对流程目录或绝对路径)
+ * - squareSpec:边长规则 [scale, 轴]:
+ *   - [0.8, "w"] / "0.8,w" / "[0.8,w]":边长 = 图片宽度 × 0.8
+ *   - [1, "h"] / "1,h":边长 = 图片高度 × 1
+ *   轴:w / width / 宽;h / height / 高
+ *
+ * 若 scale×参照边 大于 min(宽,高),会夹紧到 min(宽,高)。
+ */
+
+const path = require('path')
+const fs = require('fs')
+const { spawnSync } = require('child_process')
+const { getPythonExeFromConfig } = require('../../../../python-exe-from-config.js')
+
+const tagName = 'img-cropping'
+
+const configPath = process.env.STATIC_ROOT
+  ? path.join(path.dirname(path.resolve(process.env.STATIC_ROOT)), 'config.js')
+  : path.join(__dirname, '..', '..', '..', '..', '..', 'config.js')
+const config = fs.existsSync(configPath) ? require(configPath) : {}
+const projectRoot = (config.projectRoot && fs.existsSync(config.projectRoot))
+  ? config.projectRoot
+  : path.dirname(path.resolve(configPath))
+const squareCropScriptPath = path.join(projectRoot, 'python', 'scripts', 'img-crop-square-center.py')
+
+function buildAbsolutePath (p, folderPath) {
+  if (p == null || p === '') return null
+  const s = typeof p === 'string' ? p.trim() : String(p)
+  if (!s) return null
+  if (path.isAbsolute(s) || /^[A-Za-z]:/.test(s)) return path.normalize(s)
+  return folderPath ? path.join(folderPath, s) : path.resolve(projectRoot, s)
+}
+
+/** @returns {{ scale: number, axis: 'w'|'h' } | { error: string }} */
+function parseSquareSpec (raw) {
+  if (raw == null) return { error: 'squareSpec 为空' }
+
+  if (Array.isArray(raw)) {
+    if (raw.length < 2) return { error: 'squareSpec 数组须为 [scale, 轴],如 [0.8, "w"]' }
+    const scale = Number(raw[0])
+    const axis = normalizeAxis(raw[1])
+    if (Number.isNaN(scale) || scale <= 0) return { error: 'squareSpec 中 scale 须为大于 0 的数字' }
+    if (!axis) return { error: 'squareSpec 中轴须为 w 或 h(宽/高)' }
+    return { scale, axis }
+  }
+
+  const str = typeof raw === 'string' ? raw.trim() : String(raw).trim()
+  if (!str) return { error: 'squareSpec 为空字符串' }
+
+  try {
+    const j = JSON.parse(str)
+    if (Array.isArray(j) && j.length >= 2) return parseSquareSpec(j)
+  } catch (_) { /* 非 JSON */ }
+
+  const inner = str.startsWith('[') && str.endsWith(']') ? str.slice(1, -1).trim() : str
+  const parts = inner.split(',').map((s) => s.trim().replace(/^["']|["']$/g, ''))
+  if (parts.length < 2) return { error: 'squareSpec 格式须为 [scale,轴] 或 "scale,轴",如 [0.8,w] 或 1,h' }
+  const scale = Number(parts[0])
+  const axis = normalizeAxis(parts[1])
+  if (Number.isNaN(scale) || scale <= 0) return { error: 'scale 须为大于 0 的数字' }
+  if (!axis) return { error: '轴须为 w 或 h' }
+  return { scale, axis }
+}
+
+function normalizeAxis (v) {
+  if (v == null) return null
+  const a = String(v).trim().toLowerCase()
+  if (a === 'w' || a === 'width' || a === '宽') return 'w'
+  if (a === 'h' || a === 'height' || a === '高') return 'h'
+  return null
+}
+
+/**
+ * @param {{ imagePath: string, squareSpec: string|any[], savePath: string, folderPath?: string }} input
+ */
+async function executeImgCropping ({ imagePath, squareSpec, savePath, folderPath }) {
+  if (imagePath == null || String(imagePath).trim() === '') {
+    return { success: false, error: 'img-cropping 缺少 imagePath(inVars[0])' }
+  }
+  if (savePath == null || String(savePath).trim() === '') {
+    return { success: false, error: 'img-cropping 缺少 savePath(inVars[1])' }
+  }
+  if (squareSpec === undefined || squareSpec === null || squareSpec === '') {
+    return { success: false, error: 'img-cropping 缺少 squareSpec(inVars[2],如 [0.8,"w"])' }
+  }
+
+  const spec = parseSquareSpec(squareSpec)
+  if (spec.error) return { success: false, error: `img-cropping squareSpec 无效: ${spec.error}` }
+
+  const absIn = buildAbsolutePath(String(imagePath).trim(), folderPath)
+  const absOut = buildAbsolutePath(String(savePath).trim(), folderPath)
+  if (!absIn || !absOut) return { success: false, error: 'img-cropping 路径无效' }
+  if (!fs.existsSync(absIn)) return { success: false, error: `img-cropping 源文件不存在: ${absIn}` }
+  if (!fs.existsSync(squareCropScriptPath)) {
+    return { success: false, error: `脚本不存在: ${squareCropScriptPath}` }
+  }
+
+  const pythonExe = getPythonExeFromConfig(config)
+  const r = spawnSync(
+    pythonExe,
+    [squareCropScriptPath, absIn, absOut, String(spec.scale), spec.axis],
+    {
+      encoding: 'utf-8',
+      timeout: 60000,
+      cwd: projectRoot
+    }
+  )
+
+  const combined = ((r.stdout || '') + '\n' + (r.stderr || '')).trim()
+  if (r.status !== 0) {
+    return { success: false, error: combined || 'img-cropping 裁剪失败' }
+  }
+
+  let meta = {}
+  const lines = (r.stdout || '').split(/\r?\n/).map((l) => l.trim()).filter(Boolean)
+  const jsonLine = lines.filter((l) => l.startsWith('{')).pop()
+  if (jsonLine) {
+    try {
+      meta = JSON.parse(jsonLine)
+    } catch (_) {}
+  }
+  if (!meta.success) {
+    return { success: false, error: meta.error || combined || 'img-cropping 未返回成功' }
+  }
+
+  return {
+    success: true,
+    path: absOut,
+    value: absOut,
+    result: absOut,
+    crop: { x: meta.x, y: meta.y, side: meta.side, sourceWidth: meta.width, sourceHeight: meta.height }
+  }
+}
+
+module.exports = { tagName, executeImgCropping, parseSquareSpec }

+ 101 - 0
nodejs/ef-compiler/actions/fun/img/img-scale.js

@@ -0,0 +1,101 @@
+/**
+ * fun 结点:img-scale — 图片等比缩放(宽高同乘同一比例)
+ *
+ * 入参(inVars 顺序,与 img-cropping 一致:先路径再参数):
+ * - imagePath:源图(相对流程目录或绝对路径)
+ * - savePath:输出路径
+ * - scale:缩放系数。支持:
+ *   - 0.8 → 长宽各 ×0.8
+ *   - "80%" → 80%;整数 10~100 视为百分比,如 80 即 80%
+ *   - 1~10 之间的小数/整数为倍率,如 2 即 2 倍、1.5 即 1.5 倍(整数 10 为 10% 而非 10 倍)
+ */
+
+const path = require('path')
+const fs = require('fs')
+const { spawnSync } = require('child_process')
+const { getPythonExeFromConfig } = require('../../../../python-exe-from-config.js')
+
+const tagName = 'img-scale'
+
+const configPath = process.env.STATIC_ROOT
+  ? path.join(path.dirname(path.resolve(process.env.STATIC_ROOT)), 'config.js')
+  : path.join(__dirname, '..', '..', '..', '..', '..', 'config.js')
+const config = fs.existsSync(configPath) ? require(configPath) : {}
+const projectRoot = (config.projectRoot && fs.existsSync(config.projectRoot))
+  ? config.projectRoot
+  : path.dirname(path.resolve(configPath))
+const scaleScriptPath = path.join(projectRoot, 'python', 'scripts', 'img-scale-proportional.py')
+
+function buildAbsolutePath (p, folderPath) {
+  if (p == null || p === '') return null
+  const s = typeof p === 'string' ? p.trim() : String(p)
+  if (!s) return null
+  if (path.isAbsolute(s) || /^[A-Za-z]:/.test(s)) return path.normalize(s)
+  return folderPath ? path.join(folderPath, s) : path.resolve(projectRoot, s)
+}
+
+/**
+ * @param {{ imagePath: string, savePath: string, scale: string|number, folderPath?: string }} input
+ */
+async function executeImgScale ({ imagePath, savePath, scale, folderPath }) {
+  if (imagePath == null || String(imagePath).trim() === '') {
+    return { success: false, error: 'img-scale 缺少 imagePath(inVars[0])' }
+  }
+  if (savePath == null || String(savePath).trim() === '') {
+    return { success: false, error: 'img-scale 缺少 savePath(inVars[1])' }
+  }
+  if (scale === undefined || scale === null || scale === '') {
+    return { success: false, error: 'img-scale 缺少 scale(inVars[2],如 0.8 或 80%)' }
+  }
+
+  const scaleStr = typeof scale === 'number' && Number.isFinite(scale) ? String(scale) : String(scale).trim()
+  if (!scaleStr) return { success: false, error: 'img-scale scale 无效' }
+
+  const absIn = buildAbsolutePath(String(imagePath).trim(), folderPath)
+  const absOut = buildAbsolutePath(String(savePath).trim(), folderPath)
+  if (!absIn || !absOut) return { success: false, error: 'img-scale 路径无效' }
+  if (!fs.existsSync(absIn)) return { success: false, error: `img-scale 源文件不存在: ${absIn}` }
+  if (!fs.existsSync(scaleScriptPath)) {
+    return { success: false, error: `脚本不存在: ${scaleScriptPath}` }
+  }
+
+  const pythonExe = getPythonExeFromConfig(config)
+  const r = spawnSync(pythonExe, [scaleScriptPath, absIn, absOut, scaleStr], {
+    encoding: 'utf-8',
+    timeout: 120_000,
+    cwd: projectRoot,
+  })
+
+  const combined = ((r.stdout || '') + '\n' + (r.stderr || '')).trim()
+  if (r.status !== 0) {
+    return { success: false, error: combined || 'img-scale 执行失败' }
+  }
+
+  let meta = {}
+  const lines = (r.stdout || '').split(/\r?\n/).map((l) => l.trim()).filter(Boolean)
+  const jsonLine = lines.filter((l) => l.startsWith('{')).pop()
+  if (jsonLine) {
+    try {
+      meta = JSON.parse(jsonLine)
+    } catch (_) {}
+  }
+  if (!meta.success) {
+    return { success: false, error: meta.error || combined || 'img-scale 未返回成功' }
+  }
+
+  return {
+    success: true,
+    path: absOut,
+    value: absOut,
+    result: absOut,
+    scaled: {
+      scale: meta.scale,
+      sourceWidth: meta.sourceWidth,
+      sourceHeight: meta.sourceHeight,
+      width: meta.width,
+      height: meta.height,
+    },
+  }
+}
+
+module.exports = { tagName, executeImgScale }

+ 17 - 2
nodejs/ef-compiler/actions/json/json-to-arr.js → nodejs/ef-compiler/actions/fun/json/json-to-arr.js

@@ -1,6 +1,7 @@
 /**
  * json 结点:将 JSON 字符串解析为数组(如 img-url-json -> img-url-arr)
  * 对应 process:type: "json", method: "json-to-arr", inVars: ["{img-url-json}"], outVars: ["{img-url-arr}"]
+ * 或统一写法:type: "fun", method: "json-to-arr", inVars / outVars 可省略
  */
 
 /**
@@ -10,13 +11,27 @@
  */
 async function executeJsonToArr({ jsonString }) {
   if (jsonString == null) return { success: false, error: 'json-to-arr 缺少输入(如 inVars[0])' }
-  const str = typeof jsonString === 'string' ? jsonString.trim() : String(jsonString).trim()
+  let str = typeof jsonString === 'string' ? jsonString.trim() : String(jsonString).trim()
+  // 去掉可能的 markdown 代码块包裹
+  const codeBlockMatch = str.match(/^```(?:json)?\s*([\s\S]*?)```\s*$/m)
+  if (codeBlockMatch) str = codeBlockMatch[1].trim()
   if (!str) return { success: false, error: 'json-to-arr 输入为空' }
   let parsed
   try {
     parsed = JSON.parse(str)
   } catch (e) {
-    return { success: false, error: `JSON 解析失败: ${e && e.message ? e.message : String(e)}` }
+    // AI 可能截断:尝试补全后再解析
+    if (str.startsWith('[')) {
+      if (/^\[\s*"[^"]*$/.test(str)) {
+        try { parsed = JSON.parse(str + '"]') } catch (_) {}
+      }
+      if (parsed == null && str.endsWith('"')) {
+        try { parsed = JSON.parse(str + ']') } catch (_) {}
+      }
+    }
+    if (parsed == null) {
+      return { success: false, error: `JSON 解析失败: ${e && e.message ? e.message : String(e)}` }
+    }
   }
   const arr = Array.isArray(parsed) ? parsed : (parsed != null ? [parsed] : [])
   return { success: true, result: arr }

+ 18 - 13
nodejs/ef-compiler/actions/fun/ocr.js

@@ -9,12 +9,15 @@ const fs = require('fs')
 const os = require('os')
 const { spawnSync } = require('child_process')
 const { captureScreenshot } = require('../../../adb/adb-screencap.js')
+const { getPythonExeFromConfig } = require('../../../python-exe-from-config.js')
 
 const configPath = process.env.STATIC_ROOT
-  ? path.join(path.dirname(process.env.STATIC_ROOT), 'config.js')
+  ? path.join(path.dirname(path.resolve(process.env.STATIC_ROOT)), 'config.js')
   : path.join(__dirname, '..', '..', '..', '..', 'config.js')
-const projectRoot = path.dirname(path.resolve(configPath))
 const config = fs.existsSync(configPath) ? require(configPath) : {}
+const projectRoot = (config.projectRoot && fs.existsSync(config.projectRoot))
+  ? config.projectRoot
+  : path.dirname(path.resolve(configPath))
 const ocrScriptPath = path.join(projectRoot, 'python', 'scripts', 'ocr.py')
 
 const tagName = 'ocr'
@@ -25,13 +28,15 @@ const schema = {
   outputs: { variable: '识别文本 或 中心点 JSON' },
 }
 
-function getPythonPath() {
-  const base = config.pythonPath?.path || path.join(projectRoot, 'python', 'py')
-  const winPy = path.join(base, 'python.exe')
-  if (fs.existsSync(winPy)) return winPy
-  const unixPy = path.join(base, 'python')
-  if (fs.existsSync(unixPy)) return unixPy
-  return 'python'
+/** 仅 RapidOCR 需要 opencv:使用本地 python/opencv-4.13.0 时在 spawn 时注入 PYTHONPATH */
+function getOcrEnv() {
+  const env = { ...process.env, PYTHONIOENCODING: 'utf-8' }
+  const opencvDir = path.join(projectRoot, 'python', 'opencv-4.13.0')
+  if (fs.existsSync(opencvDir)) {
+    const prev = env.PYTHONPATH || ''
+    env.PYTHONPATH = prev ? `${opencvDir}${path.delimiter}${prev}` : opencvDir
+  }
+  return env
 }
 
 /**
@@ -55,11 +60,11 @@ async function executeOcr({ imagePath, folderPath }) {
     return { success: false, error: `图片不存在: ${resolvedImage}` }
   }
 
-  const pythonPath = getPythonPath()
+  const pythonPath = getPythonExeFromConfig(config)
   const r = spawnSync(pythonPath, [ocrScriptPath, '--image', resolvedImage, '--project-root', projectRoot], {
     encoding: 'utf-8',
     timeout: 60000,
-    env: { ...process.env, PYTHONIOENCODING: 'utf-8' },
+    env: getOcrEnv(),
     cwd: projectRoot,
   })
 
@@ -98,11 +103,11 @@ async function executeOcrFindText({ device, findText, folderPath }) {
     if (!fs.existsSync(screenshotPath) || fs.statSync(screenshotPath).size === 0) {
       return { success: false, error: '设备截图失败或为空' }
     }
-    const pythonPath = getPythonPath()
+    const pythonPath = getPythonExeFromConfig(config)
     const r = spawnSync(pythonPath, [ocrScriptPath, '--image', screenshotPath, '--find-text', findText.trim(), '--project-root', projectRoot], {
       encoding: 'utf-8',
       timeout: 60000,
-      env: { ...process.env, PYTHONIOENCODING: 'utf-8' },
+      env: getOcrEnv(),
       cwd: projectRoot,
     })
     const outStr = (r.stdout || '').trim()

+ 19 - 10
nodejs/ef-compiler/ef-compiler.js

@@ -4,10 +4,10 @@
 const path = require('path')
 const fs = require('fs')
 
-// --- 配置:projectRoot = 与根目录 config.js 同目录(仓库根或 exe 同目录)---
+// --- 配置:根目录 config.js(与 Electron 共用);projectRoot 优先取 config.projectRoot ---
 const defaultRoot = path.resolve(__dirname, '..', '..')
 const configPath = process.env.STATIC_ROOT
-  ? path.join(path.dirname(process.env.STATIC_ROOT), 'config.js')
+  ? path.join(path.dirname(path.resolve(process.env.STATIC_ROOT)), 'config.js')
   : path.join(defaultRoot, 'config.js')
 let projectRoot = path.dirname(path.resolve(configPath))
 let adbInteractPath = path.join(projectRoot, 'nodejs', 'adb', 'adb-interact.js')
@@ -15,6 +15,7 @@ let nodeExePath = 'node'
 if (fs.existsSync(configPath)) {
   try {
     const cfg = require(configPath)
+    if (cfg.projectRoot && fs.existsSync(cfg.projectRoot)) projectRoot = cfg.projectRoot
     nodeExePath = cfg.nodejsPath
       ? (path.isAbsolute(cfg.nodejsPath) ? cfg.nodejsPath : path.join(projectRoot, cfg.nodejsPath))
       : (process.env.STATIC_ROOT ? path.join(projectRoot, 'node', process.platform === 'win32' ? 'node.exe' : 'node') : 'node')
@@ -38,14 +39,7 @@ const workflowJsonParser = require('./workflow-json-parser.js')
 const sequenceRunner = require('./sequence-runner.js')
 const actions = require('./actions/fun/fun-parser.js')
 
-// --- 功能模块(fun 目录)与运行时 API ---
-const { matchImageAndGetCoordinate } = require('./actions/fun/img-center-point-location.js')
-const { readTextFile } = require('./actions/fun/read-txt.js')
-const { writeTextFile } = require('./actions/fun/save-txt.js')
-
-const electronAPI = runtimeApi.createElectronAPI({ matchImageAndGetCoordinate, readTextFile, writeTextFile }, compilerConfig)
-
-// --- 共享状态(变量上下文、步骤计数、当前工作流目录等)---
+// --- 共享状态(须早于 electronAPI:匹配 tmp 依赖当前工作流目录)---
 const state = {
   variableContext: {},
   variableContextInitialized: false,
@@ -54,6 +48,21 @@ const state = {
   declaredVariableNames: [], // workflow.variables 的 key,用于校验 inVars/outVars 引用是否已声明
 }
 
+// --- 功能模块(fun 目录)与运行时 API ---
+const imgCenterLocation = require('./actions/fun/img/img-center-point-location.js')
+const { readTextFile } = require('./actions/fun/IO/read-txt.js')
+const { writeTextFile } = require('./actions/fun/IO/save-txt.js')
+
+/** 按图点击等:临时目录落在「当前流程文件夹/tmp/」;调用方也可传第三参覆盖 */
+async function matchImageAndGetCoordinate (device, imagePath, folderPathFromCaller) {
+  const fp = folderPathFromCaller != null && folderPathFromCaller !== ''
+    ? folderPathFromCaller
+    : state.currentWorkflowFolderPath
+  return imgCenterLocation.matchImageAndGetCoordinate(device, imagePath, fp)
+}
+
+const electronAPI = runtimeApi.createElectronAPI({ matchImageAndGetCoordinate, readTextFile, writeTextFile }, compilerConfig)
+
 // --- 从各组件抽出的工具方法(供本文件与 ctx 使用)---
 const extractVarName = setParser.extractVarName
 const replaceVariablesInString = setParser.replaceVariablesInString

+ 45 - 22
nodejs/ef-compiler/sequence-runner.js

@@ -3,9 +3,13 @@ const variableParser = require('./variable-parser.js')
 
 /**
  * 执行操作序列(schedule/if/for/while + 普通步骤)
- * 单文件 ≤500 行。ctx: executeAction, logMessage, evaluateCondition, getActionName, parseDelayString, calculateWaitTime, state
- * state: variableContext, globalStepCounter, currentWorkflowFolderPath, variableContextInitialized
+ * 规则:除 schedule 的定时间隔外,均为同步串行——上一步 await 结束才执行下一步;任一步失败立即 return,后续步骤不执行。
+ * 根级 execute 数组之间可有 stepInterval;for/if/while/try 内部及 schedule 单次 tick 内步骤间隔为 0。
  */
+function isOk(r) {
+  return r && r.success === true
+}
+
 async function executeActionSequence(
   actions,
   device,
@@ -20,6 +24,8 @@ async function executeActionSequence(
   const { executeAction, logMessage, evaluateCondition, getActionName, parseDelayString, calculateWaitTime, replaceVariablesInString, state } = ctx
   const variableContext = state.variableContext
   const DEFAULT_STEP_INTERVAL = ctx.DEFAULT_STEP_INTERVAL ?? 1000
+  /** 嵌套体内步骤紧接执行;仅根级兄弟结点之间用 stepInterval */
+  const innerInterval = 0
 
   if (depth === 0) {
     state.globalStepCounter = 0
@@ -65,8 +71,10 @@ async function executeActionSequence(
           }
         }
         if (actionsToExecute.length > 0) {
-          const result = await executeActionSequence(actionsToExecute, device, folderPath, resolution, interval, onStepComplete, shouldStop, depth + 1, ctx)
-          if (!result.success) return result
+          const result = await executeActionSequence(actionsToExecute, device, folderPath, resolution, innerInterval, onStepComplete, shouldStop, depth + 1, ctx)
+          if (!isOk(result)) {
+            return { success: false, error: (result && result.error != null) ? String(result.error) : 'failed', completedSteps: result && result.completedSteps != null ? result.completedSteps : completedSteps }
+          }
           completedSteps += result.completedSteps || 0
         }
       }
@@ -77,8 +85,10 @@ async function executeActionSequence(
       const conditionResult = evaluateCondition(action.condition, variableContext)
       const actionsToExecute = conditionResult ? (action.then || action.ture || []) : (action.else || action.false || [])
       if (actionsToExecute.length > 0) {
-        const result = await executeActionSequence(actionsToExecute, device, folderPath, resolution, interval, onStepComplete, shouldStop, depth + 1, ctx)
-        if (!result.success) return result
+        const result = await executeActionSequence(actionsToExecute, device, folderPath, resolution, innerInterval, onStepComplete, shouldStop, depth + 1, ctx)
+        if (!isOk(result)) {
+          return { success: false, error: (result && result.error != null) ? String(result.error) : 'failed', completedSteps: result && result.completedSteps != null ? result.completedSteps : completedSteps }
+        }
         completedSteps += result.completedSteps || 0
       }
       continue
@@ -100,8 +110,10 @@ async function executeActionSequence(
           if (shouldStop && shouldStop()) return { success: false, error: 'Execution stopped', completedSteps }
           if (action.variable) variableContext[action.variable.replace(/^\{|\}$/g, '').trim()] = i
           if (action.body && action.body.length > 0) {
-            const result = await executeActionSequence(action.body, device, folderPath, resolution, interval, onStepComplete, shouldStop, depth + 1, ctx)
-            if (!result.success) return result
+            const result = await executeActionSequence(action.body, device, folderPath, resolution, innerInterval, onStepComplete, shouldStop, depth + 1, ctx)
+            if (!isOk(result)) {
+              return { success: false, error: (result && result.error != null) ? String(result.error) : 'failed', completedSteps: result && result.completedSteps != null ? result.completedSteps : completedSteps }
+            }
             completedSteps += result.completedSteps || 0
           }
         }
@@ -125,8 +137,10 @@ async function executeActionSequence(
           if (indexKey !== null) variableContext[indexKey] = i
           if (variableKey !== null) variableContext[variableKey] = items[i]
           if (action.body && action.body.length > 0) {
-            const result = await executeActionSequence(action.body, device, folderPath, resolution, interval, onStepComplete, shouldStop, depth + 1, ctx)
-            if (!result.success) return result
+            const result = await executeActionSequence(action.body, device, folderPath, resolution, innerInterval, onStepComplete, shouldStop, depth + 1, ctx)
+            if (!isOk(result)) {
+              return { success: false, error: (result && result.error != null) ? String(result.error) : 'failed', completedSteps: result && result.completedSteps != null ? result.completedSteps : completedSteps }
+            }
             completedSteps += result.completedSteps || 0
           }
         }
@@ -139,8 +153,10 @@ async function executeActionSequence(
       while (evaluateCondition(action.condition, variableContext)) {
         if (shouldStop && shouldStop()) return { success: false, error: 'Execution stopped', completedSteps }
         if (action.body && action.body.length > 0) {
-          const result = await executeActionSequence(action.body, device, folderPath, resolution, interval, onStepComplete, shouldStop, depth + 1, ctx)
-          if (!result.success) return result
+          const result = await executeActionSequence(action.body, device, folderPath, resolution, innerInterval, onStepComplete, shouldStop, depth + 1, ctx)
+          if (!isOk(result)) {
+            return { success: false, error: (result && result.error != null) ? String(result.error) : 'failed', completedSteps: result && result.completedSteps != null ? result.completedSteps : completedSteps }
+          }
           completedSteps += result.completedSteps || 0
         }
       }
@@ -151,23 +167,30 @@ async function executeActionSequence(
       const tryActions = action.try || action.body || []
       const successActions = action.success || []
       const failActions = action.fail || action.catch || []
+      /** try 主路径失败后:即使 fail 分支执行成功,默认仍向上返回失败,避免 for 进入下一轮或继续执行后续兄弟步骤。需继续时请设 continueAfterFail: true */
+      const continueAfterFail = action.continueAfterFail === true
       const result = tryActions.length > 0
-        ? await executeActionSequence(tryActions, device, folderPath, resolution, interval, onStepComplete, shouldStop, depth + 1, ctx)
+        ? await executeActionSequence(tryActions, device, folderPath, resolution, innerInterval, onStepComplete, shouldStop, depth + 1, ctx)
         : { success: true, completedSteps: 0 }
-      if (result.success && successActions.length > 0) {
-        const successResult = await executeActionSequence(successActions, device, folderPath, resolution, interval, onStepComplete, shouldStop, depth + 1, ctx)
-        if (!successResult.success) return successResult
+      if (isOk(result) && successActions.length > 0) {
+        const successResult = await executeActionSequence(successActions, device, folderPath, resolution, innerInterval, onStepComplete, shouldStop, depth + 1, ctx)
+        if (!isOk(successResult)) {
+          return { success: false, error: (successResult && successResult.error != null) ? String(successResult.error) : 'failed', completedSteps }
+        }
         completedSteps += (result.completedSteps || 0) + (successResult.completedSteps || 0)
-      } else if (result.success) {
+      } else if (isOk(result)) {
         completedSteps += result.completedSteps || 0
       } else {
         const errMsg = (result.error != null && result.error !== '') ? String(result.error) : 'Unknown error'
         const timeStr = new Date().toISOString().replace('T', ' ').slice(0, 19)
         await logMessage(`[sequence-runner] [try failed] ${timeStr} ${errMsg}`, folderPath).catch(() => {})
         if (failActions.length > 0) {
-          const failResult = await executeActionSequence(failActions, device, folderPath, resolution, interval, onStepComplete, shouldStop, depth + 1, ctx)
-          if (!failResult.success) return failResult
+          const failResult = await executeActionSequence(failActions, device, folderPath, resolution, innerInterval, onStepComplete, shouldStop, depth + 1, ctx)
+          if (!isOk(failResult)) return failResult
           completedSteps += (result.completedSteps || 0) + (failResult.completedSteps || 0)
+          if (!continueAfterFail) {
+            return { success: false, error: errMsg, completedSteps }
+          }
         } else {
           return result
         }
@@ -212,12 +235,12 @@ async function executeActionSequence(
 
       const result = await executeAction(action, device, folderPath, resolution)
 
-      if (result.success && result.skipped) { /* 步骤跳过不写 log */ }
+      if (isOk(result) && result.skipped) { /* 步骤跳过不写 log */ }
 
       // 统一由 echo-parser.logActionError 打印结点报错,各结点只需 return { success: false, error } 即可
-      if (!result.success) {
+      if (!isOk(result)) {
         await logActionError(action, result, { getActionName, logMessage, folderPath }).catch(() => {})
-        const errDetail = result.error != null && result.error !== '' ? String(result.error) : 'unknown'
+        const errDetail = result && result.error != null && result.error !== '' ? String(result.error) : 'unknown'
         return { success: false, error: errDetail, completedSteps: i }
       }
 

+ 2 - 0
nodejs/ef-compiler/variable-parser.js

@@ -16,6 +16,8 @@ const INPUT_KEYS = [
   'inputDataString', 'textVariable', 'senderVariable', 'appendMode',
   'summaryPrompt', 'historyPrompt', 'model', 'prompt', 'systemPrompt',
   'regionArea', 'saveDir', 'url', 'filename', 'imageUrl',
+  'imagePath', 'squareSpec', 'scale', 'method',
+  'recursive',
 ]
 
 /**

+ 11 - 3
nodejs/ef-compiler/workflow-json-parser.js

@@ -10,7 +10,7 @@ const funNodeRegistry = require('./actions/fun/fun-node-registry.js')
 const actionModules = [
   require('./actions/delay-parser.js'),
   setParser,
-  require('./actions/adb/adb-parser.js'),
+  require('./actions/fun/adb/adb-parser.js'),
   require('./actions/echo-parser.js'),
   require('./actions/random-parser.js'),
   require('./actions/schedule-parser.js'),
@@ -66,6 +66,7 @@ function getActionName(action) {
     'img-bounding-box-location': 'img bounding box',
     'img-center-point-location': 'img center point',
     'img-cropping': 'img cropping',
+    'img-scale': 'img scale',
     'read-last-message': 'read last message',
     'read-txt': 'read text file',
     'read-text': 'read text file',
@@ -83,7 +84,9 @@ function getActionName(action) {
     'echo': 'echo',
   }
   ;(funNodeRegistry || []).forEach((def) => { typeNames[def.type] = def.displayName || def.type })
-  const typeName = (action.type === 'fun' || action.type === 'ai') ? (typeNames[action.method] || action.method || action.type) : (typeNames[action.type] || action.type)
+  const typeName = (action.type === 'fun' || action.type === 'ai' || action.type === 'io')
+    ? (typeNames[action.method] || action.method || action.type)
+    : (typeNames[action.type] || action.type)
   const value = action.value || action.target || ''
   const displayValue = typeof value === 'string' ? value : JSON.stringify(value)
   if (action.type === 'schedule') {
@@ -103,7 +106,12 @@ function getActionName(action) {
   if (action.type === 'for') return `${typeName}: ${action.variable || ''}`
   if (action.type === 'try') return typeName
   if (action.type === 'set') return `${typeName}: ${action.variable || ''}`
-  if (action.type === 'fun' || action.type === 'ai') return `${typeName}: ${displayValue}`
+  if (action.type === 'fun' || action.type === 'ai' || action.type === 'io') {
+    const m = action.method && String(action.method)
+    if (m && m.startsWith('adb-')) return `adb ${m.slice(4)}: ${displayValue}`
+    if (m === 'json-to-arr' || m === 'json-json-to-arr') return `json to arr: ${displayValue}`
+    return `${typeName}: ${displayValue}`
+  }
   return `${typeName}: ${displayValue}`
 }
 

+ 24 - 0
nodejs/python-exe-from-config.js

@@ -0,0 +1,24 @@
+'use strict'
+const path = require('path')
+
+/**
+ * Python 可执行文件路径:只认根目录 config.js 的导出(pythonExePath / pythonDir / pythonPath.path)。
+ * 禁止在业务代码里拼接 python/x64、venv、py/python.exe 等路径。
+ *
+ * @param {object} config - require(项目根/config.js) 的结果
+ * @returns {string}
+ */
+function getPythonExeFromConfig (config) {
+  if (!config || typeof config !== 'object') return 'python'
+  const explicit = config.pythonExePath != null && String(config.pythonExePath).trim()
+  if (explicit) {
+    const p = String(config.pythonExePath).trim()
+    return path.isAbsolute(p) ? p : path.join(config.projectRoot || '', p)
+  }
+  const dir = config.pythonDir || (config.pythonPath && config.pythonPath.path)
+  if (!dir) return 'python'
+  const base = path.isAbsolute(dir) ? dir : path.join(config.projectRoot || '', dir)
+  return path.join(base, process.platform === 'win32' ? 'python.exe' : 'python')
+}
+
+module.exports = { getPythonExeFromConfig }

+ 11 - 5
nodejs/run-process.js

@@ -48,8 +48,10 @@ ensureProcessDirAndLog()
 let config, projectRoot, adbPath, ipListJson, ipList, actions, resolution, executeActionSequence
 try {
   const configPath = process.env.STATIC_ROOT ? path.join(path.dirname(staticRoot), 'config.js') : path.join(__dirname, '..', 'config.js')
-  projectRoot = path.dirname(path.resolve(configPath))
   config = require(configPath)
+  projectRoot = (config.projectRoot && fs.existsSync(config.projectRoot))
+    ? config.projectRoot
+    : path.dirname(path.resolve(configPath))
   adbPath = config.adbPath?.path
     ? (path.isAbsolute(config.adbPath.path) ? config.adbPath.path : path.resolve(projectRoot, config.adbPath.path))
     : path.join(projectRoot, 'lib', 'scrcpy-adb', process.platform === 'win32' ? 'adb.exe' : 'adb')
@@ -119,7 +121,7 @@ async function ensureDeviceConnected(ip, port, logLineFn) {
   return false
 }
 
-/** 启动执行:遍历 ip 列表并异步执行脚本;任一台失败则停止全部并返回失败设备 IP;log.txt 仅写入报错与带 log:true 的 echo */
+/** 启动执行:多设备串行(一台跑完再跑下一台,与流程步骤同步语义一致);任一台失败则 shouldStop;log.txt 仅写入报错与带 log:true 的 echo */
 async function start() {
   logLine(`Process "${scriptName}" start, devices: ${ipList.length}`)
   let failedIp = null
@@ -135,13 +137,17 @@ async function start() {
     }
     logLine(`Running on ${deviceId}`)
     const result = await executeActionSequence(actions, deviceId, folderPath, resolution, 1000, null, () => shouldStop)
-    if (!result.success) {
+    const ok = result && result.success === true
+    if (!ok) {
       if (!failedIp) { failedIp = ip; shouldStop = true }
     }
-    return { ip, success: result.success }
+    return { ip, success: ok }
   }
 
-  const results = await Promise.all(ipList.map(ip => runOne(ip)))
+  const results = []
+  for (const ip of ipList) {
+    results.push(await runOne(ip))
+  }
 
   const output = failedIp
     ? { success: false, failedIp, results }

+ 5 - 2
package/pack-resources/config.js

@@ -3,7 +3,8 @@ const path = require('path')
 const projectRoot = __dirname
 
 // Python:嵌入式解释器在 python/py
-const pythonPath = path.join(projectRoot, 'python', 'py')
+const pythonDir = path.join(projectRoot, 'python', 'py')
+const pythonExePath = path.join(pythonDir, process.platform === 'win32' ? 'python.exe' : 'python')
 
 // Node.js:打包目录下 node/node.exe
 const nodejsPath = path.join(projectRoot, 'node', process.platform === 'win32' ? 'node.exe' : 'node')
@@ -16,7 +17,9 @@ module.exports = {
   window: { width: 800, height: 600, autoHideMenuBar: true },
   devTools: { enabled: false },
   vite: { port: 9527, host: 'localhost' },
-  pythonPath: { path: pythonPath },
+  pythonPath: { path: pythonDir },
+  pythonDir,
+  pythonExePath,
   adbPath: { path: adbPath },
   nodejsPath,
 }

+ 8 - 3
package/pack-resources/electron-pack-win.js

@@ -331,15 +331,20 @@ async function main() {
     const packagedConfig = `// 打包后使用:python 在 exe 同目录 python/py,不依赖系统 Python
 const path = require('path')
 const projectRoot = path.dirname(process.execPath)
+const isWin = process.platform === 'win32'
+const pythonDir = path.join(projectRoot, 'python', 'py')
+const pythonExePath = path.join(pythonDir, isWin ? 'python.exe' : 'python')
 
 module.exports = {
   projectRoot,
   window: { width: 800, height: 600, autoHideMenuBar: true },
   devTools: { enabled: false },
   vite: { port: 9527, host: 'localhost' },
-  pythonPath: { path: path.join(projectRoot, 'python', 'py') },
-  adbPath: { path: path.join(projectRoot, 'scrcpy-adb', process.platform === 'win32' ? 'adb.exe' : 'adb') },
-  nodejsPath: path.join(projectRoot, 'node', process.platform === 'win32' ? 'node.exe' : 'node')
+  pythonPath: { path: pythonDir },
+  pythonDir,
+  pythonExePath,
+  adbPath: { path: path.join(projectRoot, 'scrcpy-adb', isWin ? 'adb.exe' : 'adb') },
+  nodejsPath: path.join(projectRoot, 'node', isWin ? 'node.exe' : 'node')
 }
 `
     try {

+ 0 - 25
python/RoMa/.github/actions/uv-build/action.yml

@@ -1,25 +0,0 @@
-name: uv-build
-description: Build the project with uv and run smoke tests
-inputs:
-  python-version:
-    description: Python version to install
-    required: true
-runs:
-  using: composite
-  steps:
-    - name: Install uv
-      uses: astral-sh/setup-uv@v6
-    - name: Install Python
-      shell: bash
-      run: uv python install ${{ inputs.python-version }}
-    - name: Build
-      shell: bash
-      run: uv build
-    - name: Smoke test (wheel)
-      shell: bash
-      run: uv run --isolated --no-project --with dist/*.whl tests/smoke_test.py
-    - name: Smoke test (source distribution)
-      shell: bash
-      run: uv run --isolated --no-project --with dist/*.tar.gz tests/smoke_test.py
-
-

+ 0 - 22
python/RoMa/.github/workflows/build.yml

@@ -1,22 +0,0 @@
-name: Build
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    branches:
-      - main
-jobs:
-  run:
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v5
-      - name: Build and smoke test
-        uses: ./.github/actions/uv-build
-        with:
-          python-version: ${{ matrix.python-version }}

+ 0 - 39
python/RoMa/.github/workflows/publish.yml

@@ -1,39 +0,0 @@
-name: PyPI Publish
-
-on:
-  push:
-    tags:
-      # Publish on any tag starting with a `v`, e.g., v0.1.0
-      - v*
-  workflow_dispatch:
-jobs:
-  build:
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v5
-      - name: Build and smoke test
-        uses: ./.github/actions/uv-build
-        with:
-          python-version: ${{ matrix.python-version }}
-
-  publish:
-    runs-on: ubuntu-latest
-    needs: build
-    environment:
-      name: pypi
-    permissions:
-      id-token: write
-      contents: read
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v5
-      - name: Build and smoke test
-        uses: ./.github/actions/uv-build
-        with:
-          python-version: '3.12'
-      - name: Publish
-        run: uv publish

+ 0 - 12
python/RoMa/.gitignore

@@ -1,12 +0,0 @@
-*.egg-info*
-*.vscode*
-*__pycache__*
-vis*
-workspace*
-.venv
-.DS_Store
-jobs/*
-*ignore_me*
-*.pth
-wandb*
-results/*

+ 0 - 1
python/RoMa/.python-version

@@ -1 +0,0 @@
-3.12

+ 0 - 21
python/RoMa/LICENSE

@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2023 Johan Edstedt
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.

+ 0 - 163
python/RoMa/README.md

@@ -1,163 +0,0 @@
-# 
-<p align="center">
-  <h1 align="center"> <ins>RoMa</ins> 🏛️:<br> Robust Dense Feature Matching <br> ⭐CVPR 2024⭐</h1>
-  <p align="center">
-    <a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
-    ·
-    <a href="https://scholar.google.com/citations?user=HS2WuHkAAAAJ">Qiyu Sun</a>
-    ·
-    <a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg Bökman</a>
-    ·
-    <a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">Mårten Wadenbäck</a>
-    ·
-    <a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
-  </p>
-  <h2 align="center"><p>
-    <a href="https://arxiv.org/abs/2305.15404" align="center">Paper</a> | 
-    <a href="https://parskatt.github.io/RoMa" align="center">Project Page</a>
-  </p></h2>
-  <div align="center"></div>
-</p>
-<br/>
-<p align="center">
-    <img src="https://github.com/Parskatt/RoMa/assets/22053118/15d8fea7-aa6d-479f-8a93-350d950d006b" alt="example" width=80%>
-    <br>
-    <em>RoMa is the robust dense feature matcher capable of estimating pixel-dense warps and reliable certainties for almost any image pair.</em>
-</p>
-
-## Setup/Install
-In your python environment (tested on Linux python 3.12), run:
-```bash
-uv pip install -e .
-```
-or 
-```bash
-uv sync
-```
-You can also install `romatch` directly as a package from PyPI by
-```bash
-uv pip install romatch
-```
-or 
-```bash
-uv add romatch
-```
-
-## Fused local correlation kernel
-Include the `--extra fused-local-corr` flag as:
-```bash
-uv sync --extra fused-local-corr
-```
-or 
-```bash
-uv pip install romatch[fused-local-corr]
-```
-or
-```bash
-uv add romatch[fused-local-corr]
-```
-
-
-
-## Demo / How to Use
-We provide two demos in the [demos folder](demo).
-Here's the gist of it:
-```python
-from romatch import roma_outdoor
-roma_model = roma_outdoor(device=device)
-# Match
-warp, certainty = roma_model.match(imA_path, imB_path, device=device)
-# Sample matches for estimation
-matches, certainty = roma_model.sample(warp, certainty)
-# Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1])
-kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
-# Find a fundamental matrix (or anything else of interest)
-F, mask = cv2.findFundamentalMat(
-    kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
-)
-```
-
-**New**: You can also match arbitrary keypoints with RoMa. See [match_keypoints](romatch/models/matcher.py) in RegressionMatcher.
-
-## Settings
-
-### Resolution
-By default RoMa uses an initial resolution of (560,560) which is then upsampled to (864,864). 
-You can change this at construction (see roma_outdoor kwargs).
-You can also change this later, by changing the roma_model.w_resized, roma_model.h_resized, and roma_model.upsample_res.
-
-### Sampling
-roma_model.sample_thresh controls the thresholding used when sampling matches for estimation. In certain cases a lower or higher threshold may improve results.
-
-
-## Reproducing Results
-The experiments in the paper are provided in the [experiments folder](experiments).
-
-### Training
-1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets.
-2. Run the relevant experiment, e.g.,
-```bash
-torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py
-```
-### Testing
-```bash
-python experiments/roma_outdoor.py --only_test --benchmark mega-1500
-```
-## License
-All our code except DINOv2 is MIT license.
-DINOv2 has an Apache 2 license [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE).
-
-## Acknowledgement
-Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
-
-## Tiny RoMa
-If you find that RoMa is too heavy, you might want to try Tiny RoMa which is built on top of XFeat.
-```python
-from romatch import tiny_roma_v1_outdoor
-tiny_roma_model = tiny_roma_v1_outdoor(device=device)
-```
-Mega1500:
-|  | AUC@5 | AUC@10 | AUC@20 |
-|----------|----------|----------|----------|
-| XFeat    | 46.4    | 58.9    | 69.2    |
-| XFeat*    |  51.9   | 67.2    | 78.9    |
-| Tiny RoMa v1    | 56.4 | 69.5 | 79.5     |
-| RoMa    |  -   | -    | -    |
-
-Mega-8-Scenes (See DKM):
-|  | AUC@5 | AUC@10 | AUC@20 |
-|----------|----------|----------|----------|
-| XFeat    | -    | -    | -    |
-| XFeat*    |  50.1   | 64.4    | 75.2    |
-| Tiny RoMa v1    | 57.7 | 70.5 | 79.6     |
-| RoMa    |  -   | -    | -    |
-
-IMC22 :'):
-|  | mAA@10 |
-|----------|----------|
-| XFeat    | 42.1    |
-| XFeat*    |  -   |
-| Tiny RoMa v1    | 42.2 |
-| RoMa    |  -   |
-
-## Reproducibility
-There are a few diffs in the current codebase compared to the original repo used to run experiments.
-
-1. The `scale_factor` used in the `match` method now is relative to the original training resolution of `560`. Previosly it was based on the set coarse resolution (which might or might not be `560`).
-2. Newer PyTorch, original code used something like `2.1`.
-3. Stochastic eval: both RANSAC and the chosen correspondences can affect results in `Mega1500`.
-4. Matrix inverse in GP has been replaced with cholesky decomp.
-
-That being said, if diff of results are $>0.5$ there probably is something wrong, please let me know.
-
-
-## BibTeX
-If you find our models useful, please consider citing our paper!
-```
-@inproceedings{edstedt2024roma,
-title={{RoMa: Robust Dense Feature Matching}},
-author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and Wadenbäck, Mårten and Felsberg, Michael},
-booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
-year={2024}
-}
-```

BIN
python/RoMa/assets/sacre_coeur_A.jpg


BIN
python/RoMa/assets/sacre_coeur_B.jpg


BIN
python/RoMa/assets/toronto_A.jpg


BIN
python/RoMa/assets/toronto_B.jpg


+ 0 - 2
python/RoMa/data/.gitignore

@@ -1,2 +0,0 @@
-*
-!.gitignore

+ 0 - 47
python/RoMa/demo/demo_3D_effect.py

@@ -1,47 +0,0 @@
-from PIL import Image
-import torch
-import torch.nn.functional as F
-import numpy as np
-from romatch.utils.utils import tensor_to_pil
-
-from romatch import roma_outdoor
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-if torch.backends.mps.is_available():
-    device = torch.device('mps')
-
-if __name__ == "__main__":
-    from argparse import ArgumentParser
-    parser = ArgumentParser()
-    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
-    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
-    parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str)
-
-    args, _ = parser.parse_known_args()
-    im1_path = args.im_A_path
-    im2_path = args.im_B_path
-    save_path = args.save_path
-
-    # Create model
-    roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
-    roma_model.symmetric = False
-
-    H, W = roma_model.get_output_resolution()
-
-    im1 = Image.open(im1_path).resize((W, H))
-    im2 = Image.open(im2_path).resize((W, H))
-
-    # Match
-    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
-    # Sampling not needed, but can be done with model.sample(warp, certainty)
-    x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
-    x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
-
-    coords_A, coords_B = warp[...,:2], warp[...,2:]
-    for i, x in enumerate(np.linspace(0,2*np.pi,200)):
-        t = (1 + np.cos(x))/2
-        interp_warp = (1-t)*coords_A + t*coords_B
-        im2_transfer_rgb = F.grid_sample(
-        x2[None], interp_warp[None], mode="bilinear", align_corners=False
-        )[0]
-        tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg")

+ 0 - 34
python/RoMa/demo/demo_fundamental.py

@@ -1,34 +0,0 @@
-from PIL import Image
-import torch
-import cv2
-from romatch import roma_outdoor
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-if torch.backends.mps.is_available():
-    device = torch.device('mps')
-
-if __name__ == "__main__":
-    from argparse import ArgumentParser
-    parser = ArgumentParser()
-    parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
-    parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
-
-    args, _ = parser.parse_known_args()
-    im1_path = args.im_A_path
-    im2_path = args.im_B_path
-
-    # Create model
-    roma_model = roma_outdoor(device=device)
-
-
-    W_A, H_A = Image.open(im1_path).size
-    W_B, H_B = Image.open(im2_path).size
-
-    # Match
-    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
-    # Sample matches for estimation
-    matches, certainty = roma_model.sample(warp, certainty)
-    kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)    
-    F, mask = cv2.findFundamentalMat(
-        kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
-    )

+ 0 - 50
python/RoMa/demo/demo_match.py

@@ -1,50 +0,0 @@
-import os
-os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
-import torch
-from PIL import Image
-import torch.nn.functional as F
-import numpy as np
-from romatch.utils.utils import tensor_to_pil
-
-from romatch import roma_outdoor
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-if torch.backends.mps.is_available():
-    device = torch.device('mps')
-
-if __name__ == "__main__":
-    from argparse import ArgumentParser
-    parser = ArgumentParser()
-    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
-    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
-    parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
-
-    args, _ = parser.parse_known_args()
-    im1_path = args.im_A_path
-    im2_path = args.im_B_path
-    save_path = args.save_path
-
-    # Create model
-    roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
-
-    H, W = roma_model.get_output_resolution()
-
-    im1 = Image.open(im1_path).resize((W, H))
-    im2 = Image.open(im2_path).resize((W, H))
-
-    # Match
-    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
-    # Sampling not needed, but can be done with model.sample(warp, certainty)
-    x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
-    x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
-
-    im2_transfer_rgb = F.grid_sample(
-    x2[None], warp[:, :, :W, 2:], mode="bilinear", align_corners=False
-    )[0]
-    im1_transfer_rgb = F.grid_sample(
-    x1[None], warp[:, :, W:, :2], mode="bilinear", align_corners=False
-    )[0]
-    warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
-    white_im = torch.ones((H,2*W),device=device)
-    vis_im = certainty * warp_im + (1 - certainty) * white_im
-    tensor_to_pil(vis_im, unnormalize=False).save(save_path)

+ 0 - 43
python/RoMa/demo/demo_match_opencv_sift.py

@@ -1,43 +0,0 @@
-from PIL import Image
-import numpy as np
-
-import numpy as np
-import cv2 as cv
-import matplotlib.pyplot as plt
-
-
-
-if __name__ == "__main__":
-    from argparse import ArgumentParser
-    parser = ArgumentParser()
-    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
-    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
-    parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
-
-    args, _ = parser.parse_known_args()
-    im1_path = args.im_A_path
-    im2_path = args.im_B_path
-    save_path = args.save_path
-
-    img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE)          # queryImage
-    img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE) # trainImage
-    # Initiate SIFT detector
-    sift = cv.SIFT_create()
-    # find the keypoints and descriptors with SIFT
-    kp1, des1 = sift.detectAndCompute(img1,None)
-    kp2, des2 = sift.detectAndCompute(img2,None)
-    # BFMatcher with default params
-    bf = cv.BFMatcher()
-    matches = bf.knnMatch(des1,des2,k=2)
-    # Apply ratio test
-    good = []
-    for m,n in matches:
-        if m.distance < 0.75*n.distance:
-            good.append([m])
-    # cv.drawMatchesKnn expects list of lists as matches.
-    draw_params = dict(matchColor = (255,0,0), # draw matches in red color
-                   singlePointColor = None,
-                   flags = 2)
-
-    img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params)
-    Image.fromarray(img3).save("demo/sift_matches.png")

+ 0 - 77
python/RoMa/demo/demo_match_tiny.py

@@ -1,77 +0,0 @@
-import os
-os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
-import torch
-from PIL import Image
-import torch.nn.functional as F
-import numpy as np
-from romatch.utils.utils import tensor_to_pil
-
-from romatch import tiny_roma_v1_outdoor
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-if torch.backends.mps.is_available():
-    device = torch.device('mps')
-
-if __name__ == "__main__":
-    from argparse import ArgumentParser
-    parser = ArgumentParser()
-    parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
-    parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
-    parser.add_argument("--save_A_path", default="demo/tiny_roma_warp_A.jpg", type=str)
-    parser.add_argument("--save_B_path", default="demo/tiny_roma_warp_B.jpg", type=str)
-
-    args, _ = parser.parse_known_args()
-    im1_path = args.im_A_path
-    im2_path = args.im_B_path
-
-    # Create model
-    roma_model = tiny_roma_v1_outdoor(device=device)
-
-    # Match
-    warp, certainty1 = roma_model.match(im1_path, im2_path)
-    
-    h1, w1 = warp.shape[:2]
-    
-    # maybe im1.size != im2.size
-    im1 = Image.open(im1_path).resize((w1, h1))
-    im2 = Image.open(im2_path)
-    x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
-    x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
-    
-    h2, w2 = x2.shape[1:]
-    g1_p2x = w2 / 2 * (warp[..., 2] + 1)
-    g1_p2y = h2 / 2 * (warp[..., 3] + 1)
-    g2_p1x = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
-    g2_p1y = torch.zeros((h2, w2), dtype=torch.float32).to(device) - 2
-
-    x, y = torch.meshgrid(
-        torch.arange(w1, device=device),
-        torch.arange(h1, device=device),
-        indexing="xy",
-    )
-    g2x = torch.round(g1_p2x[y, x]).long()
-    g2y = torch.round(g1_p2y[y, x]).long()
-    idx_x = torch.bitwise_and(0 <= g2x, g2x < w2)
-    idx_y = torch.bitwise_and(0 <= g2y, g2y < h2)
-    idx = torch.bitwise_and(idx_x, idx_y)
-    g2_p1x[g2y[idx], g2x[idx]] = x[idx].float() * 2 / w1 - 1
-    g2_p1y[g2y[idx], g2x[idx]] = y[idx].float() * 2 / h1 - 1
-
-    certainty2 = F.grid_sample(
-        certainty1[None][None],
-        torch.stack([g2_p1x, g2_p1y], dim=2)[None],
-        mode="bilinear",
-        align_corners=False,
-    )[0]
-    
-    white_im1 = torch.ones((h1, w1), device = device)
-    white_im2 = torch.ones((h2, w2), device = device)
-    
-    certainty1 = F.avg_pool2d(certainty1[None], kernel_size=5, stride=1, padding=2)[0]
-    certainty2 = F.avg_pool2d(certainty2[None], kernel_size=5, stride=1, padding=2)[0]
-    
-    vis_im1 = certainty1 * x1 + (1 - certainty1) * white_im1
-    vis_im2 = certainty2 * x2 + (1 - certainty2) * white_im2
-    
-    tensor_to_pil(vis_im1, unnormalize=False).save(args.save_A_path)
-    tensor_to_pil(vis_im2, unnormalize=False).save(args.save_B_path)

+ 0 - 2
python/RoMa/demo/gif/.gitignore

@@ -1,2 +0,0 @@
-*
-!.gitignore

+ 0 - 59
python/RoMa/experiments/eval_roma_outdoor.py

@@ -1,59 +0,0 @@
-import json
-
-from romatch.benchmarks import MegadepthDenseBenchmark
-from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, HpatchesHomogBenchmark
-from romatch.benchmarks import Mega1500PoseLibBenchmark
-
-def test_mega_8_scenes(model, name):
-    mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
-                                                scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
-                                                    'mega_8_scenes_0025_0.1_0.3.npz',
-                                                    'mega_8_scenes_0021_0.1_0.3.npz',
-                                                    'mega_8_scenes_0008_0.1_0.3.npz',
-                                                    'mega_8_scenes_0032_0.1_0.3.npz',
-                                                    'mega_8_scenes_1589_0.1_0.3.npz',
-                                                    'mega_8_scenes_0063_0.1_0.3.npz',
-                                                    'mega_8_scenes_0024_0.1_0.3.npz',
-                                                    'mega_8_scenes_0019_0.3_0.5.npz',
-                                                    'mega_8_scenes_0025_0.3_0.5.npz',
-                                                    'mega_8_scenes_0021_0.3_0.5.npz',
-                                                    'mega_8_scenes_0008_0.3_0.5.npz',
-                                                    'mega_8_scenes_0032_0.3_0.5.npz',
-                                                    'mega_8_scenes_1589_0.3_0.5.npz',
-                                                    'mega_8_scenes_0063_0.3_0.5.npz',
-                                                    'mega_8_scenes_0024_0.3_0.5.npz'])
-    mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
-    print(mega_8_scenes_results)
-    json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
-
-def test_mega1500(model, name):
-    mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
-    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
-    json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
-
-def test_mega1500_poselib(model, name):
-    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth")
-    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
-    json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
-
-def test_mega_dense(model, name):
-    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
-    megadense_results = megadense_benchmark.benchmark(model)
-    json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))
-    
-def test_hpatches(model, name):
-    hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
-    hpatches_results = hpatches_benchmark.benchmark(model)
-    json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))
-
-
-if __name__ == "__main__":
-    from romatch import roma_outdoor
-
-    device = "cuda"
-    model = roma_outdoor(device = device, coarse_res = 672, upsample_res = 1344)
-    experiment_name = "roma_latest"
-    test_mega1500(model, experiment_name)
-    test_hpatches(model, experiment_name)
-    #test_mega1500_poselib(model, experiment_name)
-    

+ 0 - 84
python/RoMa/experiments/eval_tiny_roma_v1_outdoor.py

@@ -1,84 +0,0 @@
-import torch
-import os
-from pathlib import Path
-import json
-from romatch.benchmarks import ScanNetBenchmark
-from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
-from romatch.benchmarks import MegaDepthPoseEstimationBenchmark
-
-def test_mega_8_scenes(model, name):
-    mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
-                                                scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
-                                                    'mega_8_scenes_0025_0.1_0.3.npz',
-                                                    'mega_8_scenes_0021_0.1_0.3.npz',
-                                                    'mega_8_scenes_0008_0.1_0.3.npz',
-                                                    'mega_8_scenes_0032_0.1_0.3.npz',
-                                                    'mega_8_scenes_1589_0.1_0.3.npz',
-                                                    'mega_8_scenes_0063_0.1_0.3.npz',
-                                                    'mega_8_scenes_0024_0.1_0.3.npz',
-                                                    'mega_8_scenes_0019_0.3_0.5.npz',
-                                                    'mega_8_scenes_0025_0.3_0.5.npz',
-                                                    'mega_8_scenes_0021_0.3_0.5.npz',
-                                                    'mega_8_scenes_0008_0.3_0.5.npz',
-                                                    'mega_8_scenes_0032_0.3_0.5.npz',
-                                                    'mega_8_scenes_1589_0.3_0.5.npz',
-                                                    'mega_8_scenes_0063_0.3_0.5.npz',
-                                                    'mega_8_scenes_0024_0.3_0.5.npz'])
-    mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
-    print(mega_8_scenes_results)
-    json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
-
-def test_mega1500(model, name):
-    mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
-    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
-    json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
-
-def test_mega1500_poselib(model, name):
-    #model.exact_softmax = True
-    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1)
-    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
-    json.dump(mega1500_results, open(f"results/mega1500_poselib_{name}.json", "w"))
-
-def test_mega_8_scenes_poselib(model, name):
-    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1,
-                                                  scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
-                                                    'mega_8_scenes_0025_0.1_0.3.npz',
-                                                    'mega_8_scenes_0021_0.1_0.3.npz',
-                                                    'mega_8_scenes_0008_0.1_0.3.npz',
-                                                    'mega_8_scenes_0032_0.1_0.3.npz',
-                                                    'mega_8_scenes_1589_0.1_0.3.npz',
-                                                    'mega_8_scenes_0063_0.1_0.3.npz',
-                                                    'mega_8_scenes_0024_0.1_0.3.npz',
-                                                    'mega_8_scenes_0019_0.3_0.5.npz',
-                                                    'mega_8_scenes_0025_0.3_0.5.npz',
-                                                    'mega_8_scenes_0021_0.3_0.5.npz',
-                                                    'mega_8_scenes_0008_0.3_0.5.npz',
-                                                    'mega_8_scenes_0032_0.3_0.5.npz',
-                                                    'mega_8_scenes_1589_0.3_0.5.npz',
-                                                    'mega_8_scenes_0063_0.3_0.5.npz',
-                                                    'mega_8_scenes_0024_0.3_0.5.npz'])
-    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
-    json.dump(mega1500_results, open(f"results/mega_8_scenes_poselib_{name}.json", "w"))
-
-def test_scannet_poselib(model, name):
-    scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
-    scannet_results = scannet_benchmark.benchmark(model)
-    json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
-
-def test_scannet(model, name):
-    scannet_benchmark = ScanNetBenchmark("data/scannet")
-    scannet_results = scannet_benchmark.benchmark(model)
-    json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
-
-if __name__ == "__main__":
-    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
-    os.environ["OMP_NUM_THREADS"] = "16"
-    torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
-    from romatch import tiny_roma_v1_outdoor
-
-    experiment_name = Path(__file__).stem
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    model = tiny_roma_v1_outdoor(device)
-    #test_mega1500_poselib(model, experiment_name)
-    test_mega_8_scenes_poselib(model, experiment_name)
- 

+ 0 - 322
python/RoMa/experiments/roma_indoor.py

@@ -1,322 +0,0 @@
-import os
-import torch
-from argparse import ArgumentParser
-from warnings import warn
-from torch import nn
-from torch.utils.data import ConcatDataset
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-import json
-import wandb
-from tqdm import tqdm
-
-from romatch.benchmarks import MegadepthDenseBenchmark
-from romatch.datasets.megadepth import MegadepthBuilder
-from romatch.datasets.scannet import ScanNetBuilder
-from romatch.losses.robust_loss import RobustLosses
-from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
-from romatch.train.train import train_k_steps
-from romatch.models.matcher import *
-from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
-from romatch.models.encoders import *
-from romatch.checkpointing import CheckPoint
-
-resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
-
-def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
-    gp_dim = 512
-    feat_dim = 512
-    decoder_dim = gp_dim + feat_dim
-    cls_to_coord_res = 64
-    coordinate_decoder = TransformerDecoder(
-        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]), 
-        decoder_dim, 
-        cls_to_coord_res**2 + 1,
-        is_classifier=True,
-        amp = True,
-        pos_enc = False,)
-    dw = True
-    hidden_blocks = 8
-    kernel_size = 5
-    displacement_emb = "linear"
-    disable_local_corr_grad = True
-    
-    conv_refiner = nn.ModuleDict(
-        {
-            "16": ConvRefiner(
-                2 * 512+128+(2*7+1)**2,
-                2 * 512+128+(2*7+1)**2,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks=hidden_blocks,
-                displacement_emb=displacement_emb,
-                displacement_emb_dim=128,
-                local_corr_radius = 7,
-                corr_in_other = True,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
-            ),
-            "8": ConvRefiner(
-                2 * 512+64+(2*3+1)**2,
-                2 * 512+64+(2*3+1)**2,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks=hidden_blocks,
-                displacement_emb=displacement_emb,
-                displacement_emb_dim=64,
-                local_corr_radius = 3,
-                corr_in_other = True,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
-            ),
-            "4": ConvRefiner(
-                2 * 256+32+(2*2+1)**2,
-                2 * 256+32+(2*2+1)**2,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks=hidden_blocks,
-                displacement_emb=displacement_emb,
-                displacement_emb_dim=32,
-                local_corr_radius = 2,
-                corr_in_other = True,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
-            ),
-            "2": ConvRefiner(
-                2 * 64+16,
-                128+16,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks=hidden_blocks,
-                displacement_emb=displacement_emb,
-                displacement_emb_dim=16,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
-            ),
-            "1": ConvRefiner(
-                2 * 9 + 6,
-                24,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks = hidden_blocks,
-                displacement_emb = displacement_emb,
-                displacement_emb_dim = 6,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
-            ),
-        }
-    )
-    kernel_temperature = 0.2
-    learn_temperature = False
-    no_cov = True
-    kernel = CosKernel
-    only_attention = False
-    basis = "fourier"
-    gp16 = GP(
-        kernel,
-        T=kernel_temperature,
-        learn_temperature=learn_temperature,
-        only_attention=only_attention,
-        gp_dim=gp_dim,
-        basis=basis,
-        no_cov=no_cov,
-    )
-    gps = nn.ModuleDict({"16": gp16})
-    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
-    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
-    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
-    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
-    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
-    proj = nn.ModuleDict({
-        "16": proj16,
-        "8": proj8,
-        "4": proj4,
-        "2": proj2,
-        "1": proj1,
-        })
-    displacement_dropout_p = 0.0
-    gm_warp_dropout_p = 0.0
-    decoder = Decoder(coordinate_decoder, 
-                      gps, 
-                      proj, 
-                      conv_refiner, 
-                      detach=True, 
-                      scales=["16", "8", "4", "2", "1"], 
-                      displacement_dropout_p = displacement_dropout_p,
-                      gm_warp_dropout_p = gm_warp_dropout_p)
-    h,w = resolutions[resolution]
-    encoder = CNNandDinov2(
-        cnn_kwargs = dict(
-            pretrained=pretrained_backbone,
-            amp = True),
-        amp = True,
-        use_vgg = True,
-    )
-    matcher = RegressionMatcher(encoder, decoder, h=h, w=w, alpha=1, beta=0,**kwargs)
-    return matcher
-
-def train(args):
-    dist.init_process_group('nccl')
-    #torch._dynamo.config.verbose=True
-    gpus = int(os.environ['WORLD_SIZE'])
-    # create model and move it to GPU with id rank
-    rank = dist.get_rank()
-    print(f"Start running DDP on rank {rank}")
-    device_id = rank % torch.cuda.device_count()
-    romatch.LOCAL_RANK = device_id
-    torch.cuda.set_device(device_id)
-    
-    resolution = args.train_resolution
-    wandb_log = not args.dont_log_wandb
-    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
-    wandb_mode = "online" if wandb_log and rank == 0 and False else "disabled"
-    wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
-    checkpoint_dir = "workspace/checkpoints/"
-    h,w = resolutions[resolution]
-    model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
-    # Num steps
-    global_step = 0
-    batch_size = args.gpu_batch_size
-    step_size = gpus*batch_size
-    romatch.STEP_SIZE = step_size
-    
-    N = (32 * 250000)  # 250k steps of batch size 32
-    # checkpoint every
-    k = 25000 // romatch.STEP_SIZE
-
-    # Data
-    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
-    use_horizontal_flip_aug = True
-    rot_prob = 0
-    depth_interpolation_mode = "bilinear"
-    megadepth_train1 = mega.build_scenes(
-        split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
-        ht=h,wt=w,
-    )
-    megadepth_train2 = mega.build_scenes(
-        split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
-        ht=h,wt=w,
-    )
-    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
-    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
-    
-    scannet = ScanNetBuilder(data_root="data/scannet")
-    scannet_train = scannet.build_scenes(split="train", ht=h, wt=w, use_horizontal_flip_aug = use_horizontal_flip_aug)
-    scannet_train = ConcatDataset(scannet_train)
-    scannet_ws = scannet.weight_scenes(scannet_train, alpha=0.75)
-
-    # Loss and optimizer
-    depth_loss_scannet = RobustLosses(
-        ce_weight=0.0, 
-        local_dist={1:4, 2:4, 4:8, 8:8},
-        local_largest_scale=8,
-        depth_interpolation_mode=depth_interpolation_mode,
-        alpha = 0.5,
-        c = 1e-4,)
-    # Loss and optimizer
-    depth_loss_mega = RobustLosses(
-        ce_weight=0.01, 
-        local_dist={1:4, 2:4, 4:8, 8:8},
-        local_largest_scale=8,
-        depth_interpolation_mode=depth_interpolation_mode,
-        alpha = 0.5,
-        c = 1e-4,)
-    parameters = [
-        {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
-        {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
-    ]
-    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
-    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-        optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
-    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
-    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
-    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
-    romatch.GLOBAL_STEP = global_step
-    ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
-    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
-    grad_clip_norm = 0.01
-    for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
-        mega_sampler = torch.utils.data.WeightedRandomSampler(
-            mega_ws, num_samples = batch_size * k, replacement=False
-        )
-        mega_dataloader = iter(
-            torch.utils.data.DataLoader(
-                megadepth_train,
-                batch_size = batch_size,
-                sampler = mega_sampler,
-                num_workers = 8,
-            )
-        )
-        scannet_ws_sampler = torch.utils.data.WeightedRandomSampler(
-            scannet_ws, num_samples=batch_size * k, replacement=False
-        )
-        scannet_dataloader = iter(
-            torch.utils.data.DataLoader(
-                scannet_train,
-                batch_size=batch_size,
-                sampler=scannet_ws_sampler,
-                num_workers=gpus * 8,
-            )
-        )
-        for n_k in tqdm(range(n, n + 2 * k, 2),disable = romatch.RANK > 0):
-            train_k_steps(
-                n_k, 1, mega_dataloader, ddp_model, depth_loss_mega, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
-            )
-            train_k_steps(
-                n_k + 1, 1, scannet_dataloader, ddp_model, depth_loss_scannet, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
-            )
-        checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
-        wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)
-
-def test_scannet(model, name, resolution, sample_mode):
-    scannet_benchmark = ScanNetBenchmark("data/scannet")
-    scannet_results = scannet_benchmark.benchmark(model)
-    json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
-
-if __name__ == "__main__":
-    import warnings
-    warn('Current version of romatch is not tested for training, use at your own risk.')
-
-    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
-    warnings.filterwarnings('ignore')#, category=UserWarning)#, message='WARNING batched routines are designed for small sizes.')
-    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
-    os.environ["OMP_NUM_THREADS"] = "16"
-    
-    import romatch
-    parser = ArgumentParser()
-    parser.add_argument("--test", action='store_true')
-    parser.add_argument("--debug_mode", action='store_true')
-    parser.add_argument("--dont_log_wandb", action='store_true')
-    parser.add_argument("--train_resolution", default='medium')
-    parser.add_argument("--gpu_batch_size", default=4, type=int)
-    parser.add_argument("--wandb_entity", required = False)
-
-    args, _ = parser.parse_known_args()
-    romatch.DEBUG_MODE = args.debug_mode
-    if not args.test:
-        train(args)
-    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
-    checkpoint_dir = "workspace/"
-    checkpoint_name = checkpoint_dir + experiment_name + ".pth"
-    test_resolution = "medium"
-    sample_mode = "threshold_balanced"
-    symmetric = True
-    upsample_preds = False
-    attenuate_cert = True
-
-    model = get_model(pretrained_backbone=False, resolution = test_resolution, sample_mode = sample_mode, upsample_preds = upsample_preds, symmetric=symmetric, name=experiment_name, attenuate_cert = attenuate_cert)
-    model = model.cuda()
-    states = torch.load(checkpoint_name)
-    model.load_state_dict(states["model"])
-    test_scannet(model, experiment_name, resolution = test_resolution, sample_mode = sample_mode)

+ 0 - 308
python/RoMa/experiments/train_roma_outdoor.py

@@ -1,308 +0,0 @@
-import os
-import torch
-from argparse import ArgumentParser
-from warnings import warn
-from torch import nn
-from torch.utils.data import ConcatDataset
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
-import json
-import wandb
-
-from romatch.benchmarks import MegadepthDenseBenchmark
-from romatch.datasets.megadepth import MegadepthBuilder
-from romatch.losses.robust_loss import RobustLosses
-from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
-
-from romatch.train.train import train_k_steps
-from romatch.models.matcher import *
-from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
-from romatch.models.encoders import *
-from romatch.checkpointing import CheckPoint
-
-resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
-
-def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
-    import warnings
-    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
-    gp_dim = 512
-    feat_dim = 512
-    decoder_dim = gp_dim + feat_dim
-    cls_to_coord_res = 64
-    coordinate_decoder = TransformerDecoder(
-        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]), 
-        decoder_dim, 
-        cls_to_coord_res**2 + 1,
-        is_classifier=True,
-        amp = True,
-        pos_enc = False,)
-    dw = True
-    hidden_blocks = 8
-    kernel_size = 5
-    displacement_emb = "linear"
-    disable_local_corr_grad = True
-    
-    conv_refiner = nn.ModuleDict(
-        {
-            "16": ConvRefiner(
-                2 * 512+128+(2*7+1)**2,
-                2 * 512+128+(2*7+1)**2,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks=hidden_blocks,
-                displacement_emb=displacement_emb,
-                displacement_emb_dim=128,
-                local_corr_radius = 7,
-                corr_in_other = True,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
-            ),
-            "8": ConvRefiner(
-                2 * 512+64+(2*3+1)**2,
-                2 * 512+64+(2*3+1)**2,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks=hidden_blocks,
-                displacement_emb=displacement_emb,
-                displacement_emb_dim=64,
-                local_corr_radius = 3,
-                corr_in_other = True,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
-            ),
-            "4": ConvRefiner(
-                2 * 256+32+(2*2+1)**2,
-                2 * 256+32+(2*2+1)**2,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks=hidden_blocks,
-                displacement_emb=displacement_emb,
-                displacement_emb_dim=32,
-                local_corr_radius = 2,
-                corr_in_other = True,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
-            ),
-            "2": ConvRefiner(
-                2 * 64+16,
-                128+16,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks=hidden_blocks,
-                displacement_emb=displacement_emb,
-                displacement_emb_dim=16,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
-            ),
-            "1": ConvRefiner(
-                2 * 9 + 6,
-                24,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks = hidden_blocks,
-                displacement_emb = displacement_emb,
-                displacement_emb_dim = 6,
-                amp = True,
-                disable_local_corr_grad = disable_local_corr_grad,
-                bn_momentum = 0.01,
-            ),
-        }
-    )
-    kernel_temperature = 0.2
-    learn_temperature = False
-    no_cov = True
-    kernel = CosKernel
-    only_attention = False
-    basis = "fourier"
-    gp16 = GP(
-        kernel,
-        T=kernel_temperature,
-        learn_temperature=learn_temperature,
-        only_attention=only_attention,
-        gp_dim=gp_dim,
-        basis=basis,
-        no_cov=no_cov,
-    )
-    gps = nn.ModuleDict({"16": gp16})
-    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
-    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
-    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
-    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
-    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
-    proj = nn.ModuleDict({
-        "16": proj16,
-        "8": proj8,
-        "4": proj4,
-        "2": proj2,
-        "1": proj1,
-        })
-    displacement_dropout_p = 0.0
-    gm_warp_dropout_p = 0.0
-    decoder = Decoder(coordinate_decoder, 
-                      gps, 
-                      proj, 
-                      conv_refiner, 
-                      detach=True, 
-                      scales=["16", "8", "4", "2", "1"], 
-                      displacement_dropout_p = displacement_dropout_p,
-                      gm_warp_dropout_p = gm_warp_dropout_p)
-    h,w = resolutions[resolution]
-    encoder = CNNandDinov2(
-        cnn_kwargs = dict(
-            pretrained=pretrained_backbone,
-            amp = True),
-        amp = True,
-        use_vgg = True,
-    )
-    matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs)
-    return matcher
-
-def train(args):
-    dist.init_process_group('nccl')
-    #torch._dynamo.config.verbose=True
-    gpus = int(os.environ['WORLD_SIZE'])
-    # create model and move it to GPU with id rank
-    rank = dist.get_rank()
-    print(f"Start running DDP on rank {rank}")
-    device_id = rank % torch.cuda.device_count()
-    romatch.LOCAL_RANK = device_id
-    torch.cuda.set_device(device_id)
-    
-    resolution = args.train_resolution
-    wandb_log = not args.dont_log_wandb
-    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
-    wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
-    wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
-    checkpoint_dir = "workspace/checkpoints/"
-    h,w = resolutions[resolution]
-    model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
-    # Num steps
-    global_step = 0
-    batch_size = args.gpu_batch_size
-    step_size = gpus*batch_size
-    romatch.STEP_SIZE = step_size
-    
-    N = (32 * 250000)  # 250k steps of batch size 32
-    # checkpoint every
-    k = 25000 // romatch.STEP_SIZE
-
-    # Data
-    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
-    use_horizontal_flip_aug = True
-    rot_prob = 0
-    depth_interpolation_mode = "bilinear"
-    megadepth_train1 = mega.build_scenes(
-        split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
-        ht=h,wt=w,
-    )
-    megadepth_train2 = mega.build_scenes(
-        split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
-        ht=h,wt=w,
-    )
-    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
-    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
-    # Loss and optimizer
-    depth_loss = RobustLosses(
-        ce_weight=0.01, 
-        local_dist={1:4, 2:4, 4:8, 8:8},
-        local_largest_scale=8,
-        depth_interpolation_mode=depth_interpolation_mode,
-        alpha = 0.5,
-        c = 1e-4,)
-    parameters = [
-        {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
-        {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
-    ]
-    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
-    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-        optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
-    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
-    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
-    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
-    romatch.GLOBAL_STEP = global_step
-    ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
-    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
-    grad_clip_norm = 0.01
-    for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
-        mega_sampler = torch.utils.data.WeightedRandomSampler(
-            mega_ws, num_samples = batch_size * k, replacement=False
-        )
-        mega_dataloader = iter(
-            torch.utils.data.DataLoader(
-                megadepth_train,
-                batch_size = batch_size,
-                sampler = mega_sampler,
-                num_workers = 8,
-            )
-        )
-        train_k_steps(
-            n, k, mega_dataloader, ddp_model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
-        )
-        checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
-        wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)
-
-def test_mega_8_scenes(model, name):
-    mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
-                                                scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
-                                                    'mega_8_scenes_0025_0.1_0.3.npz',
-                                                    'mega_8_scenes_0021_0.1_0.3.npz',
-                                                    'mega_8_scenes_0008_0.1_0.3.npz',
-                                                    'mega_8_scenes_0032_0.1_0.3.npz',
-                                                    'mega_8_scenes_1589_0.1_0.3.npz',
-                                                    'mega_8_scenes_0063_0.1_0.3.npz',
-                                                    'mega_8_scenes_0024_0.1_0.3.npz',
-                                                    'mega_8_scenes_0019_0.3_0.5.npz',
-                                                    'mega_8_scenes_0025_0.3_0.5.npz',
-                                                    'mega_8_scenes_0021_0.3_0.5.npz',
-                                                    'mega_8_scenes_0008_0.3_0.5.npz',
-                                                    'mega_8_scenes_0032_0.3_0.5.npz',
-                                                    'mega_8_scenes_1589_0.3_0.5.npz',
-                                                    'mega_8_scenes_0063_0.3_0.5.npz',
-                                                    'mega_8_scenes_0024_0.3_0.5.npz'])
-    mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
-    print(mega_8_scenes_results)
-    json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
-
-def test_mega1500(model, name):
-    mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
-    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
-    json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
-
-def test_mega_dense(model, name):
-    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
-    megadense_results = megadense_benchmark.benchmark(model)
-    json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))
-    
-def test_hpatches(model, name):
-    hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
-    hpatches_results = hpatches_benchmark.benchmark(model)
-    json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))
-
-
-if __name__ == "__main__":
-    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
-    os.environ["OMP_NUM_THREADS"] = "16"
-    torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
-    warn('Current version of romatch is not tested for training, use at your own risk.')
-    import romatch
-    parser = ArgumentParser()
-    parser.add_argument("--only_test", action='store_true')
-    parser.add_argument("--debug_mode", action='store_true')
-    parser.add_argument("--dont_log_wandb", action='store_true')
-    parser.add_argument("--train_resolution", default='medium')
-    parser.add_argument("--gpu_batch_size", default=8, type=int)
-    parser.add_argument("--wandb_entity", required = False)
-
-    args, _ = parser.parse_known_args()
-    romatch.DEBUG_MODE = args.debug_mode
-    if not args.only_test:
-        train(args)

+ 0 - 498
python/RoMa/experiments/train_tiny_roma_v1_outdoor.py

@@ -1,498 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import os
-import torch
-from argparse import ArgumentParser
-from pathlib import Path
-import math
-import numpy as np
-
-from torch import nn
-from torch.utils.data import ConcatDataset
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
-import json
-import wandb
-from PIL import Image
-from torchvision.transforms import ToTensor
-
-from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
-from romatch.benchmarks import Mega1500PoseLibBenchmark, ScanNetPoselibBenchmark
-from romatch.datasets.megadepth import MegadepthBuilder
-from romatch.losses.robust_loss_tiny_roma import RobustLosses
-from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
-from romatch.train.train import train_k_steps
-from romatch.checkpointing import CheckPoint
-
-resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6), "xfeat": (600,800), "big": (768, 1024)}
-
-def kde(x, std = 0.1):
-    # use a gaussian kernel to estimate density
-    x = x.half() # Do it in half precision TODO: remove hardcoding
-    scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
-    density = scores.sum(dim=-1)
-    return density
-
-class BasicLayer(nn.Module):
-    """
-        Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
-    """
-    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True):
-        super().__init__()
-        self.layer = nn.Sequential(
-                                        nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
-                                        nn.BatchNorm2d(out_channels, affine=False),
-                                        nn.ReLU(inplace = True) if relu else nn.Identity()
-                                    )
-
-    def forward(self, x):
-        return self.layer(x)
-
-class XFeatModel(nn.Module):
-    """
-        Implementation of architecture described in 
-        "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
-    """
-
-    def __init__(self, xfeat = None, 
-                 freeze_xfeat = True, 
-                 sample_mode = "threshold_balanced", 
-                 symmetric = False, 
-                 exact_softmax = False):
-        super().__init__()
-        if xfeat is None:
-            xfeat = torch.hub.load('verlab/accelerated_features', 'XFeat', pretrained = True, top_k = 4096).net
-            del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
-        if freeze_xfeat:
-            xfeat.train(False)
-            self.xfeat = [xfeat]# hide params from ddp
-        else:
-            self.xfeat = nn.ModuleList([xfeat])
-        self.freeze_xfeat = freeze_xfeat
-        match_dim = 256
-        self.coarse_matcher = nn.Sequential(
-            BasicLayer(64+64+2, match_dim,),
-            BasicLayer(match_dim, match_dim,), 
-            BasicLayer(match_dim, match_dim,), 
-            BasicLayer(match_dim, match_dim,), 
-            nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
-        fine_match_dim = 64
-        self.fine_matcher = nn.Sequential(
-            BasicLayer(24+24+2, fine_match_dim,),
-            BasicLayer(fine_match_dim, fine_match_dim,), 
-            BasicLayer(fine_match_dim, fine_match_dim,), 
-            BasicLayer(fine_match_dim, fine_match_dim,), 
-            nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),)
-        self.sample_mode = sample_mode
-        self.sample_thresh = 0.2
-        self.symmetric = symmetric
-        self.exact_softmax = exact_softmax
-    
-    @property
-    def device(self):
-        return self.fine_matcher[-1].weight.device
-    
-    def preprocess_tensor(self, x):
-        """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
-        H, W = x.shape[-2:]
-        _H, _W = (H//32) * 32, (W//32) * 32
-        rh, rw = H/_H, W/_W
-
-        x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
-        return x, rh, rw        
-    
-    def forward_single(self, x):
-        with torch.inference_mode(self.freeze_xfeat or not self.training):
-            xfeat = self.xfeat[0]
-            with torch.no_grad():
-                x = x.mean(dim=1, keepdim = True)
-                x = xfeat.norm(x)
-
-            #main backbone
-            x1 = xfeat.block1(x)
-            x2 = xfeat.block2(x1 + xfeat.skip1(x))
-            x3 = xfeat.block3(x2)
-            x4 = xfeat.block4(x3)
-            x5 = xfeat.block5(x4)
-            x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
-            x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
-            feats = xfeat.block_fusion( x3 + x4 + x5 )
-        if self.freeze_xfeat:
-            return x2.clone(), feats.clone()
-        return x2, feats
-
-    def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
-        if coords.shape[-1] == 2:
-            return self._to_pixel_coordinates(coords, H_A, W_A) 
-        
-        if isinstance(coords, (list, tuple)):
-            kpts_A, kpts_B = coords[0], coords[1]
-        else:
-            kpts_A, kpts_B = coords[...,:2], coords[...,2:]
-        return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
-
-    def _to_pixel_coordinates(self, coords, H, W):
-        kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
-        return kpts
-    
-    def pos_embed(self, corr_volume: torch.Tensor):
-        B, H1, W1, H0, W0 = corr_volume.shape 
-        grid = torch.stack(
-                torch.meshgrid(
-                    torch.linspace(-1+1/W1,1-1/W1, W1), 
-                    torch.linspace(-1+1/H1,1-1/H1, H1), 
-                    indexing = "xy"), 
-                dim = -1).float().to(corr_volume).reshape(H1*W1, 2)
-        down = 4
-        if not self.training and not self.exact_softmax:
-            grid_lr = torch.stack(
-                torch.meshgrid(
-                    torch.linspace(-1+down/W1,1-down/W1, W1//down), 
-                    torch.linspace(-1+down/H1,1-down/H1, H1//down), 
-                    indexing = "xy"), 
-                dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
-            cv = corr_volume
-            best_match = cv.reshape(B,H1*W1,H0,W0).amax(dim=1) # B, HW, H, W
-            P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1)
-            pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr)
-            pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2)
-        else:
-            P = corr_volume.reshape(B,H1*W1,H0,W0).softmax(dim=1) # B, HW, H, W
-            pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
-        return pos_embeddings
-    
-    def visualize_warp(self, warp, certainty, im_A = None, im_B = None, 
-                       im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False):
-        device = warp.device
-        H,W2,_ = warp.shape
-        W = W2//2 if symmetric else W2
-        if im_A is None:
-            from PIL import Image
-            im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
-        if not isinstance(im_A, torch.Tensor):
-            im_A = im_A.resize((W,H))
-            im_B = im_B.resize((W,H))    
-            x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
-            if symmetric:
-                x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
-        else:
-            if symmetric:
-                x_A = im_A
-            x_B = im_B
-        im_A_transfer_rgb = F.grid_sample(
-        x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
-        )[0]
-        if symmetric:
-            im_B_transfer_rgb = F.grid_sample(
-            x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
-            )[0]
-            warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
-            white_im = torch.ones((H,2*W),device=device)
-        else:
-            warp_im = im_A_transfer_rgb
-            white_im = torch.ones((H, W), device = device)
-        vis_im = certainty * warp_im + (1 - certainty) * white_im
-        if save_path is not None:
-            from romatch.utils import tensor_to_pil
-            tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
-        return vis_im
-     
-    def corr_volume(self, feat0, feat1):
-        """
-            input:
-                feat0 -> torch.Tensor(B, C, H, W)
-                feat1 -> torch.Tensor(B, C, H, W)
-            return:
-                corr_volume -> torch.Tensor(B, H, W, H, W)
-        """
-        B, C, H0, W0 = feat0.shape
-        B, C, H1, W1 = feat1.shape
-        feat0 = feat0.view(B, C, H0*W0)
-        feat1 = feat1.view(B, C, H1*W1)
-        corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0 , W0)/math.sqrt(C) #16*16*16
-        return corr_volume
-    
-    @torch.inference_mode()
-    def match_from_path(self, im0_path, im1_path):
-        device = self.device
-        im0 = ToTensor()(Image.open(im0_path))[None].to(device)
-        im1 = ToTensor()(Image.open(im1_path))[None].to(device)
-        return self.match(im0, im1, batched = False)
-    
-    @torch.inference_mode()
-    def match(self, im0, im1, *args, batched = True):
-        # stupid
-        if isinstance(im0, (str, Path)):
-            return self.match_from_path(im0, im1)
-        elif isinstance(im0, Image.Image):
-            batched = False
-            device = self.device
-            im0 = ToTensor()(im0)[None].to(device)
-            im1 = ToTensor()(im1)[None].to(device)
- 
-        B,C,H0,W0 = im0.shape
-        B,C,H1,W1 = im1.shape
-        self.train(False)
-        corresps = self.forward({"im_A":im0, "im_B":im1})
-        #return 1,1
-        flow = F.interpolate(
-            corresps[4]["flow"], 
-            size = (H0, W0), 
-            mode = "bilinear", align_corners = False).permute(0,2,3,1).reshape(B,H0,W0,2)
-        grid = torch.stack(
-            torch.meshgrid(
-                torch.linspace(-1+1/W0,1-1/W0, W0), 
-                torch.linspace(-1+1/H0,1-1/H0, H0), 
-                indexing = "xy"), 
-            dim = -1).float().to(flow.device).expand(B, H0, W0, 2)
-        
-        certainty = F.interpolate(corresps[4]["certainty"], size = (H0,W0), mode = "bilinear", align_corners = False)
-        warp, cert = torch.cat((grid, flow), dim = -1), certainty[:,0].sigmoid()
-        if batched:
-            return warp, cert
-        else:
-            return warp[0], cert[0]
-
-    def sample(
-        self,
-        matches,
-        certainty,
-        num=10000,
-    ):
-        if "threshold" in self.sample_mode:
-            upper_thresh = self.sample_thresh
-            certainty = certainty.clone()
-            certainty[certainty > upper_thresh] = 1
-        matches, certainty = (
-            matches.reshape(-1, 4),
-            certainty.reshape(-1),
-        )
-        expansion_factor = 4 if "balanced" in self.sample_mode else 1
-        good_samples = torch.multinomial(certainty, 
-                          num_samples = min(expansion_factor*num, len(certainty)), 
-                          replacement=False)
-        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
-        if "balanced" not in self.sample_mode:
-            return good_matches, good_certainty
-        density = kde(good_matches, std=0.1)
-        p = 1 / (density+1)
-        p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
-        balanced_samples = torch.multinomial(p, 
-                          num_samples = min(num,len(good_certainty)), 
-                          replacement=False)
-        return good_matches[balanced_samples], good_certainty[balanced_samples]
-            
-    def forward(self, batch):
-        """
-            input:
-                x -> torch.Tensor(B, C, H, W) grayscale or rgb images
-            return:
-
-        """
-        im0 = batch["im_A"]
-        im1 = batch["im_B"]
-        corresps = {}
-        im0, rh0, rw0 = self.preprocess_tensor(im0)
-        im1, rh1, rw1 = self.preprocess_tensor(im1)
-        B, C, H0, W0 = im0.shape
-        B, C, H1, W1 = im1.shape
-        to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None,:,None,None]
- 
-        if im0.shape[-2:] == im1.shape[-2:]:
-            x = torch.cat([im0, im1], dim=0)
-            x = self.forward_single(x)
-            feats_x0_c, feats_x1_c = x[1].chunk(2)
-            feats_x0_f, feats_x1_f = x[0].chunk(2)
-        else:
-            feats_x0_f, feats_x0_c = self.forward_single(im0)
-            feats_x1_f, feats_x1_c = self.forward_single(im1)
-        corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
-        coarse_warp = self.pos_embed(corr_volume)
-        coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:,-1:])), dim=1)
-        feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
-        coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
-        coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
-        corresps[8] = {"flow": coarse_matches[:,:2], "certainty": coarse_matches[:,2:]}
-        coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False)        
-        coarse_matches_up_detach = coarse_matches_up.detach()#note the detach
-        feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
-        fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:,:2]), dim=1))
-        fine_matches = coarse_matches_up_detach+fine_matches_delta * to_normalized
-        corresps[4] = {"flow": fine_matches[:,:2], "certainty": fine_matches[:,2:]}
-        return corresps
-    
-
-
-
-
-def train(args):
-    rank = 0
-    gpus = 1
-    device_id = rank % torch.cuda.device_count()
-    romatch.LOCAL_RANK = 0
-    torch.cuda.set_device(device_id)
-        
-    resolution = "big"
-    wandb_log = not args.dont_log_wandb
-    experiment_name = Path(__file__).stem
-    wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
-    wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
-    checkpoint_dir = "workspace/checkpoints/"
-    h,w = resolutions[resolution]
-    model = XFeatModel(freeze_xfeat = False).to(device_id)
-    # Num steps
-    global_step = 0
-    batch_size = args.gpu_batch_size
-    step_size = gpus*batch_size
-    romatch.STEP_SIZE = step_size
-    
-    N = 2_000_000  # 2M pairs
-    # checkpoint every
-    k = 25000 // romatch.STEP_SIZE
-
-    # Data
-    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
-    use_horizontal_flip_aug = True
-    normalize = False # don't imgnet normalize
-    rot_prob = 0
-    depth_interpolation_mode = "bilinear"
-    megadepth_train1 = mega.build_scenes(
-        split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
-        ht=h,wt=w, normalize = normalize
-    )
-    megadepth_train2 = mega.build_scenes(
-        split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
-        ht=h,wt=w, normalize = normalize
-    )
-    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
-    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
-    # Loss and optimizer
-    depth_loss = RobustLosses(
-        ce_weight=0.01, 
-        local_dist={4:4},
-        depth_interpolation_mode=depth_interpolation_mode,
-        alpha = {4:0.15, 8:0.15},
-        c = 1e-4,
-        epe_mask_prob_th = 0.001,
-        )
-    parameters = [
-        {"params": model.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
-    ]
-    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
-    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-        optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
-    #megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
-    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 30)
-
-    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
-    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
-    romatch.GLOBAL_STEP = global_step
-    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
-    grad_clip_norm = 0.01
-    #megadense_benchmark.benchmark(model)
-    for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
-        mega_sampler = torch.utils.data.WeightedRandomSampler(
-            mega_ws, num_samples = batch_size * k, replacement=False
-        )
-        mega_dataloader = iter(
-            torch.utils.data.DataLoader(
-                megadepth_train,
-                batch_size = batch_size,
-                sampler = mega_sampler,
-                num_workers = 8,
-            )
-        )
-        train_k_steps(
-            n, k, mega_dataloader, model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
-        )
-        checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
-        wandb.log(mega1500_benchmark.benchmark(model, model_name=experiment_name), step = romatch.GLOBAL_STEP)
-
-def test_mega_8_scenes(model, name):
-    mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
-                                                scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
-                                                    'mega_8_scenes_0025_0.1_0.3.npz',
-                                                    'mega_8_scenes_0021_0.1_0.3.npz',
-                                                    'mega_8_scenes_0008_0.1_0.3.npz',
-                                                    'mega_8_scenes_0032_0.1_0.3.npz',
-                                                    'mega_8_scenes_1589_0.1_0.3.npz',
-                                                    'mega_8_scenes_0063_0.1_0.3.npz',
-                                                    'mega_8_scenes_0024_0.1_0.3.npz',
-                                                    'mega_8_scenes_0019_0.3_0.5.npz',
-                                                    'mega_8_scenes_0025_0.3_0.5.npz',
-                                                    'mega_8_scenes_0021_0.3_0.5.npz',
-                                                    'mega_8_scenes_0008_0.3_0.5.npz',
-                                                    'mega_8_scenes_0032_0.3_0.5.npz',
-                                                    'mega_8_scenes_1589_0.3_0.5.npz',
-                                                    'mega_8_scenes_0063_0.3_0.5.npz',
-                                                    'mega_8_scenes_0024_0.3_0.5.npz'])
-    mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
-    print(mega_8_scenes_results)
-    json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
-
-def test_mega1500(model, name):
-    mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
-    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
-    json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
-
-def test_mega1500_poselib(model, name):
-    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1)
-    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
-    json.dump(mega1500_results, open(f"results/mega1500_poselib_{name}.json", "w"))
-
-def test_mega_8_scenes_poselib(model, name):
-    mega1500_benchmark = Mega1500PoseLibBenchmark("data/megadepth", num_ransac_iter = 1, test_every = 1,
-                                                  scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
-                                                    'mega_8_scenes_0025_0.1_0.3.npz',
-                                                    'mega_8_scenes_0021_0.1_0.3.npz',
-                                                    'mega_8_scenes_0008_0.1_0.3.npz',
-                                                    'mega_8_scenes_0032_0.1_0.3.npz',
-                                                    'mega_8_scenes_1589_0.1_0.3.npz',
-                                                    'mega_8_scenes_0063_0.1_0.3.npz',
-                                                    'mega_8_scenes_0024_0.1_0.3.npz',
-                                                    'mega_8_scenes_0019_0.3_0.5.npz',
-                                                    'mega_8_scenes_0025_0.3_0.5.npz',
-                                                    'mega_8_scenes_0021_0.3_0.5.npz',
-                                                    'mega_8_scenes_0008_0.3_0.5.npz',
-                                                    'mega_8_scenes_0032_0.3_0.5.npz',
-                                                    'mega_8_scenes_1589_0.3_0.5.npz',
-                                                    'mega_8_scenes_0063_0.3_0.5.npz',
-                                                    'mega_8_scenes_0024_0.3_0.5.npz'])
-    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
-    json.dump(mega1500_results, open(f"results/mega_8_scenes_poselib_{name}.json", "w"))
-
-def test_scannet_poselib(model, name):
-    scannet_benchmark = ScanNetPoselibBenchmark("data/scannet")
-    scannet_results = scannet_benchmark.benchmark(model)
-    json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
-
-def test_scannet(model, name):
-    scannet_benchmark = ScanNetBenchmark("data/scannet")
-    scannet_results = scannet_benchmark.benchmark(model)
-    json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
-
-if __name__ == "__main__":
-    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
-    os.environ["OMP_NUM_THREADS"] = "16"
-    torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
-    import romatch
-    parser = ArgumentParser()
-    parser.add_argument("--only_test", action='store_true')
-    parser.add_argument("--debug_mode", action='store_true')
-    parser.add_argument("--dont_log_wandb", action='store_true')
-    parser.add_argument("--train_resolution", default='medium')
-    parser.add_argument("--gpu_batch_size", default=8, type=int)
-    parser.add_argument("--wandb_entity", required = False)
-
-    args, _ = parser.parse_known_args()
-    romatch.DEBUG_MODE = args.debug_mode
-    if not args.only_test:
-        train(args)
-
-    experiment_name = "tiny_roma_v1_outdoor"#Path(__file__).stem
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    model = XFeatModel(freeze_xfeat=False, exact_softmax=False).to(device)
-    model.load_state_dict(torch.load(f"{experiment_name}.pth"))
-    test_mega1500_poselib(model, experiment_name)
-    

+ 0 - 42
python/RoMa/pyproject.toml

@@ -1,42 +0,0 @@
-[project]
-name = "romatch"
-version = "0.1.2"
-description = "Robust Dense Feature Matching"
-readme = "README.md"
-authors = [
-    { name = "Johan Edstedt", email = "johan.edstedt@liu.se" }
-]
-requires-python = ">=3.9"
-dependencies = [
-    "albumentations",
-    "einops",
-    "h5py",
-    "kornia",
-    "loguru",
-    "matplotlib",
-    "opencv-python",
-    "poselib>=2.0.4",
-    "timm",
-    "torch>=2.5.1",
-    "torchvision",
-    "tqdm",
-    "wandb",
-]
-
-[project.optional-dependencies]
-fused-local-corr = [
-    "fused-local-corr>=0.2.2 ; sys_platform == 'linux'",
-]
-
-[build-system]
-requires = ["uv_build>=0.8.14,<0.9.0"]
-build-backend = "uv_build"
-
-[tool.uv.build-backend]
-module-name = "romatch"
-module-root = ""
-
-[dependency-groups]
-dev = [
-    "ruff>=0.13.1",
-]

+ 0 - 8
python/RoMa/romatch/__init__.py

@@ -1,8 +0,0 @@
-import os
-from .models import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor
-
-DEBUG_MODE = False
-RANK = int(os.environ.get('RANK', default = 0))
-GLOBAL_STEP = 0
-STEP_SIZE = 1
-LOCAL_RANK = -1

+ 0 - 6
python/RoMa/romatch/benchmarks/__init__.py

@@ -1,6 +0,0 @@
-from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
-from .scannet_benchmark import ScanNetBenchmark
-from .megadepth_pose_estimation_benchmark import MegaDepthPoseEstimationBenchmark
-from .megadepth_dense_benchmark import MegadepthDenseBenchmark
-from .megadepth_pose_estimation_benchmark_poselib import Mega1500PoseLibBenchmark
-#from .scannet_benchmark_poselib import ScanNetPoselibBenchmark

+ 0 - 113
python/RoMa/romatch/benchmarks/hpatches_sequences_homog_benchmark.py

@@ -1,113 +0,0 @@
-from PIL import Image
-import numpy as np
-import torch
-import os
-
-from tqdm import tqdm
-from romatch.utils import pose_auc
-import cv2
-
-
-class HpatchesHomogBenchmark:
-    """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""
-
-    def __init__(self, dataset_path) -> None:
-        seqs_dir = "hpatches-sequences-release"
-        self.seqs_path = os.path.join(dataset_path, seqs_dir)
-        self.seq_names = sorted(os.listdir(self.seqs_path))
-        # Ignore seqs is same as LoFTR.
-        self.ignore_seqs = set(
-            [
-                "i_contruction",
-                "i_crownnight",
-                "i_dc",
-                "i_pencils",
-                "i_whitebuilding",
-                "v_artisans",
-                "v_astronautis",
-                "v_talent",
-            ]
-        )
-
-    def convert_coordinates(self, im_A_coords, im_A_to_im_B, wq, hq, wsup, hsup):
-        offset = 0.5  # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
-        im_A_coords = (
-            torch.stack(
-                (
-                    wq * (im_A_coords[..., 0] + 1) / 2,
-                    hq * (im_A_coords[..., 1] + 1) / 2,
-                ),
-                axis=-1,
-            )
-            - offset
-        )
-        im_A_to_im_B = (
-            torch.stack(
-                (
-                    wsup * (im_A_to_im_B[..., 0] + 1) / 2,
-                    hsup * (im_A_to_im_B[..., 1] + 1) / 2,
-                ),
-                axis=-1,
-            )
-            - offset
-        )
-        return im_A_coords, im_A_to_im_B
-
-    def benchmark(self, model, model_name = None):
-        n_matches = []
-        homog_dists = []
-        for seq_idx, seq_name in tqdm(
-            enumerate(self.seq_names), total=len(self.seq_names)
-        ):
-            im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
-            im_A = Image.open(im_A_path)
-            w1, h1 = im_A.size
-            for im_idx in range(2, 7):
-                im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
-                im_B = Image.open(im_B_path)
-                w2, h2 = im_B.size
-                H = np.loadtxt(
-                    os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
-                )
-                dense_matches, dense_certainty = model.match(
-                    im_A_path, im_B_path
-                )
-                good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
-                pos_a, pos_b = self.convert_coordinates(
-                    good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
-                )
-                try:
-                    H_pred, inliers = cv2.findHomography(
-                        pos_a.cpu().numpy(),
-                        pos_b.cpu().numpy(),
-                        method = cv2.RANSAC,
-                        confidence = 0.99999,
-                        ransacReprojThreshold = 3 * min(w2, h2) / 480,
-                    )
-                except:
-                    H_pred = None
-                if H_pred is None:
-                    H_pred = np.zeros((3, 3))
-                    H_pred[2, 2] = 1.0
-                corners = np.array(
-                    [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
-                )
-                real_warped_corners = np.dot(corners, np.transpose(H))
-                real_warped_corners = (
-                    real_warped_corners[:, :2] / real_warped_corners[:, 2:]
-                )
-                warped_corners = np.dot(corners, np.transpose(H_pred))
-                warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
-                mean_dist = np.mean(
-                    np.linalg.norm(real_warped_corners - warped_corners, axis=1)
-                ) / (min(w2, h2) / 480.0)
-                homog_dists.append(mean_dist)
-
-        n_matches = np.array(n_matches)
-        thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-        auc = pose_auc(np.array(homog_dists), thresholds)
-        return {
-            "hpatches_homog_auc_3": auc[2],
-            "hpatches_homog_auc_5": auc[4],
-            "hpatches_homog_auc_10": auc[9],
-        }

+ 0 - 105
python/RoMa/romatch/benchmarks/megadepth_dense_benchmark.py

@@ -1,105 +0,0 @@
-import torch
-import numpy as np
-import tqdm
-from romatch.datasets import MegadepthBuilder
-from romatch.utils import warp_kpts
-from torch.utils.data import ConcatDataset
-import romatch
-
-class MegadepthDenseBenchmark:
-    def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
-        mega = MegadepthBuilder(data_root=data_root)
-        self.dataset = ConcatDataset(
-            mega.build_scenes(split="test_loftr", ht=h, wt=w)
-        )  # fixed resolution of 384,512
-        self.num_samples = num_samples
-
-    def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
-        b, h1, w1, d = dense_matches.shape
-        with torch.no_grad():
-            x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
-            mask, x2 = warp_kpts(
-                x1.double(),
-                depth1.double(),
-                depth2.double(),
-                T_1to2.double(),
-                K1.double(),
-                K2.double(),
-            )
-            x2 = torch.stack(
-                (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
-            )
-            prob = mask.float().reshape(b, h1, w1)
-        x2_hat = dense_matches[..., 2:]
-        x2_hat = torch.stack(
-            (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
-        )
-        gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
-        gd = gd[prob == 1]
-        pck_1 = (gd < 1.0).float().mean()
-        pck_3 = (gd < 3.0).float().mean()
-        pck_5 = (gd < 5.0).float().mean()
-        return gd, pck_1, pck_3, pck_5, prob
-
-    def benchmark(self, model, batch_size=8):
-        model.train(False)
-        gd_tot = 0.0
-        pck_1_tot = 0.0
-        pck_3_tot = 0.0
-        pck_5_tot = 0.0
-        sampler = torch.utils.data.WeightedRandomSampler(
-            torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
-        )
-        B = batch_size
-        dataloader = torch.utils.data.DataLoader(
-            self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
-        )
-        for idx, data in tqdm.tqdm(enumerate(dataloader), disable = romatch.RANK > 0):
-            im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
-                data["im_A"].cuda(),
-                data["im_B"].cuda(),
-                data["im_A_depth"].cuda(),
-                data["im_B_depth"].cuda(),
-                data["T_1to2"].cuda(),
-                data["K1"].cuda(),
-                data["K2"].cuda(),
-            )
-            matches, certainty = model.match(im_A, im_B, batched=True)
-            gd, pck_1, pck_3, pck_5, prob = self.geometric_dist(
-                depth1, depth2, T_1to2, K1, K2, matches
-            )
-            if romatch.DEBUG_MODE:
-                from romatch.utils.utils import tensor_to_pil
-                import torch.nn.functional as F
-                path = "vis"
-                H, W = model.get_output_resolution()
-                white_im = torch.ones((B,1,H,W),device="cuda")
-                im_B_transfer_rgb = F.grid_sample(
-                    im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False
-                )
-                warp_im = im_B_transfer_rgb
-                c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
-                vis_im = c_b * warp_im + (1 - c_b) * white_im
-                for b in range(B):
-                    import os
-                    os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True)
-                    tensor_to_pil(vis_im[b], unnormalize=True).save(
-                        f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
-                    tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
-                        f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
-                    tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
-                        f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")
-
-
-            gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
-                gd_tot + gd.mean(),
-                pck_1_tot + pck_1,
-                pck_3_tot + pck_3,
-                pck_5_tot + pck_5,
-            )
-        return {
-            "epe": gd_tot.item() / len(dataloader),
-            "mega_pck_1": pck_1_tot.item() / len(dataloader),
-            "mega_pck_3": pck_3_tot.item() / len(dataloader),
-            "mega_pck_5": pck_5_tot.item() / len(dataloader),
-        }

+ 0 - 116
python/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark.py

@@ -1,116 +0,0 @@
-import numpy as np
-import torch
-from romatch.utils import *
-from PIL import Image
-from tqdm import tqdm
-
-class MegaDepthPoseEstimationBenchmark:
-    def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
-        if scene_names is None:
-            self.scene_names = [
-                "0015_0.1_0.3.npz",
-                "0015_0.3_0.5.npz",
-                "0022_0.1_0.3.npz",
-                "0022_0.3_0.5.npz",
-                "0022_0.5_0.7.npz",
-            ]
-        else:
-            self.scene_names = scene_names
-        self.scenes = [
-            np.load(f"{data_root}/{scene}", allow_pickle=True)
-            for scene in self.scene_names
-        ]
-        self.data_root = data_root
-
-    def benchmark(self, model, model_name = None):
-        with torch.no_grad():
-            data_root = self.data_root
-            tot_e_t, tot_e_R, tot_e_pose = [], [], []
-            thresholds = [5, 10, 20]
-            for scene_ind in range(len(self.scenes)):
-                import os
-                scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
-                scene = self.scenes[scene_ind]
-                pairs = scene["pair_infos"]
-                intrinsics = scene["intrinsics"]
-                poses = scene["poses"]
-                im_paths = scene["image_paths"]
-                pair_inds = range(len(pairs))
-                for pairind in (pbar := tqdm(pair_inds, desc = "Current AUC: ?")):
-                    idx1, idx2 = pairs[pairind][0]
-                    K1 = intrinsics[idx1].copy()
-                    T1 = poses[idx1].copy()
-                    R1, t1 = T1[:3, :3], T1[:3, 3]
-                    K2 = intrinsics[idx2].copy()
-                    T2 = poses[idx2].copy()
-                    R2, t2 = T2[:3, :3], T2[:3, 3]
-                    R, t = compute_relative_pose(R1, t1, R2, t2)
-                    T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
-                    im_A_path = f"{data_root}/{im_paths[idx1]}"
-                    im_B_path = f"{data_root}/{im_paths[idx2]}"
-                    dense_matches, dense_certainty = model.match(
-                        im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
-                    )
-                    
-                    im_A = Image.open(im_A_path)
-                    w1, h1 = im_A.size
-                    im_B = Image.open(im_B_path)
-                    w2, h2 = im_B.size
-                    if True: # Note: we keep this true as it was used in DKM/RoMa papers. There is very little difference compared to setting to False. 
-                        scale1 = 1200 / max(w1, h1)
-                        scale2 = 1200 / max(w2, h2)
-                        w1, h1 = scale1 * w1, scale1 * h1
-                        w2, h2 = scale2 * w2, scale2 * h2
-                        K1, K2 = K1.copy(), K2.copy()
-                        K1[:2] = K1[:2] * scale1
-                        K2[:2] = K2[:2] * scale2
-                    for _ in range(5):
-                        sparse_matches,_ = model.sample(
-                            dense_matches, dense_certainty, 5_000
-                        )
-                        kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
-                        kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
-                        shuffling = np.random.permutation(np.arange(len(kpts1)))
-                        kpts1 = kpts1[shuffling]
-                        kpts2 = kpts2[shuffling]
-                        try:
-                            threshold = 0.5 
-                            norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
-                            R_est, t_est, mask = estimate_pose(
-                                kpts1,
-                                kpts2,
-                                K1,
-                                K2,
-                                norm_threshold,
-                                conf=0.99999,
-                            )
-                            T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)  #
-                            e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
-                            e_pose = max(e_t, e_R)
-                        except Exception as e:
-                            print(repr(e))
-                            e_t, e_R = 90, 90
-                            e_pose = max(e_t, e_R)
-                        tot_e_t.append(e_t)
-                        tot_e_R.append(e_R)
-                        tot_e_pose.append(e_pose)
-                        pbar.set_description(f"Current AUC: {pose_auc(tot_e_pose, thresholds)}")
-
-            tot_e_pose = np.array(tot_e_pose)
-            auc = pose_auc(tot_e_pose, thresholds)
-            acc_5 = (tot_e_pose < 5).mean()
-            acc_10 = (tot_e_pose < 10).mean()
-            acc_15 = (tot_e_pose < 15).mean()
-            acc_20 = (tot_e_pose < 20).mean()
-            map_5 = acc_5
-            map_10 = np.mean([acc_5, acc_10])
-            map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
-            print(f"{model_name} auc: {auc}")
-            return {
-                "auc_5": auc[0],
-                "auc_10": auc[1],
-                "auc_20": auc[2],
-                "map_5": map_5,
-                "map_10": map_10,
-                "map_20": map_20,
-            }

+ 0 - 116
python/RoMa/romatch/benchmarks/megadepth_pose_estimation_benchmark_poselib.py

@@ -1,116 +0,0 @@
-import numpy as np
-import torch
-from romatch.utils import *
-from PIL import Image
-from tqdm import tqdm
-
-# wrap cause pyposelib is still in dev
-# will add in deps later
-import poselib
-
-class Mega1500PoseLibBenchmark:
-    def __init__(self, data_root="data/megadepth", scene_names = None, num_ransac_iter = 5, test_every = 1) -> None:
-        if scene_names is None:
-            self.scene_names = [
-                "0015_0.1_0.3.npz",
-                "0015_0.3_0.5.npz",
-                "0022_0.1_0.3.npz",
-                "0022_0.3_0.5.npz",
-                "0022_0.5_0.7.npz",
-            ]
-        else:
-            self.scene_names = scene_names
-        self.scenes = [
-            np.load(f"{data_root}/{scene}", allow_pickle=True)
-            for scene in self.scene_names
-        ]
-        self.data_root = data_root
-        self.num_ransac_iter = num_ransac_iter
-        self.test_every = test_every
-
-    def benchmark(self, model, model_name = None):
-        with torch.no_grad():
-            data_root = self.data_root
-            tot_e_t, tot_e_R, tot_e_pose = [], [], []
-            thresholds = [5, 10, 20]
-            for scene_ind in range(len(self.scenes)):
-                import os
-                scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
-                scene = self.scenes[scene_ind]
-                pairs = scene["pair_infos"]
-                intrinsics = scene["intrinsics"]
-                poses = scene["poses"]
-                im_paths = scene["image_paths"]
-                pair_inds = range(len(pairs))[::self.test_every]
-                for pairind in (pbar := tqdm(pair_inds, desc = "Current AUC: ?")):
-                    idx1, idx2 = pairs[pairind][0]
-                    K1 = intrinsics[idx1].copy()
-                    T1 = poses[idx1].copy()
-                    R1, t1 = T1[:3, :3], T1[:3, 3]
-                    K2 = intrinsics[idx2].copy()
-                    T2 = poses[idx2].copy()
-                    R2, t2 = T2[:3, :3], T2[:3, 3]
-                    R, t = compute_relative_pose(R1, t1, R2, t2)
-                    T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
-                    im_A_path = f"{data_root}/{im_paths[idx1]}"
-                    im_B_path = f"{data_root}/{im_paths[idx2]}"
-                    dense_matches, dense_certainty = model.match(
-                        im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
-                    )
-                    sparse_matches,_ = model.sample(
-                        dense_matches, dense_certainty, 5_000
-                    )
-                    
-                    im_A = Image.open(im_A_path)
-                    w1, h1 = im_A.size
-                    im_B = Image.open(im_B_path)
-                    w2, h2 = im_B.size
-                    kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
-                    kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
-                    for _ in range(self.num_ransac_iter):
-                        shuffling = np.random.permutation(np.arange(len(kpts1)))
-                        kpts1 = kpts1[shuffling]
-                        kpts2 = kpts2[shuffling]
-                        try:
-                            threshold = 1 
-                            camera1 = {'model': 'PINHOLE', 'width': w1, 'height': h1, 'params': K1[[0,1,0,1], [0,1,2,2]]}
-                            camera2 = {'model': 'PINHOLE', 'width': w2, 'height': h2, 'params': K2[[0,1,0,1], [0,1,2,2]]}
-                            relpose, res = poselib.estimate_relative_pose(
-                                kpts1, 
-                                kpts2,
-                                camera1,
-                                camera2,
-                                ransac_opt = {"max_reproj_error": 2*threshold, "max_epipolar_error": threshold, "min_inliers": 8, "max_iterations": 10_000},
-                            )
-                            Rt_est  = relpose.Rt
-                            R_est, t_est = Rt_est[:3,:3], Rt_est[:3,3:]
-                            mask = np.array(res['inliers']).astype(np.float32)
-                            T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)  #
-                            e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
-                            e_pose = max(e_t, e_R)
-                        except Exception as e:
-                            print(repr(e))
-                            e_t, e_R = 90, 90
-                            e_pose = max(e_t, e_R)
-                        tot_e_t.append(e_t)
-                        tot_e_R.append(e_R)
-                        tot_e_pose.append(e_pose)
-                        pbar.set_description(f"Current AUC: {pose_auc(tot_e_pose, thresholds)}")
-            tot_e_pose = np.array(tot_e_pose)
-            auc = pose_auc(tot_e_pose, thresholds)
-            acc_5 = (tot_e_pose < 5).mean()
-            acc_10 = (tot_e_pose < 10).mean()
-            acc_15 = (tot_e_pose < 15).mean()
-            acc_20 = (tot_e_pose < 20).mean()
-            map_5 = acc_5
-            map_10 = np.mean([acc_5, acc_10])
-            map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
-            print(f"{model_name} auc: {auc}")
-            return {
-                "auc_5": auc[0],
-                "auc_10": auc[1],
-                "auc_20": auc[2],
-                "map_5": map_5,
-                "map_10": map_10,
-                "map_20": map_20,
-            }

+ 0 - 143
python/RoMa/romatch/benchmarks/scannet_benchmark.py

@@ -1,143 +0,0 @@
-import os.path as osp
-import numpy as np
-import torch
-from romatch.utils import *
-from PIL import Image
-from tqdm import tqdm
-
-
-class ScanNetBenchmark:
-    def __init__(self, data_root="data/scannet") -> None:
-        self.data_root = data_root
-
-    def benchmark(self, model, model_name = None):
-        model.train(False)
-        with torch.no_grad():
-            data_root = self.data_root
-            tmp = np.load(osp.join(data_root, "test.npz"))
-            pairs, rel_pose = tmp["name"], tmp["rel_pose"]
-            tot_e_t, tot_e_R, tot_e_pose = [], [], []
-            pair_inds = np.random.choice(
-                range(len(pairs)), size=len(pairs), replace=False
-            )
-            for pairind in tqdm(pair_inds, smoothing=0.9):
-                scene = pairs[pairind]
-                scene_name = f"scene0{scene[0]}_00"
-                im_A_path = osp.join(
-                        self.data_root,
-                        "scans_test",
-                        scene_name,
-                        "color",
-                        f"{scene[2]}.jpg",
-                    )
-                im_A = Image.open(im_A_path)
-                im_B_path = osp.join(
-                        self.data_root,
-                        "scans_test",
-                        scene_name,
-                        "color",
-                        f"{scene[3]}.jpg",
-                    )
-                im_B = Image.open(im_B_path)
-                T_gt = rel_pose[pairind].reshape(3, 4)
-                R, t = T_gt[:3, :3], T_gt[:3, 3]
-                K = np.stack(
-                    [
-                        np.array([float(i) for i in r.split()])
-                        for r in open(
-                            osp.join(
-                                self.data_root,
-                                "scans_test",
-                                scene_name,
-                                "intrinsic",
-                                "intrinsic_color.txt",
-                            ),
-                            "r",
-                        )
-                        .read()
-                        .split("\n")
-                        if r
-                    ]
-                )
-                w1, h1 = im_A.size
-                w2, h2 = im_B.size
-                K1 = K.copy()
-                K2 = K.copy()
-                dense_matches, dense_certainty = model.match(im_A_path, im_B_path)
-                sparse_matches, sparse_certainty = model.sample(
-                    dense_matches, dense_certainty, 5000
-                )
-                scale1 = 480 / min(w1, h1)
-                scale2 = 480 / min(w2, h2)
-                w1, h1 = scale1 * w1, scale1 * h1
-                w2, h2 = scale2 * w2, scale2 * h2
-                K1 = K1 * scale1
-                K2 = K2 * scale2
-
-                offset = 0.5
-                kpts1 = sparse_matches[:, :2]
-                kpts1 = (
-                    np.stack(
-                        (
-                            w1 * (kpts1[:, 0] + 1) / 2 - offset,
-                            h1 * (kpts1[:, 1] + 1) / 2 - offset,
-                        ),
-                        axis=-1,
-                    )
-                )
-                kpts2 = sparse_matches[:, 2:]
-                kpts2 = (
-                    np.stack(
-                        (
-                            w2 * (kpts2[:, 0] + 1) / 2 - offset,
-                            h2 * (kpts2[:, 1] + 1) / 2 - offset,
-                        ),
-                        axis=-1,
-                    )
-                )
-                for _ in range(5):
-                    shuffling = np.random.permutation(np.arange(len(kpts1)))
-                    kpts1 = kpts1[shuffling]
-                    kpts2 = kpts2[shuffling]
-                    try:
-                        norm_threshold = 0.5 / (
-                        np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
-                        R_est, t_est, mask = estimate_pose(
-                            kpts1,
-                            kpts2,
-                            K1,
-                            K2,
-                            norm_threshold,
-                            conf=0.99999,
-                        )
-                        T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)  #
-                        e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
-                        e_pose = max(e_t, e_R)
-                    except Exception as e:
-                        print(repr(e))
-                        e_t, e_R = 90, 90
-                        e_pose = max(e_t, e_R)
-                    tot_e_t.append(e_t)
-                    tot_e_R.append(e_R)
-                    tot_e_pose.append(e_pose)
-                tot_e_t.append(e_t)
-                tot_e_R.append(e_R)
-                tot_e_pose.append(e_pose)
-            tot_e_pose = np.array(tot_e_pose)
-            thresholds = [5, 10, 20]
-            auc = pose_auc(tot_e_pose, thresholds)
-            acc_5 = (tot_e_pose < 5).mean()
-            acc_10 = (tot_e_pose < 10).mean()
-            acc_15 = (tot_e_pose < 15).mean()
-            acc_20 = (tot_e_pose < 20).mean()
-            map_5 = acc_5
-            map_10 = np.mean([acc_5, acc_10])
-            map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
-            return {
-                "auc_5": auc[0],
-                "auc_10": auc[1],
-                "auc_20": auc[2],
-                "map_5": map_5,
-                "map_10": map_10,
-                "map_20": map_20,
-            }

+ 0 - 1
python/RoMa/romatch/checkpointing/__init__.py

@@ -1 +0,0 @@
-from .checkpoint import CheckPoint

+ 0 - 60
python/RoMa/romatch/checkpointing/checkpoint.py

@@ -1,60 +0,0 @@
-import os
-import torch
-from torch.nn.parallel.data_parallel import DataParallel
-from torch.nn.parallel.distributed import DistributedDataParallel
-from loguru import logger
-import gc
-
-import romatch
-
-class CheckPoint:
-    def __init__(self, dir=None, name="tmp"):
-        self.name = name
-        self.dir = dir
-        os.makedirs(self.dir, exist_ok=True)
-
-    def save(
-        self,
-        model,
-        optimizer,
-        lr_scheduler,
-        n,
-        ):
-        if romatch.RANK == 0:
-            assert model is not None
-            if isinstance(model, (DataParallel, DistributedDataParallel)):
-                model = model.module
-            states = {
-                "model": model.state_dict(),
-                "n": n,
-                "optimizer": optimizer.state_dict(),
-                "lr_scheduler": lr_scheduler.state_dict(),
-            }
-            torch.save(states, self.dir + self.name + f"_latest.pth")
-            logger.info(f"Saved states {list(states.keys())}, at step {n}")
-    
-    def load(
-        self,
-        model,
-        optimizer,
-        lr_scheduler,
-        n,
-        ):
-        if os.path.exists(self.dir + self.name + f"_latest.pth") and romatch.RANK == 0:
-            states = torch.load(self.dir + self.name + f"_latest.pth")
-            if "model" in states:
-                model.load_state_dict(states["model"])
-            if "n" in states:
-                n = states["n"] if states["n"] else n
-            if "optimizer" in states:
-                try:
-                    optimizer.load_state_dict(states["optimizer"])
-                except Exception as e:
-                    print(f"Failed to load states for optimizer, with error {e}")
-            if "lr_scheduler" in states:
-                lr_scheduler.load_state_dict(states["lr_scheduler"])
-            print(f"Loaded states {list(states.keys())}, at step {n}")
-            del states
-            gc.collect()
-            torch.cuda.empty_cache()
-        return model, optimizer, lr_scheduler, n

+ 0 - 2
python/RoMa/romatch/datasets/__init__.py

@@ -1,2 +0,0 @@
-from .megadepth import MegadepthBuilder
-from .scannet import ScanNetBuilder

+ 0 - 232
python/RoMa/romatch/datasets/megadepth.py

@@ -1,232 +0,0 @@
-import os
-from PIL import Image
-import h5py
-import numpy as np
-import torch
-import torchvision.transforms.functional as tvf
-import kornia.augmentation as K
-from romatch.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
-import romatch
-from romatch.utils import *
-import math
-
-class MegadepthScene:
-    def __init__(
-        self,
-        data_root,
-        scene_info,
-        ht=384,
-        wt=512,
-        min_overlap=0.0,
-        max_overlap=1.0,
-        shake_t=0,
-        rot_prob=0.0,
-        normalize=True,
-        max_num_pairs = 100_000,
-        scene_name = None,
-        use_horizontal_flip_aug = False,
-        use_single_horizontal_flip_aug = False,
-        colorjiggle_params = None,
-        random_eraser = None,
-        use_randaug = False,
-        randaug_params = None,
-        randomize_size = False,
-    ) -> None:
-        self.data_root = data_root
-        self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}"
-        self.image_paths = scene_info["image_paths"]
-        self.depth_paths = scene_info["depth_paths"]
-        self.intrinsics = scene_info["intrinsics"]
-        self.poses = scene_info["poses"]
-        self.pairs = scene_info["pairs"]
-        self.overlaps = scene_info["overlaps"]
-        threshold = (self.overlaps > min_overlap) & (self.overlaps < max_overlap)
-        self.pairs = self.pairs[threshold]
-        self.overlaps = self.overlaps[threshold]
-        if len(self.pairs) > max_num_pairs:
-            pairinds = np.random.choice(
-                np.arange(0, len(self.pairs)), max_num_pairs, replace=False
-            )
-            self.pairs = self.pairs[pairinds]
-            self.overlaps = self.overlaps[pairinds]
-        if randomize_size:
-            area = ht * wt
-            s = int(16 * (math.sqrt(area)//16))
-            sizes = ((ht,wt), (s,s), (wt,ht))
-            choice = romatch.RANK % 3
-            ht, wt = sizes[choice] 
-        # counts, bins = np.histogram(self.overlaps,20)
-        # print(counts)
-        self.im_transform_ops = get_tuple_transform_ops(
-            resize=(ht, wt), normalize=normalize, colorjiggle_params = colorjiggle_params,
-        )
-        self.depth_transform_ops = get_depth_tuple_transform_ops(
-                resize=(ht, wt)
-            )
-        self.wt, self.ht = wt, ht
-        self.shake_t = shake_t
-        self.random_eraser = random_eraser
-        if use_horizontal_flip_aug and use_single_horizontal_flip_aug:
-            raise ValueError("Can't both flip both images and only flip one")
-        self.use_horizontal_flip_aug = use_horizontal_flip_aug
-        self.use_single_horizontal_flip_aug = use_single_horizontal_flip_aug
-        self.use_randaug = use_randaug
-
-    def load_im(self, im_path):
-        im = Image.open(im_path)
-        return im
-    
-    def horizontal_flip(self, im_A, im_B, depth_A, depth_B,  K_A, K_B):
-        im_A = im_A.flip(-1)
-        im_B = im_B.flip(-1)
-        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) 
-        flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
-        K_A = flip_mat@K_A  
-        K_B = flip_mat@K_B  
-        
-        return im_A, im_B, depth_A, depth_B, K_A, K_B
-    
-    def load_depth(self, depth_ref, crop=None):
-        depth = np.array(h5py.File(depth_ref, "r")["depth"])
-        return torch.from_numpy(depth)
-
-    def __len__(self):
-        return len(self.pairs)
-
-    def scale_intrinsic(self, K, wi, hi):
-        sx, sy = self.wt / wi, self.ht / hi
-        sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
-        return sK @ K
-
-    def rand_shake(self, *things):
-        t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
-        return [
-            tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
-            for thing in things
-        ], t
-
-    def __getitem__(self, pair_idx):
-        # read intrinsics of original size
-        idx1, idx2 = self.pairs[pair_idx]
-        K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
-        K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
-
-        # read and compute relative poses
-        T1 = self.poses[idx1]
-        T2 = self.poses[idx2]
-        T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
-            :4, :4
-        ]  # (4, 4)
-
-        # Load positive pair data
-        im_A, im_B = self.image_paths[idx1], self.image_paths[idx2]
-        depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
-        im_A_ref = os.path.join(self.data_root, im_A)
-        im_B_ref = os.path.join(self.data_root, im_B)
-        depth_A_ref = os.path.join(self.data_root, depth1)
-        depth_B_ref = os.path.join(self.data_root, depth2)
-        im_A = self.load_im(im_A_ref)
-        im_B = self.load_im(im_B_ref)
-        K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
-        K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
-
-        if self.use_randaug:
-            im_A, im_B = self.rand_augment(im_A, im_B)
-
-        depth_A = self.load_depth(depth_A_ref)
-        depth_B = self.load_depth(depth_B_ref)
-        # Process images
-        im_A, im_B = self.im_transform_ops((im_A, im_B))
-        depth_A, depth_B = self.depth_transform_ops(
-            (depth_A[None, None], depth_B[None, None])
-        )
-        
-        [im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B)
-        K1[:2, 2] += t
-        K2[:2, 2] += t
-        
-        im_A, im_B = im_A[None], im_B[None]
-        if self.random_eraser is not None:
-            im_A, depth_A = self.random_eraser(im_A, depth_A)
-            im_B, depth_B = self.random_eraser(im_B, depth_B)
-                
-        if self.use_horizontal_flip_aug:
-            if np.random.rand() > 0.5:
-                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
-        if self.use_single_horizontal_flip_aug:
-            if np.random.rand() > 0.5:
-                im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2)
-        
-        if romatch.DEBUG_MODE:
-            tensor_to_pil(im_A[0], unnormalize=True).save(
-                            f"vis/im_A.jpg")
-            tensor_to_pil(im_B[0], unnormalize=True).save(
-                            f"vis/im_B.jpg")
-            
-        data_dict = {
-            "im_A": im_A[0],
-            "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
-            "im_B": im_B[0],
-            "im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0],
-            "im_A_depth": depth_A[0, 0],
-            "im_B_depth": depth_B[0, 0],
-            "K1": K1,
-            "K2": K2,
-            "T_1to2": T_1to2,
-            "im_A_path": im_A_ref,
-            "im_B_path": im_B_ref,
-            
-        }
-        return data_dict
-
-
-class MegadepthBuilder:
-    def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None:
-        self.data_root = data_root
-        self.scene_info_root = os.path.join(data_root, "prep_scene_info")
-        self.all_scenes = os.listdir(self.scene_info_root)
-        self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
-        # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those
-        self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy'])
-        self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy'])
-        self.test_scenes_loftr = ["0015.npy", "0022.npy"]
-        self.loftr_ignore = loftr_ignore
-        self.imc21_ignore = imc21_ignore
-
-    def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs):
-        if split == "train":
-            scene_names = set(self.all_scenes) - set(self.test_scenes)
-        elif split == "train_loftr":
-            scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
-        elif split == "test":
-            scene_names = self.test_scenes
-        elif split == "test_loftr":
-            scene_names = self.test_scenes_loftr
-        elif split == "custom":
-            scene_names = scene_names
-        else:
-            raise ValueError(f"Split {split} not available")
-        scenes = []
-        for scene_name in scene_names:
-            if self.loftr_ignore and scene_name in self.loftr_ignore_scenes:
-                continue
-            if self.imc21_ignore and scene_name in self.imc21_scenes:
-                continue
-            if ".npy" not in scene_name:
-                continue
-            scene_info = np.load(
-                os.path.join(self.scene_info_root, scene_name), allow_pickle=True
-            ).item()
-            scenes.append(
-                MegadepthScene(
-                    self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs
-                )
-            )
-        return scenes
-
-    def weight_scenes(self, concat_dataset, alpha=0.5):
-        ns = []
-        for d in concat_dataset.datasets:
-            ns.append(len(d))
-        ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
-        return ws

+ 0 - 160
python/RoMa/romatch/datasets/scannet.py

@@ -1,160 +0,0 @@
-import os
-import random
-from PIL import Image
-import cv2
-import h5py
-import numpy as np
-import torch
-from torch.utils.data import (
-    Dataset,
-    DataLoader,
-    ConcatDataset)
-
-import torchvision.transforms.functional as tvf
-import kornia.augmentation as K
-import os.path as osp
-import matplotlib.pyplot as plt
-import romatch
-from romatch.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
-from romatch.utils.transforms import GeometricSequential
-from tqdm import tqdm
-
-class ScanNetScene:
-    def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.,use_horizontal_flip_aug = False,
-) -> None:
-        self.scene_root = osp.join(data_root,"scans","scans_train")
-        self.data_names = scene_info['name']
-        self.overlaps = scene_info['score']
-        # Only sample 10s
-        valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
-        self.overlaps = self.overlaps[valid]
-        self.data_names = self.data_names[valid]
-        if len(self.data_names) > 10000:
-            pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
-            self.data_names = self.data_names[pairinds]
-            self.overlaps = self.overlaps[pairinds]
-        self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
-        self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
-        self.wt, self.ht = wt, ht
-        self.shake_t = shake_t
-        self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
-        self.use_horizontal_flip_aug = use_horizontal_flip_aug
-
-    def load_im(self, im_B, crop=None):
-        im = Image.open(im_B)
-        return im
-    
-    def load_depth(self, depth_ref, crop=None):
-        depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
-        depth = depth / 1000
-        depth = torch.from_numpy(depth).float()  # (h, w)
-        return depth
-
-    def __len__(self):
-        return len(self.data_names)
-    
-    def scale_intrinsic(self, K, wi, hi):
-        sx, sy = self.wt / wi, self.ht /  hi
-        sK = torch.tensor([[sx, 0, 0],
-                        [0, sy, 0],
-                        [0, 0, 1]])
-        return sK@K
-
-    def horizontal_flip(self, im_A, im_B, depth_A, depth_B,  K_A, K_B):
-        im_A = im_A.flip(-1)
-        im_B = im_B.flip(-1)
-        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) 
-        flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
-        K_A = flip_mat@K_A  
-        K_B = flip_mat@K_B  
-        
-        return im_A, im_B, depth_A, depth_B, K_A, K_B
-    def read_scannet_pose(self,path):
-        """ Read ScanNet's Camera2World pose and transform it to World2Camera.
-        
-        Returns:
-            pose_w2c (np.ndarray): (4, 4)
-        """
-        cam2world = np.loadtxt(path, delimiter=' ')
-        world2cam = np.linalg.inv(cam2world)
-        return world2cam
-
-
-    def read_scannet_intrinsic(self,path):
-        """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
-        """
-        intrinsic = np.loadtxt(path, delimiter=' ')
-        return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float)
-
-    def __getitem__(self, pair_idx):
-        # read intrinsics of original size
-        data_name = self.data_names[pair_idx]
-        scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
-        scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
-        
-        # read the intrinsic of depthmap
-        K1 = K2 =  self.read_scannet_intrinsic(osp.join(self.scene_root,
-                       scene_name,
-                       'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
-        # read and compute relative poses
-        T1 =  self.read_scannet_pose(osp.join(self.scene_root,
-                       scene_name,
-                       'pose', f'{stem_name_1}.txt'))
-        T2 =  self.read_scannet_pose(osp.join(self.scene_root,
-                       scene_name,
-                       'pose', f'{stem_name_2}.txt'))
-        T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4]  # (4, 4)
-
-        # Load positive pair data
-        im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
-        im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
-        depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
-        depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
-
-        im_A = self.load_im(im_A_ref)
-        im_B = self.load_im(im_B_ref)
-        depth_A = self.load_depth(depth_A_ref)
-        depth_B = self.load_depth(depth_B_ref)
-
-        # Recompute camera intrinsic matrix due to the resize
-        K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
-        K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
-        # Process images
-        im_A, im_B = self.im_transform_ops((im_A, im_B))
-        depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None]))
-        if self.use_horizontal_flip_aug:
-            if np.random.rand() > 0.5:
-                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
-
-        data_dict = {'im_A': im_A,
-                    'im_B': im_B,
-                    'im_A_depth': depth_A[0,0],
-                    'im_B_depth': depth_B[0,0],
-                    'K1': K1,
-                    'K2': K2,
-                    'T_1to2':T_1to2,
-                    }
-        return data_dict
-
-
-class ScanNetBuilder:
-    def __init__(self, data_root = 'data/scannet') -> None:
-        self.data_root = data_root
-        self.scene_info_root = os.path.join(data_root,'scannet_indices')
-        self.all_scenes = os.listdir(self.scene_info_root)
-        
-    def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
-        # Note: split doesn't matter here as we always use same scannet_train scenes
-        scene_names = self.all_scenes
-        scenes = []
-        for scene_name in tqdm(scene_names, disable = romatch.RANK > 0):
-            scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
-            scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
-        return scenes
-    
-    def weight_scenes(self, concat_dataset, alpha=.5):
-        ns = []
-        for d in concat_dataset.datasets:
-            ns.append(len(d))
-        ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
-        return ws

+ 0 - 1
python/RoMa/romatch/losses/__init__.py

@@ -1 +0,0 @@
-from .robust_loss import RobustLosses

+ 0 - 161
python/RoMa/romatch/losses/robust_loss.py

@@ -1,161 +0,0 @@
-from einops.einops import rearrange
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from romatch.utils.utils import get_gt_warp
-import wandb
-import romatch
-import math
-
-class RobustLosses(nn.Module):
-    def __init__(
-        self,
-        robust=False,
-        center_coords=False,
-        scale_normalize=False,
-        ce_weight=0.01,
-        local_loss=True,
-        local_dist=4.0,
-        local_largest_scale=8,
-        smooth_mask = False,
-        depth_interpolation_mode = "bilinear",
-        mask_depth_loss = False,
-        relative_depth_error_threshold = 0.05,
-        alpha = 1.,
-        c = 1e-3,
-    ):
-        super().__init__()
-        self.robust = robust  # measured in pixels
-        self.center_coords = center_coords
-        self.scale_normalize = scale_normalize
-        self.ce_weight = ce_weight
-        self.local_loss = local_loss
-        self.local_dist = local_dist
-        self.local_largest_scale = local_largest_scale
-        self.smooth_mask = smooth_mask
-        self.depth_interpolation_mode = depth_interpolation_mode
-        self.mask_depth_loss = mask_depth_loss
-        self.relative_depth_error_threshold = relative_depth_error_threshold
-        self.avg_overlap = dict()
-        self.alpha = alpha
-        self.c = c
-
-    def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale):
-        with torch.no_grad():
-            B, C, H, W = scale_gm_cls.shape
-            device = x2.device
-            cls_res = round(math.sqrt(C))
-            G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)], indexing='ij')
-            G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
-            GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
-        cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction  = 'none')[prob > 0.99]
-        certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
-        if not torch.any(cls_loss):
-            cls_loss = (certainty_loss * 0.0)  # Prevent issues where prob is 0 everywhere
-            
-        losses = {
-            f"gm_certainty_loss_{scale}": certainty_loss.mean(),
-            f"gm_cls_loss_{scale}": cls_loss.mean(),
-        }
-        wandb.log(losses, step = romatch.GLOBAL_STEP)
-        return losses
-
-    def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale):
-        with torch.no_grad():
-            B, C, H, W = delta_cls.shape
-            device = x2.device
-            cls_res = round(math.sqrt(C))
-            G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
-            G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale
-            GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices
-        cls_loss = F.cross_entropy(delta_cls, GT, reduction  = 'none')[prob > 0.99]
-        certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob)
-        if not torch.any(cls_loss):
-            cls_loss = (certainty_loss * 0.0)  # Prevent issues where prob is 0 everywhere
-        losses = {
-            f"delta_certainty_loss_{scale}": certainty_loss.mean(),
-            f"delta_cls_loss_{scale}": cls_loss.mean(),
-        }
-        wandb.log(losses, step = romatch.GLOBAL_STEP)
-        return losses
-
-    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
-        epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
-        if scale == 1:
-            pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
-            wandb.log({"train_pck_05": pck_05}, step = romatch.GLOBAL_STEP)
-
-        ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
-        a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha
-        cs = self.c * scale
-        x = epe[prob > 0.99]
-        reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
-        if not torch.any(reg_loss):
-            reg_loss = (ce_loss * 0.0)  # Prevent issues where prob is 0 everywhere
-        losses = {
-            f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
-            f"{mode}_regression_loss_{scale}": reg_loss.mean(),
-        }
-        wandb.log(losses, step = romatch.GLOBAL_STEP)
-        return losses
-
-    def forward(self, corresps, batch):
-        scales = list(corresps.keys())
-        tot_loss = 0.0
-        # scale_weights due to differences in scale for regression gradients and classification gradients
-        scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1}
-        for scale in scales:
-            scale_corresps = corresps[scale]
-            scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = (
-                scale_corresps["certainty"],
-                scale_corresps.get("flow_pre_delta"),
-                scale_corresps.get("delta_cls"),
-                scale_corresps.get("offset_scale"),
-                scale_corresps.get("gm_cls"),
-                scale_corresps.get("gm_certainty"),
-                scale_corresps["flow"],
-                scale_corresps.get("gm_flow"),
-
-            )
-            if flow_pre_delta is not None:
-                flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
-                b, h, w, d = flow_pre_delta.shape
-            else:
-                # _ = 1
-                b, _, h, w = scale_certainty.shape
-            gt_warp, gt_prob = get_gt_warp(                
-            batch["im_A_depth"],
-            batch["im_B_depth"],
-            batch["T_1to2"],
-            batch["K1"],
-            batch["K2"],
-            H=h,
-            W=w,
-        )
-            x2 = gt_warp.float()
-            prob = gt_prob
-            
-            if self.local_largest_scale >= scale:
-                prob = prob * (
-                        F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0]
-                        < (2 / 512) * (self.local_dist[scale] * scale))
-            
-            if scale_gm_cls is not None:
-                gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale)
-                gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"]
-                tot_loss = tot_loss + scale_weights[scale] * gm_loss
-            elif scale_gm_flow is not None:
-                gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
-                gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
-                tot_loss = tot_loss + scale_weights[scale] * gm_loss
-            
-            if delta_cls is not None:
-                delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale)
-                delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"]
-                tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
-            else:
-                delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
-                reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
-                tot_loss = tot_loss + scale_weights[scale] * reg_loss
-            prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach()
-        return tot_loss

+ 0 - 160
python/RoMa/romatch/losses/robust_loss_tiny_roma.py

@@ -1,160 +0,0 @@
-from einops.einops import rearrange
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from romatch.utils.utils import get_gt_warp
-import wandb
-import romatch
-import math
-
-# This is slightly different than regular romatch due to significantly worse corresps
-# The confidence loss is quite tricky here //Johan
-
-class RobustLosses(nn.Module):
-    def __init__(
-        self,
-        robust=False,
-        center_coords=False,
-        scale_normalize=False,
-        ce_weight=0.01,
-        local_loss=True,
-        local_dist=None,
-        smooth_mask = False,
-        depth_interpolation_mode = "bilinear",
-        mask_depth_loss = False,
-        relative_depth_error_threshold = 0.05,
-        alpha = 1.,
-        c = 1e-3,
-        epe_mask_prob_th = None,
-        cert_only_on_consistent_depth = False,
-    ):
-        super().__init__()
-        if local_dist is None:
-            local_dist = {}
-        self.robust = robust  # measured in pixels
-        self.center_coords = center_coords
-        self.scale_normalize = scale_normalize
-        self.ce_weight = ce_weight
-        self.local_loss = local_loss
-        self.local_dist = local_dist
-        self.smooth_mask = smooth_mask
-        self.depth_interpolation_mode = depth_interpolation_mode
-        self.mask_depth_loss = mask_depth_loss
-        self.relative_depth_error_threshold = relative_depth_error_threshold
-        self.avg_overlap = dict()
-        self.alpha = alpha
-        self.c = c
-        self.epe_mask_prob_th = epe_mask_prob_th
-        self.cert_only_on_consistent_depth = cert_only_on_consistent_depth
-
-    def corr_volume_loss(self, mnn:torch.Tensor, corr_volume:torch.Tensor, scale):
-        b, h,w, h,w = corr_volume.shape
-        inv_temp = 10
-        corr_volume = corr_volume.reshape(-1, h*w, h*w)
-        nll = -(inv_temp*corr_volume).log_softmax(dim = 1) - (inv_temp*corr_volume).log_softmax(dim = 2)
-        corr_volume_loss = nll[mnn[:,0], mnn[:,1], mnn[:,2]].mean()
-        
-        losses = {
-            f"gm_corr_volume_loss_{scale}": corr_volume_loss.mean(),
-        }
-        wandb.log(losses, step = romatch.GLOBAL_STEP)
-        return losses
-
-    
-
-    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
-        epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
-        if scale in self.local_dist:
-            prob = prob * (epe < (2 / 512) * (self.local_dist[scale] * scale)).float()
-        if scale == 1:
-            pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
-            wandb.log({"train_pck_05": pck_05}, step = romatch.GLOBAL_STEP)
-        if self.epe_mask_prob_th is not None:
-            # if too far away from gt, certainty should be 0
-            gt_cert = prob * (epe < scale * self.epe_mask_prob_th)
-        else:
-            gt_cert = prob
-        if self.cert_only_on_consistent_depth:
-            ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0][prob > 0], gt_cert[prob > 0])
-        else:    
-            ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], gt_cert)
-        a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha
-        cs = self.c * scale
-        x = epe[prob > 0.99]
-        reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
-        if not torch.any(reg_loss):
-            reg_loss = (ce_loss * 0.0)  # Prevent issues where prob is 0 everywhere
-        losses = {
-            f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
-            f"{mode}_regression_loss_{scale}": reg_loss.mean(),
-        }
-        wandb.log(losses, step = romatch.GLOBAL_STEP)
-        return losses
-
-    def forward(self, corresps, batch):
-        scales = list(corresps.keys())
-        tot_loss = 0.0
-        # scale_weights due to differences in scale for regression gradients and classification gradients
-        for scale in scales:
-            scale_corresps = corresps[scale]
-            scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_corr_volume, scale_gm_certainty, flow, scale_gm_flow = (
-                scale_corresps["certainty"],
-                scale_corresps.get("flow_pre_delta"),
-                scale_corresps.get("delta_cls"),
-                scale_corresps.get("offset_scale"),
-                scale_corresps.get("corr_volume"),
-                scale_corresps.get("gm_certainty"),
-                scale_corresps["flow"],
-                scale_corresps.get("gm_flow"),
-
-            )
-            if flow_pre_delta is not None:
-                flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
-                b, h, w, d = flow_pre_delta.shape
-            else:
-                # _ = 1
-                b, _, h, w = scale_certainty.shape
-            gt_warp, gt_prob = get_gt_warp(                
-            batch["im_A_depth"],
-            batch["im_B_depth"],
-            batch["T_1to2"],
-            batch["K1"],
-            batch["K2"],
-            H=h,
-            W=w,
-            )
-            x2 = gt_warp.float()
-            prob = gt_prob
-                        
-            if scale_gm_corr_volume is not None:
-                gt_warp_back, _ = get_gt_warp(                
-                batch["im_B_depth"],
-                batch["im_A_depth"],
-                batch["T_1to2"].inverse(),
-                batch["K2"],
-                batch["K1"],
-                H=h,
-                W=w,
-                )
-                grid = torch.stack(torch.meshgrid(torch.linspace(-1+1/w, 1-1/w, w), torch.linspace(-1+1/h, 1-1/h, h), indexing='xy'), dim =-1).to(gt_warp.device)
-                #fwd_bck = F.grid_sample(gt_warp_back.permute(0,3,1,2), gt_warp, align_corners=False, mode = 'bilinear').permute(0,2,3,1)
-                #diff = (fwd_bck - grid).norm(dim = -1)
-                with torch.no_grad():
-                    D_B = torch.cdist(gt_warp.float().reshape(-1,h*w,2), grid.reshape(-1,h*w,2))
-                    D_A = torch.cdist(grid.reshape(-1,h*w,2), gt_warp_back.float().reshape(-1,h*w,2))
-                    inds = torch.nonzero((D_B == D_B.min(dim=-1, keepdim = True).values) 
-                                        * (D_A == D_A.min(dim=-2, keepdim = True).values)
-                                        * (D_B < 0.01)
-                                        * (D_A < 0.01))
-
-                gm_cls_losses = self.corr_volume_loss(inds, scale_gm_corr_volume, scale)
-                gm_loss = gm_cls_losses[f"gm_corr_volume_loss_{scale}"]
-                tot_loss = tot_loss + gm_loss
-            elif scale_gm_flow is not None:
-                gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
-                gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
-                tot_loss = tot_loss +  gm_loss
-            delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
-            reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
-            tot_loss = tot_loss + reg_loss
-        return tot_loss

+ 0 - 1
python/RoMa/romatch/models/__init__.py

@@ -1 +0,0 @@
-from .model_zoo import roma_outdoor, tiny_roma_v1_outdoor, roma_indoor

+ 0 - 68
python/RoMa/romatch/models/encoders.py

@@ -1,68 +0,0 @@
-import torch
-import torch.nn as nn
-import torchvision.models as tvm
-from romatch.utils.utils import get_autocast_params
-
-class VGG19(nn.Module):
-    def __init__(self, pretrained=True, amp = False, amp_dtype = torch.float16) -> None:
-        super().__init__()
-        if pretrained:
-            weights = tvm.vgg.VGG19_BN_Weights.IMAGENET1K_V1
-        else:
-            weights = None
-        self.layers = nn.ModuleList(tvm.vgg19_bn(weights=weights).features[:40])
-        self.amp = amp
-        self.amp_dtype = amp_dtype
-
-    def forward(self, x, **kwargs):
-        autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(x.device, self.amp, self.amp_dtype)
-        with torch.autocast(device_type=autocast_device, enabled=autocast_enabled, dtype = autocast_dtype):
-            feats = {}
-            scale = 1
-            for layer in self.layers:
-                if isinstance(layer, nn.MaxPool2d):
-                    feats[scale] = x
-                    scale = scale*2
-                x = layer(x)
-            return feats
-
-class CNNandDinov2(nn.Module):
-    def __init__(self, cnn_kwargs = None, amp = False, dinov2_weights = None, amp_dtype = torch.float16):
-        super().__init__()
-        if dinov2_weights is None:
-            dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
-        from .transformer import vit_large
-        vit_kwargs = dict(img_size= 518,
-            patch_size= 14,
-            init_values = 1.0,
-            ffn_layer = "mlp",
-            block_chunks = 0,
-        )
-
-        dinov2_vitl14 = vit_large(**vit_kwargs).eval()
-        dinov2_vitl14.load_state_dict(dinov2_weights)
-        cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
-        self.cnn = VGG19(**cnn_kwargs)
-        self.amp = amp
-        self.amp_dtype = amp_dtype
-        if self.amp:
-            dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
-        self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
-    
-    
-    def train(self, mode: bool = True):
-        return self.cnn.train(mode)
-    
-    def forward(self, x, upsample = False):
-        B,C,H,W = x.shape
-        feature_pyramid = self.cnn(x)
-        
-        if not upsample:
-            with torch.no_grad():
-                if self.dinov2_vitl14[0].device != x.device:
-                    self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
-                dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
-                features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
-                del dinov2_features_16
-                feature_pyramid[16] = features_16
-        return feature_pyramid

+ 0 - 1001
python/RoMa/romatch/models/matcher.py

@@ -1,1001 +0,0 @@
-import os
-import math
-import sys
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-from warnings import warn
-from PIL import Image
-
-from romatch.utils import get_tuple_transform_ops
-from romatch.utils.local_correlation import local_correlation
-from romatch.utils.utils import (
-    check_rgb,
-    cls_to_flow_refine,
-    get_autocast_params,
-    check_not_i16,
-)
-from romatch.utils.kde import kde
-from romatch.models.encoders import CNNandDinov2
-
-class ConvRefiner(nn.Module):
-    def __init__(
-        self,
-        in_dim=6,
-        hidden_dim=16,
-        out_dim=2,
-        dw=False,
-        kernel_size=5,
-        hidden_blocks=3,
-        displacement_emb=None,
-        displacement_emb_dim=None,
-        local_corr_radius=None,
-        corr_in_other=None,
-        no_im_B_fm=False,
-        amp=False,
-        concat_logits=False,
-        use_bias_block_1=True,
-        use_cosine_corr=False,
-        disable_local_corr_grad=False,
-        is_classifier=False,
-        sample_mode="bilinear",
-        norm_type=nn.BatchNorm2d,
-        bn_momentum=0.1,
-        amp_dtype=torch.float16,
-        use_custom_corr=False,
-    ):
-        super().__init__()
-        if sys.platform != "linux":
-            warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
-            use_custom_corr = False
-        self.bn_momentum = bn_momentum
-        self.block1 = self.create_block(
-            in_dim,
-            hidden_dim,
-            dw=dw,
-            kernel_size=kernel_size,
-            bias=use_bias_block_1,
-        )
-        self.hidden_blocks = nn.Sequential(
-            *[
-                self.create_block(
-                    hidden_dim,
-                    hidden_dim,
-                    dw=dw,
-                    kernel_size=kernel_size,
-                    norm_type=norm_type,
-                )
-                for hb in range(hidden_blocks)
-            ]
-        )
-        self.hidden_blocks = self.hidden_blocks
-        self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
-        if displacement_emb:
-            self.has_displacement_emb = True
-            self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0)
-        else:
-            self.has_displacement_emb = False
-        self.local_corr_radius = local_corr_radius
-        self.corr_in_other = corr_in_other
-        self.no_im_B_fm = no_im_B_fm
-        self.amp = amp
-        self.concat_logits = concat_logits
-        self.use_cosine_corr = use_cosine_corr
-        self.disable_local_corr_grad = disable_local_corr_grad
-        self.is_classifier = is_classifier
-        self.sample_mode = sample_mode
-        self.amp_dtype = amp_dtype
-        self.use_custom_corr = use_custom_corr
-
-    def create_block(
-        self,
-        in_dim,
-        out_dim,
-        dw=False,
-        kernel_size=5,
-        bias=True,
-        norm_type=nn.BatchNorm2d,
-    ):
-        num_groups = 1 if not dw else in_dim
-        if dw:
-            assert out_dim % in_dim == 0, (
-                "outdim must be divisible by indim for depthwise"
-            )
-        conv1 = nn.Conv2d(
-            in_dim,
-            out_dim,
-            kernel_size=kernel_size,
-            stride=1,
-            padding=kernel_size // 2,
-            groups=num_groups,
-            bias=bias,
-        )
-        norm = (
-            norm_type(out_dim, momentum=self.bn_momentum)
-            if norm_type is nn.BatchNorm2d
-            else norm_type(num_channels=out_dim)
-        )
-        relu = nn.ReLU(inplace=True)
-        conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
-        return nn.Sequential(conv1, norm, relu, conv2)
-
-    def forward(self, x, y, warp, scale_factor=1, logits=None):
-        b, c, hs, ws = x.shape
-        autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(
-            x.device, enabled=self.amp, dtype=self.amp_dtype
-        )
-        with torch.autocast(
-            autocast_device, enabled=autocast_enabled, dtype=autocast_dtype
-        ):
-            x_hat = F.grid_sample(
-                y, warp.permute(0, 2, 3, 1), align_corners=False, mode=self.sample_mode
-            )
-            if self.has_displacement_emb:
-                im_A_coords = torch.meshgrid(
-                    (
-                        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
-                        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
-                    ),
-                    indexing="ij",
-                )
-                im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
-                im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
-                in_displacement = warp - im_A_coords
-                emb_in_displacement = self.disp_emb(
-                    40 / 32 * scale_factor * in_displacement
-                )
-                if self.local_corr_radius:
-                    if self.corr_in_other:
-                        # Corr in other means take a kxk grid around the predicted coordinate in other image
-                        local_corr = local_correlation(
-                            x,
-                            y,
-                            self.local_corr_radius,
-                            warp,
-                            sample_mode=self.sample_mode,
-                            use_custom_corr=self.use_custom_corr,
-                        )
-                    else:
-                        raise NotImplementedError(
-                            "Local corr in own frame should not be used."
-                        )
-                    if self.no_im_B_fm:
-                        x_hat = torch.zeros_like(x)
-                    d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
-                else:
-                    d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
-            else:
-                if self.no_im_B_fm:
-                    x_hat = torch.zeros_like(x)
-                d = torch.cat((x, x_hat), dim=1)
-            if self.concat_logits:
-                d = torch.cat((d, logits), dim=1)
-            # pad d if needed
-            channel_d = d.shape[1]
-            channel_block1 = self.block1[0].in_channels
-            if channel_d != channel_block1:
-                d = F.pad(d, (0, 0, 0, 0, 0, channel_block1 - channel_d))
-            d = self.block1(d)
-            d = self.hidden_blocks(d)
-        d = self.out_conv(d.float())
-        displacement, certainty = d[:, :-1], d[:, -1:]
-        return displacement, certainty
-
-
-class CosKernel(nn.Module):  # similar to softmax kernel
-    def __init__(self, T, learn_temperature=False):
-        super().__init__()
-        self.learn_temperature = learn_temperature
-        if self.learn_temperature:
-            self.T = nn.Parameter(torch.tensor(T))
-        else:
-            self.T = T
-
-    def __call__(self, x, y, eps=1e-6):
-        c = torch.einsum("bnd,bmd->bnm", x, y) / (
-            x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
-        )
-        if self.learn_temperature:
-            T = self.T.abs() + 0.01
-        else:
-            T = torch.tensor(self.T, device=c.device)
-        K = ((c - 1.0) / T).exp()
-        return K
-
-
-class GP(nn.Module):
-    def __init__(
-        self,
-        kernel,
-        T=1,
-        learn_temperature=False,
-        only_attention=False,
-        gp_dim=64,
-        basis="fourier",
-        covar_size=5,
-        only_nearest_neighbour=False,
-        sigma_noise=0.1,
-        no_cov=False,
-        predict_features=False,
-    ):
-        super().__init__()
-        self.K = kernel(T=T, learn_temperature=learn_temperature)
-        self.sigma_noise = sigma_noise
-        self.covar_size = covar_size
-        self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
-        self.only_attention = only_attention
-        self.only_nearest_neighbour = only_nearest_neighbour
-        self.basis = basis
-        self.no_cov = no_cov
-        self.dim = gp_dim
-        self.predict_features = predict_features
-
-    def get_local_cov(self, cov):
-        K = self.covar_size
-        b, h, w, h, w = cov.shape
-        hw = h * w
-        cov = F.pad(cov, 4 * (K // 2,))  # pad v_q
-        delta = torch.stack(
-            torch.meshgrid(
-                torch.arange(-(K // 2), K // 2 + 1),
-                torch.arange(-(K // 2), K // 2 + 1),
-                indexing="ij",
-            ),
-            dim=-1,
-        )
-        positions = torch.stack(
-            torch.meshgrid(
-                torch.arange(K // 2, h + K // 2),
-                torch.arange(K // 2, w + K // 2),
-                indexing="ij",
-            ),
-            dim=-1,
-        )
-        neighbours = positions[:, :, None, None, :] + delta[None, :, :]
-        points = torch.arange(hw)[:, None].expand(hw, K**2)
-        local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
-            :,
-            points.flatten(),
-            neighbours[..., 0].flatten(),
-            neighbours[..., 1].flatten(),
-        ].reshape(b, h, w, K**2)
-        return local_cov
-
-    def reshape(self, x):
-        return rearrange(x, "b d h w -> b (h w) d")
-
-    def project_to_basis(self, x):
-        if self.basis == "fourier":
-            return torch.cos(8 * math.pi * self.pos_conv(x))
-        elif self.basis == "linear":
-            return self.pos_conv(x)
-        else:
-            raise ValueError(
-                "No other bases other than fourier and linear currently im_Bed in public release"
-            )
-
-    def get_pos_enc(self, y):
-        b, c, h, w = y.shape
-        coarse_coords = torch.meshgrid(
-            (
-                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
-                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
-            ),
-            indexing="ij",
-        )
-
-        coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
-            None
-        ].expand(b, h, w, 2)
-        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
-        coarse_embedded_coords = self.project_to_basis(coarse_coords)
-        return coarse_embedded_coords
-
-    def forward(self, x, y, **kwargs):
-        b, c, h1, w1 = x.shape
-        b, c, h2, w2 = y.shape
-        f = self.get_pos_enc(y)
-        b, d, h2, w2 = f.shape
-        x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f)
-        # K_xx = self.K(x, x)
-        K_yy = self.K(y, y)
-        K_xy = self.K(x, y)
-        K_yx = K_xy.permute(0, 2, 1)
-        sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
-        if self.training:
-            K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
-            mu_x = K_xy.matmul(K_yy_inv.matmul(f))
-        else:
-            # faster inference, possibly also useful for training
-            L_t = torch.linalg.cholesky(K_yy + sigma_noise)
-            pos_emb = torch.cholesky_solve(f.reshape(b, h2 * w2, d), L_t, upper=False)
-            mu_x = K_xy @ pos_emb
-        mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
-
-
-        # if not self.no_cov:
-        #     cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
-        #     cov_x = rearrange(
-        #         cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1
-        #     )
-        #     local_cov_x = self.get_local_cov(cov_x)
-        #     local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
-        #     gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
-        # else:
-        gp_feats = mu_x
-        return gp_feats
-
-
-class Decoder(nn.Module):
-    def __init__(
-        self,
-        embedding_decoder,
-        gps,
-        proj,
-        conv_refiner,
-        detach=False,
-        scales="all",
-        pos_embeddings=None,
-        num_refinement_steps_per_scale=1,
-        warp_noise_std=0.0,
-        displacement_dropout_p=0.0,
-        gm_warp_dropout_p=0.0,
-        flow_upsample_mode="bilinear",
-        amp_dtype=torch.float16,
-    ):
-        super().__init__()
-        self.embedding_decoder = embedding_decoder
-        self.num_refinement_steps_per_scale = num_refinement_steps_per_scale
-        self.gps = gps
-        self.proj = proj
-        self.conv_refiner = conv_refiner
-        self.detach = detach
-        if pos_embeddings is None:
-            self.pos_embeddings = {}
-        else:
-            self.pos_embeddings = pos_embeddings
-        if scales == "all":
-            self.scales = ["32", "16", "8", "4", "2", "1"]
-        else:
-            self.scales = scales
-        self.warp_noise_std = warp_noise_std
-        self.refine_init = 4
-        self.displacement_dropout_p = displacement_dropout_p
-        self.gm_warp_dropout_p = gm_warp_dropout_p
-        self.flow_upsample_mode = flow_upsample_mode
-        self.amp_dtype = amp_dtype
-
-    def get_placeholder_flow(self, b, h, w, device):
-        coarse_coords = torch.meshgrid(
-            (
-                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
-                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
-            ),
-            indexing="ij",
-        )
-        coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
-            None
-        ].expand(b, h, w, 2)
-        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
-        return coarse_coords
-
-    def get_positional_embedding(self, b, h, w, device):
-        coarse_coords = torch.meshgrid(
-            (
-                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
-                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
-            ),
-            indexing="ij",
-        )
-
-        coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
-            None
-        ].expand(b, h, w, 2)
-        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
-        coarse_embedded_coords = self.pos_embedding(coarse_coords)
-        return coarse_embedded_coords
-
-    def forward(
-        self,
-        f1,
-        f2,
-        gt_warp=None,
-        gt_prob=None,
-        upsample=False,
-        flow=None,
-        certainty=None,
-        scale_factor=1,
-    ):
-        coarse_scales = self.embedding_decoder.scales()
-        all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
-        sizes = {scale: f1[scale].shape[-2:] for scale in f1}
-        h, w = sizes[1]
-        b = f1[1].shape[0]
-        device = f1[1].device
-        coarsest_scale = int(all_scales[0])
-        old_stuff = torch.zeros(
-            b,
-            self.embedding_decoder.hidden_dim,
-            *sizes[coarsest_scale],
-            device=f1[coarsest_scale].device,
-        )
-        corresps = {}
-        if not upsample:
-            flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
-            certainty = 0.0
-        else:
-            flow = F.interpolate(
-                flow,
-                size=sizes[coarsest_scale],
-                align_corners=False,
-                mode="bilinear",
-            )
-            certainty = F.interpolate(
-                certainty,
-                size=sizes[coarsest_scale],
-                align_corners=False,
-                mode="bilinear",
-            )
-        displacement = 0.0
-        for new_scale in all_scales:
-            ins = int(new_scale)
-            corresps[ins] = {}
-            f1_s, f2_s = f1[ins], f2[ins]
-            if new_scale in self.proj:
-                autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(
-                    f1_s.device, str(f1_s) == "cuda", self.amp_dtype
-                )
-                with torch.autocast(
-                    autocast_device, enabled=autocast_enabled, dtype=autocast_dtype
-                ):
-                    if not autocast_enabled:
-                        f1_s, f2_s = f1_s.to(torch.float32), f2_s.to(torch.float32)
-                    f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
-
-            if ins in coarse_scales:
-                old_stuff = F.interpolate(
-                    old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
-                )
-                gp_posterior = self.gps[new_scale](f1_s, f2_s)
-                gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
-                    gp_posterior, f1_s, old_stuff, new_scale
-                )
-
-                if self.embedding_decoder.is_classifier:
-                    flow = cls_to_flow_refine(
-                        gm_warp_or_cls,
-                    ).permute(0, 3, 1, 2)
-                    corresps[ins].update(
-                        {
-                            "gm_cls": gm_warp_or_cls,
-                            "gm_certainty": certainty,
-                        }
-                    ) if self.training else None
-                else:
-                    corresps[ins].update(
-                        {
-                            "gm_flow": gm_warp_or_cls,
-                            "gm_certainty": certainty,
-                        }
-                    ) if self.training else None
-                    flow = gm_warp_or_cls.detach()
-
-            if new_scale in self.conv_refiner:
-                corresps[ins].update(
-                    {"flow_pre_delta": flow}
-                ) if self.training else None
-                delta_flow, delta_certainty = self.conv_refiner[new_scale](
-                    f1_s,
-                    f2_s,
-                    flow,
-                    scale_factor=scale_factor,
-                    logits=certainty,
-                )
-                corresps[ins].update(
-                    {
-                        "delta_flow": delta_flow,
-                    }
-                ) if self.training else None
-                displacement = ins * torch.stack(
-                    (
-                        delta_flow[:, 0].float() / (self.refine_init * w),
-                        delta_flow[:, 1].float() / (self.refine_init * h),
-                    ),
-                    dim=1,
-                )
-                flow = flow + displacement
-                certainty = (
-                    certainty + delta_certainty
-                )  # predict both certainty and displacement
-            corresps[ins].update(
-                {
-                    "certainty": certainty,
-                    "flow": flow,
-                }
-            )
-            if new_scale != "1":
-                flow = F.interpolate(
-                    flow,
-                    size=sizes[ins // 2],
-                    mode=self.flow_upsample_mode,
-                )
-                certainty = F.interpolate(
-                    certainty,
-                    size=sizes[ins // 2],
-                    mode=self.flow_upsample_mode,
-                )
-                if self.detach:
-                    flow = flow.detach()
-                    certainty = certainty.detach()
-        return corresps
-
-
-def _check_input(im_input):
-    if isinstance(im_input, (str, os.PathLike)):
-        im = Image.open(im_input)
-        check_not_i16(im)
-        im = im.convert("RGB")
-    elif isinstance(im_input, Image.Image):
-        check_rgb(im_input)
-        im = im_input
-    else:
-        assert isinstance(im_input, torch.Tensor), (
-            "im_input must be a string, path, or PIL image"
-        )
-        B, C, H, W = im_input.shape
-        assert C == 3, "im_input must be a RGB image"
-        assert H % 14 == 0, "im_input must be a multiple of 14"
-        assert W % 14 == 0, "im_input must be a multiple of 14"
-        im = im_input
-    return im
-
-
-class RegressionMatcher(nn.Module):
-    def __init__(
-        self,
-        encoder: CNNandDinov2,
-        decoder: Decoder,
-        h=448,
-        w=448,
-        sample_mode="threshold_balanced",
-        upsample_preds=False,
-        symmetric=False,
-        sample_thresh=0.05,
-        name=None,
-        attenuate_cert=None,
-        upsample_res=None,
-    ):
-        super().__init__()
-        self.attenuate_cert = attenuate_cert
-        self.encoder = encoder
-        self.decoder = decoder
-        self.name = name
-        self.w_resized = w
-        self.h_resized = h
-        self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
-        self.sample_mode = sample_mode
-        self.upsample_preds = upsample_preds
-        self.upsample_res = upsample_res or (14 * 16 * 6, 14 * 16 * 6)
-        self.symmetric = symmetric
-        self.sample_thresh = sample_thresh
-
-    def get_output_resolution(self):
-        if not self.upsample_preds:
-            return self.h_resized, self.w_resized
-        else:
-            return self.upsample_res
-
-    def extract_backbone_features(self, batch, batched=True, upsample=False):
-        if 'unique_images' in batch:
-            unique_images = batch['unique_images']
-            im_AB_idx = batch['im_AB_idx']
-            feature_pyramid0 = self.encoder(unique_images, upsample=upsample)
-            feature_pyramid = {
-                scale: feature_pyramid0[scale][im_AB_idx]
-                for scale in feature_pyramid0
-            }
-            return feature_pyramid
-            
-        x_q = batch["im_A"]
-        x_s = batch["im_B"]
-        if batched:
-            X = torch.cat((x_q, x_s), dim=0)
-            feature_pyramid = self.encoder(X, upsample=upsample)
-        else:
-            feature_pyramid = (
-                self.encoder(x_q, upsample=upsample),
-                self.encoder(x_s, upsample=upsample),
-            )
-        return feature_pyramid
-
-    def sample(
-        self,
-        matches,
-        certainty,
-        num=10000,
-    ):
-        if "threshold" in self.sample_mode:
-            upper_thresh = self.sample_thresh
-            certainty = certainty.clone()
-            certainty[certainty > upper_thresh] = 1
-        matches, certainty = (
-            matches.reshape(-1, 4),
-            certainty.reshape(-1),
-        )
-        expansion_factor = 4 if "balanced" in self.sample_mode else 1
-        good_samples = torch.multinomial(
-            certainty,
-            num_samples=min(expansion_factor * num, len(certainty)),
-            replacement=False,
-        )
-        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
-        if "balanced" not in self.sample_mode:
-            return good_matches, good_certainty
-        density = kde(good_matches, std=0.1)
-        p = 1 / (density + 1)
-        p[density < 10] = (
-            1e-7  # Basically should have at least 10 perfect neighbours, or around 100 ok ones
-        )
-        balanced_samples = torch.multinomial(
-            p, num_samples=min(num, len(good_certainty)), replacement=False
-        )
-        return good_matches[balanced_samples], good_certainty[balanced_samples]
-
-    def forward(self, batch, batched=True, upsample=False, scale_factor=1):
-        feature_pyramid = self.extract_backbone_features(
-            batch, batched=batched, upsample=upsample
-        )
-        if batched:
-            f_q_pyramid = {
-                scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
-            }
-            f_s_pyramid = {
-                scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
-            }
-        else:
-            f_q_pyramid, f_s_pyramid = feature_pyramid
-        corresps = self.decoder(
-            f_q_pyramid,
-            f_s_pyramid,
-            upsample=upsample,
-            **(batch["corresps"] if "corresps" in batch else {}),
-            scale_factor=scale_factor,
-        )
-
-        return corresps
-
-    def forward_symmetric(self, batch, batched=True, upsample=False, scale_factor=1):
-        feature_pyramid = self.extract_backbone_features(
-            batch, batched=batched, upsample=upsample
-        )
-        f_q_pyramid = feature_pyramid
-        f_s_pyramid = {
-            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim=0)
-            for scale, f_scale in feature_pyramid.items()
-        }
-        corresps = self.decoder(
-            f_q_pyramid,
-            f_s_pyramid,
-            upsample=upsample,
-            **(batch["corresps"] if "corresps" in batch else {}),
-            scale_factor=scale_factor,
-        )
-        return corresps
-
-    def conf_from_fb_consistency(self, flow_forward, flow_backward, th=2):
-        # assumes that flow forward is of shape (..., H, W, 2)
-        has_batch = False
-        if len(flow_forward.shape) == 3:
-            flow_forward, flow_backward = flow_forward[None], flow_backward[None]
-        else:
-            has_batch = True
-        H, W = flow_forward.shape[-3:-1]
-        th_n = 2 * th / max(H, W)
-        coords = torch.stack(
-            torch.meshgrid(
-                torch.linspace(-1 + 1 / W, 1 - 1 / W, W),
-                torch.linspace(-1 + 1 / H, 1 - 1 / H, H),
-                indexing="xy",
-            ),
-            dim=-1,
-        ).to(flow_forward.device)
-        coords_fb = F.grid_sample(
-            flow_backward.permute(0, 3, 1, 2),
-            flow_forward,
-            align_corners=False,
-            mode="bilinear",
-        ).permute(0, 2, 3, 1)
-        diff = (coords - coords_fb).norm(dim=-1)
-        in_th = (diff < th_n).float()
-        if not has_batch:
-            in_th = in_th[0]
-        return in_th
-
-    def to_pixel_coordinates(self, coords, H_A, W_A, H_B=None, W_B=None):
-        if coords.shape[-1] == 2:
-            return self._to_pixel_coordinates(coords, H_A, W_A)
-
-        if isinstance(coords, (list, tuple)):
-            kpts_A, kpts_B = coords[0], coords[1]
-        else:
-            kpts_A, kpts_B = coords[..., :2], coords[..., 2:]
-        return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(
-            kpts_B, H_B, W_B
-        )
-
-    def _to_pixel_coordinates(self, coords, H, W):
-        kpts = torch.stack(
-            (W / 2 * (coords[..., 0] + 1), H / 2 * (coords[..., 1] + 1)), axis=-1
-        )
-        return kpts
-
-    def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B):
-        if isinstance(coords, (list, tuple)):
-            kpts_A, kpts_B = coords[0], coords[1]
-        else:
-            kpts_A, kpts_B = coords[..., :2], coords[..., 2:]
-        kpts_A = torch.stack(
-            (2 / W_A * kpts_A[..., 0] - 1, 2 / H_A * kpts_A[..., 1] - 1), axis=-1
-        )
-        kpts_B = torch.stack(
-            (2 / W_B * kpts_B[..., 0] - 1, 2 / H_B * kpts_B[..., 1] - 1), axis=-1
-        )
-        return kpts_A, kpts_B
-
-    def match_keypoints(
-        self,
-        x_A,
-        x_B,
-        warp,
-        certainty,
-        return_tuple=True,
-        return_inds=False,
-        max_dist=0.005,
-        cert_th=0,
-    ):
-        x_A_to_B = F.grid_sample(
-            warp[..., -2:].permute(2, 0, 1)[None],
-            x_A[None, None],
-            align_corners=False,
-            mode="bilinear",
-        )[0, :, 0].mT
-        cert_A_to_B = F.grid_sample(
-            certainty[None, None, ...],
-            x_A[None, None],
-            align_corners=False,
-            mode="bilinear",
-        )[0, 0, 0]
-        D = torch.cdist(x_A_to_B, x_B)
-        inds_A, inds_B = torch.nonzero(
-            (D == D.min(dim=-1, keepdim=True).values)
-            * (D == D.min(dim=-2, keepdim=True).values)
-            * (cert_A_to_B[:, None] > cert_th)
-            * (D < max_dist),
-            as_tuple=True,
-        )
-
-        if return_tuple:
-            if return_inds:
-                return inds_A, inds_B
-            else:
-                return x_A[inds_A], x_B[inds_B]
-        else:
-            if return_inds:
-                return torch.cat((inds_A, inds_B), dim=-1)
-            else:
-                return torch.cat((x_A[inds_A], x_B[inds_B]), dim=-1)
-    
-    def _get_device(self):
-        # let's hope this is same for all weights
-        return self.encoder.cnn.layers[0].weight.device
-
-    @torch.inference_mode()
-    def match(
-        self,
-        im_A_input,
-        im_B_input,
-        *args,
-        im_A_high_res=None,
-        im_B_high_res=None,
-        batched=True,
-        device=None,
-    ):
-        self.train(False)
-        if not batched:
-            raise ValueError("batched must be True, non-batched inference is no longer supported.")
-        if device is None and not isinstance(im_A_input, torch.Tensor):
-            device = self._get_device()
-        elif device is None and isinstance(im_A_input, torch.Tensor):
-            device = im_A_input.device
-
-        # Check if inputs are file paths or already loaded images
-        im_A = _check_input(im_A_input)
-        im_B = _check_input(im_B_input)
-        symmetric = self.symmetric
-        ws = self.w_resized
-        hs = self.h_resized
-
-        scale_factor = math.sqrt(hs * ws / (560**2)) # divide by training resolution
-        if isinstance(im_A, Image.Image) and isinstance(im_B, Image.Image):
-            b = 1
-            w, h = im_A.size
-            w2, h2 = im_B.size
-            # Get images in good format
-
-            test_transform = get_tuple_transform_ops(
-                resize=(hs, ws), normalize=True, clahe=False
-            )
-            im_A, im_B = test_transform((im_A, im_B))
-            batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
-        elif isinstance(im_A, torch.Tensor) and isinstance(im_B, torch.Tensor):
-            b, c, h, w = im_A.shape
-            b, c, h2, w2 = im_B.shape
-            assert w == w2 and h == h2, "For batched images we assume same size"
-            batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
-            if h != self.h_resized or self.w_resized != w:
-                warn(
-                    "Model resolution and batch resolution differ, may produce unexpected results"
-                )
-            hs, ws = h, w
-        else:
-            raise ValueError(f"Unsupported input type: {type(im_A)=} and {type(im_B)=}")
-        finest_scale = 1
-        # Run matcher
-        if symmetric:
-            corresps = self.forward_symmetric(batch, scale_factor=scale_factor)
-        else:
-            corresps = self(batch, batched=True, scale_factor=scale_factor)
-
-        if self.upsample_preds:
-            hs, ws = self.upsample_res
-
-        if self.attenuate_cert:
-            low_res_certainty = F.interpolate(
-                corresps[16]["certainty"],
-                size=(hs, ws),
-                align_corners=False,
-                mode="bilinear",
-            )
-            cert_clamp = 0
-            factor = 0.5
-            low_res_certainty = (
-                factor * low_res_certainty * (low_res_certainty < cert_clamp)
-            )
-
-        finest_corresps = corresps[finest_scale]
-        if self.upsample_preds and im_A_high_res is None and im_B_high_res is None:
-            torch.cuda.empty_cache()
-            test_transform = get_tuple_transform_ops(resize=(hs, ws), normalize=True)
-            if isinstance(im_A_input, (str, os.PathLike)):
-                im_A, im_B = test_transform(
-                    (
-                        Image.open(im_A_input).convert("RGB"),
-                        Image.open(im_B_input).convert("RGB"),
-                    )
-                )
-            else:
-                assert isinstance(im_A_input, Image.Image), f"Unsupported input type: {type(im_A_input)=}"
-                assert isinstance(im_B_input, Image.Image), f"Unsupported input type: {type(im_B_input)=}"
-                im_A, im_B = test_transform((im_A_input, im_B_input))
-
-            im_A, im_B = im_A[None].to(device), im_B[None].to(device)
-            
-            batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
-        elif self.upsample_preds and im_A_high_res is not None and im_B_high_res is not None:
-            batch = {"im_A": im_A_high_res, "im_B": im_B_high_res, "corresps": finest_corresps}
-        elif self.upsample_preds:
-            raise ValueError(f"Invalid upsample_preds and high_res inputs with {im_A=},{im_A_high_res=},{im_B=} and {im_B_high_res=}")
-
-        if self.upsample_preds:
-            scale_factor = math.sqrt(
-                self.upsample_res[0]
-                * self.upsample_res[1]
-                / (560**2) # divide by training resolution
-            )
-            if symmetric:
-                corresps = self.forward_symmetric(
-                    batch, upsample=True, batched=True, scale_factor=scale_factor
-                )
-            else:
-                corresps = self(
-                    batch, batched=True, upsample=True, scale_factor=scale_factor
-                )
-
-        im_A_to_im_B = corresps[finest_scale]["flow"]
-        certainty = corresps[finest_scale]["certainty"] - (
-            low_res_certainty if self.attenuate_cert else 0
-        )
-        if finest_scale != 1:
-            im_A_to_im_B = F.interpolate(
-                im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
-            )
-            certainty = F.interpolate(
-                certainty, size=(hs, ws), align_corners=False, mode="bilinear"
-            )
-        im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
-        # Create im_A meshgrid
-        im_A_coords = torch.meshgrid(
-            (
-                torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
-                torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
-            ),
-            indexing="ij",
-        )
-        im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
-        im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
-        certainty = certainty.sigmoid()  # logits -> probs
-        im_A_coords = im_A_coords.permute(0, 2, 3, 1)
-        if (im_A_to_im_B.abs() > 1).any() and True:
-            wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
-            certainty[wrong[:, None]] = 0
-        im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
-        if symmetric:
-            A_to_B, B_to_A = im_A_to_im_B.chunk(2)
-            q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
-            im_B_coords = im_A_coords
-            s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
-            warp = torch.cat((q_warp, s_warp), dim=2)
-            certainty = torch.cat(certainty.chunk(2), dim=3)
-        else:
-            warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
-        if batched:
-            return (warp, certainty[:, 0])
-        else:
-            return (
-                warp[0],
-                certainty[0, 0],
-            )
-
-    def visualize_warp(
-        self,
-        warp,
-        certainty,
-        im_A=None,
-        im_B=None,
-        im_A_path=None,
-        im_B_path=None,
-        device="cuda",
-        symmetric=True,
-        save_path=None,
-        unnormalize=False,
-    ):
-        # assert symmetric == True, "Currently assuming bidirectional warp, might update this if someone complains ;)"
-        H, W2, _ = warp.shape
-        W = W2 // 2 if symmetric else W2
-        if im_A is None:
-            from PIL import Image
-
-            im_A, im_B = (
-                Image.open(im_A_path).convert("RGB"),
-                Image.open(im_B_path).convert("RGB"),
-            )
-        if not isinstance(im_A, torch.Tensor):
-            im_A = im_A.resize((W, H))
-            im_B = im_B.resize((W, H))
-            x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
-            if symmetric:
-                x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
-        else:
-            if symmetric:
-                x_A = im_A
-            x_B = im_B
-        im_A_transfer_rgb = F.grid_sample(
-            x_B[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False
-        )[0]
-        if symmetric:
-            im_B_transfer_rgb = F.grid_sample(
-                x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
-            )[0]
-            warp_im = torch.cat((im_A_transfer_rgb, im_B_transfer_rgb), dim=2)
-            white_im = torch.ones((H, 2 * W), device=device)
-        else:
-            warp_im = im_A_transfer_rgb
-            white_im = torch.ones((H, W), device=device)
-        vis_im = certainty * warp_im + (1 - certainty) * white_im
-        if save_path is not None:
-            from romatch.utils import tensor_to_pil
-
-            tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
-        return vis_im

+ 0 - 110
python/RoMa/romatch/models/model_zoo/__init__.py

@@ -1,110 +0,0 @@
-from typing import Union
-import torch
-from .roma_models import roma_model,roma_model_pad, tiny_roma_v1_model
-
-
-weight_urls = {
-    "romatch": {
-        "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
-        "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
-    },
-    "tiny_roma_v1": {
-        "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/tiny_roma_v1_outdoor.pth",
-    },
-    "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",  # hopefully this doesnt change :D
-}
-
-
-def tiny_roma_v1_outdoor(device, weights=None, xfeat=None):
-    if weights is None:
-        weights = torch.hub.load_state_dict_from_url(
-            weight_urls["tiny_roma_v1"]["outdoor"], map_location=device
-        )
-    if xfeat is None:
-        xfeat = torch.hub.load(
-            "verlab/accelerated_features", "XFeat", pretrained=True, top_k=4096
-        ).net
-
-    return tiny_roma_v1_model(weights=weights, xfeat=xfeat).to(device)
-
-
-def roma_outdoor(
-    device,
-    weights=None,
-    dinov2_weights=None,
-    coarse_res: Union[int, tuple[int, int]] = 560,
-    upsample_res: Union[int, tuple[int, int]] = 864,
-    amp_dtype: torch.dtype = torch.float16,
-    symmetric=True,
-    use_custom_corr=True,
-    upsample_preds=True,
-    with_padding=False,
-    do_compile=False,
-):
-    if torch.get_float32_matmul_precision() != "highest":
-        raise RuntimeError("Float32 matmul precision must be set to highest for RoMa. See also https://github.com/Parskatt/RoMaV2/issues/35")
-
-    if weights is None:
-        weights = torch.hub.load_state_dict_from_url(
-            weight_urls["romatch"]["outdoor"], map_location=device
-        )
-    if dinov2_weights is None:
-        dinov2_weights = torch.hub.load_state_dict_from_url(
-            weight_urls["dinov2"], map_location=device
-        )
-    model_init = roma_model if not with_padding else roma_model_pad
-    model = model_init(
-        resolution=coarse_res,
-        upsample_preds=upsample_preds,
-        weights=weights,
-        dinov2_weights=dinov2_weights,
-        device=device,
-        amp_dtype=amp_dtype,
-        symmetric=symmetric,
-        use_custom_corr=use_custom_corr,
-        upsample_res=upsample_res,
-    )
-    if do_compile:
-        model.compile()
-    return model
-
-
-def roma_indoor(
-    device,
-    weights=None,
-    dinov2_weights=None,
-    coarse_res: Union[int, tuple[int, int]] = 560,
-    upsample_res: Union[int, tuple[int, int]] = 864,
-    amp_dtype: torch.dtype = torch.float16,
-    symmetric=True,
-    use_custom_corr=True,
-    upsample_preds=True,
-    with_padding=False,
-    do_compile=False,
-):
-    if torch.get_float32_matmul_precision() != "highest":
-        raise RuntimeError("Float32 matmul precision must be set to highest for RoMa. See also https://github.com/Parskatt/RoMaV2/issues/35")
-
-    if weights is None:
-        weights = torch.hub.load_state_dict_from_url(
-            weight_urls["romatch"]["indoor"], map_location=device
-        )
-    if dinov2_weights is None:
-        dinov2_weights = torch.hub.load_state_dict_from_url(
-            weight_urls["dinov2"], map_location=device
-        )
-    model_init = roma_model if not with_padding else roma_model_pad
-    model = model_init(
-        resolution=coarse_res,
-        upsample_preds=upsample_preds,
-        weights=weights,
-        dinov2_weights=dinov2_weights,
-        device=device,
-        amp_dtype=amp_dtype,
-        symmetric=symmetric,
-        use_custom_corr=use_custom_corr,
-        upsample_res=upsample_res,
-    )
-    if do_compile:
-        model.compile()
-    return model

+ 0 - 399
python/RoMa/romatch/models/model_zoo/roma_models.py

@@ -1,399 +0,0 @@
-import sys
-import warnings
-from functools import partial
-
-import torch
-import torch.nn as nn
-from loguru import logger
-
-from romatch.models.encoders import CNNandDinov2
-from romatch.models.matcher import (
-    GP,
-    ConvRefiner,
-    CosKernel,
-    Decoder,
-    RegressionMatcher,
-)
-from romatch.models.tiny import TinyRoMa
-from romatch.models.transformer import Block, MemEffAttention, TransformerDecoder
-
-
-def tiny_roma_v1_model(
-    weights=None, freeze_xfeat=False, exact_softmax=False, xfeat=None
-):
-    model = TinyRoMa(
-        xfeat=xfeat, freeze_xfeat=freeze_xfeat, exact_softmax=exact_softmax
-    )
-    if weights is not None:
-        model.load_state_dict(weights)
-    return model
-
-def pad_refiner_state_dict(state_dict_old,state_dict_pad):
-    for key in state_dict_pad.keys():
-        if key.startswith('decoder.conv_refiner'):
-            param = state_dict_old[key]
-            shape_old = param.shape
-            shape_pad = state_dict_pad[key].shape
-            if shape_old != shape_pad:
-                new_param = torch.zeros(shape_pad, device=param.device, dtype=param.dtype)
-                slices = tuple(slice(0, s) for s in shape_old)
-                new_param[slices] = param
-                state_dict_old[key] = new_param
-    return state_dict_old
-
-def roma_model_pad(
-    resolution,
-    upsample_preds,
-    device=None,
-    weights=None,
-    dinov2_weights=None,
-    amp_dtype: torch.dtype = torch.float16,
-    use_custom_corr=True,
-    symmetric=True,
-    upsample_res=None,
-    sample_thresh=0.05,
-    sample_mode="threshold_balanced",
-    attenuate_cert = True,
-    refiner_channels= [1384, 1144, 576, 144, 24],
-    **kwargs,
-):
-    if sys.platform != "linux":
-        use_custom_corr = False
-        warnings.warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
-    if isinstance(resolution, int):
-        resolution = (resolution, resolution)
-    if isinstance(upsample_res, int):
-        upsample_res = (upsample_res, upsample_res)
-
-    if str(device) == "cpu":
-        amp_dtype = torch.float32
-
-    assert resolution[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
-    assert resolution[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
-
-    logger.info(
-        f"Using coarse resolution {resolution}, and upsample res {upsample_res}"
-    )
-
-    if sys.platform != "linux":
-        use_custom_corr = False
-        warnings.warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
-    warnings.filterwarnings(
-        "ignore", category=UserWarning, message="TypedStorage is deprecated"
-    )
-    gp_dim = 512
-    feat_dim = 512
-    decoder_dim = gp_dim + feat_dim
-    cls_to_coord_res = 64
-    coordinate_decoder = TransformerDecoder(
-        nn.Sequential(
-            *[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]
-        ),
-        decoder_dim,
-        cls_to_coord_res**2 + 1,
-        is_classifier=True,
-        amp=True,
-        pos_enc=False,
-    )
-    dw = True
-    hidden_blocks = 8
-    kernel_size = 5
-    displacement_emb = "linear"
-    disable_local_corr_grad = True
-    partial_conv_refiner = partial(
-        ConvRefiner,
-        kernel_size=kernel_size,
-        dw=dw,
-        hidden_blocks=hidden_blocks,
-        displacement_emb=displacement_emb,
-        corr_in_other=True,
-        amp=True,
-        disable_local_corr_grad=disable_local_corr_grad,
-        bn_momentum=0.01,
-        use_custom_corr=use_custom_corr,
-    )
-
-    conv_refiner = nn.ModuleDict(
-        {
-            "16": partial_conv_refiner(
-                refiner_channels[0],
-                refiner_channels[0],
-                2 + 1,
-                displacement_emb_dim=128,
-                local_corr_radius=7,
-            ),
-            "8": partial_conv_refiner(
-                refiner_channels[1],
-                refiner_channels[1],
-                2 + 1,
-                displacement_emb_dim=64,
-                local_corr_radius=3,
-            ),
-            "4": partial_conv_refiner(
-                refiner_channels[2],
-                refiner_channels[2],
-                2 + 1,
-                displacement_emb_dim=32,
-                local_corr_radius=2,
-            ),
-            "2": partial_conv_refiner(
-                refiner_channels[3],
-                refiner_channels[3],
-                2 + 1,
-                displacement_emb_dim=16,
-            ),
-            "1": partial_conv_refiner(
-                refiner_channels[4],
-                refiner_channels[4],
-                2 + 1,
-                displacement_emb_dim=6,
-            ),
-        }
-    )
-    kernel_temperature = 0.2
-    learn_temperature = False
-    no_cov = True
-    kernel = CosKernel
-    only_attention = False
-    basis = "fourier"
-    gp16 = GP(
-        kernel,
-        T=kernel_temperature,
-        learn_temperature=learn_temperature,
-        only_attention=only_attention,
-        gp_dim=gp_dim,
-        basis=basis,
-        no_cov=no_cov,
-    )
-    gps = nn.ModuleDict({"16": gp16})
-    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
-    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
-    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
-    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
-    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
-    proj = nn.ModuleDict(
-        {
-            "16": proj16,
-            "8": proj8,
-            "4": proj4,
-            "2": proj2,
-            "1": proj1,
-        }
-    )
-    displacement_dropout_p = 0.0
-    gm_warp_dropout_p = 0.0
-    decoder = Decoder(
-        coordinate_decoder,
-        gps,
-        proj,
-        conv_refiner,
-        detach=True,
-        scales=["16", "8", "4", "2", "1"],
-        displacement_dropout_p=displacement_dropout_p,
-        gm_warp_dropout_p=gm_warp_dropout_p,
-    )
-
-    encoder = CNNandDinov2(
-        cnn_kwargs=dict(pretrained=False, amp=True),
-        amp=True,
-        dinov2_weights=dinov2_weights,
-        amp_dtype=amp_dtype,
-    )
-    h, w = resolution
-    
-    matcher = RegressionMatcher(
-        encoder,
-        decoder,
-        h=h,
-        w=w,
-        upsample_preds=upsample_preds,
-        upsample_res=upsample_res,
-        symmetric=symmetric,
-        attenuate_cert=attenuate_cert,
-        sample_mode=sample_mode,
-        sample_thresh=sample_thresh,
-        **kwargs,
-    ).to(device)
-    if weights is not None:
-        state_dict_pad = matcher.state_dict()
-        weights = pad_refiner_state_dict(weights,state_dict_pad)
-        del state_dict_pad
-
-    matcher.load_state_dict(weights)
-    return matcher
-
-
-def roma_model(
-    resolution,
-    upsample_preds,
-    device=None,
-    weights=None,
-    dinov2_weights=None,
-    amp_dtype: torch.dtype = torch.float16,
-    use_custom_corr=True,
-    symmetric=True,
-    upsample_res=None,
-    sample_thresh=0.05,
-    sample_mode="threshold_balanced",
-    attenuate_cert = True,
-    **kwargs,
-):
-    if sys.platform != "linux":
-        use_custom_corr = False
-        warnings.warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
-    if isinstance(resolution, int):
-        resolution = (resolution, resolution)
-    if isinstance(upsample_res, int):
-        upsample_res = (upsample_res, upsample_res)
-
-    if str(device) == "cpu":
-        amp_dtype = torch.float32
-
-    assert resolution[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
-    assert resolution[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
-
-    logger.info(
-        f"Using coarse resolution {resolution}, and upsample res {upsample_res}"
-    )
-
-    if sys.platform != "linux":
-        use_custom_corr = False
-        warnings.warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
-    warnings.filterwarnings(
-        "ignore", category=UserWarning, message="TypedStorage is deprecated"
-    )
-    gp_dim = 512
-    feat_dim = 512
-    decoder_dim = gp_dim + feat_dim
-    cls_to_coord_res = 64
-    coordinate_decoder = TransformerDecoder(
-        nn.Sequential(
-            *[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]
-        ),
-        decoder_dim,
-        cls_to_coord_res**2 + 1,
-        is_classifier=True,
-        amp=True,
-        pos_enc=False,
-    )
-    dw = True
-    hidden_blocks = 8
-    kernel_size = 5
-    displacement_emb = "linear"
-    disable_local_corr_grad = True
-    partial_conv_refiner = partial(
-        ConvRefiner,
-        kernel_size=kernel_size,
-        dw=dw,
-        hidden_blocks=hidden_blocks,
-        displacement_emb=displacement_emb,
-        corr_in_other=True,
-        amp=True,
-        disable_local_corr_grad=disable_local_corr_grad,
-        bn_momentum=0.01,
-        use_custom_corr=use_custom_corr,
-    )
-
-    conv_refiner = nn.ModuleDict(
-        {
-            "16": partial_conv_refiner(
-                2 * 512 + 128 + (2 * 7 + 1) ** 2,
-                2 * 512 + 128 + (2 * 7 + 1) ** 2,
-                2 + 1,
-                displacement_emb_dim=128,
-                local_corr_radius=7,
-            ),
-            "8": partial_conv_refiner(
-                2 * 512 + 64 + (2 * 3 + 1) ** 2,
-                2 * 512 + 64 + (2 * 3 + 1) ** 2,
-                2 + 1,
-                displacement_emb_dim=64,
-                local_corr_radius=3,
-            ),
-            "4": partial_conv_refiner(
-                2 * 256 + 32 + (2 * 2 + 1) ** 2,
-                2 * 256 + 32 + (2 * 2 + 1) ** 2,
-                2 + 1,
-                displacement_emb_dim=32,
-                local_corr_radius=2,
-            ),
-            "2": partial_conv_refiner(
-                2 * 64 + 16,
-                128 + 16,
-                2 + 1,
-                displacement_emb_dim=16,
-            ),
-            "1": partial_conv_refiner(
-                2 * 9 + 6,
-                24,
-                2 + 1,
-                displacement_emb_dim=6,
-            ),
-        }
-    )
-    kernel_temperature = 0.2
-    learn_temperature = False
-    no_cov = True
-    kernel = CosKernel
-    only_attention = False
-    basis = "fourier"
-    gp16 = GP(
-        kernel,
-        T=kernel_temperature,
-        learn_temperature=learn_temperature,
-        only_attention=only_attention,
-        gp_dim=gp_dim,
-        basis=basis,
-        no_cov=no_cov,
-    )
-    gps = nn.ModuleDict({"16": gp16})
-    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
-    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
-    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
-    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
-    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
-    proj = nn.ModuleDict(
-        {
-            "16": proj16,
-            "8": proj8,
-            "4": proj4,
-            "2": proj2,
-            "1": proj1,
-        }
-    )
-    displacement_dropout_p = 0.0
-    gm_warp_dropout_p = 0.0
-    decoder = Decoder(
-        coordinate_decoder,
-        gps,
-        proj,
-        conv_refiner,
-        detach=True,
-        scales=["16", "8", "4", "2", "1"],
-        displacement_dropout_p=displacement_dropout_p,
-        gm_warp_dropout_p=gm_warp_dropout_p,
-    )
-
-    encoder = CNNandDinov2(
-        cnn_kwargs=dict(pretrained=False, amp=True),
-        amp=True,
-        dinov2_weights=dinov2_weights,
-        amp_dtype=amp_dtype,
-    )
-    h, w = resolution
-    
-    matcher = RegressionMatcher(
-        encoder,
-        decoder,
-        h=h,
-        w=w,
-        upsample_preds=upsample_preds,
-        upsample_res=upsample_res,
-        symmetric=symmetric,
-        attenuate_cert=attenuate_cert,
-        sample_mode=sample_mode,
-        sample_thresh=sample_thresh,
-        **kwargs,
-    ).to(device)
-    matcher.load_state_dict(weights)
-    return matcher

+ 0 - 304
python/RoMa/romatch/models/tiny.py

@@ -1,304 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import os
-import torch
-from pathlib import Path
-import math
-import numpy as np
-
-from torch import nn
-from PIL import Image
-from torchvision.transforms import ToTensor
-from romatch.utils.kde import kde
-
-class BasicLayer(nn.Module):
-    """
-        Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
-    """
-    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, relu = True):
-        super().__init__()
-        self.layer = nn.Sequential(
-                                        nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
-                                        nn.BatchNorm2d(out_channels, affine=False),
-                                        nn.ReLU(inplace = True) if relu else nn.Identity()
-                                    )
-
-    def forward(self, x):
-        return self.layer(x)
-
-class TinyRoMa(nn.Module):
-    """
-        Implementation of architecture described in 
-        "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
-    """
-
-    def __init__(self, xfeat = None, 
-                 freeze_xfeat = True, 
-                 sample_mode = "threshold_balanced", 
-                 symmetric = False, 
-                 exact_softmax = False):
-        super().__init__()
-        del xfeat.heatmap_head, xfeat.keypoint_head, xfeat.fine_matcher
-        if freeze_xfeat:
-            xfeat.train(False)
-            self.xfeat = [xfeat]# hide params from ddp
-        else:
-            self.xfeat = nn.ModuleList([xfeat])
-        self.freeze_xfeat = freeze_xfeat
-        match_dim = 256
-        self.coarse_matcher = nn.Sequential(
-            BasicLayer(64+64+2, match_dim,),
-            BasicLayer(match_dim, match_dim,), 
-            BasicLayer(match_dim, match_dim,), 
-            BasicLayer(match_dim, match_dim,), 
-            nn.Conv2d(match_dim, 3, kernel_size=1, bias=True, padding=0))
-        fine_match_dim = 64
-        self.fine_matcher = nn.Sequential(
-            BasicLayer(24+24+2, fine_match_dim,),
-            BasicLayer(fine_match_dim, fine_match_dim,), 
-            BasicLayer(fine_match_dim, fine_match_dim,), 
-            BasicLayer(fine_match_dim, fine_match_dim,), 
-            nn.Conv2d(fine_match_dim, 3, kernel_size=1, bias=True, padding=0),)
-        self.sample_mode = sample_mode
-        self.sample_thresh = 0.05
-        self.symmetric = symmetric
-        self.exact_softmax = exact_softmax
-    
-    @property
-    def device(self):
-        return self.fine_matcher[-1].weight.device
-    
-    def preprocess_tensor(self, x):
-        """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
-        H, W = x.shape[-2:]
-        _H, _W = (H//32) * 32, (W//32) * 32
-        rh, rw = H/_H, W/_W
-
-        x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
-        return x, rh, rw        
-    
-    def forward_single(self, x):
-        with torch.inference_mode(self.freeze_xfeat or not self.training):
-            xfeat = self.xfeat[0]
-            with torch.no_grad():
-                x = x.mean(dim=1, keepdim = True)
-                x = xfeat.norm(x)
-
-            #main backbone
-            x1 = xfeat.block1(x)
-            x2 = xfeat.block2(x1 + xfeat.skip1(x))
-            x3 = xfeat.block3(x2)
-            x4 = xfeat.block4(x3)
-            x5 = xfeat.block5(x4)
-            x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
-            x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
-            feats = xfeat.block_fusion( x3 + x4 + x5 )
-        if self.freeze_xfeat:
-            return x2.clone(), feats.clone()
-        return x2, feats
-
-    def to_pixel_coordinates(self, coords, H_A, W_A, H_B = None, W_B = None):
-        if coords.shape[-1] == 2:
-            return self._to_pixel_coordinates(coords, H_A, W_A) 
-        
-        if isinstance(coords, (list, tuple)):
-            kpts_A, kpts_B = coords[0], coords[1]
-        else:
-            kpts_A, kpts_B = coords[...,:2], coords[...,2:]
-        return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)
-
-    def _to_pixel_coordinates(self, coords, H, W):
-        kpts = torch.stack((W/2 * (coords[...,0]+1), H/2 * (coords[...,1]+1)),axis=-1)
-        return kpts
-    
-    def pos_embed(self, corr_volume: torch.Tensor):
-        B, H1, W1, H0, W0 = corr_volume.shape 
-        grid = torch.stack(
-                torch.meshgrid(
-                    torch.linspace(-1+1/W1,1-1/W1, W1), 
-                    torch.linspace(-1+1/H1,1-1/H1, H1), 
-                    indexing = "xy"), 
-                dim = -1).float().to(corr_volume).reshape(H1*W1, 2)
-        down = 4
-        if not self.training and not self.exact_softmax:
-            grid_lr = torch.stack(
-                torch.meshgrid(
-                    torch.linspace(-1+down/W1,1-down/W1, W1//down), 
-                    torch.linspace(-1+down/H1,1-down/H1, H1//down), 
-                    indexing = "xy"), 
-                dim = -1).float().to(corr_volume).reshape(H1*W1 //down**2, 2)
-            cv = corr_volume
-            best_match = cv.reshape(B,H1*W1,H0,W0).argmax(dim=1) # B, HW, H, W
-            P_lowres = torch.cat((cv[:,::down,::down].reshape(B,H1*W1 // down**2,H0,W0), best_match[:,None]),dim=1).softmax(dim=1)
-            pos_embeddings = torch.einsum('bchw,cd->bdhw', P_lowres[:,:-1], grid_lr)
-            pos_embeddings += P_lowres[:,-1] * grid[best_match].permute(0,3,1,2)
-            #print("hej")
-        else:
-            P = corr_volume.reshape(B,H1*W1,H0,W0).softmax(dim=1) # B, HW, H, W
-            pos_embeddings = torch.einsum('bchw,cd->bdhw', P, grid)
-        return pos_embeddings
-    
-    def visualize_warp(self, warp, certainty, im_A = None, im_B = None, 
-                       im_A_path = None, im_B_path = None, symmetric = True, save_path = None, unnormalize = False):
-        device = warp.device
-        H,W2,_ = warp.shape
-        W = W2//2 if symmetric else W2
-        if im_A is None:
-            from PIL import Image
-            im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
-        if not isinstance(im_A, torch.Tensor):
-            im_A = im_A.resize((W,H))
-            im_B = im_B.resize((W,H))    
-            x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
-            if symmetric:
-                x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
-        else:
-            if symmetric:
-                x_A = im_A
-            x_B = im_B
-        im_A_transfer_rgb = F.grid_sample(
-        x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
-        )[0]
-        if symmetric:
-            im_B_transfer_rgb = F.grid_sample(
-            x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
-            )[0]
-            warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
-            white_im = torch.ones((H,2*W),device=device)
-        else:
-            warp_im = im_A_transfer_rgb
-            white_im = torch.ones((H, W), device = device)
-        vis_im = certainty * warp_im + (1 - certainty) * white_im
-        if save_path is not None:
-            from romatch.utils import tensor_to_pil
-            tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
-        return vis_im
-     
-    def corr_volume(self, feat0, feat1):
-        """
-            input:
-                feat0 -> torch.Tensor(B, C, H, W)
-                feat1 -> torch.Tensor(B, C, H, W)
-            return:
-                corr_volume -> torch.Tensor(B, H, W, H, W)
-        """
-        B, C, H0, W0 = feat0.shape
-        B, C, H1, W1 = feat1.shape
-        feat0 = feat0.view(B, C, H0*W0)
-        feat1 = feat1.view(B, C, H1*W1)
-        corr_volume = torch.einsum('bci,bcj->bji', feat0, feat1).reshape(B, H1, W1, H0 , W0)/math.sqrt(C) #16*16*16
-        return corr_volume
-    
-    @torch.inference_mode()
-    def match_from_path(self, im0_path, im1_path):
-        device = self.device
-        im0 = ToTensor()(Image.open(im0_path))[None].to(device)
-        im1 = ToTensor()(Image.open(im1_path))[None].to(device)
-        return self.match(im0, im1, batched = False)
-    
-    @torch.inference_mode()
-    def match(self, im0, im1, *args, batched = True):
-        # stupid
-        if isinstance(im0, (str, Path)):
-            return self.match_from_path(im0, im1)
-        elif isinstance(im0, Image.Image):
-            batched = False
-            device = self.device
-            im0 = ToTensor()(im0)[None].to(device)
-            im1 = ToTensor()(im1)[None].to(device)
- 
-        B,C,H0,W0 = im0.shape
-        B,C,H1,W1 = im1.shape
-        self.train(False)
-        corresps = self.forward({"im_A":im0, "im_B":im1})
-        #return 1,1
-        flow = F.interpolate(
-            corresps[4]["flow"], 
-            size = (H0, W0), 
-            mode = "bilinear", align_corners = False).permute(0,2,3,1).reshape(B,H0,W0,2)
-        grid = torch.stack(
-            torch.meshgrid(
-                torch.linspace(-1+1/W0,1-1/W0, W0), 
-                torch.linspace(-1+1/H0,1-1/H0, H0), 
-                indexing = "xy"), 
-            dim = -1).float().to(flow.device).expand(B, H0, W0, 2)
-        
-        certainty = F.interpolate(corresps[4]["certainty"], size = (H0,W0), mode = "bilinear", align_corners = False)
-        warp, cert = torch.cat((grid, flow), dim = -1), certainty[:,0].sigmoid()
-        if batched:
-            return warp, cert
-        else:
-            return warp[0], cert[0]
-
-    def sample(
-        self,
-        matches,
-        certainty,
-        num=5_000,
-    ):
-        H,W,_ = matches.shape
-        if "threshold" in self.sample_mode:
-            upper_thresh = self.sample_thresh
-            certainty = certainty.clone()
-            certainty[certainty > upper_thresh] = 1
-        matches, certainty = (
-            matches.reshape(-1, 4),
-            certainty.reshape(-1),
-        )
-        expansion_factor = 4 if "balanced" in self.sample_mode else 1
-        good_samples = torch.multinomial(certainty, 
-                        num_samples = min(expansion_factor*num, len(certainty)), 
-                        replacement=False)
-        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
-        if "balanced" not in self.sample_mode:
-            return good_matches, good_certainty 
-        use_half = True if matches.device.type == "cuda" else False
-        down = 1 if matches.device.type == "cuda" else 8
-        density = kde(good_matches, std=0.1, half = use_half, down = down)
-        p = 1 / (density+1)
-        p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
-        balanced_samples = torch.multinomial(p, 
-                        num_samples = min(num,len(good_certainty)), 
-                        replacement=False)
-        return good_matches[balanced_samples], good_certainty[balanced_samples]
-        
-            
-    def forward(self, batch):
-        """
-            input:
-                x -> torch.Tensor(B, C, H, W) grayscale or rgb images
-            return:
-
-        """
-        im0 = batch["im_A"]
-        im1 = batch["im_B"]
-        corresps = {}
-        im0, rh0, rw0 = self.preprocess_tensor(im0)
-        im1, rh1, rw1 = self.preprocess_tensor(im1)
-        B, C, H0, W0 = im0.shape
-        B, C, H1, W1 = im1.shape
-        to_normalized = torch.tensor((2/W1, 2/H1, 1)).to(im0.device)[None,:,None,None]
- 
-        if im0.shape[-2:] == im1.shape[-2:]:
-            x = torch.cat([im0, im1], dim=0)
-            x = self.forward_single(x)
-            feats_x0_c, feats_x1_c = x[1].chunk(2)
-            feats_x0_f, feats_x1_f = x[0].chunk(2)
-        else:
-            feats_x0_f, feats_x0_c = self.forward_single(im0)
-            feats_x1_f, feats_x1_c = self.forward_single(im1)
-        corr_volume = self.corr_volume(feats_x0_c, feats_x1_c)
-        coarse_warp = self.pos_embed(corr_volume)
-        coarse_matches = torch.cat((coarse_warp, torch.zeros_like(coarse_warp[:,-1:])), dim=1)
-        feats_x1_c_warped = F.grid_sample(feats_x1_c, coarse_matches.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
-        coarse_matches_delta = self.coarse_matcher(torch.cat((feats_x0_c, feats_x1_c_warped, coarse_warp), dim=1))
-        coarse_matches = coarse_matches + coarse_matches_delta * to_normalized
-        corresps[8] = {"flow": coarse_matches[:,:2], "certainty": coarse_matches[:,2:]}
-        coarse_matches_up = F.interpolate(coarse_matches, size = feats_x0_f.shape[-2:], mode = "bilinear", align_corners = False)        
-        coarse_matches_up_detach = coarse_matches_up.detach()#note the detach
-        feats_x1_f_warped = F.grid_sample(feats_x1_f, coarse_matches_up_detach.permute(0, 2, 3, 1)[...,:2], mode = 'bilinear', align_corners = False)
-        fine_matches_delta = self.fine_matcher(torch.cat((feats_x0_f, feats_x1_f_warped, coarse_matches_up_detach[:,:2]), dim=1))
-        fine_matches = coarse_matches_up_detach+fine_matches_delta * to_normalized
-        corresps[4] = {"flow": fine_matches[:,:2], "certainty": fine_matches[:,2:]}
-        return corresps

+ 0 - 48
python/RoMa/romatch/models/transformer/__init__.py

@@ -1,48 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from romatch.utils.utils import get_grid, get_autocast_params
-from .layers.block import Block
-from .layers.attention import MemEffAttention
-from .dinov2 import vit_large
-
-class TransformerDecoder(nn.Module):
-    def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args, 
-                 amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, amp_dtype = torch.float16, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.blocks = blocks
-        self.to_out = nn.Linear(hidden_dim, out_dim)
-        self.hidden_dim = hidden_dim
-        self.out_dim = out_dim
-        self._scales = [16]
-        self.is_classifier = is_classifier
-        self.amp = amp
-        self.amp_dtype = amp_dtype
-        self.pos_enc = pos_enc
-        self.learned_embeddings = learned_embeddings
-        if self.learned_embeddings:
-            self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim))))
-
-    def scales(self):
-        return self._scales.copy()
-
-    def forward(self, gp_posterior, features, old_stuff, new_scale):
-        autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(gp_posterior.device, enabled=self.amp, dtype=self.amp_dtype)
-        with torch.autocast(autocast_device, enabled=autocast_enabled, dtype = autocast_dtype):
-            B,C,H,W = gp_posterior.shape
-            x = torch.cat((gp_posterior, features), dim = 1)
-            B,C,H,W = x.shape
-            grid = get_grid(B, H, W, x.device).reshape(B,H*W,2)
-            if self.learned_embeddings:
-                pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C)
-            else:
-                pos_enc = 0
-            tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc
-            z = self.blocks(tokens)
-            out = self.to_out(z)
-            out = out.permute(0,2,1).reshape(B, self.out_dim, H, W)
-            warp, certainty = out[:, :-1], out[:, -1:]
-            return warp, certainty, None
-
-

+ 0 - 359
python/RoMa/romatch/models/transformer/dinov2.py

@@ -1,359 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-# References:
-#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
-#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
-
-from functools import partial
-import math
-import logging
-from typing import Sequence, Tuple, Union, Callable
-
-import torch
-import torch.nn as nn
-import torch.utils.checkpoint
-from torch.nn.init import trunc_normal_
-
-from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
-
-
-
-def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
-    if not depth_first and include_root:
-        fn(module=module, name=name)
-    for child_name, child_module in module.named_children():
-        child_name = ".".join((name, child_name)) if name else child_name
-        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
-    if depth_first and include_root:
-        fn(module=module, name=name)
-    return module
-
-
-class BlockChunk(nn.ModuleList):
-    def forward(self, x):
-        for b in self:
-            x = b(x)
-        return x
-
-
-class DinoVisionTransformer(nn.Module):
-    def __init__(
-        self,
-        img_size=224,
-        patch_size=16,
-        in_chans=3,
-        embed_dim=768,
-        depth=12,
-        num_heads=12,
-        mlp_ratio=4.0,
-        qkv_bias=True,
-        ffn_bias=True,
-        proj_bias=True,
-        drop_path_rate=0.0,
-        drop_path_uniform=False,
-        init_values=None,  # for layerscale: None or 0 => no layerscale
-        embed_layer=PatchEmbed,
-        act_layer=nn.GELU,
-        block_fn=Block,
-        ffn_layer="mlp",
-        block_chunks=1,
-    ):
-        """
-        Args:
-            img_size (int, tuple): input image size
-            patch_size (int, tuple): patch size
-            in_chans (int): number of input channels
-            embed_dim (int): embedding dimension
-            depth (int): depth of transformer
-            num_heads (int): number of attention heads
-            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
-            qkv_bias (bool): enable bias for qkv if True
-            proj_bias (bool): enable bias for proj in attn if True
-            ffn_bias (bool): enable bias for ffn if True
-            drop_path_rate (float): stochastic depth rate
-            drop_path_uniform (bool): apply uniform drop rate across blocks
-            weight_init (str): weight init scheme
-            init_values (float): layer-scale init values
-            embed_layer (nn.Module): patch embedding layer
-            act_layer (nn.Module): MLP activation layer
-            block_fn (nn.Module): transformer block class
-            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
-            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
-        """
-        super().__init__()
-        norm_layer = partial(nn.LayerNorm, eps=1e-6)
-
-        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
-        self.num_tokens = 1
-        self.n_blocks = depth
-        self.num_heads = num_heads
-        self.patch_size = patch_size
-
-        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
-        num_patches = self.patch_embed.num_patches
-
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
-
-        if drop_path_uniform is True:
-            dpr = [drop_path_rate] * depth
-        else:
-            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
-
-        if ffn_layer == "mlp":
-            ffn_layer = Mlp
-        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
-            ffn_layer = SwiGLUFFNFused
-        elif ffn_layer == "identity":
-
-            def f(*args, **kwargs):
-                return nn.Identity()
-
-            ffn_layer = f
-        else:
-            raise NotImplementedError
-
-        blocks_list = [
-            block_fn(
-                dim=embed_dim,
-                num_heads=num_heads,
-                mlp_ratio=mlp_ratio,
-                qkv_bias=qkv_bias,
-                proj_bias=proj_bias,
-                ffn_bias=ffn_bias,
-                drop_path=dpr[i],
-                norm_layer=norm_layer,
-                act_layer=act_layer,
-                ffn_layer=ffn_layer,
-                init_values=init_values,
-            )
-            for i in range(depth)
-        ]
-        if block_chunks > 0:
-            self.chunked_blocks = True
-            chunked_blocks = []
-            chunksize = depth // block_chunks
-            for i in range(0, depth, chunksize):
-                # this is to keep the block index consistent if we chunk the block list
-                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
-            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
-        else:
-            self.chunked_blocks = False
-            self.blocks = nn.ModuleList(blocks_list)
-
-        self.norm = norm_layer(embed_dim)
-        self.head = nn.Identity()
-
-        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
-
-        self.init_weights()
-        for param in self.parameters():
-            param.requires_grad = False
-    
-    @property
-    def device(self):
-        return self.cls_token.device
-
-    def init_weights(self):
-        trunc_normal_(self.pos_embed, std=0.02)
-        nn.init.normal_(self.cls_token, std=1e-6)
-        named_apply(init_weights_vit_timm, self)
-
-    def interpolate_pos_encoding(self, x, w, h):
-        previous_dtype = x.dtype
-        npatch = x.shape[1] - 1
-        N = self.pos_embed.shape[1] - 1
-        if npatch == N and w == h:
-            return self.pos_embed
-        pos_embed = self.pos_embed.float()
-        class_pos_embed = pos_embed[:, 0]
-        patch_pos_embed = pos_embed[:, 1:]
-        dim = x.shape[-1]
-        w0 = w // self.patch_size
-        h0 = h // self.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        w0, h0 = w0 + 0.1, h0 + 0.1
-
-        patch_pos_embed = nn.functional.interpolate(
-            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
-            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
-            mode="bicubic",
-        )
-
-        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
-        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
-
-    def prepare_tokens_with_masks(self, x, masks=None):
-        B, nc, w, h = x.shape
-        x = self.patch_embed(x)
-        if masks is not None:
-            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
-
-        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-        x = x + self.interpolate_pos_encoding(x, w, h)
-
-        return x
-
-    def forward_features_list(self, x_list, masks_list):
-        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
-        for blk in self.blocks:
-            x = blk(x)
-
-        all_x = x
-        output = []
-        for x, masks in zip(all_x, masks_list):
-            x_norm = self.norm(x)
-            output.append(
-                {
-                    "x_norm_clstoken": x_norm[:, 0],
-                    "x_norm_patchtokens": x_norm[:, 1:],
-                    "x_prenorm": x,
-                    "masks": masks,
-                }
-            )
-        return output
-
-    def forward_features(self, x, masks=None):
-        if isinstance(x, list):
-            return self.forward_features_list(x, masks)
-
-        x = self.prepare_tokens_with_masks(x, masks)
-
-        for blk in self.blocks:
-            x = blk(x)
-
-        x_norm = self.norm(x)
-        return {
-            "x_norm_clstoken": x_norm[:, 0],
-            "x_norm_patchtokens": x_norm[:, 1:],
-            "x_prenorm": x,
-            "masks": masks,
-        }
-
-    def _get_intermediate_layers_not_chunked(self, x, n=1):
-        x = self.prepare_tokens_with_masks(x)
-        # If n is an int, take the n last blocks. If it's a list, take them
-        output, total_block_len = [], len(self.blocks)
-        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
-        for i, blk in enumerate(self.blocks):
-            x = blk(x)
-            if i in blocks_to_take:
-                output.append(x)
-        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
-        return output
-
-    def _get_intermediate_layers_chunked(self, x, n=1):
-        x = self.prepare_tokens_with_masks(x)
-        output, i, total_block_len = [], 0, len(self.blocks[-1])
-        # If n is an int, take the n last blocks. If it's a list, take them
-        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
-        for block_chunk in self.blocks:
-            for blk in block_chunk[i:]:  # Passing the nn.Identity()
-                x = blk(x)
-                if i in blocks_to_take:
-                    output.append(x)
-                i += 1
-        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
-        return output
-
-    def get_intermediate_layers(
-        self,
-        x: torch.Tensor,
-        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
-        reshape: bool = False,
-        return_class_token: bool = False,
-        norm=True,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
-        if self.chunked_blocks:
-            outputs = self._get_intermediate_layers_chunked(x, n)
-        else:
-            outputs = self._get_intermediate_layers_not_chunked(x, n)
-        if norm:
-            outputs = [self.norm(out) for out in outputs]
-        class_tokens = [out[:, 0] for out in outputs]
-        outputs = [out[:, 1:] for out in outputs]
-        if reshape:
-            B, _, w, h = x.shape
-            outputs = [
-                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
-                for out in outputs
-            ]
-        if return_class_token:
-            return tuple(zip(outputs, class_tokens))
-        return tuple(outputs)
-
-    def forward(self, *args, is_training=False, **kwargs):
-        ret = self.forward_features(*args, **kwargs)
-        if is_training:
-            return ret
-        else:
-            return self.head(ret["x_norm_clstoken"])
-
-
-def init_weights_vit_timm(module: nn.Module, name: str = ""):
-    """ViT weight initialization, original timm impl (for reproducibility)"""
-    if isinstance(module, nn.Linear):
-        trunc_normal_(module.weight, std=0.02)
-        if module.bias is not None:
-            nn.init.zeros_(module.bias)
-
-
-def vit_small(patch_size=16, **kwargs):
-    model = DinoVisionTransformer(
-        patch_size=patch_size,
-        embed_dim=384,
-        depth=12,
-        num_heads=6,
-        mlp_ratio=4,
-        block_fn=partial(Block, attn_class=MemEffAttention),
-        **kwargs,
-    )
-    return model
-
-
-def vit_base(patch_size=16, **kwargs):
-    model = DinoVisionTransformer(
-        patch_size=patch_size,
-        embed_dim=768,
-        depth=12,
-        num_heads=12,
-        mlp_ratio=4,
-        block_fn=partial(Block, attn_class=MemEffAttention),
-        **kwargs,
-    )
-    return model
-
-
-def vit_large(patch_size=16, **kwargs):
-    model = DinoVisionTransformer(
-        patch_size=patch_size,
-        embed_dim=1024,
-        depth=24,
-        num_heads=16,
-        mlp_ratio=4,
-        block_fn=partial(Block, attn_class=MemEffAttention),
-        **kwargs,
-    )
-    return model
-
-
-def vit_giant2(patch_size=16, **kwargs):
-    """
-    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
-    """
-    model = DinoVisionTransformer(
-        patch_size=patch_size,
-        embed_dim=1536,
-        depth=40,
-        num_heads=24,
-        mlp_ratio=4,
-        block_fn=partial(Block, attn_class=MemEffAttention),
-        **kwargs,
-    )
-    return model

+ 0 - 12
python/RoMa/romatch/models/transformer/layers/__init__.py

@@ -1,12 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from .dino_head import DINOHead
-from .mlp import Mlp
-from .patch_embed import PatchEmbed
-from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
-from .block import NestedTensorBlock
-from .attention import MemEffAttention

+ 0 - 96
python/RoMa/romatch/models/transformer/layers/attention.py

@@ -1,96 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-# References:
-#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
-#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
-
-import logging
-
-from torch import Tensor
-from torch import nn
-import torch
-
-
-logger = logging.getLogger("dinov2")
-
-
-try:
-    from xformers.ops import memory_efficient_attention, unbind, fmha
-
-    XFORMERS_AVAILABLE = True
-except ImportError:
-    # logger.warning("xFormers not available")
-    XFORMERS_AVAILABLE = False
-
-
-class Attention(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        num_heads: int = 8,
-        qkv_bias: bool = False,
-        proj_bias: bool = True,
-        attn_drop: float = 0.0,
-        proj_drop: float = 0.0,
-    ) -> None:
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = head_dim**-0.5
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim, bias=proj_bias)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x: Tensor) -> Tensor:
-        # use new pytorch native attn
-        qkv = self.qkv(x)
-        B, N, _ = qkv.shape
-        C = self.qkv.in_features
-
-        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
-        q, k, v = torch.unbind(qkv, 2)
-        q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
-        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
-        x = x.transpose(1, 2).reshape([B, N, C])
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-        # old code below
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-
-        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
-        attn = q @ k.transpose(-2, -1)
-
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-
-class MemEffAttention(Attention):
-    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
-        if not XFORMERS_AVAILABLE:
-            assert attn_bias is None, "xFormers is required for nested tensors usage"
-            return super().forward(x)
-
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
-
-        q, k, v = unbind(qkv, 2)
-
-        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
-        x = x.reshape([B, N, C])
-
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x

+ 0 - 252
python/RoMa/romatch/models/transformer/layers/block.py

@@ -1,252 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-# References:
-#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
-#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
-
-import logging
-from typing import Callable, List, Any, Tuple, Dict
-
-import torch
-from torch import nn, Tensor
-
-from .attention import Attention, MemEffAttention
-from .drop_path import DropPath
-from .layer_scale import LayerScale
-from .mlp import Mlp
-
-
-logger = logging.getLogger("dinov2")
-
-
-try:
-    from xformers.ops import fmha
-    from xformers.ops import scaled_index_add, index_select_cat
-
-    XFORMERS_AVAILABLE = True
-except ImportError:
-    # logger.warning("xFormers not available")
-    XFORMERS_AVAILABLE = False
-
-
-class Block(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        num_heads: int,
-        mlp_ratio: float = 4.0,
-        qkv_bias: bool = False,
-        proj_bias: bool = True,
-        ffn_bias: bool = True,
-        drop: float = 0.0,
-        attn_drop: float = 0.0,
-        init_values=None,
-        drop_path: float = 0.0,
-        act_layer: Callable[..., nn.Module] = nn.GELU,
-        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
-        attn_class: Callable[..., nn.Module] = Attention,
-        ffn_layer: Callable[..., nn.Module] = Mlp,
-    ) -> None:
-        super().__init__()
-        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
-        self.norm1 = norm_layer(dim)
-        self.attn = attn_class(
-            dim,
-            num_heads=num_heads,
-            qkv_bias=qkv_bias,
-            proj_bias=proj_bias,
-            attn_drop=attn_drop,
-            proj_drop=drop,
-        )
-        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
-        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
-        self.norm2 = norm_layer(dim)
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = ffn_layer(
-            in_features=dim,
-            hidden_features=mlp_hidden_dim,
-            act_layer=act_layer,
-            drop=drop,
-            bias=ffn_bias,
-        )
-        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
-        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
-        self.sample_drop_ratio = drop_path
-
-    def forward(self, x: Tensor) -> Tensor:
-        def attn_residual_func(x: Tensor) -> Tensor:
-            return self.ls1(self.attn(self.norm1(x)))
-
-        def ffn_residual_func(x: Tensor) -> Tensor:
-            return self.ls2(self.mlp(self.norm2(x)))
-
-        if self.training and self.sample_drop_ratio > 0.1:
-            # the overhead is compensated only for a drop path rate larger than 0.1
-            x = drop_add_residual_stochastic_depth(
-                x,
-                residual_func=attn_residual_func,
-                sample_drop_ratio=self.sample_drop_ratio,
-            )
-            x = drop_add_residual_stochastic_depth(
-                x,
-                residual_func=ffn_residual_func,
-                sample_drop_ratio=self.sample_drop_ratio,
-            )
-        elif self.training and self.sample_drop_ratio > 0.0:
-            x = x + self.drop_path1(attn_residual_func(x))
-            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
-        else:
-            x = x + attn_residual_func(x)
-            x = x + ffn_residual_func(x)
-        return x
-
-
-def drop_add_residual_stochastic_depth(
-    x: Tensor,
-    residual_func: Callable[[Tensor], Tensor],
-    sample_drop_ratio: float = 0.0,
-) -> Tensor:
-    # 1) extract subset using permutation
-    b, n, d = x.shape
-    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
-    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
-    x_subset = x[brange]
-
-    # 2) apply residual_func to get residual
-    residual = residual_func(x_subset)
-
-    x_flat = x.flatten(1)
-    residual = residual.flatten(1)
-
-    residual_scale_factor = b / sample_subset_size
-
-    # 3) add the residual
-    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
-    return x_plus_residual.view_as(x)
-
-
-def get_branges_scales(x, sample_drop_ratio=0.0):
-    b, n, d = x.shape
-    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
-    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
-    residual_scale_factor = b / sample_subset_size
-    return brange, residual_scale_factor
-
-
-def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
-    if scaling_vector is None:
-        x_flat = x.flatten(1)
-        residual = residual.flatten(1)
-        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
-    else:
-        x_plus_residual = scaled_index_add(
-            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
-        )
-    return x_plus_residual
-
-
-attn_bias_cache: Dict[Tuple, Any] = {}
-
-
-def get_attn_bias_and_cat(x_list, branges=None):
-    """
-    this will perform the index select, cat the tensors, and provide the attn_bias from cache
-    """
-    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
-    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
-    if all_shapes not in attn_bias_cache.keys():
-        seqlens = []
-        for b, x in zip(batch_sizes, x_list):
-            for _ in range(b):
-                seqlens.append(x.shape[1])
-        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
-        attn_bias._batch_sizes = batch_sizes
-        attn_bias_cache[all_shapes] = attn_bias
-
-    if branges is not None:
-        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
-    else:
-        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
-        cat_tensors = torch.cat(tensors_bs1, dim=1)
-
-    return attn_bias_cache[all_shapes], cat_tensors
-
-
-def drop_add_residual_stochastic_depth_list(
-    x_list: List[Tensor],
-    residual_func: Callable[[Tensor, Any], Tensor],
-    sample_drop_ratio: float = 0.0,
-    scaling_vector=None,
-) -> Tensor:
-    # 1) generate random set of indices for dropping samples in the batch
-    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
-    branges = [s[0] for s in branges_scales]
-    residual_scale_factors = [s[1] for s in branges_scales]
-
-    # 2) get attention bias and index+concat the tensors
-    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
-
-    # 3) apply residual_func to get residual, and split the result
-    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
-
-    outputs = []
-    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
-        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
-    return outputs
-
-
-class NestedTensorBlock(Block):
-    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
-        """
-        x_list contains a list of tensors to nest together and run
-        """
-        assert isinstance(self.attn, MemEffAttention)
-
-        if self.training and self.sample_drop_ratio > 0.0:
-
-            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
-                return self.attn(self.norm1(x), attn_bias=attn_bias)
-
-            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
-                return self.mlp(self.norm2(x))
-
-            x_list = drop_add_residual_stochastic_depth_list(
-                x_list,
-                residual_func=attn_residual_func,
-                sample_drop_ratio=self.sample_drop_ratio,
-                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
-            )
-            x_list = drop_add_residual_stochastic_depth_list(
-                x_list,
-                residual_func=ffn_residual_func,
-                sample_drop_ratio=self.sample_drop_ratio,
-                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
-            )
-            return x_list
-        else:
-
-            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
-                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
-
-            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
-                return self.ls2(self.mlp(self.norm2(x)))
-
-            attn_bias, x = get_attn_bias_and_cat(x_list)
-            x = x + attn_residual_func(x, attn_bias=attn_bias)
-            x = x + ffn_residual_func(x)
-            return attn_bias.split(x)
-
-    def forward(self, x_or_x_list):
-        if isinstance(x_or_x_list, Tensor):
-            return super().forward(x_or_x_list)
-        elif isinstance(x_or_x_list, list):
-            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
-            return self.forward_nested(x_or_x_list)
-        else:
-            raise AssertionError

+ 0 - 59
python/RoMa/romatch/models/transformer/layers/dino_head.py

@@ -1,59 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn as nn
-from torch.nn.init import trunc_normal_
-from torch.nn.utils import weight_norm
-
-
-class DINOHead(nn.Module):
-    def __init__(
-        self,
-        in_dim,
-        out_dim,
-        use_bn=False,
-        nlayers=3,
-        hidden_dim=2048,
-        bottleneck_dim=256,
-        mlp_bias=True,
-    ):
-        super().__init__()
-        nlayers = max(nlayers, 1)
-        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
-        self.apply(self._init_weights)
-        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
-        self.last_layer.weight_g.data.fill_(1)
-
-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=0.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-
-    def forward(self, x):
-        x = self.mlp(x)
-        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
-        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
-        x = self.last_layer(x)
-        return x
-
-
-def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
-    if nlayers == 1:
-        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
-    else:
-        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
-        if use_bn:
-            layers.append(nn.BatchNorm1d(hidden_dim))
-        layers.append(nn.GELU())
-        for _ in range(nlayers - 2):
-            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
-            if use_bn:
-                layers.append(nn.BatchNorm1d(hidden_dim))
-            layers.append(nn.GELU())
-        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
-        return nn.Sequential(*layers)

Некоторые файлы не были показаны из-за большого количества измененных файлов