mock.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import os
  2. import pickle
  3. import time
  4. import numpy as np
  5. from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
  6. from ray.rllib.utils.annotations import override
  7. from ray.tune import result as tune_result
  8. class _MockTrainer(Algorithm):
  9. """Mock Algorithm for use in tests."""
  10. @classmethod
  11. @override(Algorithm)
  12. def get_default_config(cls) -> AlgorithmConfig:
  13. return (
  14. AlgorithmConfig()
  15. .framework("tf")
  16. .update_from_dict(
  17. {
  18. "mock_error": False,
  19. "persistent_error": False,
  20. "test_variable": 1,
  21. "user_checkpoint_freq": 0,
  22. "sleep": 0,
  23. }
  24. )
  25. )
  26. @classmethod
  27. def default_resource_request(cls, config: AlgorithmConfig):
  28. return None
  29. @override(Algorithm)
  30. def setup(self, config):
  31. self.callbacks = self.config.callbacks_class()
  32. # Add needed properties.
  33. self.info = None
  34. self.restored = False
  35. @override(Algorithm)
  36. def step(self):
  37. if (
  38. self.config.mock_error
  39. and self.iteration == 1
  40. and (self.config.persistent_error or not self.restored)
  41. ):
  42. raise Exception("mock error")
  43. if self.config.sleep:
  44. time.sleep(self.config.sleep)
  45. result = dict(
  46. episode_reward_mean=10, episode_len_mean=10, timesteps_this_iter=10, info={}
  47. )
  48. if self.config.user_checkpoint_freq > 0 and self.iteration > 0:
  49. if self.iteration % self.config.user_checkpoint_freq == 0:
  50. result.update({tune_result.SHOULD_CHECKPOINT: True})
  51. return result
  52. @override(Algorithm)
  53. def save_checkpoint(self, checkpoint_dir):
  54. path = os.path.join(checkpoint_dir, "mock_agent.pkl")
  55. with open(path, "wb") as f:
  56. pickle.dump(self.info, f)
  57. @override(Algorithm)
  58. def load_checkpoint(self, checkpoint_dir):
  59. path = os.path.join(checkpoint_dir, "mock_agent.pkl")
  60. with open(path, "rb") as f:
  61. info = pickle.load(f)
  62. self.info = info
  63. self.restored = True
  64. @staticmethod
  65. @override(Algorithm)
  66. def _get_env_id_and_creator(env_specifier, config):
  67. # No env to register.
  68. return None, None
  69. def set_info(self, info):
  70. self.info = info
  71. return info
  72. def get_info(self, sess=None):
  73. return self.info
  74. class _SigmoidFakeData(_MockTrainer):
  75. """Algorithm that returns sigmoid learning curves.
  76. This can be helpful for evaluating early stopping algorithms."""
  77. @classmethod
  78. @override(Algorithm)
  79. def get_default_config(cls) -> AlgorithmConfig:
  80. return AlgorithmConfig().update_from_dict(
  81. {
  82. "width": 100,
  83. "height": 100,
  84. "offset": 0,
  85. "iter_time": 10,
  86. "iter_timesteps": 1,
  87. }
  88. )
  89. def step(self):
  90. i = max(0, self.iteration - self.config.offset)
  91. v = np.tanh(float(i) / self.config.width)
  92. v *= self.config.height
  93. return dict(
  94. episode_reward_mean=v,
  95. episode_len_mean=v,
  96. timesteps_this_iter=self.config.iter_timesteps,
  97. time_this_iter_s=self.config.iter_time,
  98. info={},
  99. )
  100. class _ParameterTuningTrainer(_MockTrainer):
  101. @classmethod
  102. @override(Algorithm)
  103. def get_default_config(cls) -> AlgorithmConfig:
  104. return AlgorithmConfig().update_from_dict(
  105. {
  106. "reward_amt": 10,
  107. "dummy_param": 10,
  108. "dummy_param2": 15,
  109. "iter_time": 10,
  110. "iter_timesteps": 1,
  111. }
  112. )
  113. def step(self):
  114. return dict(
  115. episode_reward_mean=self.config.reward_amt * self.iteration,
  116. episode_len_mean=self.config.reward_amt,
  117. timesteps_this_iter=self.config.iter_timesteps,
  118. time_this_iter_s=self.config.iter_time,
  119. info={},
  120. )
  121. def _algorithm_import_failed(trace):
  122. """Returns dummy Algorithm class for if PyTorch etc. is not installed."""
  123. class _AlgorithmImportFailed(Algorithm):
  124. _name = "AlgorithmImportFailed"
  125. def setup(self, config):
  126. raise ImportError(trace)
  127. return _AlgorithmImportFailed