| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
- """Scheduler for classic wandb Sweeps."""
- from __future__ import annotations
- import logging
- from pprint import pformat as pf
- from typing import Any
- import wandb
- from wandb.sdk.launch.sweeps import SweepNotFoundError
- from wandb.sdk.launch.sweeps.scheduler import LOG_PREFIX, RunState, Scheduler, SweepRun
- _logger = logging.getLogger(__name__)
- class SweepScheduler(Scheduler):
- """A controller/agent that populates a Launch RunQueue from a sweeps RunQueue."""
- def __init__(
- self,
- *args: Any,
- **kwargs: Any,
- ):
- super().__init__(*args, **kwargs)
- def _get_next_sweep_run(self, worker_id: int) -> SweepRun | None:
- """Called by the main scheduler execution loop.
- Expected to return a properly formatted SweepRun if the scheduler
- is alive, or None and set the appropriate scheduler state:
- FAILED: self.fail_sweep()
- STOPPED: self.stop_sweep()
- """
- commands: list[dict[str, Any]] = self._get_sweep_commands(worker_id)
- for command in commands:
- # The command "type" can be one of "run", "resume", "stop", "exit"
- _type = command.get("type")
- if _type in ["exit", "stop"]:
- self.stop_sweep()
- return None
- if _type not in ["run", "resume"]:
- self.fail_sweep(f"AgentHeartbeat unknown command: {_type}")
- _run_id: str | None = command.get("run_id")
- if not _run_id:
- self.fail_sweep(f"No run id in agent heartbeat: {command}")
- return None
- if _run_id in self._runs:
- wandb.termlog(f"{LOG_PREFIX}Skipping duplicate run: {_run_id}")
- continue
- return SweepRun(
- id=_run_id,
- state=RunState.PENDING,
- args=command.get("args", {}),
- logs=command.get("logs", []),
- worker_id=worker_id,
- )
- return None
- def _get_sweep_commands(self, worker_id: int) -> list[dict[str, Any]]:
- """Helper to receive sweep command from backend."""
- # AgentHeartbeat wants a Dict of runs which are running or queued
- _run_states: dict[str, bool] = {}
- for run_id, run in self._yield_runs():
- # Filter out runs that are from a different worker thread
- if run.worker_id == worker_id and run.state.is_alive:
- _run_states[run_id] = True
- _logger.debug(f"Sending states: \n{pf(_run_states)}\n")
- try:
- commands: list[dict[str, Any]] = self._api.agent_heartbeat(
- agent_id=self._workers[worker_id].agent_id,
- metrics={},
- run_states=_run_states,
- )
- except SweepNotFoundError:
- wandb.termerror(
- f"{LOG_PREFIX}Sweep was deleted or agent was not found. Stopping sweep."
- )
- self.stop_sweep()
- return []
- _logger.debug(f"AgentHeartbeat commands: \n{pf(commands)}\n")
- return commands
- def _exit(self) -> None:
- pass
- def _poll(self) -> None:
- _logger.debug(f"_poll. _runs: {self._runs}")
- def _load_state(self) -> None:
- pass
- def _save_state(self) -> None:
- pass
|