| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160 |
from typing import Any

from huggingface_hub.hf_api import InferenceProviderMapping
from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url
from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
from huggingface_hub.utils import get_session
# Provider name handed to ``TaskProviderHelper`` by every Replicate helper below.
_PROVIDER: str = "replicate"
# Root URL of Replicate's HTTP API, handed to ``TaskProviderHelper`` as ``base_url``;
# presumably combined with the path returned by ``ReplicateTask._prepare_route``.
_BASE_URL: str = "https://api.replicate.com"
class ReplicateTask(TaskProviderHelper):
    """Base helper for running inference through Replicate's predictions API.

    The mapped provider id selects the route:

    - An id of the form ``"<model>:<version>"`` is sent to the generic
      ``/v1/predictions`` endpoint, with the version in the payload.
    - Any other id goes to ``/v1/models/<id>/predictions``.

    Subclasses override payload construction and/or response parsing per task.
    """

    def __init__(self, task: str):
        super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task)

    def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]:
        """Return the parent's headers plus ``Prefer: wait``.

        ``Prefer: wait`` selects Replicate's synchronous mode, where the request
        is held open while the prediction runs. ``get_response`` treats a
        response that still has no ``output`` as a timeout.
        """
        headers = super()._prepare_headers(headers, api_key)
        headers["Prefer"] = "wait"
        return headers

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        """Return the API path for ``mapped_model`` (see the class docstring)."""
        if ":" in mapped_model:
            # Version-pinned id: the version is carried in the payload's "version"
            # field, so the generic endpoint is used.
            return "/v1/predictions"
        return f"/v1/models/{mapped_model}/predictions"

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
    ) -> dict | None:
        """Build ``{"input": {"prompt": inputs, **parameters}}``.

        ``parameters`` is passed through ``filter_none`` (presumably dropping
        ``None`` values). When the mapped provider id is ``"<model>:<version>"``,
        everything after the first ``:`` is sent as ``payload["version"]``.
        """
        mapped_model = provider_mapping_info.provider_id
        payload: dict[str, Any] = {"input": {"prompt": inputs, **filter_none(parameters)}}
        if ":" in mapped_model:
            version = mapped_model.split(":", 1)[1]
            payload["version"] = version
        return payload

    def get_response(self, response: bytes | dict, request_params: RequestParameters | None = None) -> Any:
        """Download the prediction's output file and return its raw bytes.

        ``output`` may be a single URL string; otherwise its first item is used
        as the URL.

        Raises:
            TimeoutError: if the prediction response carries no ``output``.
            ValueError: if ``output`` is an empty sequence.
            Exception: whatever the session response's ``raise_for_status()``
                raises when downloading the output file fails.
        """
        response_dict = _as_dict(response)
        output = response_dict.get("output")
        if output is None:
            raise TimeoutError(
                f"Inference request timed out after 60 seconds. No output generated for model {response_dict.get('model')}. "
                "The model might be in cold state or starting up. Please try again later."
            )
        if isinstance(output, str):
            output_url = output
        elif output:
            # Non-string output (e.g. a list of file URLs): use the first item.
            output_url = output[0]
        else:
            # Previously this crashed with a bare IndexError; report it explicitly instead.
            raise ValueError(
                f"Received malformed response from Replicate API: empty output for model {response_dict.get('model')}"
            )
        output_response = get_session().get(output_url)
        # Without this check an HTTP error page would be returned as if it were the result.
        output_response.raise_for_status()
        return output_response.content
class ReplicateTextToImageTask(ReplicateTask):
    """Replicate helper for ``text-to-image``.

    Builds the base prompt payload and, when the mapping carries LoRA adapter
    weights, adds ``input["lora_weights"]``.
    """

    def __init__(self):
        super().__init__("text-to-image")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
    ) -> dict | None:
        """Return the base payload, with ``lora_weights`` set when an adapter is mapped.

        ``adapter_weights_path`` only acts as a presence flag here. The URL sent
        is built from the Hub repo id (``hf_model_id``).
        """
        base_payload: dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info)  # type: ignore
        has_adapter = provider_mapping_info.adapter_weights_path is not None
        if has_adapter:
            repo_url = f"https://huggingface.co/{provider_mapping_info.hf_model_id}"
            base_payload["input"]["lora_weights"] = repo_url
        return base_payload
class ReplicateTextToSpeechTask(ReplicateTask):
    """Replicate helper for ``text-to-speech``; sends the input text as ``input["text"]``."""

    def __init__(self):
        super().__init__("text-to-speech")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
    ) -> dict | None:
        """Return the base payload with its ``prompt`` key renamed to ``text``."""
        payload: dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info)  # type: ignore
        model_input = payload["input"]
        # The base class stores the input under "prompt"; this task sends it as "text".
        spoken_text = model_input.pop("prompt")
        model_input["text"] = spoken_text
        return payload
class ReplicateAutomaticSpeechRecognitionTask(ReplicateTask):
    """Replicate helper for ``automatic-speech-recognition``.

    Sends the audio as a URL under ``input["audio"]`` and normalizes the
    several output shapes Replicate ASR models return into ``{"text": ...}``.
    """

    def __init__(self) -> None:
        super().__init__("automatic-speech-recognition")

    def _prepare_payload_as_dict(
        self,
        inputs: Any,
        parameters: dict,
        provider_mapping_info: InferenceProviderMapping,
    ) -> dict | None:
        """Build ``{"input": {"audio": <url>, **parameters}}``, adding ``version`` when pinned.

        ``inputs`` is converted by ``_as_url``, with ``audio/wav`` as the default
        MIME type. Keys in ``parameters`` are applied after ``"audio"`` and can
        therefore override it.
        """
        provider_id = provider_mapping_info.provider_id
        model_input: dict[str, Any] = {"audio": _as_url(inputs, default_mime_type="audio/wav")}
        model_input.update(filter_none(parameters))
        payload: dict[str, Any] = {"input": model_input}

        # "<model>:<version>" -> everything after the first ":" is the version.
        _, separator, version = provider_id.partition(":")
        if separator:
            payload["version"] = version
        return payload

    def get_response(self, response: bytes | dict, request_params: RequestParameters | None = None) -> Any:
        """Extract the transcript from a Replicate ASR response as ``{"text": str}``.

        Accepted ``output`` shapes:

        - a string;
        - a list whose first item is a string;
        - a dict (or a list whose first item is a dict) with ``txt_file`` (a URL
          that is downloaded), ``translation`` or ``transcription``.

        When several dict keys are present, the precedence is ``txt_file`` >
        ``translation`` > ``transcription``.

        Raises:
            ValueError: if no transcript can be found in the response.
        """
        response_dict = _as_dict(response)
        output = response_dict.get("output")

        if isinstance(output, str):
            return {"text": output}

        if isinstance(output, list) and output:
            head = output[0]
            if isinstance(head, str):
                return {"text": head}
            if isinstance(head, dict):
                output = head

        if isinstance(output, dict):
            txt_file = output.get("txt_file")
            if isinstance(txt_file, str):
                file_response = get_session().get(txt_file)
                file_response.raise_for_status()
                return {"text": file_response.text}
            for field_name in ("translation", "transcription"):
                candidate = output.get(field_name)
                if isinstance(candidate, str):
                    return {"text": candidate}

        raise ValueError("Received malformed response from Replicate automatic-speech-recognition API")
class ReplicateImageToImageTask(ReplicateTask):
    """Replicate helper for ``image-to-image``.

    Sends the source image (as a URL) under every input key that Replicate
    models are known to use for it.
    """

    def __init__(self):
        super().__init__("image-to-image")

    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
    ) -> dict | None:
        """Build the image-to-image payload, adding ``version`` when the id is pinned.

        ``inputs`` is converted by ``_as_url``, with ``image/jpeg`` as the
        default MIME type. Keys in ``parameters`` are applied last and can
        override the image keys.
        """
        source_url = _as_url(inputs, default_mime_type="image/jpeg")
        # Different Replicate models expect the image in different keys, so the
        # same URL is sent both as a single value and as a one-element list.
        model_input: dict[str, Any] = dict(
            image=source_url,
            images=[source_url],
            input_image=source_url,
            input_images=[source_url],
        )
        model_input.update(filter_none(parameters))
        payload: dict[str, Any] = {"input": model_input}

        # "<model>:<version>" -> everything after the first ":" is the version.
        _, separator, version = provider_mapping_info.provider_id.partition(":")
        if separator:
            payload["version"] = version
        return payload
|