novita.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from typing import Any
  2. from huggingface_hub.hf_api import InferenceProviderMapping
  3. from huggingface_hub.inference._common import RequestParameters, _as_dict
  4. from huggingface_hub.inference._providers._common import (
  5. BaseConversationalTask,
  6. BaseTextGenerationTask,
  7. TaskProviderHelper,
  8. filter_none,
  9. )
  10. from huggingface_hub.utils import get_session
  11. _PROVIDER = "novita"
  12. _BASE_URL = "https://api.novita.ai"
  13. class NovitaTextGenerationTask(BaseTextGenerationTask):
  14. def __init__(self):
  15. super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
  16. def _prepare_route(self, mapped_model: str, api_key: str) -> str:
  17. # there is no v1/ route for novita
  18. return "/v3/openai/completions"
  19. def get_response(self, response: bytes | dict, request_params: RequestParameters | None = None) -> Any:
  20. output = _as_dict(response)["choices"][0]
  21. return {
  22. "generated_text": output["text"],
  23. "details": {
  24. "finish_reason": output.get("finish_reason"),
  25. "seed": output.get("seed"),
  26. },
  27. }
  28. class NovitaConversationalTask(BaseConversationalTask):
  29. def __init__(self):
  30. super().__init__(provider=_PROVIDER, base_url=_BASE_URL)
  31. def _prepare_route(self, mapped_model: str, api_key: str) -> str:
  32. # there is no v1/ route for novita
  33. return "/v3/openai/chat/completions"
  34. class NovitaTextToVideoTask(TaskProviderHelper):
  35. def __init__(self):
  36. super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task="text-to-video")
  37. def _prepare_route(self, mapped_model: str, api_key: str) -> str:
  38. return f"/v3/hf/{mapped_model}"
  39. def _prepare_payload_as_dict(
  40. self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
  41. ) -> dict | None:
  42. return {"prompt": inputs, **filter_none(parameters)}
  43. def get_response(self, response: bytes | dict, request_params: RequestParameters | None = None) -> Any:
  44. response_dict = _as_dict(response)
  45. if not (
  46. isinstance(response_dict, dict)
  47. and "video" in response_dict
  48. and isinstance(response_dict["video"], dict)
  49. and "video_url" in response_dict["video"]
  50. ):
  51. raise ValueError("Expected response format: { 'video': { 'video_url': string } }")
  52. video_url = response_dict["video"]["video_url"]
  53. return get_session().get(video_url).content