mock_trainable.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import json
  2. import os
  3. import time
  4. import numpy as np
  5. from ray.tune import Trainable
  6. MOCK_TRAINABLE_NAME = "mock_trainable"
  7. MOCK_ERROR_KEY = "mock_error"
  8. class MyTrainableClass(Trainable):
  9. """Example agent whose learning curve is a random sigmoid.
  10. The dummy hyperparameters "width" and "height" determine the slope and
  11. maximum reward value reached.
  12. """
  13. def setup(self, config):
  14. self._sleep_time = config.get("sleep", 0)
  15. self._mock_error = config.get(MOCK_ERROR_KEY, False)
  16. self._persistent_error = config.get("persistent_error", False)
  17. self.timestep = 0
  18. self.restored = False
  19. def step(self):
  20. if (
  21. self._mock_error
  22. and self.timestep > 0 # allow at least 1 successful checkpoint.
  23. and (self._persistent_error or not self.restored)
  24. ):
  25. raise RuntimeError(f"Failing on purpose! {self.timestep=}")
  26. if self._sleep_time > 0:
  27. time.sleep(self._sleep_time)
  28. self.timestep += 1
  29. v = np.tanh(float(self.timestep) / self.config.get("width", 1))
  30. v *= self.config.get("height", 1)
  31. # Here we use `episode_reward_mean`, but you can also report other
  32. # objectives such as loss or accuracy.
  33. return {"episode_reward_mean": v}
  34. def save_checkpoint(self, checkpoint_dir):
  35. path = os.path.join(checkpoint_dir, "checkpoint")
  36. with open(path, "w") as f:
  37. f.write(json.dumps({"timestep": self.timestep}))
  38. def load_checkpoint(self, checkpoint_dir):
  39. path = os.path.join(checkpoint_dir, "checkpoint")
  40. with open(path, "r") as f:
  41. self.timestep = json.loads(f.read())["timestep"]
  42. self.restored = True
  43. def register_mock_trainable():
  44. from ray.tune import register_trainable
  45. register_trainable(MOCK_TRAINABLE_NAME, MyTrainableClass)