pyagent.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. """Agent - Agent object.
  2. Manage wandb agent.
  3. """
  4. import ctypes
  5. import logging
  6. import os
  7. import queue
  8. import socket
  9. import sys
  10. import threading
  11. import time
  12. import traceback
  13. import wandb
  14. from wandb.apis import InternalApi
  15. from wandb.sdk.launch.sweeps import SweepNotFoundError
  16. from wandb.sdk.launch.sweeps import utils as sweep_utils
  17. from wandb.sdk.lib import config_util
  18. logger = logging.getLogger(__name__)
  19. def _terminate_thread(thread):
  20. if not thread.is_alive():
  21. return
  22. if hasattr(thread, "_terminated"):
  23. return
  24. thread._terminated = True
  25. tid = getattr(thread, "_thread_id", None)
  26. if tid is None:
  27. for k, v in threading._active.items():
  28. if v is thread:
  29. tid = k
  30. if tid is None:
  31. # This should never happen
  32. return
  33. logger.debug(f"Terminating thread: {tid}")
  34. res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
  35. ctypes.c_long(tid), ctypes.py_object(Exception)
  36. )
  37. if res == 0:
  38. # This should never happen
  39. return
  40. elif res != 1:
  41. # Revert
  42. logger.debug(f"Termination failed for thread {tid}")
  43. ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), None)
  44. class Job:
  45. def __init__(self, command):
  46. self.command = command
  47. job_type = command.get("type")
  48. self.type = job_type
  49. self.run_id = command.get("run_id")
  50. self.config = command.get("args")
  51. def __repr__(self):
  52. if self.type == "run":
  53. return f"Job({self.run_id},{self.config})"
  54. elif self.type == "stop":
  55. return f"stop({self.run_id})"
  56. else:
  57. return "exit"
  58. class RunStatus:
  59. QUEUED = "QUEUED"
  60. RUNNING = "RUNNING"
  61. STOPPED = "STOPPED"
  62. ERRORED = "ERRORED"
  63. DONE = "DONE"
  64. class Agent:
  65. FLAPPING_MAX_SECONDS = 60
  66. FLAPPING_MAX_FAILURES = 3
  67. MAX_INITIAL_FAILURES = 5
  68. def __init__(
  69. self, sweep_id=None, project=None, entity=None, function=None, count=None
  70. ):
  71. self._sweep_path = sweep_id
  72. self._sweep_id = None
  73. self._project = project
  74. self._entity = entity
  75. self._function = function
  76. self._count = count
  77. # glob_config = os.path.expanduser('~/.config/wandb/settings')
  78. # loc_config = 'wandb/settings'
  79. # files = (glob_config, loc_config)
  80. self._api = InternalApi()
  81. self._agent_id = None
  82. self._max_initial_failures = wandb.env.get_agent_max_initial_failures(
  83. self.MAX_INITIAL_FAILURES
  84. )
  85. # if the directory to log to is not set, set it
  86. if os.environ.get(wandb.env.DIR) is None:
  87. os.environ[wandb.env.DIR] = os.path.abspath(os.getcwd())
  88. def _init(self):
  89. # These are not in constructor so that Agent instance can be rerun
  90. self._run_threads = {}
  91. self._run_status = {}
  92. self._queue = queue.Queue()
  93. self._exit_flag = False
  94. self._exceptions = {}
  95. self._start_time = time.time()
  96. def _register(self):
  97. logger.debug("Agent._register()")
  98. agent = self._api.register_agent(socket.gethostname(), sweep_id=self._sweep_id)
  99. self._agent_id = agent["id"]
  100. logger.debug(f"agent_id = {self._agent_id}")
  101. def _setup(self):
  102. logger.debug("Agent._setup()")
  103. self._init()
  104. parts = dict(entity=self._entity, project=self._project, name=self._sweep_path)
  105. err = sweep_utils.parse_sweep_id(parts)
  106. if err:
  107. wandb.termerror(err)
  108. return
  109. entity = parts.get("entity") or self._entity
  110. project = parts.get("project") or self._project
  111. sweep_id = parts.get("name") or self._sweep_id
  112. if sweep_id:
  113. os.environ[wandb.env.SWEEP_ID] = sweep_id
  114. if entity:
  115. wandb.env.set_entity(entity)
  116. if project:
  117. wandb.env.set_project(project)
  118. if sweep_id:
  119. self._sweep_id = sweep_id
  120. self._register()
  121. def _stop_run(self, run_id):
  122. logger.debug(f"Stopping run {run_id}.")
  123. self._run_status[run_id] = RunStatus.STOPPED
  124. thread = self._run_threads.get(run_id)
  125. if thread:
  126. _terminate_thread(thread)
  127. def _stop_all_runs(self):
  128. logger.debug("Stopping all runs.")
  129. for run in list(self._run_threads.keys()):
  130. self._stop_run(run)
  131. def _exit(self):
  132. self._stop_all_runs()
  133. self._exit_flag = True
  134. # _terminate_thread(self._main_thread)
  135. def _heartbeat(self):
  136. while True:
  137. if self._exit_flag:
  138. return
  139. # if not self._main_thread.is_alive():
  140. # return
  141. run_status = {
  142. run: True
  143. for run, status in self._run_status.items()
  144. if status in (RunStatus.QUEUED, RunStatus.RUNNING)
  145. }
  146. try:
  147. commands = self._api.agent_heartbeat(self._agent_id, {}, run_status)
  148. except SweepNotFoundError:
  149. wandb.termerror(
  150. "Sweep was deleted or agent was not found. Stopping sweep."
  151. )
  152. self._exit()
  153. return
  154. if commands:
  155. job = Job(commands[0])
  156. logger.debug(f"Job received: {job}")
  157. if job.type in ["run", "resume"]:
  158. self._queue.put(job)
  159. self._run_status[job.run_id] = RunStatus.QUEUED
  160. elif job.type == "stop":
  161. self._stop_run(job.run_id)
  162. elif job.type == "exit":
  163. self._exit()
  164. return
  165. time.sleep(5)
  166. def _run_jobs_from_queue(self):
  167. global _INSTANCES
  168. _INSTANCES += 1
  169. try:
  170. waiting = False
  171. count = 0
  172. while True:
  173. if self._exit_flag:
  174. return
  175. try:
  176. try:
  177. job = self._queue.get(timeout=5)
  178. if self._exit_flag:
  179. logger.debug("Exiting main loop due to exit flag.")
  180. wandb.termlog("Sweep Agent: Exiting.")
  181. return
  182. except queue.Empty:
  183. if not waiting:
  184. logger.debug("Paused.")
  185. wandb.termlog("Sweep Agent: Waiting for job.")
  186. waiting = True
  187. time.sleep(5)
  188. if self._exit_flag:
  189. logger.debug("Exiting main loop due to exit flag.")
  190. wandb.termlog("Sweep Agent: Exiting.")
  191. return
  192. continue
  193. if waiting:
  194. logger.debug("Resumed.")
  195. wandb.termlog("Job received.")
  196. waiting = False
  197. count += 1
  198. run_id = job.run_id
  199. if self._run_status[run_id] == RunStatus.STOPPED:
  200. continue
  201. logger.debug(f"Spawning new thread for run {run_id}.")
  202. thread = threading.Thread(target=self._run_job, args=(job,))
  203. self._run_threads[run_id] = thread
  204. thread.start()
  205. self._run_status[run_id] = RunStatus.RUNNING
  206. thread.join()
  207. logger.debug(f"Thread joined for run {run_id}.")
  208. if self._run_status[run_id] == RunStatus.RUNNING:
  209. self._run_status[run_id] = RunStatus.DONE
  210. elif self._run_status[run_id] == RunStatus.ERRORED:
  211. exc = self._exceptions[run_id]
  212. # Extract to reduce a decision point to avoid ruff c901
  213. log_str, term_str = _get_exception_logger_and_term_strs(exc)
  214. logger.error(f"Run {run_id} errored:\n{log_str}")
  215. wandb.termerror(f"Run {run_id} errored:{term_str}")
  216. if os.getenv(wandb.env.AGENT_DISABLE_FLAPPING) == "true":
  217. self._exit_flag = True
  218. return
  219. elif (
  220. time.time() - self._start_time < self.FLAPPING_MAX_SECONDS
  221. ) and (len(self._exceptions) >= self.FLAPPING_MAX_FAILURES):
  222. msg = f"Detected {self.FLAPPING_MAX_FAILURES} failed runs in the first {self.FLAPPING_MAX_SECONDS} seconds, killing sweep."
  223. logger.error(msg)
  224. wandb.termerror(msg)
  225. wandb.termlog(
  226. "To disable this check set WANDB_AGENT_DISABLE_FLAPPING=true"
  227. )
  228. self._exit_flag = True
  229. return
  230. if (
  231. self._max_initial_failures < len(self._exceptions)
  232. and len(self._exceptions) >= count
  233. ):
  234. msg = f"Detected {self._max_initial_failures} failed runs in a row at start, killing sweep."
  235. logger.error(msg)
  236. wandb.termerror(msg)
  237. wandb.termlog(
  238. "To change this value set WANDB_AGENT_MAX_INITIAL_FAILURES=val"
  239. )
  240. self._exit_flag = True
  241. return
  242. if self._count and self._count == count:
  243. logger.debug("Exiting main loop because max count reached.")
  244. self._exit_flag = True
  245. return
  246. except KeyboardInterrupt:
  247. logger.debug("Ctrl + C detected. Stopping sweep.")
  248. wandb.termlog("Ctrl + C detected. Stopping sweep.")
  249. self._exit()
  250. return
  251. except Exception:
  252. if self._exit_flag:
  253. logger.debug("Exiting main loop due to exit flag.")
  254. wandb.termlog("Sweep Agent: Killed.")
  255. return
  256. else:
  257. raise
  258. finally:
  259. _INSTANCES -= 1
  260. def _run_job(self, job):
  261. try:
  262. run_id = job.run_id
  263. config_file = os.path.join(
  264. "wandb", "sweep-" + self._sweep_id, "config-" + run_id + ".yaml"
  265. )
  266. os.environ[wandb.env.RUN_ID] = run_id
  267. base_dir = os.environ.get(wandb.env.DIR, "")
  268. sweep_param_path = os.path.join(base_dir, config_file)
  269. os.environ[wandb.env.SWEEP_PARAM_PATH] = sweep_param_path
  270. config_util.save_config_file_from_dict(sweep_param_path, job.config)
  271. os.environ[wandb.env.SWEEP_ID] = self._sweep_id
  272. wandb.teardown()
  273. wandb.termlog(f"Agent Starting Run: {run_id} with config:")
  274. for k, v in job.config.items():
  275. wandb.termlog("\t{}: {}".format(k, v["value"]))
  276. try:
  277. self._function()
  278. except KeyboardInterrupt:
  279. raise
  280. except Exception as e:
  281. # Log the run's exceptions directly to stderr to match CLI case, and wrap so we
  282. # can identify it as coming from the job later later. This will get automatically
  283. # logged by console_capture.py. Exception handler below will also handle exceptions
  284. # in setup code.
  285. exc_repr = _format_exception_traceback(e)
  286. print(exc_repr, file=sys.stderr) # noqa: T201
  287. raise _JobError(f"Run threw exception: {str(e)}") from e
  288. wandb.finish()
  289. except KeyboardInterrupt:
  290. raise
  291. except Exception as e:
  292. wandb.finish(exit_code=1)
  293. if self._run_status[run_id] == RunStatus.RUNNING:
  294. self._run_status[run_id] = RunStatus.ERRORED
  295. self._exceptions[run_id] = e
  296. finally:
  297. # clean up the environment changes made
  298. os.environ.pop(wandb.env.RUN_ID, None)
  299. os.environ.pop(wandb.env.SWEEP_ID, None)
  300. os.environ.pop(wandb.env.SWEEP_PARAM_PATH, None)
  301. def run(self):
  302. logger.info(
  303. f"Starting sweep agent: entity={self._entity}, project={self._project}, count={self._count}"
  304. )
  305. self._setup()
  306. # self._main_thread = threading.Thread(target=self._run_jobs_from_queue)
  307. self._heartbeat_thread = threading.Thread(target=self._heartbeat)
  308. self._heartbeat_thread.daemon = True
  309. # self._main_thread.start()
  310. self._heartbeat_thread.start()
  311. # self._main_thread.join()
  312. self._run_jobs_from_queue()
  313. def pyagent(sweep_id, function, entity=None, project=None, count=None):
  314. """Generic agent entrypoint, used for CLI or jupyter.
  315. Args:
  316. sweep_id (dict): Sweep ID generated by CLI or sweep API
  317. function (func, optional): A function to call instead of the "program"
  318. entity (str, optional): W&B Entity
  319. project (str, optional): W&B Project
  320. count (int, optional): the number of trials to run.
  321. """
  322. if not callable(function):
  323. raise TypeError("function parameter must be callable!")
  324. agent = Agent(
  325. sweep_id,
  326. function=function,
  327. entity=entity,
  328. project=project,
  329. count=count,
  330. )
  331. agent.run()
  332. def _format_exception_traceback(exc):
  333. return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
  334. class _JobError(Exception):
  335. """Exception raised when a job fails during execution."""
  336. pass
  337. def _get_exception_logger_and_term_strs(exc):
  338. if isinstance(exc, _JobError) and exc.__cause__:
  339. # If it's a JobException, get the original exception for display
  340. job_exc = exc.__cause__
  341. log_str = _format_exception_traceback(job_exc)
  342. # Don't long full stacktrace to terminal again because we already
  343. # printed it to stderr.
  344. term_str = " " + str(job_exc)
  345. else:
  346. log_str = _format_exception_traceback(exc)
  347. term_str = "\n" + log_str
  348. return log_str, term_str
  349. _INSTANCES = 0
  350. def is_running():
  351. return bool(_INSTANCES)