download-img-by-prompt.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 根据描述词(prompt)用 imagedl 搜索并下载一张图片到指定路径。
  5. 用法: python download-img-by-prompt.py --prompt "关键词" --save-path "/path/to/out.png"
  6. 需设置 PYTHONPATH 包含 python/imagedl 的父目录(如项目根下的 python)。
  7. 输出: JSON 到 stdout,{ "success": true, "path": "..." } 或 { "success": false, "error": "..." }
  8. """
  9. import os
  10. import sys
  11. import json
  12. import tempfile
  13. import shutil
  14. def main():
  15. prompt = None
  16. save_path = None
  17. source = 'BaiduImageClient'
  18. i = 1
  19. while i < len(sys.argv):
  20. if sys.argv[i] == '--prompt' and i + 1 < len(sys.argv):
  21. prompt = sys.argv[i + 1]
  22. i += 2
  23. elif sys.argv[i] == '--save-path' and i + 1 < len(sys.argv):
  24. save_path = sys.argv[i + 1]
  25. i += 2
  26. elif sys.argv[i] == '--source' and i + 1 < len(sys.argv):
  27. source = sys.argv[i + 1]
  28. i += 2
  29. else:
  30. i += 1
  31. if not prompt or not save_path:
  32. print(json.dumps({"success": False, "error": "缺少 --prompt 或 --save-path"}))
  33. sys.exit(1)
  34. save_path = os.path.abspath(os.path.normpath(save_path))
  35. save_dir = os.path.dirname(save_path)
  36. os.makedirs(save_dir, exist_ok=True)
  37. # 确保能 import imagedl(由调用方设置 PYTHONPATH 或在此添加)
  38. imagedl_parent = os.environ.get('IMAGEDL_PARENT')
  39. if imagedl_parent:
  40. if imagedl_parent not in sys.path:
  41. sys.path.insert(0, imagedl_parent)
  42. try:
  43. from imagedl.imagedl.imagedl import ImageClient
  44. except ImportError as e:
  45. print(json.dumps({"success": False, "error": f"无法导入 imagedl: {e}。请设置 PYTHONPATH 或 IMAGEDL_PARENT 为 python/imagedl 的父目录(即项目下的 python)。"}))
  46. sys.exit(1)
  47. tmp_dir = tempfile.mkdtemp(prefix='imagedl_download_img_')
  48. try:
  49. client = ImageClient(
  50. image_source=source,
  51. init_image_client_cfg={'work_dir': tmp_dir, 'disable_print': True},
  52. search_limits=5,
  53. num_threadings=1
  54. )
  55. image_infos = client.search(prompt, search_limits_overrides=1, num_threadings_overrides=1)
  56. if not image_infos or len(image_infos) == 0:
  57. print(json.dumps({"success": False, "error": "未搜到任何图片"}))
  58. sys.exit(1)
  59. # 只取前几张,依次尝试下载,成功一张即用
  60. exts = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp')
  61. max_try = min(5, len(image_infos))
  62. found = None
  63. for i in range(max_try):
  64. client.download([image_infos[i]], num_threadings_overrides=1)
  65. for root, _, files in os.walk(tmp_dir):
  66. for f in files:
  67. if any(f.lower().endswith(ext) for ext in exts):
  68. found = os.path.join(root, f)
  69. break
  70. if found:
  71. break
  72. if found:
  73. break
  74. if not found:
  75. print(json.dumps({"success": False, "error": "下载后未找到图片文件"}))
  76. sys.exit(1)
  77. # 若 save_path 无扩展名则沿用下载文件的扩展名
  78. if not os.path.splitext(save_path)[1]:
  79. save_path = save_path + os.path.splitext(found)[1]
  80. shutil.copy2(found, save_path)
  81. print(json.dumps({"success": True, "path": save_path}))
  82. except Exception as e:
  83. print(json.dumps({"success": False, "error": str(e)}))
  84. sys.exit(1)
  85. finally:
  86. shutil.rmtree(tmp_dir, ignore_errors=True)
  87. if __name__ == '__main__':
  88. main()