replicate.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. from typing import Any
  2. from huggingface_hub.hf_api import InferenceProviderMapping
  3. from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url
  4. from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none
  5. from huggingface_hub.utils import get_session
# Provider name handed to TaskProviderHelper by every Replicate task helper in this module.
_PROVIDER = "replicate"
# Base URL handed to TaskProviderHelper; the routes built in this module are relative to it.
_BASE_URL = "https://api.replicate.com"
  8. class ReplicateTask(TaskProviderHelper):
  9. def __init__(self, task: str):
  10. super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task)
  11. def _prepare_headers(self, headers: dict, api_key: str) -> dict[str, Any]:
  12. headers = super()._prepare_headers(headers, api_key)
  13. headers["Prefer"] = "wait"
  14. return headers
  15. def _prepare_route(self, mapped_model: str, api_key: str) -> str:
  16. if ":" in mapped_model:
  17. return "/v1/predictions"
  18. return f"/v1/models/{mapped_model}/predictions"
  19. def _prepare_payload_as_dict(
  20. self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
  21. ) -> dict | None:
  22. mapped_model = provider_mapping_info.provider_id
  23. payload: dict[str, Any] = {"input": {"prompt": inputs, **filter_none(parameters)}}
  24. if ":" in mapped_model:
  25. version = mapped_model.split(":", 1)[1]
  26. payload["version"] = version
  27. return payload
  28. def get_response(self, response: bytes | dict, request_params: RequestParameters | None = None) -> Any:
  29. response_dict = _as_dict(response)
  30. if response_dict.get("output") is None:
  31. raise TimeoutError(
  32. f"Inference request timed out after 60 seconds. No output generated for model {response_dict.get('model')}"
  33. "The model might be in cold state or starting up. Please try again later."
  34. )
  35. output_url = (
  36. response_dict["output"] if isinstance(response_dict["output"], str) else response_dict["output"][0]
  37. )
  38. return get_session().get(output_url).content
  39. class ReplicateTextToImageTask(ReplicateTask):
  40. def __init__(self):
  41. super().__init__("text-to-image")
  42. def _prepare_payload_as_dict(
  43. self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
  44. ) -> dict | None:
  45. payload: dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore
  46. if provider_mapping_info.adapter_weights_path is not None:
  47. payload["input"]["lora_weights"] = f"https://huggingface.co/{provider_mapping_info.hf_model_id}"
  48. return payload
  49. class ReplicateTextToSpeechTask(ReplicateTask):
  50. def __init__(self):
  51. super().__init__("text-to-speech")
  52. def _prepare_payload_as_dict(
  53. self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
  54. ) -> dict | None:
  55. payload: dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore
  56. payload["input"]["text"] = payload["input"].pop("prompt") # rename "prompt" to "text" for TTS
  57. return payload
  58. class ReplicateAutomaticSpeechRecognitionTask(ReplicateTask):
  59. def __init__(self) -> None:
  60. super().__init__("automatic-speech-recognition")
  61. def _prepare_payload_as_dict(
  62. self,
  63. inputs: Any,
  64. parameters: dict,
  65. provider_mapping_info: InferenceProviderMapping,
  66. ) -> dict | None:
  67. mapped_model = provider_mapping_info.provider_id
  68. audio_url = _as_url(inputs, default_mime_type="audio/wav")
  69. payload: dict[str, Any] = {
  70. "input": {
  71. **{"audio": audio_url},
  72. **filter_none(parameters),
  73. }
  74. }
  75. if ":" in mapped_model:
  76. payload["version"] = mapped_model.split(":", 1)[1]
  77. return payload
  78. def get_response(self, response: bytes | dict, request_params: RequestParameters | None = None) -> Any:
  79. response_dict = _as_dict(response)
  80. output = response_dict.get("output")
  81. if isinstance(output, str):
  82. return {"text": output}
  83. if isinstance(output, list) and output:
  84. first_item = output[0]
  85. if isinstance(first_item, str):
  86. return {"text": first_item}
  87. if isinstance(first_item, dict):
  88. output = first_item
  89. text: str | None = None
  90. if isinstance(output, dict):
  91. transcription = output.get("transcription")
  92. if isinstance(transcription, str):
  93. text = transcription
  94. translation = output.get("translation")
  95. if isinstance(translation, str):
  96. text = translation
  97. txt_file = output.get("txt_file")
  98. if isinstance(txt_file, str):
  99. text_response = get_session().get(txt_file)
  100. text_response.raise_for_status()
  101. text = text_response.text
  102. if text is not None:
  103. return {"text": text}
  104. raise ValueError("Received malformed response from Replicate automatic-speech-recognition API")
  105. class ReplicateImageToImageTask(ReplicateTask):
  106. def __init__(self):
  107. super().__init__("image-to-image")
  108. def _prepare_payload_as_dict(
  109. self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
  110. ) -> dict | None:
  111. image_url = _as_url(inputs, default_mime_type="image/jpeg")
  112. # Different Replicate models expect the image in different keys
  113. payload: dict[str, Any] = {
  114. "input": {
  115. "image": image_url,
  116. "images": [image_url],
  117. "input_image": image_url,
  118. "input_images": [image_url],
  119. **filter_none(parameters),
  120. }
  121. }
  122. mapped_model = provider_mapping_info.provider_id
  123. if ":" in mapped_model:
  124. version = mapped_model.split(":", 1)[1]
  125. payload["version"] = version
  126. return payload