| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- 根据描述词(prompt)用 imagedl 搜索并下载一张图片到指定路径。
- 用法: python download-img-by-prompt.py --prompt "关键词" --save-path "/path/to/out.png"
- 需设置 PYTHONPATH 包含 python/imagedl 的父目录(如项目根下的 python)。
- 输出: JSON 到 stdout,{ "success": true, "path": "..." } 或 { "success": false, "error": "..." }
- """
- import os
- import sys
- import json
- import tempfile
- import shutil
# File suffixes treated as a successfully downloaded image.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp')


def _fail(message):
    """Print a failure JSON object to stdout and exit with status 1."""
    print(json.dumps({"success": False, "error": message}))
    sys.exit(1)


def _parse_args(argv):
    """Extract the --prompt, --save-path and --source values from *argv*.

    Unknown tokens, and a flag given as the last token without a value, are
    skipped. When a flag is repeated the last value wins.

    Returns:
        A ``(prompt, save_path, source)`` tuple. ``prompt`` and ``save_path``
        are ``None`` when absent; ``source`` defaults to 'BaiduImageClient'.
    """
    options = {
        '--prompt': None,
        '--save-path': None,
        '--source': 'BaiduImageClient',
    }
    i = 1  # argv[0] is the script name
    while i < len(argv):
        if argv[i] in options and i + 1 < len(argv):
            options[argv[i]] = argv[i + 1]
            i += 2
        else:
            i += 1
    return options['--prompt'], options['--save-path'], options['--source']


def _find_image(root_dir):
    """Return the first image file found under *root_dir*, or None."""
    for root, _, files in os.walk(root_dir):
        for name in files:
            if name.lower().endswith(_IMAGE_EXTENSIONS):
                return os.path.join(root, name)
    return None


def main():
    """Search for an image matching --prompt and copy it to --save-path.

    Reads its options from ``sys.argv``. Exactly one JSON object is printed
    to stdout: ``{"success": true, "path": ...}`` on success, otherwise
    ``{"success": false, "error": ...}`` followed by ``sys.exit(1)``.

    Raises:
        SystemExit: with code 1 on every failure path.
    """
    prompt, save_path, source = _parse_args(sys.argv)
    if not prompt or not save_path:
        _fail("缺少 --prompt 或 --save-path")

    save_path = os.path.abspath(os.path.normpath(save_path))
    try:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
    except OSError as e:
        # Keep the stdout contract: report filesystem errors as JSON
        # instead of letting a traceback escape.
        _fail(str(e))

    # Make imagedl importable: either the caller sets PYTHONPATH, or the
    # directory named by IMAGEDL_PARENT is prepended here.
    imagedl_parent = os.environ.get('IMAGEDL_PARENT')
    if imagedl_parent and imagedl_parent not in sys.path:
        sys.path.insert(0, imagedl_parent)
    try:
        from imagedl.imagedl.imagedl import ImageClient
    except ImportError as e:
        _fail(
            f"无法导入 imagedl: {e}。请设置 PYTHONPATH 或 IMAGEDL_PARENT 为 "
            "python/imagedl 的父目录(即项目下的 python)。"
        )

    tmp_dir = None
    try:
        # Created inside the try so a failure here is also reported as JSON.
        tmp_dir = tempfile.mkdtemp(prefix='imagedl_download_img_')
        client = ImageClient(
            image_source=source,
            init_image_client_cfg={'work_dir': tmp_dir, 'disable_print': True},
            search_limits=5,
            num_threadings=1,
        )
        # NOTE(review): search_limits_overrides=1 may restrict the result to a
        # single entry, which would make the multi-candidate retry below a
        # no-op -- confirm against the imagedl API.
        image_infos = client.search(
            prompt, search_limits_overrides=1, num_threadings_overrides=1
        )
        if not image_infos:
            _fail("未搜到任何图片")

        # Try the first few candidates in order; the first one that leaves an
        # image file in the work directory wins.
        found = None
        last_error = None
        for index in range(min(5, len(image_infos))):
            try:
                client.download(
                    [image_infos[index]], num_threadings_overrides=1
                )
            except Exception as e:
                # One broken candidate must not abort the whole run.
                last_error = e
                continue
            found = _find_image(tmp_dir)
            if found:
                break
        if not found:
            _fail(
                "下载后未找到图片文件" if last_error is None else str(last_error)
            )

        # If save_path has no extension, reuse the downloaded file's extension.
        if not os.path.splitext(save_path)[1]:
            save_path += os.path.splitext(found)[1]
        shutil.copy2(found, save_path)
        print(json.dumps({"success": True, "path": save_path}))
    except Exception as e:
        # SystemExit raised by _fail is not an Exception subclass, so it
        # passes through untouched.
        _fail(str(e))
    finally:
        if tmp_dir:
            shutil.rmtree(tmp_dir, ignore_errors=True)
# Run the downloader only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|