| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- from typing import Any, Callable, Dict, List, Optional
- from ray.rllib.callbacks.callbacks import RLlibCallback
- from ray.rllib.utils import force_list
- from ray.rllib.utils.annotations import OldAPIStack
- def make_callback(
- callback_name: str,
- callbacks_objects: Optional[List[RLlibCallback]] = None,
- callbacks_functions: Optional[List[Callable]] = None,
- *,
- args: List[Any] = None,
- kwargs: Dict[str, Any] = None,
- ) -> None:
- """Calls an RLlibCallback method or a registered callback callable.
- Args:
- callback_name: The name of the callback method or key, for example:
- "on_episode_start" or "on_train_result".
- callbacks_objects: The RLlibCallback object or list of RLlibCallback objects
- to call the `callback_name` method on (in the order they appear in the
- list).
- callbacks_functions: The callable or list of callables to call
- (in the order they appear in the list).
- args: Call args to pass to the method/callable calls.
- kwargs: Call kwargs to pass to the method/callable calls.
- """
- # Loop through all available RLlibCallback objects.
- callbacks_objects = force_list(callbacks_objects)
- for callback_obj in callbacks_objects:
- getattr(callback_obj, callback_name)(*(args or ()), **(kwargs or {}))
- # Loop through all available RLlibCallback objects.
- callbacks_functions = force_list(callbacks_functions)
- for callback_fn in callbacks_functions:
- callback_fn(*(args or ()), **(kwargs or {}))
- @OldAPIStack
- def _make_multi_callbacks(callback_class_list):
- class _MultiCallbacks(RLlibCallback):
- IS_CALLBACK_CONTAINER = True
- def __init__(self):
- super().__init__()
- self._callback_list = [
- callback_class() for callback_class in callback_class_list
- ]
- def on_algorithm_init(self, **kwargs) -> None:
- for callback in self._callback_list:
- callback.on_algorithm_init(**kwargs)
- def on_workers_recreated(self, **kwargs) -> None:
- for callback in self._callback_list:
- callback.on_workers_recreated(**kwargs)
- # Only on new API stack.
- def on_env_runners_recreated(self, **kwargs) -> None:
- pass
- def on_offline_eval_runners_recreated(self, **kwargs) -> None:
- pass
- def on_checkpoint_loaded(self, **kwargs) -> None:
- for callback in self._callback_list:
- callback.on_checkpoint_loaded(**kwargs)
- def on_create_policy(self, *, policy_id, policy) -> None:
- for callback in self._callback_list:
- callback.on_create_policy(policy_id=policy_id, policy=policy)
- def on_environment_created(self, **kwargs) -> None:
- for callback in self._callback_list:
- callback.on_environment_created(**kwargs)
- def on_sub_environment_created(self, **kwargs) -> None:
- for callback in self._callback_list:
- callback.on_sub_environment_created(**kwargs)
- def on_episode_created(self, **kwargs) -> None:
- for callback in self._callback_list:
- callback.on_episode_created(**kwargs)
- def on_episode_start(self, **kwargs) -> None:
- for callback in self._callback_list:
- callback.on_episode_start(**kwargs)
- def on_episode_step(self, **kwargs) -> None:
- for callback in self._callback_list:
- callback.on_episode_step(**kwargs)
- def on_episode_end(self, **kwargs) -> None:
- for callback in self._callback_list:
- callback.on_episode_end(**kwargs)
- def on_evaluate_start(self, **kwargs) -> None:
- for callback in self._callback_list:
- callback.on_evaluate_start(**kwargs)
- def on_evaluate_end(self, **kwargs) -> None:
- for callback in self._callback_list:
- callback.on_evaluate_end(**kwargs)
- # TODO (simon, sven): Fix the test such that we can simply remove
- # these.
- def on_evaluate_offline_start(self, **kwargs):
- for callback in self._callback_list:
- callback.on_evaluate_offline_start(**kwargs)
- def on_evaluate_offline_end(self, **kwargs):
- for callback in self._callback_list:
- callback.on_evaluate_offline_end(**kwargs)
- def on_postprocess_trajectory(
- self,
- *,
- worker,
- episode,
- agent_id,
- policy_id,
- policies,
- postprocessed_batch,
- original_batches,
- **kwargs,
- ) -> None:
- for callback in self._callback_list:
- callback.on_postprocess_trajectory(
- worker=worker,
- episode=episode,
- agent_id=agent_id,
- policy_id=policy_id,
- policies=policies,
- postprocessed_batch=postprocessed_batch,
- original_batches=original_batches,
- **kwargs,
- )
- def on_sample_end(self, **kwargs) -> None:
- for callback in self._callback_list:
- callback.on_sample_end(**kwargs)
- def on_learn_on_batch(
- self, *, policy, train_batch, result: dict, **kwargs
- ) -> None:
- for callback in self._callback_list:
- callback.on_learn_on_batch(
- policy=policy, train_batch=train_batch, result=result, **kwargs
- )
- def on_train_result(self, **kwargs) -> None:
- for callback in self._callback_list:
- callback.on_train_result(**kwargs)
- return _MultiCallbacks
|