wavespeed.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. import base64
  2. import time
  3. from abc import ABC
  4. from typing import Any
  5. from urllib.parse import urlparse
  6. from huggingface_hub.hf_api import InferenceProviderMapping
  7. from huggingface_hub.inference._common import RequestParameters, _as_dict
  8. from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
  9. from huggingface_hub.utils import get_session, hf_raise_for_status
  10. from huggingface_hub.utils.logging import get_logger
  # Module-level logger, named after this module.
  11. logger = get_logger(__name__)
  12. # Delay, in seconds, slept before each poll of the task result URL.
  13. _POLLING_INTERVAL = 0.5
  14. class WavespeedAITask(TaskProviderHelper, ABC):
  15. def __init__(self, task: str):
  16. super().__init__(provider="wavespeed", base_url="https://api.wavespeed.ai", task=task)
  17. def _prepare_route(self, mapped_model: str, api_key: str) -> str:
  18. return f"/api/v3/{mapped_model}"
  19. def get_response(
  20. self,
  21. response: bytes | dict,
  22. request_params: RequestParameters | None = None,
  23. ) -> Any:
  24. response_dict = _as_dict(response)
  25. data = response_dict.get("data", {})
  26. result_path = data.get("urls", {}).get("get")
  27. if not result_path:
  28. raise ValueError("No result URL found in the response")
  29. if request_params is None:
  30. raise ValueError("A `RequestParameters` object should be provided to get responses with WaveSpeed AI.")
  31. # Parse the request URL to determine base URL
  32. parsed_url = urlparse(request_params.url)
  33. # Add /wavespeed to base URL if going through HF router
  34. if parsed_url.netloc == "router.huggingface.co":
  35. base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/wavespeed"
  36. else:
  37. base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
  38. # Extract path from result_path URL
  39. if isinstance(result_path, str):
  40. result_url_path = urlparse(result_path).path
  41. else:
  42. result_url_path = result_path
  43. result_url = f"{base_url}{result_url_path}"
  44. logger.info("Processing request, polling for results...")
  45. # Poll until task is completed
  46. while True:
  47. time.sleep(_POLLING_INTERVAL)
  48. result_response = get_session().get(result_url, headers=request_params.headers)
  49. hf_raise_for_status(result_response)
  50. result = result_response.json()
  51. task_result = result.get("data", {})
  52. status = task_result.get("status")
  53. if status == "completed":
  54. # Get content from the first output URL
  55. if not task_result.get("outputs") or len(task_result["outputs"]) == 0:
  56. raise ValueError("No output URL in completed response")
  57. output_url = task_result["outputs"][0]
  58. return get_session().get(output_url).content
  59. elif status == "failed":
  60. error_msg = task_result.get("error", "Task failed with no specific error message")
  61. raise ValueError(f"WaveSpeed AI task failed: {error_msg}")
  62. elif status in ["processing", "created"]:
  63. continue
  64. else:
  65. raise ValueError(f"Unknown status: {status}")
  66. class WavespeedAITextToImageTask(WavespeedAITask):
  67. def __init__(self):
  68. super().__init__("text-to-image")
  69. def _prepare_payload_as_dict(
  70. self,
  71. inputs: Any,
  72. parameters: dict,
  73. provider_mapping_info: InferenceProviderMapping,
  74. ) -> dict | None:
  75. return {"prompt": inputs, **filter_none(parameters)}
  76. class WavespeedAITextToVideoTask(WavespeedAITextToImageTask):
  77. def __init__(self):
  78. WavespeedAITask.__init__(self, "text-to-video")
  79. class WavespeedAIImageToImageTask(WavespeedAITask):
  80. def __init__(self):
  81. super().__init__("image-to-image")
  82. def _prepare_payload_as_dict(
  83. self,
  84. inputs: Any,
  85. parameters: dict,
  86. provider_mapping_info: InferenceProviderMapping,
  87. ) -> dict | None:
  88. # Convert inputs to image (URL or base64)
  89. if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
  90. image = inputs
  91. elif isinstance(inputs, str):
  92. # If input is a file path, read it first
  93. with open(inputs, "rb") as f:
  94. file_content = f.read()
  95. image_b64 = base64.b64encode(file_content).decode("utf-8")
  96. image = f"data:image/jpeg;base64,{image_b64}"
  97. else:
  98. # If input is binary data
  99. image_b64 = base64.b64encode(inputs).decode("utf-8")
  100. image = f"data:image/jpeg;base64,{image_b64}"
  101. # Extract prompt from parameters if present
  102. prompt = parameters.pop("prompt", None)
  103. payload = {"image": image, **filter_none(parameters)}
  104. if prompt is not None:
  105. payload["prompt"] = prompt
  106. return payload
  107. class WavespeedAIImageToVideoTask(WavespeedAIImageToImageTask):
  108. def __init__(self):
  109. WavespeedAITask.__init__(self, "image-to-video")