utils.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from typing import Any, Callable, Dict, List, Optional
  2. from ray.rllib.callbacks.callbacks import RLlibCallback
  3. from ray.rllib.utils import force_list
  4. from ray.rllib.utils.annotations import OldAPIStack
  def make_callback(
      callback_name: str,
      callbacks_objects: Optional[List[RLlibCallback]] = None,
      callbacks_functions: Optional[List[Callable]] = None,
      *,
      args: Optional[List[Any]] = None,
      kwargs: Optional[Dict[str, Any]] = None,
  ) -> None:
      """Calls an RLlibCallback method or a registered callback callable.

      First, the method named `callback_name` is looked up (via `getattr`) and
      called on every entry of `callbacks_objects`. Afterwards, every entry of
      `callbacks_functions` is called directly. All calls receive the same
      positional and keyword arguments.

      Args:
          callback_name: The name of the callback method or key, for example:
              "on_episode_start" or "on_train_result".
          callbacks_objects: The RLlibCallback object or list of RLlibCallback
              objects to call the `callback_name` method on (in the order they
              appear in the list). Normalized through `force_list` before
              iteration.
          callbacks_functions: The callable or list of callables to call
              (in the order they appear in the list). Normalized through
              `force_list` before iteration.
          args: Call args to pass to the method/callable calls. A falsy value
              (e.g. None) results in no positional args being passed.
          kwargs: Call kwargs to pass to the method/callable calls. A falsy value
              (e.g. None) results in no keyword args being passed.
      """
      # Resolve the (possibly missing) call arguments once; they are the same for
      # every callback invoked below.
      call_args = args or ()
      call_kwargs = kwargs or {}

      # Loop through all available RLlibCallback objects.
      for callback_obj in force_list(callbacks_objects):
          getattr(callback_obj, callback_name)(*call_args, **call_kwargs)

      # Loop through all available callback callables.
      for callback_fn in force_list(callbacks_functions):
          callback_fn(*call_args, **call_kwargs)
  @OldAPIStack
  def _make_multi_callbacks(callback_class_list):
      """Builds an RLlibCallback subclass that wraps several callback classes.

      The returned class instantiates every entry of `callback_class_list` (with
      no constructor arguments) and forwards each supported hook to all of these
      instances, in list order.

      Args:
          callback_class_list: Iterable of zero-argument callables (typically
              callback classes) used to create the wrapped callback instances.

      Returns:
          The `_MultiCallbacks` class (not an instance of it).
      """

      def _relay(targets, hook_name, payload):
          # Call the `hook_name` method of each target, in order, unpacking
          # `payload` as keyword arguments for every individual call.
          for target in targets:
              getattr(target, hook_name)(**payload)

      class _MultiCallbacks(RLlibCallback):
          """Container callback that relays hooks to its wrapped callbacks."""

          IS_CALLBACK_CONTAINER = True

          def __init__(self):
              super().__init__()
              self._callback_list = [cls() for cls in callback_class_list]

          def on_algorithm_init(self, **kwargs) -> None:
              _relay(self._callback_list, "on_algorithm_init", kwargs)

          def on_workers_recreated(self, **kwargs) -> None:
              _relay(self._callback_list, "on_workers_recreated", kwargs)

          # Only on new API stack.
          def on_env_runners_recreated(self, **kwargs) -> None:
              pass

          def on_offline_eval_runners_recreated(self, **kwargs) -> None:
              pass

          def on_checkpoint_loaded(self, **kwargs) -> None:
              _relay(self._callback_list, "on_checkpoint_loaded", kwargs)

          def on_create_policy(self, *, policy_id, policy) -> None:
              _relay(
                  self._callback_list,
                  "on_create_policy",
                  {"policy_id": policy_id, "policy": policy},
              )

          def on_environment_created(self, **kwargs) -> None:
              _relay(self._callback_list, "on_environment_created", kwargs)

          def on_sub_environment_created(self, **kwargs) -> None:
              _relay(self._callback_list, "on_sub_environment_created", kwargs)

          def on_episode_created(self, **kwargs) -> None:
              _relay(self._callback_list, "on_episode_created", kwargs)

          def on_episode_start(self, **kwargs) -> None:
              _relay(self._callback_list, "on_episode_start", kwargs)

          def on_episode_step(self, **kwargs) -> None:
              _relay(self._callback_list, "on_episode_step", kwargs)

          def on_episode_end(self, **kwargs) -> None:
              _relay(self._callback_list, "on_episode_end", kwargs)

          def on_evaluate_start(self, **kwargs) -> None:
              _relay(self._callback_list, "on_evaluate_start", kwargs)

          def on_evaluate_end(self, **kwargs) -> None:
              _relay(self._callback_list, "on_evaluate_end", kwargs)

          # TODO (simon, sven): Fix the test such that we can simply remove
          # these.
          def on_evaluate_offline_start(self, **kwargs):
              _relay(self._callback_list, "on_evaluate_offline_start", kwargs)

          def on_evaluate_offline_end(self, **kwargs):
              _relay(self._callback_list, "on_evaluate_offline_end", kwargs)

          def on_postprocess_trajectory(
              self,
              *,
              worker,
              episode,
              agent_id,
              policy_id,
              policies,
              postprocessed_batch,
              original_batches,
              **kwargs,
          ) -> None:
              # Re-assemble the explicitly named arguments together with any
              # extra kwargs into one keyword payload for the wrapped callbacks.
              payload = dict(
                  worker=worker,
                  episode=episode,
                  agent_id=agent_id,
                  policy_id=policy_id,
                  policies=policies,
                  postprocessed_batch=postprocessed_batch,
                  original_batches=original_batches,
                  **kwargs,
              )
              _relay(self._callback_list, "on_postprocess_trajectory", payload)

          def on_sample_end(self, **kwargs) -> None:
              _relay(self._callback_list, "on_sample_end", kwargs)

          def on_learn_on_batch(
              self, *, policy, train_batch, result: dict, **kwargs
          ) -> None:
              payload = dict(
                  policy=policy, train_batch=train_batch, result=result, **kwargs
              )
              _relay(self._callback_list, "on_learn_on_batch", payload)

          def on_train_result(self, **kwargs) -> None:
              _relay(self._callback_list, "on_train_result", kwargs)

      return _MultiCallbacks