scheduler_sweep.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. """Scheduler for classic wandb Sweeps."""
  2. from __future__ import annotations
  3. import logging
  4. from pprint import pformat as pf
  5. from typing import Any
  6. import wandb
  7. from wandb.sdk.launch.sweeps import SweepNotFoundError
  8. from wandb.sdk.launch.sweeps.scheduler import LOG_PREFIX, RunState, Scheduler, SweepRun
  9. _logger = logging.getLogger(__name__)
  10. class SweepScheduler(Scheduler):
  11. """A controller/agent that populates a Launch RunQueue from a sweeps RunQueue."""
  12. def __init__(
  13. self,
  14. *args: Any,
  15. **kwargs: Any,
  16. ):
  17. super().__init__(*args, **kwargs)
  18. def _get_next_sweep_run(self, worker_id: int) -> SweepRun | None:
  19. """Called by the main scheduler execution loop.
  20. Expected to return a properly formatted SweepRun if the scheduler
  21. is alive, or None and set the appropriate scheduler state:
  22. FAILED: self.fail_sweep()
  23. STOPPED: self.stop_sweep()
  24. """
  25. commands: list[dict[str, Any]] = self._get_sweep_commands(worker_id)
  26. for command in commands:
  27. # The command "type" can be one of "run", "resume", "stop", "exit"
  28. _type = command.get("type")
  29. if _type in ["exit", "stop"]:
  30. self.stop_sweep()
  31. return None
  32. if _type not in ["run", "resume"]:
  33. self.fail_sweep(f"AgentHeartbeat unknown command: {_type}")
  34. _run_id: str | None = command.get("run_id")
  35. if not _run_id:
  36. self.fail_sweep(f"No run id in agent heartbeat: {command}")
  37. return None
  38. if _run_id in self._runs:
  39. wandb.termlog(f"{LOG_PREFIX}Skipping duplicate run: {_run_id}")
  40. continue
  41. return SweepRun(
  42. id=_run_id,
  43. state=RunState.PENDING,
  44. args=command.get("args", {}),
  45. logs=command.get("logs", []),
  46. worker_id=worker_id,
  47. )
  48. return None
  49. def _get_sweep_commands(self, worker_id: int) -> list[dict[str, Any]]:
  50. """Helper to receive sweep command from backend."""
  51. # AgentHeartbeat wants a Dict of runs which are running or queued
  52. _run_states: dict[str, bool] = {}
  53. for run_id, run in self._yield_runs():
  54. # Filter out runs that are from a different worker thread
  55. if run.worker_id == worker_id and run.state.is_alive:
  56. _run_states[run_id] = True
  57. _logger.debug(f"Sending states: \n{pf(_run_states)}\n")
  58. try:
  59. commands: list[dict[str, Any]] = self._api.agent_heartbeat(
  60. agent_id=self._workers[worker_id].agent_id,
  61. metrics={},
  62. run_states=_run_states,
  63. )
  64. except SweepNotFoundError:
  65. wandb.termerror(
  66. f"{LOG_PREFIX}Sweep was deleted or agent was not found. Stopping sweep."
  67. )
  68. self.stop_sweep()
  69. return []
  70. _logger.debug(f"AgentHeartbeat commands: \n{pf(commands)}\n")
  71. return commands
  72. def _exit(self) -> None:
  73. pass
  74. def _poll(self) -> None:
  75. _logger.debug(f"_poll. _runs: {self._runs}")
  76. def _load_state(self) -> None:
  77. pass
  78. def _save_state(self) -> None:
  79. pass