| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- import base64
- import time
- from abc import ABC
- from typing import Any
- from urllib.parse import urlparse
- from huggingface_hub.hf_api import InferenceProviderMapping
- from huggingface_hub.inference._common import RequestParameters, _as_dict
- from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
- from huggingface_hub.utils import get_session, hf_raise_for_status
- from huggingface_hub.utils.logging import get_logger
- # Module-level logger, created from this module's `__name__`.
- logger = get_logger(__name__)
- # Delay (in seconds) slept before each poll of the task result URL.
- _POLLING_INTERVAL = 0.5
class WavespeedAITask(TaskProviderHelper, ABC):
    """Base helper shared by all WaveSpeed AI tasks.

    The initial response only carries a URL to poll; `get_response` polls that URL until the
    task reports a terminal status and then downloads the first output.
    """

    def __init__(self, task: str):
        super().__init__(provider="wavespeed", base_url="https://api.wavespeed.ai", task=task)

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        """Return the API route for `mapped_model` (`api_key` is not used here)."""
        return f"/api/v3/{mapped_model}"

    def get_response(
        self,
        response: bytes | dict,
        request_params: RequestParameters | None = None,
    ) -> Any:
        """Poll the task until it finishes and return the raw bytes of its first output.

        Args:
            response: Body of the initial request. The polling URL is read from `data.urls.get`.
            request_params: Parameters of the initial request. Its URL decides which host is
                polled and its headers are sent with every polling call.

        Returns:
            The content downloaded from the first URL listed in the task's `outputs`.

        Raises:
            ValueError: If the response holds no polling URL, `request_params` is missing, the
                completed task lists no output, the task failed, or the status is unknown.

        HTTP errors on the polling calls and on the output download are raised by
        `hf_raise_for_status`.
        """
        response_dict = _as_dict(response)
        # `or {}` also covers an explicit JSON `null`, so a malformed body ends up in the
        # ValueError below instead of an AttributeError.
        data = response_dict.get("data") or {}
        urls = data.get("urls") or {}
        result_path = urls.get("get")

        if not result_path:
            raise ValueError("No result URL found in the response")
        if request_params is None:
            raise ValueError("A `RequestParameters` object should be provided to get responses with WaveSpeed AI.")

        # Poll on the same host as the initial request; calls going through the HF router
        # need the `/wavespeed` prefix.
        parsed_url = urlparse(request_params.url)
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
        if parsed_url.netloc == "router.huggingface.co":
            base_url = f"{base_url}/wavespeed"

        # Only the path of the URL returned by the provider is kept; the host comes from `base_url`.
        result_url_path = urlparse(result_path).path if isinstance(result_path, str) else result_path
        result_url = f"{base_url}{result_url_path}"

        logger.info("Processing request, polling for results...")

        # Poll until the task reaches a terminal status.
        while True:
            time.sleep(_POLLING_INTERVAL)

            result_response = get_session().get(result_url, headers=request_params.headers)
            hf_raise_for_status(result_response)

            task_result = result_response.json().get("data") or {}
            status = task_result.get("status")

            if status == "completed":
                outputs = task_result.get("outputs")
                if not outputs:
                    raise ValueError("No output URL in completed response")
                # Check the download status so that an HTTP error body is never returned
                # as if it were the generated content.
                output_response = get_session().get(outputs[0])
                hf_raise_for_status(output_response)
                return output_response.content

            if status == "failed":
                error_msg = task_result.get("error", "Task failed with no specific error message")
                raise ValueError(f"WaveSpeed AI task failed: {error_msg}")

            if status not in ("processing", "created"):
                raise ValueError(f"Unknown status: {status}")
class WavespeedAITextToImageTask(WavespeedAITask):
    """WaveSpeed AI helper registered for the `text-to-image` task."""

    def __init__(self):
        super().__init__("text-to-image")

    def _prepare_payload_as_dict(
        self,
        inputs: Any,
        parameters: dict,
        provider_mapping_info: InferenceProviderMapping,
    ) -> dict | None:
        """Build the request payload: `inputs` under `prompt`, followed by the filtered parameters.

        A `prompt` entry surviving `filter_none(parameters)` overrides `inputs`.
        `provider_mapping_info` is not used.
        """
        payload = {"prompt": inputs}
        payload.update(filter_none(parameters))
        return payload
class WavespeedAITextToVideoTask(WavespeedAITextToImageTask):
    """WaveSpeed AI helper registered for the `text-to-video` task.

    Inherits payload preparation from `WavespeedAITextToImageTask`.
    """

    def __init__(self):
        # The direct parent's `__init__` takes no argument and always registers
        # "text-to-image", so the base initializer is called explicitly instead.
        WavespeedAITask.__init__(self, task="text-to-video")
class WavespeedAIImageToImageTask(WavespeedAITask):
    """WaveSpeed AI helper registered for the `image-to-image` task."""

    def __init__(self):
        super().__init__("image-to-image")

    def _prepare_payload_as_dict(
        self,
        inputs: Any,
        parameters: dict,
        provider_mapping_info: InferenceProviderMapping,
    ) -> dict | None:
        """Build the request payload from an input image and optional parameters.

        Args:
            inputs: The image. An `http://` / `https://` string is forwarded as is, any other
                string is opened as a file path, and anything else is base64-encoded directly.
            parameters: Extra request parameters. `None` values are dropped and an optional
                `prompt` entry is placed last in the payload. The dict is left unmodified.
            provider_mapping_info: Not used.

        Returns:
            A dict holding `image`, the filtered parameters and, when set, `prompt`.
        """
        if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
            image = inputs
        else:
            if isinstance(inputs, str):
                # Any other string is treated as a path to a local file.
                with open(inputs, "rb") as f:
                    raw_image = f.read()
            else:
                raw_image = inputs
            image_b64 = base64.b64encode(raw_image).decode("utf-8")
            # The data URL is always labelled `image/jpeg`, whatever the actual image format.
            image = f"data:image/jpeg;base64,{image_b64}"

        # Read the prompt without popping it, so the caller's `parameters` dict is not mutated.
        prompt = parameters.get("prompt")
        other_parameters = {key: value for key, value in parameters.items() if key != "prompt"}

        payload = {"image": image, **filter_none(other_parameters)}
        if prompt is not None:
            payload["prompt"] = prompt
        return payload
class WavespeedAIImageToVideoTask(WavespeedAIImageToImageTask):
    """WaveSpeed AI helper registered for the `image-to-video` task.

    Inherits payload preparation from `WavespeedAIImageToImageTask`.
    """

    def __init__(self):
        # The direct parent's `__init__` takes no argument and always registers
        # "image-to-image", so the base initializer is called explicitly instead.
        WavespeedAITask.__init__(self, task="image-to-video")
|