sambanova.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. from typing import Any
  2. from huggingface_hub.hf_api import InferenceProviderMapping
  3. from huggingface_hub.inference._common import RequestParameters, _as_dict
  4. from huggingface_hub.inference._providers._common import BaseConversationalTask, TaskProviderHelper, filter_none
  5. class SambanovaConversationalTask(BaseConversationalTask):
  6. def __init__(self):
  7. super().__init__(provider="sambanova", base_url="https://api.sambanova.ai")
  8. def _prepare_payload_as_dict(
  9. self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
  10. ) -> dict | None:
  11. response_format_config = parameters.get("response_format")
  12. if isinstance(response_format_config, dict):
  13. if response_format_config.get("type") == "json_schema":
  14. json_schema_config = response_format_config.get("json_schema", {})
  15. strict = json_schema_config.get("strict")
  16. if isinstance(json_schema_config, dict) and (strict is True or strict is None):
  17. json_schema_config["strict"] = False
  18. payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info)
  19. return payload
  20. class SambanovaFeatureExtractionTask(TaskProviderHelper):
  21. def __init__(self):
  22. super().__init__(provider="sambanova", base_url="https://api.sambanova.ai", task="feature-extraction")
  23. def _prepare_route(self, mapped_model: str, api_key: str) -> str:
  24. return "/v1/embeddings"
  25. def _prepare_payload_as_dict(
  26. self, inputs: Any, parameters: dict, provider_mapping_info: InferenceProviderMapping
  27. ) -> dict | None:
  28. parameters = filter_none(parameters)
  29. return {"input": inputs, "model": provider_mapping_info.provider_id, **parameters}
  30. def get_response(self, response: bytes | dict, request_params: RequestParameters | None = None) -> Any:
  31. embeddings = _as_dict(response)["data"]
  32. return [embedding["embedding"] for embedding in embeddings]