wandb_agent.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686
  1. from __future__ import annotations
  2. import contextlib
  3. import logging
  4. import multiprocessing
  5. import os
  6. import platform
  7. import queue
  8. import re
  9. import signal
  10. import socket
  11. import subprocess
  12. import sys
  13. import time
  14. import traceback
  15. from typing import Any, Callable
  16. import wandb
  17. from wandb import util
  18. from wandb.sdk import wandb_login
  19. from wandb.sdk.lib import config_util, ipython
  20. logger = logging.getLogger(__name__)
  21. class AgentError(Exception):
  22. pass
  23. class AgentProcess:
  24. """Launch and manage a process."""
  25. def __init__(
  26. self,
  27. env=None,
  28. command=None,
  29. function=None,
  30. run_id=None,
  31. in_jupyter=None,
  32. forward_signals=False,
  33. ):
  34. self._popen = None
  35. self._proc = None
  36. self._finished_q = multiprocessing.Queue()
  37. self._proc_killed = False
  38. # Store original handlers
  39. self._original_handlers = {}
  40. # Set up handlers for all possible signals
  41. if forward_signals:
  42. skip_signals = {
  43. getattr(signal, "SIGKILL", None),
  44. getattr(signal, "SIGSTOP", None),
  45. }
  46. skip_signals.discard(None)
  47. for signum in signal.valid_signals():
  48. # Skip signals that can't be caught
  49. if signum in skip_signals:
  50. continue
  51. with contextlib.suppress(OSError, ValueError):
  52. # Some signals might not be supported on all platforms
  53. self._original_handlers[signum] = signal.getsignal(signum)
  54. signal.signal(signum, self._forward_signal)
  55. if command:
  56. if platform.system() == "Windows":
  57. kwargs = dict(creationflags=subprocess.CREATE_NEW_PROCESS_GROUP)
  58. env.pop(wandb.env.SERVICE, None)
  59. # TODO: Determine if we need the same stdin workaround as POSIX case below.
  60. self._popen = subprocess.Popen(command, env=env, **kwargs)
  61. else:
  62. if sys.version_info >= (3, 11):
  63. # preexec_fn=os.setpgrp is not thread-safe; process_group was introduced in
  64. # python 3.11 to replace it, so use that when possible
  65. kwargs = dict(process_group=0)
  66. else:
  67. kwargs = dict(preexec_fn=os.setpgrp)
  68. env.pop(wandb.env.SERVICE, None)
  69. # Upon spawning the subprocess in a new process group, the child's process group is
  70. # not connected to the controlling terminal's stdin. If it tries to access stdin,
  71. # it gets a SIGTTIN and blocks until we give it the terminal, which we don't want
  72. # to do.
  73. #
  74. # By using subprocess.PIPE, we give it an independent stdin. However, it will still
  75. # block if it tries to read from stdin, because we're not writing anything to it.
  76. # We immediately close the subprocess's stdin here so it can fail fast and get an
  77. # EOF.
  78. #
  79. # (One situation that makes this relevant is that importing `readline` even
  80. # indirectly can cause the child to attempt to access stdin, which can trigger the
  81. # deadlock. In Python 3.13, `import torch` indirectly imports `readline` via `pdb`,
  82. # meaning `import torch` in a run script can deadlock unless we override stdin.
  83. # See https://github.com/wandb/wandb/pull/10489 description for more details.)
  84. #
  85. # Also, we avoid spawning a new session because that breaks preempted child process
  86. # handling.
  87. self._popen = subprocess.Popen(
  88. command,
  89. env=env,
  90. stdin=subprocess.PIPE,
  91. **kwargs,
  92. )
  93. self._popen.stdin.close()
  94. elif function:
  95. self._proc = multiprocessing.Process(
  96. target=self._start,
  97. args=(self._finished_q, env, function, run_id, in_jupyter),
  98. )
  99. self._proc.start()
  100. else:
  101. raise AgentError("Agent Process requires command or function")
  102. def _forward_signal(self, signum, frame):
  103. """Forward a received signal to any child process, mirroring the agent's behavior."""
  104. if self._popen:
  105. if platform.system() == "Windows" and signum in (
  106. signal.SIGINT,
  107. signal.SIGTERM,
  108. ):
  109. # On Windows, we can only send CTRL_BREAK_EVENT or CTRL_C_EVENT
  110. self._popen.send_signal(signal.CTRL_BREAK_EVENT)
  111. else:
  112. self._popen.send_signal(signum)
  113. if self._proc:
  114. if hasattr(signal, "SIGKILL") and signum == signal.SIGKILL:
  115. self._proc.kill()
  116. else:
  117. self._proc.send_signal(signum)
  118. # Call original handler to ensure parent process handles signal
  119. original_handler = self._original_handlers.get(signum)
  120. if original_handler and callable(original_handler):
  121. original_handler(signum, frame)
  122. def _start(self, finished_q, env, function, run_id, in_jupyter):
  123. if env:
  124. for k, v in env.items():
  125. os.environ[k] = v
  126. # call user function
  127. wandb.termlog(f"Agent Started Run: {run_id}")
  128. if function:
  129. function()
  130. wandb.termlog(f"Agent Finished Run: {run_id}\n")
  131. # complete the run
  132. run = wandb.run
  133. if run:
  134. wandb.join()
  135. # signal that the process is finished
  136. finished_q.put(True)
  137. def poll(self):
  138. if self._popen:
  139. return self._popen.poll()
  140. if self._proc_killed:
  141. # we need to join process to prevent zombies
  142. self._proc.join()
  143. return True
  144. try:
  145. finished = self._finished_q.get(False, 0)
  146. if finished:
  147. return True
  148. except queue.Empty:
  149. pass
  150. return
  151. def wait(self):
  152. if self._popen:
  153. # if on windows, wait() will block and we won't be able to interrupt
  154. if platform.system() == "Windows":
  155. while True:
  156. p = self._popen.poll()
  157. if p is not None:
  158. return p
  159. time.sleep(1)
  160. return self._popen.wait()
  161. return self._proc.join()
  162. def kill(self):
  163. if self._popen:
  164. return self._popen.kill()
  165. pid = self._proc.pid
  166. if pid:
  167. ret = os.kill(pid, signal.SIGKILL)
  168. self._proc_killed = True
  169. return ret
  170. return
  171. def terminate(self):
  172. if self._popen:
  173. # windows terminate is too strong, send Ctrl-C instead
  174. if platform.system() == "Windows":
  175. return self._popen.send_signal(signal.CTRL_C_EVENT)
  176. return self._popen.terminate()
  177. return self._proc.terminate()
  178. class Agent:
  179. POLL_INTERVAL = 5
  180. REPORT_INTERVAL = 0
  181. KILL_DELAY = 30
  182. FLAPPING_MAX_SECONDS = 60
  183. FLAPPING_MAX_FAILURES = 3
  184. MAX_INITIAL_FAILURES = 5
  185. DEFAULT_SWEEP_COMMAND: list[str] = [
  186. "${env}",
  187. "${interpreter}",
  188. "${program}",
  189. "${args}",
  190. ]
  191. SWEEP_COMMAND_ENV_VAR_REGEX = re.compile(r"\$\{envvar\:([A-Z0-9_]*)\}")
  192. def __init__(
  193. self,
  194. api,
  195. queue,
  196. sweep_id=None,
  197. function=None,
  198. in_jupyter=None,
  199. count=None,
  200. forward_signals=False,
  201. ):
  202. self._api = api
  203. self._queue = queue
  204. self._run_processes = {} # keyed by run.id (GQL run name)
  205. self._server_responses = []
  206. self._sweep_id = sweep_id
  207. self._in_jupyter = in_jupyter
  208. self._log = []
  209. self._running = True
  210. self._last_report_time = None
  211. self._function = function
  212. self._report_interval = wandb.env.get_agent_report_interval(
  213. self.REPORT_INTERVAL
  214. )
  215. self._kill_delay = wandb.env.get_agent_kill_delay(self.KILL_DELAY)
  216. self._finished = 0
  217. self._failed = 0
  218. self._count = count
  219. self._sweep_command = []
  220. self._max_initial_failures = wandb.env.get_agent_max_initial_failures(
  221. self.MAX_INITIAL_FAILURES
  222. )
  223. self._forward_signals = forward_signals
  224. if self._report_interval is None:
  225. raise AgentError("Invalid agent report interval")
  226. if self._kill_delay is None:
  227. raise AgentError("Invalid agent kill delay")
  228. # if the directory to log to is not set, set it
  229. if os.environ.get("WANDB_DIR") is None:
  230. os.environ["WANDB_DIR"] = os.path.abspath(os.getcwd())
  231. def is_flapping(self):
  232. """Determine if the process is flapping.
  233. Flapping occurs if the agents receives FLAPPING_MAX_FAILURES non-0 exit codes in
  234. the first FLAPPING_MAX_SECONDS.
  235. """
  236. if os.getenv(wandb.env.AGENT_DISABLE_FLAPPING) == "true":
  237. return False
  238. if time.time() < wandb.START_TIME + self.FLAPPING_MAX_SECONDS:
  239. return self._failed >= self.FLAPPING_MAX_FAILURES
  240. def is_failing(self):
  241. return (
  242. self._failed >= self._finished
  243. and self._max_initial_failures <= self._failed
  244. )
  245. def run(self): # noqa: C901
  246. # TODO: catch exceptions, handle errors, show validation warnings, and make more generic
  247. import yaml
  248. sweep_obj = self._api.sweep(self._sweep_id, "{}")
  249. if sweep_obj:
  250. sweep_yaml = sweep_obj.get("config")
  251. if sweep_yaml:
  252. sweep_config = yaml.safe_load(sweep_yaml)
  253. if sweep_config:
  254. sweep_command = sweep_config.get("command")
  255. if sweep_command and isinstance(sweep_command, list):
  256. self._sweep_command = sweep_command
  257. # TODO: include sweep ID
  258. agent = self._api.register_agent(socket.gethostname(), sweep_id=self._sweep_id)
  259. agent_id = agent["id"]
  260. try:
  261. while self._running:
  262. commands = util.read_many_from_queue(
  263. self._queue, 100, self.POLL_INTERVAL
  264. )
  265. for command in commands:
  266. command["resp_queue"].put(self._process_command(command))
  267. now = util.stopwatch_now()
  268. if self._last_report_time is None or (
  269. self._report_interval != 0
  270. and now > self._last_report_time + self._report_interval
  271. ):
  272. logger.info("Running runs: %s", list(self._run_processes.keys()))
  273. self._last_report_time = now
  274. run_status = {}
  275. for run_id, run_process in list(self._run_processes.items()):
  276. poll_result = run_process.poll()
  277. if poll_result is None:
  278. run_status[run_id] = True
  279. continue
  280. elif (
  281. not isinstance(poll_result, bool)
  282. and isinstance(poll_result, int)
  283. and poll_result > 0
  284. ):
  285. self._failed += 1
  286. # TODO: raise an exception
  287. if self.is_flapping():
  288. logger.error(
  289. "Detected %i failed runs in the first %i seconds, shutting down.",
  290. self.FLAPPING_MAX_FAILURES,
  291. self.FLAPPING_MAX_SECONDS,
  292. )
  293. logger.info(
  294. "To disable this check set WANDB_AGENT_DISABLE_FLAPPING=true"
  295. )
  296. self._running = False
  297. break
  298. # TODO: raise an exception
  299. if self.is_failing():
  300. logger.error(
  301. "Detected %i failed runs in a row, shutting down.",
  302. self._max_initial_failures,
  303. )
  304. logger.info(
  305. "To change this value set WANDB_AGENT_MAX_INITIAL_FAILURES=val"
  306. )
  307. self._running = False
  308. break
  309. logger.info("Cleaning up finished run: %s", run_id)
  310. # wandb.teardown() was added with wandb service and is a hammer to make
  311. # sure that active runs are finished before moving on to another agent run
  312. #
  313. # In the future, a lighter weight way to implement this could be to keep a
  314. # service process open for all the agent instances and inform_finish when
  315. # the run should be marked complete. This however could require
  316. # inform_finish on every run created by this process.
  317. if hasattr(wandb, "teardown"):
  318. exit_code = 0
  319. if isinstance(poll_result, int):
  320. exit_code = poll_result
  321. elif isinstance(poll_result, bool):
  322. exit_code = -1
  323. wandb.teardown(exit_code)
  324. del self._run_processes[run_id]
  325. self._last_report_time = None
  326. self._finished += 1
  327. if self._count and self._finished >= self._count or not self._running:
  328. self._running = False
  329. continue
  330. commands = self._api.agent_heartbeat(agent_id, {}, run_status)
  331. # TODO: send _server_responses
  332. self._server_responses = []
  333. for command in commands:
  334. self._server_responses.append(self._process_command(command))
  335. except KeyboardInterrupt:
  336. try:
  337. wandb.termlog(
  338. "Ctrl-c pressed. Waiting for runs to end. Press ctrl-c again to terminate them."
  339. )
  340. for _, run_process in self._run_processes.items():
  341. run_process.wait()
  342. except KeyboardInterrupt:
  343. pass
  344. finally:
  345. try:
  346. if not self._in_jupyter:
  347. wandb.termlog("Terminating and syncing runs. Press ctrl-c to kill.")
  348. for _, run_process in self._run_processes.items():
  349. try:
  350. run_process.terminate()
  351. except OSError:
  352. pass # if process is already dead
  353. for _, run_process in self._run_processes.items():
  354. run_process.wait()
  355. except KeyboardInterrupt:
  356. wandb.termlog("Killing runs and quitting.")
  357. for _, run_process in self._run_processes.items():
  358. try:
  359. run_process.kill()
  360. except OSError:
  361. pass # if process is already dead
  362. def _process_command(self, command):
  363. logger.info("Agent received command: {}".format(command.get("type", "Unknown")))
  364. response = {
  365. "id": command.get("id"),
  366. "result": None,
  367. }
  368. try:
  369. command_type = command["type"]
  370. if command_type == "run":
  371. result = self._command_run(command)
  372. elif command_type == "stop":
  373. result = self._command_stop(command)
  374. elif command_type == "exit":
  375. result = self._command_exit(command)
  376. elif command_type == "resume":
  377. result = self._command_run(command)
  378. else:
  379. raise AgentError(f"No such command: {command_type}") # noqa: TRY301
  380. response["result"] = result
  381. except Exception:
  382. logger.exception("Exception while processing command: %s", command)
  383. ex_type, ex, tb = sys.exc_info()
  384. response["exception"] = f"{ex_type.__name__}: {str(ex)}"
  385. response["traceback"] = traceback.format_tb(tb)
  386. del tb
  387. self._log.append((command, response))
  388. return response
  389. def _command_run(self, command):
  390. from wandb.sdk.launch.sweeps import utils as sweep_utils
  391. logger.info(
  392. "Agent starting run with config:\n"
  393. + "\n".join(
  394. ["\t{}: {}".format(k, v["value"]) for k, v in command["args"].items()]
  395. )
  396. )
  397. if self._in_jupyter:
  398. wandb.termlog(
  399. f"Agent Starting Run: {command.get('run_id')} with config:\n"
  400. + "\n".join(
  401. [f"\t{k}: {v['value']}" for k, v in command["args"].items()]
  402. )
  403. )
  404. # Setup sweep command
  405. sweep_command: list[str] = sweep_utils.create_sweep_command(self._sweep_command)
  406. run_id = command.get("run_id")
  407. sweep_id = os.environ.get(wandb.env.SWEEP_ID)
  408. # TODO(jhr): move into settings
  409. config_file = os.path.join(
  410. "wandb", "sweep-" + sweep_id, "config-" + run_id + ".yaml"
  411. )
  412. json_file = os.path.join(
  413. "wandb", "sweep-" + sweep_id, "config-" + run_id + ".json"
  414. )
  415. os.environ[wandb.env.RUN_ID] = run_id
  416. base_dir = os.environ.get(wandb.env.DIR, "")
  417. sweep_param_path = os.path.join(base_dir, config_file)
  418. os.environ[wandb.env.SWEEP_PARAM_PATH] = sweep_param_path
  419. config_util.save_config_file_from_dict(sweep_param_path, command["args"])
  420. env = dict(os.environ)
  421. sweep_vars: dict[str, Any] = sweep_utils.create_sweep_command_args(command)
  422. if "${args_json_file}" in sweep_command:
  423. with open(json_file, "w") as fp:
  424. fp.write(sweep_vars["args_json"][0])
  425. if self._function:
  426. # make sure that each run regenerates setup singleton
  427. wandb.teardown()
  428. proc = AgentProcess(
  429. function=self._function,
  430. env=env,
  431. run_id=run_id,
  432. in_jupyter=self._in_jupyter,
  433. forward_signals=self._forward_signals,
  434. )
  435. else:
  436. sweep_vars["interpreter"] = ["python"]
  437. sweep_vars["program"] = [command["program"]]
  438. sweep_vars["args_json_file"] = [json_file]
  439. if platform.system() != "Windows":
  440. sweep_vars["env"] = ["/usr/bin/env"]
  441. command_list = []
  442. for c in sweep_command:
  443. c = str(c)
  444. if c.startswith("${") and c.endswith("}"):
  445. replace_list = sweep_vars.get(c[2:-1])
  446. command_list += replace_list or []
  447. else:
  448. command_list += [c]
  449. logger.info(
  450. "About to run command: {}".format(
  451. " ".join(f'"{c}"' if " " in c else c for c in command_list)
  452. )
  453. )
  454. proc = AgentProcess(
  455. command=command_list, env=env, forward_signals=self._forward_signals
  456. )
  457. self._run_processes[run_id] = proc
  458. # we keep track of when we sent the sigterm to give processes a chance
  459. # to handle the signal before sending sigkill every heartbeat
  460. self._run_processes[run_id].last_sigterm_time = None
  461. self._last_report_time = None
  462. def _command_stop(self, command):
  463. run_id = command["run_id"]
  464. if run_id in self._run_processes:
  465. proc = self._run_processes[run_id]
  466. now = util.stopwatch_now()
  467. if proc.last_sigterm_time is None:
  468. proc.last_sigterm_time = now
  469. logger.info("Stop: %s", run_id)
  470. try:
  471. proc.terminate()
  472. except OSError: # if process is already dead
  473. pass
  474. elif now > proc.last_sigterm_time + self._kill_delay:
  475. logger.info("Kill: %s", run_id)
  476. try:
  477. proc.kill()
  478. except OSError: # if process is already dead
  479. pass
  480. else:
  481. logger.error("Run %s not running", run_id)
  482. def _command_exit(self, command):
  483. logger.info("Received exit command. Killing runs and quitting.")
  484. for _, proc in self._run_processes.items():
  485. try:
  486. proc.kill()
  487. except OSError:
  488. # process is already dead
  489. pass
  490. self._running = False
  491. class AgentApi:
  492. def __init__(self, queue):
  493. self._queue = queue
  494. self._command_id = 0
  495. self._multiproc_manager = multiprocessing.Manager()
  496. def command(self, command):
  497. command["origin"] = "local"
  498. command["id"] = f"local-{self._command_id}"
  499. self._command_id += 1
  500. resp_queue = self._multiproc_manager.Queue()
  501. command["resp_queue"] = resp_queue
  502. self._queue.put(command)
  503. result = resp_queue.get()
  504. print("result:", result) # noqa: T201
  505. if "exception" in result:
  506. print("Exception occurred while running command") # noqa: T201
  507. for line in result["traceback"]:
  508. print(line.strip()) # noqa: T201
  509. print(result["exception"]) # noqa: T201
  510. return result
  511. def run_agent(
  512. sweep_id,
  513. function=None,
  514. in_jupyter=None,
  515. entity=None,
  516. project=None,
  517. count=None,
  518. forward_signals=False,
  519. ):
  520. from wandb.apis import InternalApi
  521. from wandb.sdk.launch.sweeps import utils as sweep_utils
  522. parts = dict(entity=entity, project=project, name=sweep_id)
  523. err = sweep_utils.parse_sweep_id(parts)
  524. if err:
  525. wandb.termerror(err)
  526. return
  527. entity = parts.get("entity") or entity
  528. project = parts.get("project") or project
  529. sweep_id = parts.get("name") or sweep_id
  530. if entity:
  531. wandb.env.set_entity(entity)
  532. if project:
  533. wandb.env.set_project(project)
  534. if sweep_id:
  535. # TODO(jhr): remove when jobspec is merged
  536. os.environ[wandb.env.SWEEP_ID] = sweep_id
  537. logger.setLevel(logging.DEBUG)
  538. ch = logging.StreamHandler()
  539. log_level = logging.DEBUG
  540. if in_jupyter:
  541. log_level = logging.ERROR
  542. ch.setLevel(log_level)
  543. formatter = logging.Formatter(
  544. "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
  545. )
  546. ch.setFormatter(formatter)
  547. try:
  548. logger.addHandler(ch)
  549. api = InternalApi()
  550. queue = multiprocessing.Queue()
  551. agent = Agent(
  552. api,
  553. queue,
  554. sweep_id=sweep_id,
  555. function=function,
  556. in_jupyter=in_jupyter,
  557. count=count,
  558. forward_signals=forward_signals,
  559. )
  560. agent.run()
  561. finally:
  562. # make sure we remove the logging handler (important for jupyter notebooks)
  563. logger.removeHandler(ch)
  564. def agent(
  565. sweep_id: str,
  566. function: Callable | None = None,
  567. entity: str | None = None,
  568. project: str | None = None,
  569. count: int | None = None,
  570. forward_signals: bool = False,
  571. ) -> None:
  572. """Start one or more sweep agents.
  573. The sweep agent uses the `sweep_id` to know which sweep it
  574. is a part of, what function to execute, and (optionally) how
  575. many agents to run.
  576. Args:
  577. sweep_id: The unique identifier for a sweep. A sweep ID
  578. is generated by W&B CLI or Python SDK.
  579. function: A function to call instead of the "program"
  580. specified in the sweep config.
  581. entity: The username or team name where you want to send W&B
  582. runs created by the sweep to. Ensure that the entity you
  583. specify already exists. If you don't specify an entity,
  584. the run will be sent to your default entity,
  585. which is usually your username.
  586. project: The name of the project where W&B runs created from
  587. the sweep are sent to. If the project is not specified, the
  588. run is sent to a project labeled "Uncategorized".
  589. count: The number of sweep config trials to try.
  590. forward_signals: Whether to forward signals the agent receives
  591. to the child processes. Only supported by CLI agent.
  592. """
  593. from wandb.agents.pyagent import pyagent
  594. global _INSTANCES
  595. _INSTANCES += 1
  596. try:
  597. # make sure we are logged in
  598. wandb_login._login(_silent=True)
  599. if function:
  600. return pyagent(sweep_id, function, entity, project, count)
  601. return run_agent(
  602. sweep_id,
  603. function=function,
  604. in_jupyter=ipython.in_jupyter(),
  605. entity=entity,
  606. project=project,
  607. count=count,
  608. forward_signals=forward_signals,
  609. )
  610. finally:
  611. _INSTANCES -= 1
  612. _INSTANCES = 0
  613. def _is_running():
  614. return bool(_INSTANCES)