import os
import pickle
import time

import numpy as np

from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
from ray.rllib.utils.annotations import override
from ray.tune import result as tune_result


class _MockTrainer(Algorithm):
    """Mock Algorithm for use in tests.

    Behavior is driven by the following config keys (see
    ``get_default_config``): ``mock_error``, ``persistent_error``,
    ``user_checkpoint_freq`` and ``sleep``.
    """

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfig:
        """Return a "tf" AlgorithmConfig extended with the mock-only keys."""
        return (
            AlgorithmConfig()
            .framework("tf")
            .update_from_dict(
                {
                    "mock_error": False,
                    "persistent_error": False,
                    "test_variable": 1,
                    "user_checkpoint_freq": 0,
                    "sleep": 0,
                }
            )
        )

    @classmethod
    @override(Algorithm)
    def default_resource_request(cls, config: AlgorithmConfig):
        """Return None, i.e. this mock declares no resource request."""
        return None

    @override(Algorithm)
    def setup(self, config):
        """Instantiate the callbacks and initialize the mock's state.

        Note that the callbacks class is read from ``self.config``; the
        ``config`` argument itself is not used here.
        """
        self.callbacks = self.config.callbacks_class()

        # Add needed properties.
        self.info = None
        self.restored = False

    @override(Algorithm)
    def step(self):
        """Return a fixed result dict, optionally failing or sleeping first.

        Raises:
            Exception: With message "mock error" when ``mock_error`` is set,
                ``self.iteration == 1``, and either ``persistent_error`` is
                set or this instance has not been restored from a checkpoint.
        """
        if (
            self.config.mock_error
            and self.iteration == 1
            and (self.config.persistent_error or not self.restored)
        ):
            raise Exception("mock error")

        if self.config.sleep:
            time.sleep(self.config.sleep)

        result = dict(
            episode_reward_mean=10,
            episode_len_mean=10,
            timesteps_this_iter=10,
            info={},
        )

        # Request a checkpoint every `user_checkpoint_freq` iterations (never
        # at iteration 0). The `> 0` test comes first so that the modulo is
        # never evaluated with a zero frequency.
        checkpoint_freq = self.config.user_checkpoint_freq
        if (
            checkpoint_freq > 0
            and self.iteration > 0
            and self.iteration % checkpoint_freq == 0
        ):
            result[tune_result.SHOULD_CHECKPOINT] = True

        return result

    @override(Algorithm)
    def save_checkpoint(self, checkpoint_dir):
        """Pickle ``self.info`` to ``<checkpoint_dir>/mock_agent.pkl``."""
        path = os.path.join(checkpoint_dir, "mock_agent.pkl")
        with open(path, "wb") as f:
            pickle.dump(self.info, f)

    @override(Algorithm)
    def load_checkpoint(self, checkpoint_dir):
        """Restore ``self.info`` from ``<checkpoint_dir>/mock_agent.pkl``.

        Also sets ``self.restored`` to True, which suppresses the
        non-persistent mock error in ``step``.
        """
        path = os.path.join(checkpoint_dir, "mock_agent.pkl")
        # NOTE(security): unpickling executes arbitrary code embedded in the
        # file. Only load checkpoints that were written by `save_checkpoint`
        # from a trusted source.
        with open(path, "rb") as f:
            info = pickle.load(f)
        self.info = info
        self.restored = True

    @staticmethod
    @override(Algorithm)
    def _get_env_id_and_creator(env_specifier, config):
        """Return ``(None, None)``; both arguments are ignored."""
        # No env to register.
        return None, None

    def set_info(self, info):
        """Store ``info`` on the instance and return it."""
        self.info = info
        return info

    def get_info(self, sess=None):
        """Return the stored info. The ``sess`` argument is unused."""
        return self.info


class _SigmoidFakeData(_MockTrainer):
    """Algorithm that returns sigmoid learning curves.

    This can be helpful for evaluating early stopping algorithms.

    The reported value is ``height * tanh(max(0, iteration - offset) / width)``.
    """

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfig:
        """Return an AlgorithmConfig carrying the curve-shape parameters."""
        return AlgorithmConfig().update_from_dict(
            {
                "width": 100,
                "height": 100,
                "offset": 0,
                "iter_time": 10,
                "iter_timesteps": 1,
            }
        )

    @override(Algorithm)
    def step(self):
        """Return the curve value for the current iteration as a result dict.

        ``time_this_iter_s`` and ``timesteps_this_iter`` are taken verbatim
        from the config rather than measured.
        """
        # Iterations before `offset` are clamped to the start of the curve.
        i = max(0, self.iteration - self.config.offset)
        v = np.tanh(float(i) / self.config.width)
        v *= self.config.height
        return dict(
            episode_reward_mean=v,
            episode_len_mean=v,
            timesteps_this_iter=self.config.iter_timesteps,
            time_this_iter_s=self.config.iter_time,
            info={},
        )


class _ParameterTuningTrainer(_MockTrainer):
    """Mock Algorithm whose reported reward is ``reward_amt * iteration``."""

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfig:
        """Return an AlgorithmConfig carrying the tunable dummy parameters."""
        return AlgorithmConfig().update_from_dict(
            {
                "reward_amt": 10,
                "dummy_param": 10,
                "dummy_param2": 15,
                "iter_time": 10,
                "iter_timesteps": 1,
            }
        )

    @override(Algorithm)
    def step(self):
        """Return a result dict whose reward grows linearly with iteration.

        ``time_this_iter_s`` and ``timesteps_this_iter`` are taken verbatim
        from the config rather than measured.
        """
        return dict(
            episode_reward_mean=self.config.reward_amt * self.iteration,
            episode_len_mean=self.config.reward_amt,
            timesteps_this_iter=self.config.iter_timesteps,
            time_this_iter_s=self.config.iter_time,
            info={},
        )


def _algorithm_import_failed(trace):
    """Return a dummy Algorithm class for when PyTorch etc. is not installed.

    Args:
        trace: Value passed to the ``ImportError`` raised from the returned
            class's ``setup``.

    Returns:
        An ``Algorithm`` subclass whose ``setup`` always raises
        ``ImportError(trace)``.
    """

    class _AlgorithmImportFailed(Algorithm):
        _name = "AlgorithmImportFailed"

        @override(Algorithm)
        def setup(self, config):
            raise ImportError(trace)

    return _AlgorithmImportFailed