| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686 |
- from __future__ import annotations
- import contextlib
- import logging
- import multiprocessing
- import os
- import platform
- import queue
- import re
- import signal
- import socket
- import subprocess
- import sys
- import time
- import traceback
- from typing import Any, Callable
- import wandb
- from wandb import util
- from wandb.sdk import wandb_login
- from wandb.sdk.lib import config_util, ipython
logger = logging.getLogger(__name__)


class AgentError(Exception):
    """Raised for invalid agent settings, an unknown command, or a run with nothing to execute."""
class AgentProcess:
    """Launch and manage one sweep run in a child process.

    The run is either an external ``command`` started with ``subprocess.Popen``
    (kept in ``_popen``) or a Python ``function`` executed in a
    ``multiprocessing.Process`` (kept in ``_proc``). Exactly one of the two is set.
    """

    def __init__(
        self,
        env=None,
        command=None,
        function=None,
        run_id=None,
        in_jupyter=None,
        forward_signals=False,
    ):
        """Start the run.

        Args:
            env: Environment for the child. For ``command`` it is passed to
                ``Popen`` after the ``wandb.env.SERVICE`` key is removed from it
                (``None`` means a copy of ``os.environ``). For ``function`` its
                items are copied into ``os.environ`` inside the child.
            command: Argument list to execute. Takes precedence over ``function``.
            function: Zero-argument callable to run in a child process.
            run_id: Run ID. It is only used in log messages for ``function`` runs.
            in_jupyter: Passed to ``_start``, which currently ignores it.
            forward_signals: If true, every catchable signal this process
                receives is forwarded to the child and then passed on to the
                handler that was installed before.

        Raises:
            AgentError: If neither ``command`` nor ``function`` is given.
        """
        # Validate before touching any process-wide signal state.
        if not command and not function:
            raise AgentError("Agent Process requires command or function")

        self._popen = None
        self._proc = None
        self._finished_q = multiprocessing.Queue()
        self._proc_killed = False
        # Time the agent first asked this run to stop (see Agent._command_stop).
        self.last_sigterm_time = None
        # Handlers replaced by ``_forward_signal``, keyed by signal number.
        self._original_handlers = {}

        if forward_signals:
            self._install_signal_handlers()

        try:
            if command:
                self._popen = self._spawn_command(command, env)
            else:
                self._proc = multiprocessing.Process(
                    target=self._start,
                    args=(self._finished_q, env, function, run_id, in_jupyter),
                )
                self._proc.start()
        except BaseException:
            # There is no child to forward to, so don't leave our handlers installed.
            self._restore_signal_handlers()
            raise

    @staticmethod
    def _spawn_command(command, env):
        """Start ``command`` in a new process group and return its ``Popen``."""
        if env is None:
            env = dict(os.environ)
        # The child must not inherit the agent's wandb.env.SERVICE setting.
        env.pop(wandb.env.SERVICE, None)

        if platform.system() == "Windows":
            # TODO: Determine if we need the same stdin workaround as POSIX case below.
            return subprocess.Popen(
                command,
                env=env,
                creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
            )

        if sys.version_info >= (3, 11):
            # preexec_fn=os.setpgrp is not thread-safe; process_group was introduced in
            # python 3.11 to replace it, so use that when possible
            kwargs = dict(process_group=0)
        else:
            kwargs = dict(preexec_fn=os.setpgrp)

        # Upon spawning the subprocess in a new process group, the child's process group is
        # not connected to the controlling terminal's stdin. If it tries to access stdin,
        # it gets a SIGTTIN and blocks until we give it the terminal, which we don't want
        # to do.
        #
        # By using subprocess.PIPE, we give it an independent stdin. However, it will still
        # block if it tries to read from stdin, because we're not writing anything to it.
        # We immediately close the subprocess's stdin here so it can fail fast and get an
        # EOF.
        #
        # (One situation that makes this relevant is that importing `readline` even
        # indirectly can cause the child to attempt to access stdin, which can trigger the
        # deadlock. In Python 3.13, `import torch` indirectly imports `readline` via `pdb`,
        # meaning `import torch` in a run script can deadlock unless we override stdin.
        # See https://github.com/wandb/wandb/pull/10489 description for more details.)
        #
        # Also, we avoid spawning a new session because that breaks preempted child process
        # handling.
        popen = subprocess.Popen(
            command,
            env=env,
            stdin=subprocess.PIPE,
            **kwargs,
        )
        popen.stdin.close()
        return popen

    def _install_signal_handlers(self):
        """Install ``_forward_signal`` for every signal that can be caught."""
        skip_signals = {
            getattr(signal, "SIGKILL", None),
            getattr(signal, "SIGSTOP", None),
        }
        skip_signals.discard(None)
        for signum in signal.valid_signals():
            # Skip signals that can't be caught
            if signum in skip_signals:
                continue
            # Some signals can't be handled on every platform, and handlers can
            # only be installed from the main thread.
            with contextlib.suppress(OSError, ValueError):
                self._original_handlers[signum] = signal.getsignal(signum)
                signal.signal(signum, self._forward_signal)

    def _restore_signal_handlers(self):
        """Reinstall the handlers that ``_install_signal_handlers`` replaced.

        A signal is only restored while its current handler is still this
        object's ``_forward_signal``, so a handler installed later (for example,
        by another ``AgentProcess``) is never clobbered. Entries that cannot be
        restored are kept, so ``_forward_signal`` still chains to them. Without
        this, every run would chain onto the previous run's handler and the
        chain would grow for the life of the agent. Calling this again is a no-op.
        """
        for signum, original in list(self._original_handlers.items()):
            if original is None:
                # The previous handler was not installed from Python, so it
                # cannot be reinstated.
                continue
            with contextlib.suppress(OSError, ValueError, TypeError):
                if signal.getsignal(signum) == self._forward_signal:
                    signal.signal(signum, original)
                    del self._original_handlers[signum]

    def _forward_signal(self, signum, frame):
        """Forward a received signal to the child, then run the previous handler."""
        # Delivery is best effort. The child may already be gone, and on Windows
        # Popen.send_signal raises ValueError for anything except SIGTERM and
        # the CTRL_* events. An exception here would surface at an arbitrary
        # point in the main program.
        with contextlib.suppress(OSError, ValueError):
            if self._popen:
                if platform.system() == "Windows" and signum in (
                    signal.SIGINT,
                    signal.SIGTERM,
                ):
                    # On Windows, we can only send CTRL_BREAK_EVENT or CTRL_C_EVENT
                    self._popen.send_signal(signal.CTRL_BREAK_EVENT)
                else:
                    self._popen.send_signal(signum)
            elif (
                self._proc
                and self._proc.pid
                and platform.system() != "Windows"
            ):
                # multiprocessing.Process has no send_signal(), so signal the pid
                # directly. This is skipped on Windows, where os.kill() with an
                # ordinary signal number terminates the process outright.
                os.kill(self._proc.pid, signum)

        # Call original handler to ensure parent process handles signal
        original_handler = self._original_handlers.get(signum)
        if callable(original_handler):
            original_handler(signum, frame)

    def _start(self, finished_q, env, function, run_id, in_jupyter):
        """Child-process entry point for ``function`` runs."""
        if env:
            for k, v in env.items():
                os.environ[k] = v

        # call user function
        wandb.termlog(f"Agent Started Run: {run_id}")
        if function:
            function()
        wandb.termlog(f"Agent Finished Run: {run_id}\n")

        # complete the run
        run = wandb.run
        if run:
            wandb.join()

        # signal that the process is finished
        finished_q.put(True)

    def poll(self):
        """Return the run's status without blocking.

        Returns:
            ``None`` while the run is still going. For a command run, the
            child's exit code once it has exited. For a function run, ``True``
            once it has reported completion (or was killed with ``kill()``), or
            the process exit code if it exited without reporting (for example,
            because the user function raised).
        """
        if self._popen:
            result = self._popen.poll()
        elif self._proc_killed:
            # we need to join process to prevent zombies
            self._proc.join()
            result = True
        else:
            result = self._poll_function_process()

        if result is not None:
            self._restore_signal_handlers()
        return result

    def _poll_function_process(self):
        """Return True if the function child reported, its exit code if it died silently, else None."""
        # Sample liveness before reading the queue. The child puts its message
        # before exiting, and multiprocessing flushes queued data before the
        # process ends, so an empty queue after the child was already dead means
        # it never reported.
        alive = self._proc.is_alive()
        try:
            if self._finished_q.get(False, 0):
                return True
        except queue.Empty:
            pass
        if alive:
            return None
        self._proc.join()
        return self._proc.exitcode

    def wait(self):
        """Block until the run ends, then return the command's exit code (``None`` for functions)."""
        if self._popen:
            # if on windows, wait() will block and we won't be able to interrupt
            if platform.system() == "Windows":
                result = self._popen.poll()
                while result is None:
                    time.sleep(1)
                    result = self._popen.poll()
            else:
                result = self._popen.wait()
        else:
            result = self._proc.join()
        self._restore_signal_handlers()
        return result

    def kill(self):
        """Forcefully stop the run."""
        if self._popen:
            return self._popen.kill()
        if self._proc.pid:
            # Process.kill() sends SIGKILL on POSIX and terminates the process
            # on Windows, where signal.SIGKILL does not exist.
            self._proc.kill()
            self._proc_killed = True
        return None

    def terminate(self):
        """Ask the run to stop."""
        if self._popen:
            # windows terminate is too strong (it is a hard kill), send Ctrl-C instead
            if platform.system() == "Windows":
                return self._popen.send_signal(signal.CTRL_C_EVENT)
            return self._popen.terminate()
        return self._proc.terminate()
class Agent:
    """Run sweep trials locally, driven by local commands and backend heartbeats.

    Commands arrive as dicts with a ``"type"`` of ``run``, ``resume``, ``stop``
    or ``exit``. Each started run is tracked in ``_run_processes`` as an
    ``AgentProcess`` keyed by run ID.
    """

    # Timeout handed to util.read_many_from_queue each loop (presumably seconds).
    POLL_INTERVAL = 5
    # Default for the agent report interval; 0 disables periodic "Running runs" logs.
    REPORT_INTERVAL = 0
    # Default delay between the first stop (terminate) and a kill for a run.
    KILL_DELAY = 30
    # Flapping = FLAPPING_MAX_FAILURES failed runs within FLAPPING_MAX_SECONDS of start.
    FLAPPING_MAX_SECONDS = 60
    FLAPPING_MAX_FAILURES = 3
    # Default threshold used by is_failing().
    MAX_INITIAL_FAILURES = 5

    # NOTE(review): the next two are not referenced in this class; presumably
    # used by code elsewhere.
    DEFAULT_SWEEP_COMMAND: list[str] = [
        "${env}",
        "${interpreter}",
        "${program}",
        "${args}",
    ]

    SWEEP_COMMAND_ENV_VAR_REGEX = re.compile(r"\$\{envvar\:([A-Z0-9_]*)\}")

    def __init__(
        self,
        api,
        queue,
        sweep_id=None,
        function=None,
        in_jupyter=None,
        count=None,
        forward_signals=False,
    ):
        """Create an agent.

        Args:
            api: Backend client (``run_agent`` passes an ``InternalApi``).
            queue: Queue of local command dicts. Each must carry a
                ``"resp_queue"`` that receives the command's response.
            sweep_id: ID of the sweep to work on.
            function: Optional callable run in place of the sweep's program.
            in_jupyter: Whether the agent runs inside a notebook.
            count: Stop after this many runs have finished (falsy = no limit).
            forward_signals: Forward received signals to the run processes.

        Raises:
            AgentError: If the configured report interval or kill delay is invalid.
        """
        self._api = api
        self._queue = queue
        self._run_processes = {}  # keyed by run.id (GQL run name)
        self._server_responses = []
        self._sweep_id = sweep_id
        self._in_jupyter = in_jupyter
        self._log = []
        self._running = True
        self._last_report_time = None
        self._function = function
        self._report_interval = wandb.env.get_agent_report_interval(
            self.REPORT_INTERVAL
        )
        self._kill_delay = wandb.env.get_agent_kill_delay(self.KILL_DELAY)
        self._finished = 0
        self._failed = 0
        self._count = count
        self._sweep_command = []
        self._max_initial_failures = wandb.env.get_agent_max_initial_failures(
            self.MAX_INITIAL_FAILURES
        )
        self._forward_signals = forward_signals
        if self._report_interval is None:
            raise AgentError("Invalid agent report interval")
        if self._kill_delay is None:
            raise AgentError("Invalid agent kill delay")
        # if the directory to log to is not set, set it
        if os.environ.get("WANDB_DIR") is None:
            os.environ["WANDB_DIR"] = os.path.abspath(os.getcwd())

    def is_flapping(self):
        """Determine if the process is flapping.

        Flapping occurs if the agent receives FLAPPING_MAX_FAILURES non-0 exit codes in
        the first FLAPPING_MAX_SECONDS.
        """
        if os.getenv(wandb.env.AGENT_DISABLE_FLAPPING) == "true":
            return False
        if time.time() < wandb.START_TIME + self.FLAPPING_MAX_SECONDS:
            return self._failed >= self.FLAPPING_MAX_FAILURES
        return False

    def is_failing(self):
        """Return True when failures are at least the finished count and at least the max-initial threshold."""
        return (
            self._failed >= self._finished
            and self._max_initial_failures <= self._failed
        )

    @staticmethod
    def _is_failed_exit(poll_result):
        """Return True if a poll result is a positive integer exit code (bools don't count)."""
        return (
            not isinstance(poll_result, bool)
            and isinstance(poll_result, int)
            and poll_result > 0
        )

    @staticmethod
    def _teardown_exit_code(poll_result):
        """Map a poll result to the exit code passed to ``wandb.teardown``.

        ``bool`` must be checked before ``int``: ``True`` is an ``int``, so the
        reverse order would make the ``-1`` case unreachable.
        """
        if isinstance(poll_result, bool):
            return -1
        if isinstance(poll_result, int):
            return poll_result
        return 0

    def run(self):  # noqa: C901
        """Register with the backend and process commands until told to stop.

        Uses the sweep config's ``command`` list if one is present. Each loop
        handles local commands, polls runs (cleaning up finished ones and
        stopping on flapping/failing), then sends a heartbeat and processes the
        commands it returns. The loop ends on an ``exit`` command, once ``count``
        runs have finished, or on too many failures. On the way out, remaining
        runs are terminated and waited for. They are killed if Ctrl-C is
        pressed during that wait.
        """
        # TODO: catch exceptions, handle errors, show validation warnings, and make more generic
        import yaml

        sweep_obj = self._api.sweep(self._sweep_id, "{}")
        if sweep_obj:
            sweep_yaml = sweep_obj.get("config")
            if sweep_yaml:
                sweep_config = yaml.safe_load(sweep_yaml)
                if sweep_config:
                    sweep_command = sweep_config.get("command")
                    if sweep_command and isinstance(sweep_command, list):
                        self._sweep_command = sweep_command

        # TODO: include sweep ID
        agent = self._api.register_agent(socket.gethostname(), sweep_id=self._sweep_id)
        agent_id = agent["id"]

        try:
            while self._running:
                commands = util.read_many_from_queue(
                    self._queue, 100, self.POLL_INTERVAL
                )
                for command in commands:
                    command["resp_queue"].put(self._process_command(command))

                now = util.stopwatch_now()
                if self._last_report_time is None or (
                    self._report_interval != 0
                    and now > self._last_report_time + self._report_interval
                ):
                    logger.info("Running runs: %s", list(self._run_processes.keys()))
                    self._last_report_time = now

                run_status = {}
                for run_id, run_process in list(self._run_processes.items()):
                    poll_result = run_process.poll()
                    if poll_result is None:
                        run_status[run_id] = True
                        continue
                    if self._is_failed_exit(poll_result):
                        self._failed += 1
                        # TODO: raise an exception
                        if self.is_flapping():
                            logger.error(
                                "Detected %i failed runs in the first %i seconds, shutting down.",
                                self.FLAPPING_MAX_FAILURES,
                                self.FLAPPING_MAX_SECONDS,
                            )
                            logger.info(
                                "To disable this check set WANDB_AGENT_DISABLE_FLAPPING=true"
                            )
                            self._running = False
                            break
                        # TODO: raise an exception
                        if self.is_failing():
                            logger.error(
                                "Detected %i failed runs in a row, shutting down.",
                                self._max_initial_failures,
                            )
                            logger.info(
                                "To change this value set WANDB_AGENT_MAX_INITIAL_FAILURES=val"
                            )
                            self._running = False
                            break
                    logger.info("Cleaning up finished run: %s", run_id)

                    # wandb.teardown() was added with wandb service and is a hammer to make
                    # sure that active runs are finished before moving on to another agent run
                    #
                    # In the future, a lighter weight way to implement this could be to keep a
                    # service process open for all the agent instances and inform_finish when
                    # the run should be marked complete. This however could require
                    # inform_finish on every run created by this process.
                    if hasattr(wandb, "teardown"):
                        wandb.teardown(self._teardown_exit_code(poll_result))

                    del self._run_processes[run_id]
                    self._last_report_time = None
                    self._finished += 1

                # Also reached after a flapping/failing break above.
                if self._count and self._finished >= self._count or not self._running:
                    self._running = False
                    continue

                commands = self._api.agent_heartbeat(agent_id, {}, run_status)

                # TODO: send _server_responses
                self._server_responses = []
                for command in commands:
                    self._server_responses.append(self._process_command(command))

        except KeyboardInterrupt:
            try:
                wandb.termlog(
                    "Ctrl-c pressed. Waiting for runs to end. Press ctrl-c again to terminate them."
                )
                for run_process in self._run_processes.values():
                    run_process.wait()
            except KeyboardInterrupt:
                pass
        finally:
            try:
                if not self._in_jupyter:
                    wandb.termlog("Terminating and syncing runs. Press ctrl-c to kill.")
                for run_process in self._run_processes.values():
                    with contextlib.suppress(OSError):  # if process is already dead
                        run_process.terminate()
                for run_process in self._run_processes.values():
                    run_process.wait()
            except KeyboardInterrupt:
                wandb.termlog("Killing runs and quitting.")
                for run_process in self._run_processes.values():
                    with contextlib.suppress(OSError):  # if process is already dead
                        run_process.kill()

    def _process_command(self, command):
        """Dispatch one command dict and return a response dict.

        The response has ``"id"`` and ``"result"``. If handling raised, it also
        has ``"exception"`` and ``"traceback"``. Every (command, response) pair
        is appended to ``self._log``.
        """
        logger.info("Agent received command: %s", command.get("type", "Unknown"))
        response = {
            "id": command.get("id"),
            "result": None,
        }
        try:
            command_type = command["type"]
            if command_type in ("run", "resume"):
                result = self._command_run(command)
            elif command_type == "stop":
                result = self._command_stop(command)
            elif command_type == "exit":
                result = self._command_exit(command)
            else:
                raise AgentError(f"No such command: {command_type}")  # noqa: TRY301
            response["result"] = result
        except Exception:
            logger.exception("Exception while processing command: %s", command)
            ex_type, ex, tb = sys.exc_info()
            response["exception"] = f"{ex_type.__name__}: {ex}"
            response["traceback"] = traceback.format_tb(tb)
            del tb

        self._log.append((command, response))

        return response

    def _command_run(self, command):
        """Write the run's config files and start it as an ``AgentProcess``."""
        from wandb.sdk.launch.sweeps import utils as sweep_utils

        config_lines = "\n".join(
            f"\t{k}: {v['value']}" for k, v in command["args"].items()
        )
        logger.info("Agent starting run with config:\n%s", config_lines)
        if self._in_jupyter:
            wandb.termlog(
                f"Agent Starting Run: {command.get('run_id')} with config:\n"
                + config_lines
            )

        # Setup sweep command
        sweep_command: list[str] = sweep_utils.create_sweep_command(self._sweep_command)

        run_id = command.get("run_id")
        sweep_id = os.environ.get(wandb.env.SWEEP_ID)
        # TODO(jhr): move into settings
        config_file = os.path.join(
            "wandb", "sweep-" + sweep_id, "config-" + run_id + ".yaml"
        )
        json_file = os.path.join(
            "wandb", "sweep-" + sweep_id, "config-" + run_id + ".json"
        )

        os.environ[wandb.env.RUN_ID] = run_id

        base_dir = os.environ.get(wandb.env.DIR, "")
        sweep_param_path = os.path.join(base_dir, config_file)
        os.environ[wandb.env.SWEEP_PARAM_PATH] = sweep_param_path
        config_util.save_config_file_from_dict(sweep_param_path, command["args"])

        env = dict(os.environ)

        sweep_vars: dict[str, Any] = sweep_utils.create_sweep_command_args(command)

        if "${args_json_file}" in sweep_command:
            # Unlike sweep_param_path, json_file is not placed under base_dir,
            # so its directory is not guaranteed to exist yet.
            os.makedirs(os.path.dirname(json_file), exist_ok=True)
            with open(json_file, "w") as fp:
                fp.write(sweep_vars["args_json"][0])

        if self._function:
            # make sure that each run regenerates setup singleton
            wandb.teardown()
            proc = AgentProcess(
                function=self._function,
                env=env,
                run_id=run_id,
                in_jupyter=self._in_jupyter,
                forward_signals=self._forward_signals,
            )
        else:
            sweep_vars["interpreter"] = ["python"]
            sweep_vars["program"] = [command["program"]]
            sweep_vars["args_json_file"] = [json_file]
            if platform.system() != "Windows":
                sweep_vars["env"] = ["/usr/bin/env"]
            command_list = []
            for c in sweep_command:
                c = str(c)
                if c.startswith("${") and c.endswith("}"):
                    # "${name}" expands to the list in sweep_vars["name"] (or nothing).
                    replace_list = sweep_vars.get(c[2:-1])
                    command_list += replace_list or []
                else:
                    command_list += [c]
            logger.info(
                "About to run command: %s",
                " ".join(f'"{c}"' if " " in c else c for c in command_list),
            )
            proc = AgentProcess(
                command=command_list, env=env, forward_signals=self._forward_signals
            )
        self._run_processes[run_id] = proc

        # we keep track of when we sent the sigterm to give processes a chance
        # to handle the signal before sending sigkill every heartbeat
        self._run_processes[run_id].last_sigterm_time = None
        self._last_report_time = None

    def _command_stop(self, command):
        """Terminate a run the first time; kill it if still asked after ``_kill_delay``."""
        run_id = command["run_id"]
        proc = self._run_processes.get(run_id)
        if proc is None:
            logger.error("Run %s not running", run_id)
            return
        now = util.stopwatch_now()
        if proc.last_sigterm_time is None:
            proc.last_sigterm_time = now
            logger.info("Stop: %s", run_id)
            with contextlib.suppress(OSError):  # if process is already dead
                proc.terminate()
        elif now > proc.last_sigterm_time + self._kill_delay:
            logger.info("Kill: %s", run_id)
            with contextlib.suppress(OSError):  # if process is already dead
                proc.kill()

    def _command_exit(self, command):
        """Kill every run and make ``run()`` leave its loop."""
        logger.info("Received exit command. Killing runs and quitting.")
        for proc in self._run_processes.values():
            with contextlib.suppress(OSError):  # process is already dead
                proc.kill()
        self._running = False
class AgentApi:
    """Client side of the local command channel to an ``Agent``.

    Commands are put on the agent's queue together with a manager-backed
    response queue, and the caller blocks until the agent answers.
    """

    def __init__(self, queue):
        self._queue = queue
        self._command_id = 0
        self._multiproc_manager = multiprocessing.Manager()

    def command(self, command):
        """Send ``command`` to the agent, print the reply, and return it.

        The dict is mutated in place: ``origin``, ``id`` and ``resp_queue`` are
        set on it before it is queued.
        """
        next_id = self._command_id
        self._command_id = next_id + 1
        reply_queue = self._multiproc_manager.Queue()
        command.update(origin="local", id=f"local-{next_id}", resp_queue=reply_queue)
        self._queue.put(command)

        reply = reply_queue.get()
        print("result:", reply)  # noqa: T201
        if "exception" in reply:
            print("Exception occurred while running command")  # noqa: T201
            for tb_line in reply["traceback"]:
                print(tb_line.strip())  # noqa: T201
            print(reply["exception"])  # noqa: T201
        return reply
def run_agent(
    sweep_id,
    function=None,
    in_jupyter=None,
    entity=None,
    project=None,
    count=None,
    forward_signals=False,
):
    """Resolve the sweep, attach console logging, and run an ``Agent`` to completion.

    ``sweep_utils.parse_sweep_id`` fills in ``entity``/``project``/``name``;
    non-empty parsed values take priority over the arguments and are exported
    to the environment. If the sweep ID cannot be parsed, the error is printed
    and the function returns without starting an agent.
    """
    from wandb.apis import InternalApi
    from wandb.sdk.launch.sweeps import utils as sweep_utils

    sweep_parts = dict(entity=entity, project=project, name=sweep_id)
    parse_error = sweep_utils.parse_sweep_id(sweep_parts)
    if parse_error:
        wandb.termerror(parse_error)
        return None

    entity = sweep_parts.get("entity") or entity
    project = sweep_parts.get("project") or project
    sweep_id = sweep_parts.get("name") or sweep_id

    if entity:
        wandb.env.set_entity(entity)
    if project:
        wandb.env.set_project(project)
    if sweep_id:
        # TODO(jhr): remove when jobspec is merged
        os.environ[wandb.env.SWEEP_ID] = sweep_id

    logger.setLevel(logging.DEBUG)
    console_handler = logging.StreamHandler()
    # Only errors reach the console inside a notebook.
    console_handler.setLevel(logging.ERROR if in_jupyter else logging.DEBUG)
    console_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )

    try:
        logger.addHandler(console_handler)
        backend = InternalApi()
        command_queue = multiprocessing.Queue()
        Agent(
            backend,
            command_queue,
            sweep_id=sweep_id,
            function=function,
            in_jupyter=in_jupyter,
            count=count,
            forward_signals=forward_signals,
        ).run()
    finally:
        # make sure we remove the logging handler (important for jupyter notebooks)
        logger.removeHandler(console_handler)
def agent(
    sweep_id: str,
    function: Callable | None = None,
    entity: str | None = None,
    project: str | None = None,
    count: int | None = None,
    forward_signals: bool = False,
) -> None:
    """Start one or more sweep agents.

    The agent uses `sweep_id` to find the sweep it belongs to and which
    program or function to run, and optionally how many runs to execute.

    Args:
        sweep_id: The unique identifier for a sweep. A sweep ID
            is generated by W&B CLI or Python SDK.
        function: A function to call instead of the "program"
            specified in the sweep config.
        entity: The username or team name where you want to send W&B
            runs created by the sweep to. Ensure that the entity you
            specify already exists. If you don't specify an entity,
            the run will be sent to your default entity,
            which is usually your username.
        project: The name of the project where W&B runs created from
            the sweep are sent to. If the project is not specified, the
            run is sent to a project labeled "Uncategorized".
        count: The number of sweep config trials to try.
        forward_signals: Whether to forward signals the agent receives
            to the child processes. Only supported by CLI agent.
    """
    from wandb.agents.pyagent import pyagent

    global _INSTANCES
    _INSTANCES += 1
    try:
        # make sure we are logged in
        wandb_login._login(_silent=True)
        if not function:
            return run_agent(
                sweep_id,
                function=function,
                in_jupyter=ipython.in_jupyter(),
                entity=entity,
                project=project,
                count=count,
                forward_signals=forward_signals,
            )
        # Python-function sweeps go through pyagent, which does not take forward_signals.
        return pyagent(sweep_id, function, entity, project, count)
    finally:
        _INSTANCES -= 1
# Number of agent() calls currently in progress in this process.
_INSTANCES = 0


def _is_running():
    """Return True while at least one agent() call is active in this process."""
    return _INSTANCES != 0
|