| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- import abc
- import logging
- import os
- from typing import Any, Dict
- from ray.data import Dataset
- from ray.rllib.policy import Policy
- from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI
- from ray.rllib.utils.typing import SampleBatchType
- logger = logging.getLogger(__name__)
@DeveloperAPI
class OfflineEvaluator(abc.ABC):
    """Interface for an offline evaluator of a policy.

    Subclasses must implement `estimate`. `train` and `estimate_on_dataset`
    are optional hooks with default implementations defined here.
    """

    @DeveloperAPI
    def __init__(self, policy: Policy, **kwargs):
        """Initializes an OfflineEvaluator instance.

        Args:
            policy: Policy to evaluate.
            kwargs: forward compatibility placeholder. Ignored by this base
                class.
        """
        self.policy = policy

    @abc.abstractmethod
    @DeveloperAPI
    def estimate(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]:
        """Returns the evaluation results for the given batch of episodes.

        Args:
            batch: The batch to evaluate.
            kwargs: forward compatibility placeholder.

        Returns:
            The evaluation done on the given batch. The returned
            dict can be any arbitrary mapping of strings to metrics.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    @DeveloperAPI
    def train(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]:
        """Trains any model the evaluator needs before estimating.

        Sometimes you need to train a model inside an evaluator. This method
        abstracts the training process. The base implementation does no
        training and returns an empty dict.

        Args:
            batch: SampleBatch to train on.
            kwargs: forward compatibility placeholder.

        Returns:
            Any optional metrics to return from the evaluator.
        """
        return {}

    @ExperimentalAPI
    def estimate_on_dataset(
        self,
        dataset: Dataset,
        *,
        # The default is evaluated once, when the class body is executed.
        # os.cpu_count() returns None when the CPU count cannot be determined,
        # so fall back to 1 to keep the default a valid int.
        n_parallelism: int = os.cpu_count() or 1,
    ) -> Dict[str, Any]:
        """Calculates the estimate of the metrics based on the given offline dataset.

        Typically, the dataset is passed through only once via n_parallel tasks in
        mini-batches to improve the run-time of metric estimation.

        NOTE: This base implementation has no body beyond this docstring, so
        calling it without overriding returns None. The contract below
        describes what an overriding implementation is expected to provide.

        Args:
            dataset: The ray dataset object to do offline evaluation on.
            n_parallelism: The number of parallelism to use for the computation.
                Defaults to the CPU count detected at import time, or 1 if it
                could not be determined.

        Returns:
            Dict[str, Any]: A dictionary of the estimated values.
        """
|