sync.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. """sync."""
  2. from __future__ import annotations
  3. import atexit
  4. import datetime
  5. import fnmatch
  6. import os
  7. import queue
  8. import sys
  9. import tempfile
  10. import threading
  11. import time
  12. from urllib.parse import quote as url_quote
  13. import wandb
  14. from wandb.proto import wandb_internal_pb2 # type: ignore
  15. from wandb.sdk.interface.interface_queue import InterfaceQueue
  16. from wandb.sdk.internal import context, datastore, handler, sender, tb_watcher
  17. from wandb.sdk.internal.settings_static import SettingsStatic
  18. from wandb.sdk.lib import filesystem
  19. from wandb.util import check_and_warn_old
  # Extension of a run's local ".wandb" data file.
  WANDB_SUFFIX: str = ".wandb"
  # Appended to a ".wandb" file path to name its "already synced" marker file.
  SYNCED_SUFFIX: str = ".synced"
  # Any filename containing this substring is treated as a TensorBoard event file.
  TFEVENT_SUBSTRING: str = ".tfevents."
  23. class _LocalRun:
  24. def __init__(self, path, synced=None):
  25. self.path = path
  26. self.synced = synced
  27. self.offline = os.path.basename(path).startswith("offline-")
  28. self.datetime = datetime.datetime.strptime(
  29. os.path.basename(path).split("run-")[1].split("-")[0], "%Y%m%d_%H%M%S"
  30. )
  31. def __str__(self):
  32. return self.path
  class SyncThread(threading.Thread):
      """Thread that replays locally stored run data to the W&B backend.

      :meth:`run` processes each entry of ``sync_list`` in order.

      A ``.wandb`` file, or a directory containing one, is scanned record by
      record and forwarded through a ``sender.SendManager``.

      When ``sync_tensorboard`` is set and tfevent files are found for an item
      that is not a ``.wandb`` file, those tfevent files are uploaded as a new
      run instead.
      """

      def __init__(
          self,
          sync_list,
          project=None,
          entity=None,
          run_id=None,
          job_type=None,
          view=None,
          verbose=None,
          mark_synced=None,
          app_url=None,
          sync_tensorboard=None,
          log_path=None,
          append=None,
          skip_console=None,
          replace_tags=None,
      ):
          """Store the sync options.

          Args:
              sync_list: paths (files or directories) to sync, in order.
              project: when set, overrides ``project`` on synced run records.
              entity: when set, overrides ``entity`` on synced run records.
              run_id: when set, overrides ``run_id`` on synced run records.
              job_type: when set, overrides ``job_type`` on synced run records.
              view: only print records instead of sending them.
              verbose: with ``view``, print whole records rather than their type.
              mark_synced: write a ``<file>.synced`` marker after a finished
                  run has been synced.
              app_url: base URL used when printing links to synced runs.
              sync_tensorboard: look for tfevent files under each sync item.
              log_path: path printed at start so users can find the logs.
              append: pass ``resume="allow"`` to the sender (the run need not
                  exist already).
              skip_console: drop ``output`` / ``output_raw`` records.
              replace_tags: mapping of old tag -> new tag applied to run records.
          """
          super().__init__()
          self._sync_list = sync_list
          self._project = project
          self._entity = entity
          self._run_id = run_id
          self._job_type = job_type
          self._view = view
          self._verbose = verbose
          self._mark_synced = mark_synced
          self._app_url = app_url
          self._sync_tensorboard = sync_tensorboard
          self._log_path = log_path
          self._append = append
          self._skip_console = skip_console
          self._replace_tags = replace_tags or {}
          # Scratch root dir used for tensorboard syncs; removed at interpreter exit.
          self._tmp_dir = tempfile.TemporaryDirectory()
          atexit.register(self._tmp_dir.cleanup)

      def _parse_pb(self, data, exit_pb=None):
          """Decode one serialized record and decide whether it should be sent.

          Args:
              data: serialized ``wandb_internal_pb2.Record`` bytes.
              exit_pb: the pending ``exit`` record, if one has been seen.

          Returns:
              ``(pb, exit_pb, skip)``:
              - ``pb`` is the record to send.
              - ``exit_pb`` is the (possibly updated) pending exit record.
              - ``skip`` is True when the caller must not send ``pb``.

          Raises:
              AssertionError: a ``final`` record appears before any ``exit``.
          """
          pb = wandb_internal_pb2.Record()
          pb.ParseFromString(data)
          record_type = pb.WhichOneof("record_type")
          if self._view:
              if self._verbose:
                  print("Record:", pb)  # noqa: T201
              else:
                  print("Record:", record_type)  # noqa: T201
              return pb, exit_pb, True
          if record_type == "run":
              if self._run_id:
                  pb.run.run_id = self._run_id
              if self._project:
                  pb.run.project = self._project
              if self._entity:
                  pb.run.entity = self._entity
              if self._job_type:
                  pb.run.job_type = self._job_type
              # Replace tags if specified
              if self._replace_tags:
                  new_tags = [self._replace_tags.get(tag, tag) for tag in pb.run.tags]
                  pb.run.ClearField("tags")
                  pb.run.tags.extend(new_tags)
              # Request a response so run() can read back the run URL.
              pb.control.req_resp = True
          elif record_type in ("output", "output_raw") and self._skip_console:
              return pb, exit_pb, True
          elif record_type == "exit":
              # Hold the exit record back; it is sent in place of "final".
              exit_pb = pb
              return pb, exit_pb, True
          elif record_type == "final":
              # Explicit raise instead of `assert` so the check survives `python -O`.
              if not exit_pb:
                  raise AssertionError("final seen without exit")
              pb = exit_pb
              exit_pb = None
          return pb, exit_pb, False

      def _find_tfevent_files(self, sync_item):
          """Locate tfevent files for ``sync_item`` when tensorboard sync is on.

          Returns:
              ``(tb_event_files, tb_logdirs, tb_root)``:
              - ``tb_event_files`` is the number of tfevent files found.
              - ``tb_logdirs`` holds the distinct directories containing them,
                in discovery order.
              - ``tb_root`` is the root handed to the tensorboard watcher, or
                None when nothing was found or ``sync_tensorboard`` is off.
          """
          tb_event_files = 0
          tb_logdirs = []
          tb_root = None
          if not self._sync_tensorboard:
              return tb_event_files, tb_logdirs, tb_root
          if os.path.isdir(sync_item):
              files = []
              for dirpath, _, filenames in os.walk(sync_item):
                  for name in filenames:
                      if TFEVENT_SUBSTRING in name:
                          files.append(os.path.join(dirpath, name))
              tb_event_files = len(files)
              # dict.fromkeys de-duplicates while keeping first-seen order (O(n)).
              tb_logdirs = list(
                  dict.fromkeys(os.path.dirname(os.path.abspath(f)) for f in files)
              )
              if tb_logdirs:
                  # commonprefix() is character-based; dirname() trims the result
                  # to a directory (for a single logdir this is its parent).
                  tb_root = os.path.dirname(os.path.commonprefix(tb_logdirs))
          elif TFEVENT_SUBSTRING in sync_item:
              tb_root = os.path.dirname(os.path.abspath(sync_item))
              tb_logdirs.append(tb_root)
              tb_event_files = 1
          return tb_event_files, tb_logdirs, tb_root

      def _setup_tensorboard(self, tb_root, tb_logdirs, tb_event_files, sync_item):
          """Return true if this sync item can be synced as tensorboard.

          A ``.wandb`` sync item takes precedence: tensorboard data found next to
          it is not streamed, and a warning is printed.
          """
          if tb_root is not None:
              if tb_event_files > 0 and sync_item.endswith(WANDB_SUFFIX):
                  wandb.termwarn("Found .wandb file, not streaming tensorboard metrics.")
              else:
                  print(f"Found {tb_event_files} tfevent files in {tb_root}")  # noqa: T201
                  if len(tb_logdirs) > 3:
                      wandb.termwarn(
                          f"Found {len(tb_logdirs)} directories containing tfevent files. "
                          "If these represent multiple experiments, sync them "
                          "individually or pass a list of paths."
                      )
                  return True
          return False

      def _send_tensorboard(self, tb_root, tb_logdirs, send_manager):
          """Upload the tfevent data in ``tb_logdirs`` as a new run.

          ``send_manager`` supplies the API client (used to look up the default
          entity when none was given) and the sender settings.

          A separate handler -> sender pipeline is then built around fresh
          queues, and the tensorboard watcher's records are pumped through it
          until both sides are drained.
          """
          if self._entity is None:
              viewer, _ = send_manager._api.viewer_server_info()
              self._entity = viewer.get("entity")
          proto_run = wandb_internal_pb2.RunRecord()
          proto_run.run_id = self._run_id or wandb.util.generate_id()
          proto_run.project = self._project or wandb.util.auto_project_name(None)
          proto_run.entity = self._entity
          proto_run.telemetry.feature.sync_tfevents = True

          url = (
              f"{self._app_url}"
              f"/{url_quote(proto_run.entity)}"
              f"/{url_quote(proto_run.project)}"
              f"/runs/{url_quote(proto_run.run_id)}"
          )
          print(f"Syncing: {url} ...")  # noqa: T201
          sys.stdout.flush()
          # using a handler here automatically handles the step
          # logic, adds summaries to the run, and handles different
          # file types (like images)... but we need to remake the send_manager
          record_q = queue.Queue()
          sender_record_q = queue.Queue()
          new_interface = InterfaceQueue(record_q)
          context_keeper = context.ContextKeeper()
          tb_send_manager = sender.SendManager(
              settings=send_manager._settings,
              record_q=sender_record_q,
              result_q=queue.Queue(),
              interface=new_interface,
              context_keeper=context_keeper,
          )
          record = tb_send_manager._interface._make_record(run=proto_run)
          settings = wandb.Settings(
              root_dir=self._tmp_dir.name,
              run_id=proto_run.run_id,
              x_start_time=time.time(),
          )

          settings_static = SettingsStatic(dict(settings))

          handle_manager = handler.HandleManager(
              settings=settings_static,
              record_q=record_q,
              result_q=None,
              stopped=False,
              writer_q=sender_record_q,
              interface=new_interface,
              context_keeper=context_keeper,
          )

          filesystem.mkdir_exists_ok(settings.files_dir)
          tb_send_manager.send_run(record, file_dir=settings.files_dir)
          watcher = tb_watcher.TBWatcher(
              settings_static,
              proto_run,
              new_interface,
              True,
          )

          for tb in tb_logdirs:
              watcher.add(tb, True, tb_root)
              sys.stdout.flush()
          watcher.finish()

          # send all of our records like a boss
          progress_step = 0
          spinner_states = ["-", "\\", "|", "/"]
          line = " Uploading data to wandb\r"
          while len(handle_manager) > 0:
              data = next(handle_manager)
              handle_manager.handle(data)
              # Drain whatever the handler forwarded to the sender queue.
              while len(tb_send_manager) > 0:
                  data = next(tb_send_manager)
                  tb_send_manager.send(data)

              print_line = spinner_states[progress_step % 4] + line
              wandb.termlog(print_line, newline=False, prefix=True)
              progress_step += 1

          # finish sending any data
          while len(tb_send_manager) > 0:
              data = next(tb_send_manager)
              tb_send_manager.send(data)
          sys.stdout.flush()
          handle_manager.finish()
          tb_send_manager.finish()

      def _robust_scan(self, ds):
          """Attempt to scan data, handling incomplete files.

          Returns ``ds.scan_data()``; the caller stops scanning on None.

          An AssertionError raised while ``ds`` is in its last block is treated
          as an incomplete file: a warning is printed and None is returned. Any
          other AssertionError is re-raised.
          """
          try:
              return ds.scan_data()
          except AssertionError as e:
              if ds.in_last_block():
                  wandb.termwarn(
                      f".wandb file is incomplete ({e}), be sure to sync this run again once it's finished"
                  )
                  return None
              else:
                  raise

      def run(self):
          """Sync every item in ``self._sync_list`` in order.

          A directory without tfevent data is skipped when ``check_and_warn_old``
          flags its contents or when it does not hold exactly one ``.wandb``
          file.

          A ``<file>.synced`` marker is written only when ``mark_synced`` is set,
          ``view`` is off, and an ``exit`` record was seen.
          """
          if self._log_path is not None:
              print(f"Find logs at: {self._log_path}")  # noqa: T201
          for sync_item in self._sync_list:
              tb_event_files, tb_logdirs, tb_root = self._find_tfevent_files(sync_item)
              if os.path.isdir(sync_item):
                  files = os.listdir(sync_item)
                  filtered_files = list(filter(lambda f: f.endswith(WANDB_SUFFIX), files))
                  if tb_root is None and (
                      check_and_warn_old(files) or len(filtered_files) != 1
                  ):
                      print(f"Skipping directory: {sync_item}")  # noqa: T201
                      continue
                  if len(filtered_files) > 0:
                      sync_item = os.path.join(sync_item, filtered_files[0])
              sync_tb = self._setup_tensorboard(
                  tb_root, tb_logdirs, tb_event_files, sync_item
              )
              # If we're syncing tensorboard, let's use a tmp dir for images etc.
              root_dir = self._tmp_dir.name if sync_tb else os.path.dirname(sync_item)

              # When appending we are allowing a possible resume, ie the run
              # does not have to exist already
              resume = "allow" if self._append else None

              sm = sender.SendManager.setup(root_dir, resume=resume)
              if sync_tb:
                  self._send_tensorboard(tb_root, tb_logdirs, sm)
                  continue

              ds = datastore.DataStore()
              try:
                  ds.open_for_scan(sync_item)
              except AssertionError as e:
                  print(f".wandb file is empty ({e}), skipping: {sync_item}")  # noqa: T201
                  continue

              # save exit for final send
              exit_pb = None
              finished = False
              shown = False
              while True:
                  data = self._robust_scan(ds)
                  if data is None:
                      break
                  pb, exit_pb, cont = self._parse_pb(data, exit_pb)
                  if exit_pb is not None:
                      finished = True
                  if cont:
                      continue
                  sm.send(pb)
                  # send any records that were added in previous send
                  while not sm._record_q.empty():
                      data = sm._record_q.get(block=True)
                      sm.send(data)

                  if pb.control.req_resp:
                      result = sm._result_q.get(block=True)
                      result_type = result.WhichOneof("result_type")
                      if not shown and result_type == "run_result":
                          r = result.run_result.run
                          # TODO(jhr): hardcode until we have settings in sync
                          url = (
                              f"{self._app_url}"
                              f"/{url_quote(r.entity)}"
                              f"/{url_quote(r.project)}"
                              f"/runs/{url_quote(r.run_id)}"
                          )
                          print(f"Syncing: {url} ... ", end="")  # noqa: T201
                          sys.stdout.flush()
                          shown = True
              sm.finish()
              # Only mark synced if the run actually finished
              if self._mark_synced and not self._view and finished:
                  synced_file = f"{sync_item}{SYNCED_SUFFIX}"
                  with open(synced_file, "w"):
                      pass
              print("done.")  # noqa: T201
  class SyncManager:
      """Collect paths to sync and process them on a background ``SyncThread``.

      Paths are queued with :meth:`add`. :meth:`start` launches one worker
      thread for all of them, and :meth:`is_done` reports whether that thread
      has exited.
      """

      def __init__(
          self,
          project=None,
          entity=None,
          run_id=None,
          job_type=None,
          mark_synced=None,
          app_url=None,
          view=None,
          verbose=None,
          sync_tensorboard=None,
          log_path=None,
          append=None,
          skip_console=None,
          replace_tags=None,
      ):
          # Options below are forwarded unchanged to SyncThread by start().
          self._project = project
          self._entity = entity
          self._run_id = run_id
          self._job_type = job_type
          self._mark_synced = mark_synced
          self._app_url = app_url
          self._view = view
          self._verbose = verbose
          self._sync_tensorboard = sync_tensorboard
          self._log_path = log_path
          self._append = append
          self._skip_console = skip_console
          self._replace_tags = replace_tags or {}
          # Worker state: queued absolute paths, and the thread once started.
          self._sync_list = []
          self._thread = None

      def status(self):
          """Placeholder that currently reports nothing."""
          return None

      def add(self, p):
          """Queue ``p`` (converted with ``str``) for syncing as an absolute path."""
          absolute_path = os.path.abspath(str(p))
          self._sync_list.append(absolute_path)

      def start(self):
          """Launch a single SyncThread that processes every queued path."""
          # NOTE: one thread handles all paths; a thread per file was considered.
          thread_options = {
              "project": self._project,
              "entity": self._entity,
              "run_id": self._run_id,
              "job_type": self._job_type,
              "view": self._view,
              "verbose": self._verbose,
              "mark_synced": self._mark_synced,
              "app_url": self._app_url,
              "sync_tensorboard": self._sync_tensorboard,
              "log_path": self._log_path,
              "append": self._append,
              "skip_console": self._skip_console,
              "replace_tags": self._replace_tags,
          }
          self._thread = SyncThread(sync_list=self._sync_list, **thread_options)
          self._thread.start()

      def is_done(self):
          """Return True once the worker thread is no longer alive."""
          still_running = self._thread.is_alive()
          return not still_running

      def poll(self):
          """Wait one second, then report False."""
          time.sleep(1)
          return False
  365. def get_runs(
  366. include_offline: bool = True,
  367. include_online: bool = True,
  368. include_synced: bool = False,
  369. include_unsynced: bool = True,
  370. exclude_globs: list[str] | None = None,
  371. include_globs: list[str] | None = None,
  372. ):
  373. # TODO(jhr): grab dir info from settings
  374. base = ".wandb" if os.path.exists(".wandb") else "wandb"
  375. if not os.path.exists(base):
  376. return ()
  377. all_dirs = os.listdir(base)
  378. dirs = []
  379. if include_offline:
  380. dirs += filter(lambda _d: _d.startswith("offline-run-"), all_dirs)
  381. if include_online:
  382. dirs += filter(lambda _d: _d.startswith("run-"), all_dirs)
  383. # find run file in each dir
  384. fnames = []
  385. dirs.sort()
  386. for d in dirs:
  387. paths = os.listdir(os.path.join(base, d))
  388. if exclude_globs:
  389. paths = set(paths)
  390. for g in exclude_globs:
  391. paths = paths - set(fnmatch.filter(paths, g))
  392. paths = list(paths)
  393. if include_globs:
  394. new_paths = set()
  395. for g in include_globs:
  396. new_paths = new_paths.union(fnmatch.filter(paths, g))
  397. paths = list(new_paths)
  398. for f in paths:
  399. if f.endswith(WANDB_SUFFIX):
  400. fnames.append(os.path.join(base, d, f))
  401. filtered = []
  402. for f in fnames:
  403. dname = os.path.dirname(f)
  404. # TODO(frz): online runs are assumed to be synced, verify from binary log.
  405. if os.path.exists(f"{f}{SYNCED_SUFFIX}") or os.path.basename(dname).startswith(
  406. "run-"
  407. ):
  408. if include_synced:
  409. filtered.append(_LocalRun(dname, True))
  410. else:
  411. if include_unsynced:
  412. filtered.append(_LocalRun(dname, False))
  413. return tuple(filtered)
  def get_run_from_path(path):
      """Describe the run directory at ``path`` as a ``_LocalRun`` (sync state unknown)."""
      local_run = _LocalRun(path)
      return local_run