offline_evaluator.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import abc
  2. import logging
  3. import os
  4. from typing import Any, Dict
  5. from ray.data import Dataset
  6. from ray.rllib.policy import Policy
  7. from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI
  8. from ray.rllib.utils.typing import SampleBatchType
  # Module-level logger, keyed by this module's fully qualified name.
logger = logging.getLogger(name=__name__)
  10. @DeveloperAPI
  11. class OfflineEvaluator(abc.ABC):
  12. """Interface for an offline evaluator of a policy"""
  @DeveloperAPI
  def __init__(self, policy: Policy, **kwargs):
      """Initializes an OfflineEvaluator instance.

      Args:
          policy: Policy to evaluate.
          kwargs: Forward-compatibility placeholder; currently ignored by
              this base implementation.
      """
      # Only the policy is stored. Extra keyword arguments are accepted (and
      # dropped) so callers can pass options without breaking this base class.
      self.policy = policy
  @abc.abstractmethod
  @DeveloperAPI
  def estimate(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]:
      """Evaluates the policy on a batch of episodes.

      Abstract: concrete evaluators must override this method.

      Args:
          batch: The batch of episodes to run the evaluation on.
          kwargs: Reserved for forward compatibility.

      Returns:
          A mapping from string metric names to metric values. Which
          metrics appear is up to each concrete evaluator.

      Raises:
          NotImplementedError: Always, in this abstract base version.
      """
      raise NotImplementedError()
  @DeveloperAPI
  def train(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]:
      """Optional hook for evaluators that need to fit an internal model.

      The base implementation performs no training and reports nothing.
      Evaluators that do train a model are expected to override it.

      Args:
          batch: SampleBatch to train on.
          kwargs: Reserved for forward compatibility.

      Returns:
          Optional training metrics; a fresh empty dict in this base version.
      """
      no_metrics: Dict[str, Any] = dict()
      return no_metrics
  44. @ExperimentalAPI
  45. def estimate_on_dataset(
  46. self,
  47. dataset: Dataset,
  48. *,
  49. n_parallelism: int = os.cpu_count(),
  50. ) -> Dict[str, Any]:
  51. """Calculates the estimate of the metrics based on the given offline dataset.
  52. Typically, the dataset is passed through only once via n_parallel tasks in
  53. mini-batches to improve the run-time of metric estimation.
  54. Args:
  55. dataset: The ray dataset object to do offline evaluation on.
  56. n_parallelism: The number of parallelism to use for the computation.
  57. Returns:
  58. Dict[str, Any]: A dictionary of the estimated values.
  59. """