| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393 |
- """Agent - Agent object.
- Manage wandb agent.
- """
- import ctypes
- import logging
- import os
- import queue
- import socket
- import sys
- import threading
- import time
- import traceback
- import wandb
- from wandb.apis import InternalApi
- from wandb.sdk.launch.sweeps import SweepNotFoundError
- from wandb.sdk.launch.sweeps import utils as sweep_utils
- from wandb.sdk.lib import config_util
- logger = logging.getLogger(__name__)
def _terminate_thread(thread):
    """Ask a running thread to stop by raising ``Exception`` inside it.

    The exception is injected asynchronously through the CPython C API, so
    the target thread only sees it the next time it executes Python bytecode.
    The call is a no-op when the thread is not alive or has already been
    asked to terminate once.

    Args:
        thread: The ``threading.Thread`` (or compatible object) to interrupt.
            A ``_thread_id`` attribute, when present, overrides ``ident``.
    """
    if not thread.is_alive():
        return
    if hasattr(thread, "_terminated"):
        # Only inject once per thread; the marker is set on the first request.
        return
    thread._terminated = True

    tid = getattr(thread, "_thread_id", None)
    if tid is None:
        # Use the public identifier instead of scanning the private
        # ``threading._active`` dict, which other threads mutate concurrently
        # and which can therefore raise RuntimeError while being iterated.
        tid = getattr(thread, "ident", None)
    if tid is None:
        # This should never happen for a live thread.
        return

    logger.debug(f"Terminating thread: {tid}")
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
        ctypes.c_long(tid), ctypes.py_object(Exception)
    )
    if res == 0:
        # No thread state matched the id; this should never happen.
        return
    if res != 1:
        # More than one thread state was modified: revert by clearing the
        # pending asynchronous exception.
        logger.debug(f"Termination failed for thread {tid}")
        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), None)
class Job:
    """A single command received from the sweep backend.

    Attributes are read from the raw command dict with ``dict.get``, so any
    of them may be ``None`` when the corresponding key is absent.
    """

    def __init__(self, command):
        """Unpack ``command`` into ``type``, ``run_id`` and ``config``.

        Args:
            command (dict): Raw command; the keys ``"type"``, ``"run_id"``
                and ``"args"`` are read when present.
        """
        self.command = command
        self.type = command.get("type")
        self.run_id = command.get("run_id")
        self.config = command.get("args")

    def __repr__(self):
        """Return a short description used in the agent's debug log."""
        if self.type == "run":
            return f"Job({self.run_id},{self.config})"
        if self.type == "resume":
            # Resume commands are queued exactly like run commands; do not
            # let them show up as "exit" in the log.
            return f"resume({self.run_id},{self.config})"
        if self.type == "stop":
            return f"stop({self.run_id})"
        if self.type == "exit":
            return "exit"
        # Unknown or missing type: make that visible instead of claiming the
        # command was an exit request.
        return f"unknown({self.type!r})"
class RunStatus:
    """String constants describing where a run is in the agent's lifecycle."""

    # Received from the backend and waiting in the agent's job queue.
    QUEUED = "QUEUED"
    # Currently executing in a worker thread.
    RUNNING = "RUNNING"
    # A stop was requested for the run.
    STOPPED = "STOPPED"
    # The run's worker raised an exception.
    ERRORED = "ERRORED"
    # The run's worker finished without being stopped or erroring.
    DONE = "DONE"
class Agent:
    """In-process sweep agent that calls a user function once per sweep run.

    ``run()`` starts a daemon heartbeat thread that polls the backend for
    commands and queues run jobs, then drains that queue on the calling
    thread. Jobs execute one at a time, each in its own worker thread, so a
    "stop" command can interrupt the job through ``_terminate_thread``.
    """

    # Failures only count as "flapping" this many seconds after the start.
    FLAPPING_MAX_SECONDS = 60
    # Number of failed runs inside the flapping window that kills the sweep.
    FLAPPING_MAX_FAILURES = 3
    # Default handed to wandb.env.get_agent_max_initial_failures().
    MAX_INITIAL_FAILURES = 5

    def __init__(
        self, sweep_id=None, project=None, entity=None, function=None, count=None
    ):
        """Store the sweep settings and create the API client.

        Args:
            sweep_id: Sweep identifier or path; it is parsed in ``_setup``.
            project: Project used when the sweep path does not provide one.
            entity: Entity used when the sweep path does not provide one.
            function: Zero-argument callable executed for every run.
            count: Stop after this many jobs; a falsy value means no limit.
        """
        self._sweep_path = sweep_id
        self._sweep_id = None
        self._project = project
        self._entity = entity
        self._function = function
        self._count = count
        self._api = InternalApi()
        self._agent_id = None
        self._max_initial_failures = wandb.env.get_agent_max_initial_failures(
            self.MAX_INITIAL_FAILURES
        )
        # If the directory to log to is not set, default it to the cwd.
        if os.environ.get(wandb.env.DIR) is None:
            os.environ[wandb.env.DIR] = os.path.abspath(os.getcwd())

    def _init(self):
        """Reset the per-execution state.

        These attributes are not set in the constructor so that an Agent
        instance can be run more than once.
        """
        self._run_threads = {}
        self._run_status = {}
        self._queue = queue.Queue()
        self._exit_flag = False
        self._exceptions = {}
        self._start_time = time.time()

    def _register(self):
        """Register this host as an agent of the sweep and keep the agent id."""
        logger.debug("Agent._register()")
        agent = self._api.register_agent(socket.gethostname(), sweep_id=self._sweep_id)
        self._agent_id = agent["id"]
        logger.debug(f"agent_id = {self._agent_id}")

    def _setup(self):
        """Resolve the sweep path, export it to the environment and register.

        Returns:
            bool: ``False`` when the sweep id could not be parsed (the error
            is reported through ``wandb.termerror`` and nothing is
            registered), ``True`` otherwise.
        """
        logger.debug("Agent._setup()")
        self._init()
        parts = dict(entity=self._entity, project=self._project, name=self._sweep_path)
        err = sweep_utils.parse_sweep_id(parts)
        if err:
            wandb.termerror(err)
            return False
        entity = parts.get("entity") or self._entity
        project = parts.get("project") or self._project
        sweep_id = parts.get("name") or self._sweep_id
        if sweep_id:
            os.environ[wandb.env.SWEEP_ID] = sweep_id
        if entity:
            wandb.env.set_entity(entity)
        if project:
            wandb.env.set_project(project)
        if sweep_id:
            self._sweep_id = sweep_id
        self._register()
        return True

    def _stop_run(self, run_id):
        """Mark ``run_id`` as stopped and interrupt its worker thread, if any."""
        logger.debug(f"Stopping run {run_id}.")
        self._run_status[run_id] = RunStatus.STOPPED
        thread = self._run_threads.get(run_id)
        if thread:
            _terminate_thread(thread)

    def _stop_all_runs(self):
        """Stop every run for which a worker thread has been created."""
        logger.debug("Stopping all runs.")
        # Iterate over a snapshot: the dict is shared with the main loop.
        for run in list(self._run_threads.keys()):
            self._stop_run(run)

    def _exit(self):
        """Stop all runs and ask both agent loops to finish."""
        self._stop_all_runs()
        self._exit_flag = True

    def _heartbeat(self):
        """Poll the backend for commands until the exit flag is set.

        Each iteration reports the runs that are queued or running, handles
        the first returned command (if any) and then sleeps for 5 seconds.
        """
        while True:
            if self._exit_flag:
                return
            run_status = {
                run: True
                for run, status in self._run_status.items()
                if status in (RunStatus.QUEUED, RunStatus.RUNNING)
            }
            try:
                commands = self._api.agent_heartbeat(self._agent_id, {}, run_status)
            except SweepNotFoundError:
                wandb.termerror(
                    "Sweep was deleted or agent was not found. Stopping sweep."
                )
                self._exit()
                return
            if commands:
                job = Job(commands[0])
                logger.debug(f"Job received: {job}")
                if job.type in ["run", "resume"]:
                    # Record the status BEFORE enqueueing: the main loop reads
                    # self._run_status[run_id] as soon as it dequeues the job
                    # and would hit a KeyError if it won the race.
                    self._run_status[job.run_id] = RunStatus.QUEUED
                    self._queue.put(job)
                elif job.type == "stop":
                    self._stop_run(job.run_id)
                elif job.type == "exit":
                    self._exit()
                    return
            time.sleep(5)

    def _run_jobs_from_queue(self):
        """Run queued jobs one at a time until the agent is told to exit.

        The loop ends when the exit flag is set, when ``count`` jobs have
        been taken from the queue, on Ctrl+C, or when one of the failure
        checks (flapping, maximum initial failures) trips. The module-level
        ``_INSTANCES`` counter is incremented for the duration of the loop.
        """
        global _INSTANCES
        _INSTANCES += 1
        try:
            waiting = False
            count = 0
            while True:
                if self._exit_flag:
                    return
                try:
                    try:
                        job = self._queue.get(timeout=5)
                        if self._exit_flag:
                            logger.debug("Exiting main loop due to exit flag.")
                            wandb.termlog("Sweep Agent: Exiting.")
                            return
                    except queue.Empty:
                        if not waiting:
                            logger.debug("Paused.")
                            wandb.termlog("Sweep Agent: Waiting for job.")
                            waiting = True
                        time.sleep(5)
                        if self._exit_flag:
                            logger.debug("Exiting main loop due to exit flag.")
                            wandb.termlog("Sweep Agent: Exiting.")
                            return
                        continue
                    if waiting:
                        logger.debug("Resumed.")
                        wandb.termlog("Job received.")
                        waiting = False
                    count += 1
                    run_id = job.run_id
                    if self._run_status[run_id] == RunStatus.STOPPED:
                        # A stop arrived while the job was still queued.
                        continue
                    logger.debug(f"Spawning new thread for run {run_id}.")
                    thread = threading.Thread(target=self._run_job, args=(job,))
                    self._run_threads[run_id] = thread
                    # Mark the run RUNNING before the thread starts: _run_job
                    # only records ERRORED when it sees RUNNING, so setting
                    # the status afterwards could lose a fast failure and
                    # report the run as DONE.
                    self._run_status[run_id] = RunStatus.RUNNING
                    thread.start()
                    thread.join()
                    logger.debug(f"Thread joined for run {run_id}.")
                    if self._run_status[run_id] == RunStatus.RUNNING:
                        self._run_status[run_id] = RunStatus.DONE
                    elif self._run_status[run_id] == RunStatus.ERRORED:
                        exc = self._exceptions[run_id]
                        # Extract to reduce a decision point to avoid ruff c901
                        log_str, term_str = _get_exception_logger_and_term_strs(exc)
                        logger.error(f"Run {run_id} errored:\n{log_str}")
                        wandb.termerror(f"Run {run_id} errored:{term_str}")
                        # WANDB_AGENT_DISABLE_FLAPPING=true turns the flapping
                        # check off (as the hint printed below promises); it
                        # must not stop the agent on the first failed run.
                        flapping_check_disabled = (
                            os.getenv(wandb.env.AGENT_DISABLE_FLAPPING) == "true"
                        )
                        if (
                            not flapping_check_disabled
                            and time.time() - self._start_time
                            < self.FLAPPING_MAX_SECONDS
                            and len(self._exceptions) >= self.FLAPPING_MAX_FAILURES
                        ):
                            msg = (
                                f"Detected {self.FLAPPING_MAX_FAILURES} failed runs "
                                f"in the first {self.FLAPPING_MAX_SECONDS} seconds, "
                                "killing sweep."
                            )
                            logger.error(msg)
                            wandb.termerror(msg)
                            wandb.termlog(
                                "To disable this check set WANDB_AGENT_DISABLE_FLAPPING=true"
                            )
                            self._exit_flag = True
                            return
                        # Every job taken so far has failed and the number of
                        # failures exceeds the configured maximum.
                        if (
                            self._max_initial_failures < len(self._exceptions)
                            and len(self._exceptions) >= count
                        ):
                            msg = (
                                f"Detected {self._max_initial_failures} failed runs "
                                "in a row at start, killing sweep."
                            )
                            logger.error(msg)
                            wandb.termerror(msg)
                            wandb.termlog(
                                "To change this value set WANDB_AGENT_MAX_INITIAL_FAILURES=val"
                            )
                            self._exit_flag = True
                            return
                    if self._count and self._count == count:
                        logger.debug("Exiting main loop because max count reached.")
                        self._exit_flag = True
                        return
                except KeyboardInterrupt:
                    logger.debug("Ctrl + C detected. Stopping sweep.")
                    wandb.termlog("Ctrl + C detected. Stopping sweep.")
                    self._exit()
                    return
                except Exception:
                    if self._exit_flag:
                        # The exception was raised while the agent was being
                        # shut down; treat it as the shutdown signal.
                        logger.debug("Exiting main loop due to exit flag.")
                        wandb.termlog("Sweep Agent: Killed.")
                        return
                    else:
                        raise
        finally:
            _INSTANCES -= 1

    def _run_job(self, job):
        """Execute one run: write its config, call the user function, clean up.

        Runs in a worker thread. Exceptions are not propagated; they are
        stored in ``self._exceptions`` and, if the run is still marked
        RUNNING, its status is set to ERRORED. The environment variables set
        for the run are removed again in all cases.

        Args:
            job (Job): The run job taken from the queue.
        """
        try:
            run_id = job.run_id
            config_file = os.path.join(
                "wandb", "sweep-" + self._sweep_id, "config-" + run_id + ".yaml"
            )
            os.environ[wandb.env.RUN_ID] = run_id
            base_dir = os.environ.get(wandb.env.DIR, "")
            sweep_param_path = os.path.join(base_dir, config_file)
            os.environ[wandb.env.SWEEP_PARAM_PATH] = sweep_param_path
            config_util.save_config_file_from_dict(sweep_param_path, job.config)
            os.environ[wandb.env.SWEEP_ID] = self._sweep_id
            wandb.teardown()
            wandb.termlog(f"Agent Starting Run: {run_id} with config:")
            for k, v in job.config.items():
                wandb.termlog("\t{}: {}".format(k, v["value"]))
            try:
                self._function()
            except KeyboardInterrupt:
                raise
            except Exception as e:
                # Log the run's exception directly to stderr to match the CLI
                # case, and wrap it so we can identify it as coming from the
                # job later. This will get automatically logged by
                # console_capture.py. The exception handler below will also
                # handle exceptions raised in the setup code above.
                exc_repr = _format_exception_traceback(e)
                print(exc_repr, file=sys.stderr)  # noqa: T201
                raise _JobError(f"Run threw exception: {str(e)}") from e
            wandb.finish()
        except KeyboardInterrupt:
            raise
        except Exception as e:
            wandb.finish(exit_code=1)
            # Leave a STOPPED status alone: the exception may be the one
            # injected by _terminate_thread.
            if self._run_status[run_id] == RunStatus.RUNNING:
                self._run_status[run_id] = RunStatus.ERRORED
            self._exceptions[run_id] = e
        finally:
            # Clean up the environment changes made for this run.
            os.environ.pop(wandb.env.RUN_ID, None)
            os.environ.pop(wandb.env.SWEEP_ID, None)
            os.environ.pop(wandb.env.SWEEP_PARAM_PATH, None)

    def run(self):
        """Set up the agent, start the heartbeat thread and process jobs.

        Blocks until the job loop exits. Returns immediately, without
        starting any thread, when ``_setup`` reports that the sweep id could
        not be parsed.
        """
        logger.info(
            f"Starting sweep agent: entity={self._entity}, project={self._project}, count={self._count}"
        )
        # Compare with ``is False`` so an overriding _setup that returns None
        # keeps working. Without a registered agent id the heartbeat cannot
        # work and the job loop would wait for jobs forever.
        if self._setup() is False:
            logger.debug("Agent setup failed; not starting the sweep agent.")
            return
        self._heartbeat_thread = threading.Thread(target=self._heartbeat)
        self._heartbeat_thread.daemon = True
        self._heartbeat_thread.start()
        self._run_jobs_from_queue()
def pyagent(sweep_id, function, entity=None, project=None, count=None):
    """Generic agent entrypoint, used for CLI or jupyter.

    Builds an ``Agent`` for the sweep and blocks in ``Agent.run()``.

    Args:
        sweep_id (str): Sweep ID generated by CLI or sweep API.
        function (callable): The function to call for every run instead of
            the "program".
        entity (str, optional): W&B Entity.
        project (str, optional): W&B Project.
        count (int, optional): The number of trials to run.

    Raises:
        TypeError: If ``function`` is not callable.
    """
    if callable(function):
        Agent(
            sweep_id=sweep_id,
            project=project,
            entity=entity,
            function=function,
            count=count,
        ).run()
        return
    raise TypeError("function parameter must be callable!")
def _format_exception_traceback(exc):
    """Render ``exc`` and its traceback as one string."""
    rendered_lines = traceback.format_exception(type(exc), exc, exc.__traceback__)
    return "".join(rendered_lines)
class _JobError(Exception):
    """Raised in place of an exception thrown by the job's own function.

    The original exception is attached as ``__cause__`` by the raiser.
    """
def _get_exception_logger_and_term_strs(exc):
    """Build the log-file text and the terminal text for a failed run.

    Args:
        exc: The exception recorded for the run.

    Returns:
        tuple: ``(log_str, term_str)``. ``log_str`` is always a full
        formatted traceback; ``term_str`` is the suffix appended to the
        terminal error message.
    """
    job_exc = exc.__cause__ if isinstance(exc, _JobError) else None
    if job_exc:
        # For a _JobError, display the original exception raised by the job.
        # Don't log the full stacktrace to the terminal again because it was
        # already printed to stderr.
        return _format_exception_traceback(job_exc), " " + str(job_exc)
    full_trace = _format_exception_traceback(exc)
    return full_trace, "\n" + full_trace
# Number of Agent job loops currently executing; Agent._run_jobs_from_queue
# increments it on entry and decrements it on exit.
_INSTANCES = 0


def is_running():
    """Return True while at least one agent job loop is active."""
    return _INSTANCES != 0
|