from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional, Union

from ray.autoscaler._private.cli_logger import cli_logger


class CreateClusterEvent(Enum):
    """Events to track in ray.autoscaler.sdk.create_or_update_cluster.

    Attributes:
        up_started : Invoked at the beginning of create_or_update_cluster.
        ssh_keypair_downloaded : Invoked when the ssh keypair is downloaded.
        cluster_booting_started : Invoked when the cluster booting starts.
        acquiring_new_head_node : Invoked before the head node is acquired.
        head_node_acquired : Invoked after the head node is acquired.
        ssh_control_acquired : Invoked when the node is being updated.
        run_initialization_cmd : Invoked before all initialization
            commands are called and again before each initialization command.
        run_setup_cmd : Invoked before all setup commands are
            called and again before each setup command.
        start_ray_runtime : Invoked before ray start commands are run.
        start_ray_runtime_completed : Invoked after ray start commands
            are run.
        cluster_booting_completed : Invoked after cluster booting
            is completed.
    """

    up_started = auto()
    ssh_keypair_downloaded = auto()
    cluster_booting_started = auto()
    acquiring_new_head_node = auto()
    head_node_acquired = auto()
    ssh_control_acquired = auto()
    run_initialization_cmd = auto()
    run_setup_cmd = auto()
    start_ray_runtime = auto()
    start_ray_runtime_completed = auto()
    cluster_booting_completed = auto()


class _EventSystem:
    """Event system that handles storing and calling callbacks for events.

    Attributes:
        callback_map (Dict[CreateClusterEvent, List[Callable]]) : Stores the
            list of callbacks registered for each event, in registration
            order.
    """

    def __init__(self) -> None:
        self.callback_map: Dict[
            CreateClusterEvent, List[Callable[[Dict], None]]
        ] = {}

    def add_callback_handler(
        self,
        event: CreateClusterEvent,
        callback: Union[Callable[[Dict], None], List[Callable[[Dict], None]]],
    ) -> None:
        """Stores callback handler for event.

        Args:
            event: Event that callback should be called on. See
                CreateClusterEvent for details on the events available to be
                registered against. Anything that is not a CreateClusterEvent
                member triggers a warning; the callback is still stored under
                that key.
            callback (Callable[[Dict], None]): Callable object that is invoked
                when specified event occurs, or a list of such callables.
        """
        if not isinstance(event, CreateClusterEvent):
            cli_logger.warning(
                f"{event} is not currently tracked, and this"
                " callback will not be invoked."
            )

        # isinstance (rather than an exact type check) so that list subclasses
        # are unpacked too instead of being stored as one un-callable entry.
        new_callbacks = callback if isinstance(callback, list) else [callback]
        self.callback_map.setdefault(event, []).extend(new_callbacks)

    def execute_callback(
        self,
        event: CreateClusterEvent,
        event_data: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Executes all callbacks for event.

        Args:
            event: Event that is invoked. See CreateClusterEvent
                for details on the available events.
            event_data (Dict[str, Any]): Argument that is passed to each
                callable object stored for this particular event. The
                "event_name" key is set on this dict in place before the
                callbacks run; a new dict is created when None is given.
        """
        if event_data is None:
            event_data = {}

        event_data["event_name"] = event

        # Iterate over a snapshot: a callback may register further handlers
        # for this same event, which would otherwise grow the list while it
        # is being iterated.
        for callback in list(self.callback_map.get(event, ())):
            callback(event_data)

    def clear_callbacks_for_event(self, event: CreateClusterEvent) -> None:
        """Clears stored callable objects for event.

        Args:
            event: Event that has callable objects stored in map.
                See CreateClusterEvent for details on the available events.
                Events with no stored callbacks are ignored.
        """
        self.callback_map.pop(event, None)


# Module-level _EventSystem instance.
global_event_system = _EventSystem()