constants.py 893 B

123456789101112131415161718192021222324252627282930313233343536373839
  1. from typing import Literal
  class TaskType:
      """Base class for engine task-type namespaces.

      Subclasses declare each supported task as a public class attribute
      holding a string, e.g. ``GENERATE = "generate"``; :meth:`values`
      collects those strings.
      """

      @classmethod
      def values(cls):
          """Return a set of all valid task type values.

          A value is any class attribute whose name does not start with an
          underscore and whose value is a ``str``. Attributes inherited from
          base classes are included. When a subclass redefines a name, only
          the subclass's value for that name is reported.

          Returns:
              set[str]: The task type strings visible on ``cls``.
          """
          found = set()
          # ``dir()`` + ``getattr()`` resolve names through the full MRO.
          # ``vars(cls)`` only sees attributes defined directly on ``cls`` and
          # would drop constants inherited from an intermediate subclass.
          for name in dir(cls):
              if name.startswith("_"):
                  continue
              value = getattr(cls, name)
              # Skips non-string members such as the ``values`` method itself.
              if isinstance(value, str):
                  found.add(value)
          return found
  class vLLMTaskType(TaskType):
      """Task types accepted by the vLLM engine."""

      GENERATE = "generate"  # Text generation.
      EMBED = "embed"  # Embedding generation.
      CLASSIFY = "classify"  # E.g. sequence-classification models.
      SCORE = "score"  # E.g. cross-encoder models.
  class SGLangTaskType(TaskType):
      """Task types accepted by the SGLang engine."""

      GENERATE = "generate"  # Text generation.
  # Runtime ``Literal`` aliases built from the vLLMTaskType / SGLangTaskType
  # namespaces.
  #
  # The values are sorted because ``values()`` returns a set. Iteration order
  # for a set of strings varies between interpreter runs (string hash
  # randomization), so sorting makes ``__args__``, the repr, and anything
  # derived from these aliases reproducible. On Python 3.9.1+ ``Literal``
  # equality ignores argument order, so the accepted values are unchanged.
  #
  # NOTE(review): static type checkers cannot evaluate ``Literal[tuple(...)]``.
  # These aliases presumably only matter to runtime consumers; confirm
  # against callers.
  TypeVLLMTaskType = Literal[tuple(sorted(vLLMTaskType.values()))]
  TypeSGLangTaskType = Literal[tuple(sorted(SGLangTaskType.values()))]