#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Search for an image matching a text prompt with imagedl and save one result to a path.

Usage:
    python download-img-by-prompt.py --prompt "keywords" --save-path "/path/to/out.png" [--source BaiduImageClient]

PYTHONPATH must include the parent directory of python/imagedl (e.g. the project's
``python`` directory), or IMAGEDL_PARENT must be set to that directory.

Output: exactly one JSON object on stdout,
    {"success": true, "path": "..."}  or  {"success": false, "error": "..."}
Exit status is 0 on success and 1 on failure.
"""
import contextlib
import itertools
import json
import os
import shutil
import sys
import tempfile

# imagedl client used when --source is not given.
DEFAULT_SOURCE = 'BaiduImageClient'
# How many search results to fetch and try, in order, until one downloads successfully.
MAX_CANDIDATES = 5
# Extensions (lower-case) that count as a downloaded image file.
IMAGE_EXTS = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp')


def emit(payload):
    """Print *payload* as a single JSON line to stdout (the script's only stdout output)."""
    print(json.dumps(payload))


def fail(message):
    """Print a failure JSON object carrying *message* and exit with status 1."""
    emit({"success": False, "error": message})
    sys.exit(1)


def parse_args(argv):
    """Parse ``--prompt``, ``--save-path`` and ``--source`` from *argv* (program name excluded).

    Parsing is deliberately lenient: unknown tokens, and a flag given as the last
    token without a value, are skipped. A repeated flag keeps its last value.

    Returns:
        tuple ``(prompt, save_path, source)``; prompt/save_path are ``None`` when absent,
        source defaults to ``DEFAULT_SOURCE``.
    """
    values = {'--prompt': None, '--save-path': None, '--source': DEFAULT_SOURCE}
    i = 0
    while i < len(argv):
        flag = argv[i]
        if flag in values and i + 1 < len(argv):
            values[flag] = argv[i + 1]
            i += 2
        else:
            i += 1
    return values['--prompt'], values['--save-path'], values['--source']


def list_image_files(root):
    """Return the sorted paths of non-empty image files anywhere under *root*.

    A file counts as an image when its lower-cased name ends with one of
    ``IMAGE_EXTS``. Zero-byte files are skipped because they are not usable images.
    """
    paths = []
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            if name.lower().endswith(IMAGE_EXTS):
                path = os.path.join(dirpath, name)
                if os.path.getsize(path) > 0:
                    paths.append(path)
    # Sorting makes the choice deterministic regardless of directory listing order.
    return sorted(paths)


def resolve_save_path(save_path, downloaded):
    """Return *save_path*, appending *downloaded*'s extension if *save_path* has none."""
    if os.path.splitext(save_path)[1]:
        return save_path
    return save_path + os.path.splitext(downloaded)[1]


def fetch_image(image_client_cls, source, prompt, work_dir):
    """Search *prompt* and download candidates into *work_dir* until one yields an image.

    Args:
        image_client_cls: the imagedl ``ImageClient`` class (injected so it can be faked).
        source: imagedl image source name, e.g. ``'BaiduImageClient'``.
        prompt: search keywords.
        work_dir: directory imagedl downloads into.

    Returns:
        Path of the image file produced by the first successful candidate.

    Raises:
        RuntimeError: nothing was found by the search, or no candidate produced an image.
        Exception: anything raised by client construction or ``search`` propagates.
    """
    client = image_client_cls(
        image_source=source,
        init_image_client_cfg={'work_dir': work_dir, 'disable_print': True},
        search_limits=MAX_CANDIDATES,
        num_threadings=1,
    )
    # Ask for MAX_CANDIDATES results (previously 1), otherwise the retry loop below
    # never has a second candidate to fall back on.
    image_infos = client.search(
        prompt, search_limits_overrides=MAX_CANDIDATES, num_threadings_overrides=1
    )
    if not image_infos:
        raise RuntimeError("未搜到任何图片")

    last_error = None
    for info in itertools.islice(image_infos, MAX_CANDIDATES):
        # Snapshot first so leftovers of an earlier failed attempt are never picked up.
        before = set(list_image_files(work_dir))
        try:
            client.download([info], num_threadings_overrides=1)
        except Exception as e:  # one broken candidate must not abort the remaining ones
            last_error = e
            continue
        new_files = [p for p in list_image_files(work_dir) if p not in before]
        if new_files:
            return new_files[0]

    message = "下载后未找到图片文件"
    if last_error is not None:
        message += f" ({last_error})"
    raise RuntimeError(message)


def import_image_client():
    """Import and return imagedl's ``ImageClient``; on ImportError print JSON and exit 1.

    If IMAGEDL_PARENT is set, it is prepended to ``sys.path`` first.
    """
    imagedl_parent = os.environ.get('IMAGEDL_PARENT')
    if imagedl_parent and imagedl_parent not in sys.path:
        sys.path.insert(0, imagedl_parent)
    try:
        # Anything printed at import time must not pollute the JSON on stdout.
        with contextlib.redirect_stdout(sys.stderr):
            from imagedl.imagedl.imagedl import ImageClient
    except ImportError as e:
        fail(f"无法导入 imagedl: {e}。请设置 PYTHONPATH 或 IMAGEDL_PARENT 为 python/imagedl 的父目录(即项目下的 python)。")
    return ImageClient


def main():
    """CLI entry point: fetch one image for --prompt and copy it to --save-path.

    Prints exactly one JSON object to stdout; exits with status 1 on any failure.
    """
    prompt, save_path, source = parse_args(sys.argv[1:])
    if not prompt or not save_path:
        fail("缺少 --prompt 或 --save-path")

    save_path = os.path.abspath(os.path.normpath(save_path))
    try:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
    except OSError as e:
        fail(str(e))

    image_client_cls = import_image_client()

    tmp_dir = tempfile.mkdtemp(prefix='imagedl_download_img_')
    try:
        # Library output goes to stderr so the caller can always parse stdout as JSON.
        with contextlib.redirect_stdout(sys.stderr):
            found = fetch_image(image_client_cls, source, prompt, tmp_dir)
        save_path = resolve_save_path(save_path, found)
        shutil.copy2(found, save_path)
    except Exception as e:  # top-level boundary: every failure is reported as JSON
        fail(str(e) or type(e).__name__)
    finally:
        # The downloaded file lives in tmp_dir, so cleanup only after the copy.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    emit({"success": True, "path": save_path})


if __name__ == '__main__':
    main()