handler.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842
  1. """Handle Manager."""
  2. from __future__ import annotations
  3. import json
  4. import logging
  5. import math
  6. import numbers
  7. import time
  8. from collections import defaultdict
  9. from collections.abc import Iterable, Sequence
  10. from queue import Queue
  11. from threading import Event
  12. from typing import TYPE_CHECKING, Any, Callable, cast
  13. from wandb.errors.links import url_registry
  14. from wandb.proto.wandb_internal_pb2 import (
  15. HistoryRecord,
  16. InternalMessages,
  17. MetricRecord,
  18. Record,
  19. Result,
  20. RunRecord,
  21. SampledHistoryItem,
  22. SummaryItem,
  23. SummaryRecord,
  24. SummaryRecordRequest,
  25. )
  26. from ..interface.interface_queue import InterfaceQueue
  27. from ..lib import handler_util, proto_util
  28. from . import context, sample, tb_watcher
  29. from .settings_static import SettingsStatic
  30. if TYPE_CHECKING:
  31. from wandb.proto.wandb_internal_pb2 import MetricSummary
# Type alias for the in-memory consolidated run summary: top-level summary keys
# mapped to values that are serialized with ``json.dumps`` when sent, and that
# may themselves be nested dictionaries (see ``_dict_nested_set``).
SummaryDict = dict[str, Any]

logger = logging.getLogger(__name__)

# Update (March 5, 2024): Since ~2020/2021, when constructing the summary
# object, we had replaced the artifact path for media types with the latest
# artifact path. The primary purpose of this was to support live updating of
# media objects in the UI (since the default artifact path was fully qualified
# and would not update). However, in March of 2024, a bug was discovered with
# this approach: the path is incorrect when the media object is logged to
# another artifact before being logged to the run. Setting this to `False`
# disables the copy behavior. The impact is that users will need to refresh
# to see updates. Ironically, this live-updating behavior is not currently
# supported in the UI, so the impact of this change is minimal.
REPLACE_SUMMARY_ART_PATH_WITH_LATEST = False
  45. def _dict_nested_set(target: dict[str, Any], key_list: Sequence[str], v: Any) -> None:
  46. # recurse down the dictionary structure:
  47. for k in key_list[:-1]:
  48. target.setdefault(k, {})
  49. new_target = target.get(k)
  50. if TYPE_CHECKING:
  51. new_target = cast(dict[str, Any], new_target)
  52. target = new_target
  53. # use the last element of the key to write the leaf:
  54. target[key_list[-1]] = v
  55. class HandleManager:
  56. _consolidated_summary: SummaryDict
  57. _sampled_history: dict[str, sample.UniformSampleAccumulator]
  58. _partial_history: dict[str, Any]
  59. _run_proto: RunRecord | None
  60. _settings: SettingsStatic
  61. _record_q: Queue[Record]
  62. _result_q: Queue[Result]
  63. _stopped: Event
  64. _writer_q: Queue[Record]
  65. _interface: InterfaceQueue
  66. _tb_watcher: tb_watcher.TBWatcher | None
  67. _metric_defines: dict[str, MetricRecord]
  68. _metric_globs: dict[str, MetricRecord]
  69. _metric_track: dict[tuple[str, ...], float]
  70. _metric_copy: dict[tuple[str, ...], Any]
  71. _track_time: float | None
  72. _accumulate_time: float
  73. _run_start_time: float | None
  74. _context_keeper: context.ContextKeeper
  75. def __init__(
  76. self,
  77. settings: SettingsStatic,
  78. record_q: Queue[Record],
  79. result_q: Queue[Result],
  80. stopped: Event,
  81. writer_q: Queue[Record],
  82. interface: InterfaceQueue,
  83. context_keeper: context.ContextKeeper,
  84. ) -> None:
  85. self._settings = settings
  86. self._record_q = record_q
  87. self._result_q = result_q
  88. self._stopped = stopped
  89. self._writer_q = writer_q
  90. self._interface = interface
  91. self._context_keeper = context_keeper
  92. self._tb_watcher = None
  93. self._step = 0
  94. self._track_time = None
  95. self._accumulate_time = 0
  96. self._run_start_time = None
  97. # keep track of summary from key/val updates
  98. self._consolidated_summary = dict()
  99. self._sampled_history = defaultdict(sample.UniformSampleAccumulator)
  100. self._run_proto = None
  101. self._partial_history = dict()
  102. self._metric_defines = defaultdict(MetricRecord)
  103. self._metric_globs = defaultdict(MetricRecord)
  104. self._metric_track = dict()
  105. self._metric_copy = dict()
  106. self._internal_messages = InternalMessages()
  107. self._dropped_history = False
  108. def __len__(self) -> int:
  109. return self._record_q.qsize()
  110. def handle(self, record: Record) -> None:
  111. self._context_keeper.add_from_record(record)
  112. record_type = record.WhichOneof("record_type")
  113. assert record_type
  114. handler_str = "handle_" + record_type
  115. handler: Callable[[Record], None] = getattr(self, handler_str, None) # type: ignore
  116. assert handler, f"unknown handle: {handler_str}" # type: ignore
  117. handler(record)
  118. def handle_request(self, record: Record) -> None:
  119. request_type = record.request.WhichOneof("request_type")
  120. assert request_type
  121. handler_str = "handle_request_" + request_type
  122. handler: Callable[[Record], None] = getattr(self, handler_str, None) # type: ignore
  123. if request_type != "network_status":
  124. logger.debug(f"handle_request: {request_type}")
  125. assert handler, f"unknown handle: {handler_str}" # type: ignore
  126. handler(record)
  127. def _dispatch_record(self, record: Record, always_send: bool = False) -> None:
  128. if always_send:
  129. record.control.always_send = True
  130. self._writer_q.put(record)
  131. def _respond_result(self, result: Result) -> None:
  132. context_id = context.context_id_from_result(result)
  133. self._context_keeper.release(context_id)
  134. self._result_q.put(result)
  135. def debounce(self) -> None:
  136. pass
  137. def handle_request_cancel(self, record: Record) -> None:
  138. self._dispatch_record(record)
  139. def handle_request_defer(self, record: Record) -> None:
  140. defer = record.request.defer
  141. state = defer.state
  142. logger.info(f"handle defer: {state}")
  143. if state == defer.FLUSH_TB:
  144. if self._tb_watcher:
  145. # shutdown tensorboard workers so we get all metrics flushed
  146. self._tb_watcher.finish()
  147. self._tb_watcher = None
  148. elif state == defer.FLUSH_PARTIAL_HISTORY:
  149. self._flush_partial_history()
  150. elif state == defer.FLUSH_SUM:
  151. self._save_summary(self._consolidated_summary, flush=True)
  152. # defer is used to drive the sender finish state machine
  153. self._dispatch_record(record, always_send=True)
  154. def handle_request_python_packages(self, record: Record) -> None:
  155. self._dispatch_record(record)
  156. def handle_run(self, record: Record) -> None:
  157. if self._settings._offline:
  158. self._run_proto = record.run
  159. result = proto_util._result_from_record(record)
  160. result.run_result.run.CopyFrom(record.run)
  161. self._respond_result(result)
  162. self._dispatch_record(record)
  163. def handle_stats(self, record: Record) -> None:
  164. self._dispatch_record(record)
  165. def handle_config(self, record: Record) -> None:
  166. self._dispatch_record(record)
  167. def handle_output(self, record: Record) -> None:
  168. self._dispatch_record(record)
  169. def handle_output_raw(self, record: Record) -> None:
  170. self._dispatch_record(record)
  171. def handle_files(self, record: Record) -> None:
  172. self._dispatch_record(record)
  173. def handle_request_link_artifact(self, record: Record) -> None:
  174. self._dispatch_record(record)
  175. def handle_use_artifact(self, record: Record) -> None:
  176. self._dispatch_record(record)
  177. def handle_artifact(self, record: Record) -> None:
  178. self._dispatch_record(record)
  179. def handle_alert(self, record: Record) -> None:
  180. self._dispatch_record(record)
  181. def _save_summary(self, summary_dict: SummaryDict, flush: bool = False) -> None:
  182. summary = SummaryRecord()
  183. for k, v in summary_dict.items():
  184. update = summary.update.add()
  185. update.key = k
  186. update.value_json = json.dumps(v)
  187. if flush:
  188. record = Record(summary=summary)
  189. self._dispatch_record(record)
  190. elif not self._settings._offline:
  191. # Send this summary update as a request since we aren't persisting every update
  192. summary_record = SummaryRecordRequest(summary=summary)
  193. request_record = self._interface._make_request(
  194. summary_record=summary_record
  195. )
  196. self._dispatch_record(request_record)
  197. def _save_history(
  198. self,
  199. history: HistoryRecord,
  200. ) -> None:
  201. for item in history.item:
  202. # TODO(jhr) save nested keys?
  203. k = item.key
  204. v = json.loads(item.value_json)
  205. if isinstance(v, numbers.Real):
  206. self._sampled_history[k].add(v)
  207. def _update_summary_metrics(
  208. self,
  209. s: MetricSummary,
  210. kl: list[str],
  211. v: numbers.Real,
  212. float_v: float,
  213. goal_max: bool | None,
  214. ) -> bool:
  215. updated = False
  216. best_key: tuple[str, ...] | None = None
  217. if s.none:
  218. return False
  219. if s.copy and len(kl) > 1:
  220. # non-key list copy already done in _update_summary
  221. _dict_nested_set(self._consolidated_summary, kl, v)
  222. return True
  223. if s.last:
  224. last_key = tuple(kl + ["last"])
  225. old_last = self._metric_track.get(last_key)
  226. if old_last is None or float_v != old_last:
  227. self._metric_track[last_key] = float_v
  228. _dict_nested_set(self._consolidated_summary, last_key, v)
  229. updated = True
  230. if s.best:
  231. best_key = tuple(kl + ["best"])
  232. if s.max or best_key and goal_max:
  233. max_key = tuple(kl + ["max"])
  234. old_max = self._metric_track.get(max_key)
  235. if old_max is None or float_v > old_max:
  236. self._metric_track[max_key] = float_v
  237. if s.max:
  238. _dict_nested_set(self._consolidated_summary, max_key, v)
  239. updated = True
  240. if best_key:
  241. _dict_nested_set(self._consolidated_summary, best_key, v)
  242. updated = True
  243. # defaulting to minimize if goal is not specified
  244. if s.min or best_key and not goal_max:
  245. min_key = tuple(kl + ["min"])
  246. old_min = self._metric_track.get(min_key)
  247. if old_min is None or float_v < old_min:
  248. self._metric_track[min_key] = float_v
  249. if s.min:
  250. _dict_nested_set(self._consolidated_summary, min_key, v)
  251. updated = True
  252. if best_key:
  253. _dict_nested_set(self._consolidated_summary, best_key, v)
  254. updated = True
  255. if s.mean:
  256. tot_key = tuple(kl + ["tot"])
  257. num_key = tuple(kl + ["num"])
  258. avg_key = tuple(kl + ["mean"])
  259. tot = self._metric_track.get(tot_key, 0.0)
  260. num = self._metric_track.get(num_key, 0)
  261. tot += float_v
  262. num += 1
  263. self._metric_track[tot_key] = tot
  264. self._metric_track[num_key] = num
  265. _dict_nested_set(self._consolidated_summary, avg_key, tot / num)
  266. updated = True
  267. return updated
  268. def _update_summary_leaf(
  269. self,
  270. kl: list[str],
  271. v: Any,
  272. d: MetricRecord | None = None,
  273. ) -> bool:
  274. has_summary = d and d.HasField("summary")
  275. if len(kl) == 1:
  276. copy_key = tuple(kl)
  277. old_copy = self._metric_copy.get(copy_key)
  278. if old_copy is None or v != old_copy:
  279. self._metric_copy[copy_key] = v
  280. # Store copy metric if not specified, or copy behavior
  281. if not has_summary or (d and d.summary.copy):
  282. self._consolidated_summary[kl[0]] = v
  283. return True
  284. if not d:
  285. return False
  286. if not has_summary:
  287. return False
  288. if not isinstance(v, numbers.Real):
  289. return False
  290. if math.isnan(v):
  291. return False
  292. float_v = float(v)
  293. goal_max = None
  294. if d.goal:
  295. goal_max = d.goal == d.GOAL_MAXIMIZE
  296. return bool(
  297. self._update_summary_metrics(
  298. d.summary, kl=kl, v=v, float_v=float_v, goal_max=goal_max
  299. )
  300. )
  301. def _update_summary_list(
  302. self,
  303. kl: list[str],
  304. v: Any,
  305. d: MetricRecord | None = None,
  306. ) -> bool:
  307. metric_key = ".".join([k.replace(".", "\\.") for k in kl])
  308. d = self._metric_defines.get(metric_key, d)
  309. # if the dict has _type key, it's a wandb table object
  310. if isinstance(v, dict) and not handler_util.metric_is_wandb_dict(v):
  311. updated = False
  312. for nk, nv in v.items():
  313. if self._update_summary_list(kl=kl[:] + [nk], v=nv, d=d):
  314. updated = True
  315. return updated
  316. # If the dict is a media object, update the pointer to the latest alias
  317. elif (
  318. REPLACE_SUMMARY_ART_PATH_WITH_LATEST
  319. and isinstance(v, dict)
  320. and handler_util.metric_is_wandb_dict(v)
  321. ):
  322. if "_latest_artifact_path" in v and "artifact_path" in v:
  323. # TODO: Make non-destructive?
  324. v["artifact_path"] = v["_latest_artifact_path"]
  325. updated = self._update_summary_leaf(kl=kl, v=v, d=d)
  326. return updated
  327. def _update_summary_media_objects(self, v: dict[str, Any]) -> dict[str, Any]:
  328. # For now, non-recursive - just top level
  329. for nk, nv in v.items():
  330. if REPLACE_SUMMARY_ART_PATH_WITH_LATEST and (
  331. isinstance(nv, dict)
  332. and handler_util.metric_is_wandb_dict(nv)
  333. and "_latest_artifact_path" in nv
  334. and "artifact_path" in nv
  335. ):
  336. # TODO: Make non-destructive?
  337. nv["artifact_path"] = nv["_latest_artifact_path"]
  338. v[nk] = nv
  339. return v
  340. def _update_summary(self, history_dict: dict[str, Any]) -> list[str]:
  341. # keep old behavior fast path if no define metrics have been used
  342. if not self._metric_defines:
  343. history_dict = self._update_summary_media_objects(history_dict)
  344. self._consolidated_summary.update(history_dict)
  345. return list(history_dict.keys())
  346. updated_keys = []
  347. for k, v in history_dict.items():
  348. if self._update_summary_list(kl=[k], v=v):
  349. updated_keys.append(k)
  350. return updated_keys
  351. def _history_assign_step(
  352. self,
  353. history: HistoryRecord,
  354. history_dict: dict[str, Any],
  355. ) -> None:
  356. has_step = history.HasField("step")
  357. item = history.item.add()
  358. item.key = "_step"
  359. if has_step:
  360. step = history.step.num
  361. history_dict["_step"] = step
  362. item.value_json = json.dumps(step)
  363. self._step = step + 1
  364. else:
  365. history_dict["_step"] = self._step
  366. item.value_json = json.dumps(self._step)
  367. self._step += 1
  368. def _history_define_metric(self, hkey: str) -> MetricRecord | None:
  369. """Check for hkey match in glob metrics and return the defined metric."""
  370. # Dont define metric for internal metrics
  371. if hkey.startswith("_"):
  372. return None
  373. for k, mglob in self._metric_globs.items():
  374. if k.endswith("*") and hkey.startswith(k[:-1]):
  375. m = MetricRecord()
  376. m.CopyFrom(mglob)
  377. m.ClearField("glob_name")
  378. m.options.defined = False
  379. m.name = hkey
  380. return m
  381. return None
  382. def _history_update_leaf(
  383. self,
  384. kl: list[str],
  385. v: Any,
  386. history_dict: dict[str, Any],
  387. update_history: dict[str, Any],
  388. ) -> None:
  389. hkey = ".".join([k.replace(".", "\\.") for k in kl])
  390. m = self._metric_defines.get(hkey)
  391. if not m:
  392. m = self._history_define_metric(hkey)
  393. if not m:
  394. return
  395. mr = Record()
  396. mr.metric.CopyFrom(m)
  397. mr.control.local = True # Dont store this, just send it
  398. self._handle_defined_metric(mr)
  399. if m.options.step_sync and m.step_metric and m.step_metric not in history_dict:
  400. copy_key = tuple([m.step_metric])
  401. step = self._metric_copy.get(copy_key)
  402. if step is not None:
  403. update_history[m.step_metric] = step
  404. def _history_update_list(
  405. self,
  406. kl: list[str],
  407. v: Any,
  408. history_dict: dict[str, Any],
  409. update_history: dict[str, Any],
  410. ) -> None:
  411. if isinstance(v, dict):
  412. for nk, nv in v.items():
  413. self._history_update_list(
  414. kl=kl[:] + [nk],
  415. v=nv,
  416. history_dict=history_dict,
  417. update_history=update_history,
  418. )
  419. return
  420. self._history_update_leaf(
  421. kl=kl, v=v, history_dict=history_dict, update_history=update_history
  422. )
  423. def _history_update(
  424. self,
  425. history: HistoryRecord,
  426. history_dict: dict[str, Any],
  427. ) -> None:
  428. # if syncing an old run, we can skip this logic
  429. if history_dict.get("_step") is None:
  430. self._history_assign_step(history, history_dict)
  431. update_history: dict[str, Any] = {}
  432. # Look for metric matches
  433. if self._metric_defines or self._metric_globs:
  434. for hkey, hval in history_dict.items():
  435. self._history_update_list([hkey], hval, history_dict, update_history)
  436. if update_history:
  437. history_dict.update(update_history)
  438. for k, v in update_history.items():
  439. item = history.item.add()
  440. item.key = k
  441. item.value_json = json.dumps(v)
  442. def handle_history(self, record: Record) -> None:
  443. history_dict = proto_util.dict_from_proto_list(record.history.item)
  444. # Inject _runtime if it is not present
  445. if history_dict is not None and "_runtime" not in history_dict:
  446. self._history_assign_runtime(record.history, history_dict)
  447. self._history_update(record.history, history_dict)
  448. self._dispatch_record(record)
  449. self._save_history(record.history)
  450. # update summary from history
  451. updated_keys = self._update_summary(history_dict)
  452. if updated_keys:
  453. updated_items = {k: self._consolidated_summary[k] for k in updated_keys}
  454. self._save_summary(updated_items)
  455. def _flush_partial_history(
  456. self,
  457. step: int | None = None,
  458. ) -> None:
  459. if not self._partial_history:
  460. return
  461. history = HistoryRecord()
  462. for k, v in self._partial_history.items():
  463. item = history.item.add()
  464. item.key = k
  465. item.value_json = json.dumps(v)
  466. if step is not None:
  467. history.step.num = step
  468. self.handle_history(Record(history=history))
  469. self._partial_history = {}
  470. def handle_request_sender_mark_report(self, record: Record) -> None:
  471. self._dispatch_record(record, always_send=True)
  472. def handle_request_status_report(self, record: Record) -> None:
  473. self._dispatch_record(record, always_send=True)
  474. def handle_request_partial_history(self, record: Record) -> None:
  475. partial_history = record.request.partial_history
  476. flush = None
  477. if partial_history.HasField("action"):
  478. flush = partial_history.action.flush
  479. step = None
  480. if partial_history.HasField("step"):
  481. step = partial_history.step.num
  482. history_dict = proto_util.dict_from_proto_list(partial_history.item)
  483. if step is not None:
  484. if step < self._step:
  485. if not self._dropped_history:
  486. message = (
  487. "Step only supports monotonically increasing values, use define_metric to set a custom x "
  488. f"axis. For details see: {url_registry.url('define-metric')}"
  489. )
  490. self._internal_messages.warning.append(message)
  491. self._dropped_history = True
  492. message = (
  493. f"(User provided step: {step} is less than current step: {self._step}. "
  494. f"Dropping entry: {history_dict})."
  495. )
  496. self._internal_messages.warning.append(message)
  497. return
  498. elif step > self._step:
  499. self._flush_partial_history()
  500. self._step = step
  501. elif flush is None:
  502. flush = True
  503. self._partial_history.update(history_dict)
  504. if flush:
  505. self._flush_partial_history(self._step)
  506. def handle_summary(self, record: Record) -> None:
  507. summary = record.summary
  508. for item in summary.update:
  509. if len(item.nested_key) > 0:
  510. # we use either key or nested_key -- not both
  511. assert item.key == ""
  512. key = tuple(item.nested_key)
  513. else:
  514. # no counter-assertion here, because technically
  515. # summary[""] is valid
  516. key = (item.key,)
  517. target = self._consolidated_summary
  518. # recurse down the dictionary structure:
  519. for prop in key[:-1]:
  520. target = target[prop]
  521. # use the last element of the key to write the leaf:
  522. target[key[-1]] = json.loads(item.value_json)
  523. for item in summary.remove:
  524. if len(item.nested_key) > 0:
  525. # we use either key or nested_key -- not both
  526. assert item.key == ""
  527. key = tuple(item.nested_key)
  528. else:
  529. # no counter-assertion here, because technically
  530. # summary[""] is valid
  531. key = (item.key,)
  532. target = self._consolidated_summary
  533. # recurse down the dictionary structure:
  534. for prop in key[:-1]:
  535. target = target[prop]
  536. # use the last element of the key to erase the leaf:
  537. del target[key[-1]]
  538. self._save_summary(self._consolidated_summary)
  539. def handle_exit(self, record: Record) -> None:
  540. if self._track_time is not None:
  541. self._accumulate_time += time.time() - self._track_time
  542. record.exit.runtime = int(self._accumulate_time)
  543. self._dispatch_record(record, always_send=True)
  544. def handle_final(self, record: Record) -> None:
  545. self._dispatch_record(record, always_send=True)
  546. def handle_preempting(self, record: Record) -> None:
  547. self._dispatch_record(record)
  548. def handle_header(self, record: Record) -> None:
  549. self._dispatch_record(record)
  550. def handle_footer(self, record: Record) -> None:
  551. self._dispatch_record(record)
  552. def handle_metadata(self, record: Record) -> None:
  553. self._dispatch_record(record)
  554. def handle_request_attach(self, record: Record) -> None:
  555. result = proto_util._result_from_record(record)
  556. attach_id = record.request.attach.attach_id
  557. assert attach_id
  558. assert self._run_proto
  559. result.response.attach_response.run.CopyFrom(self._run_proto)
  560. self._respond_result(result)
  561. def handle_request_log_artifact(self, record: Record) -> None:
  562. self._dispatch_record(record)
  563. def handle_telemetry(self, record: Record) -> None:
  564. self._dispatch_record(record)
  565. def handle_request_run_start(self, record: Record) -> None:
  566. run_start = record.request.run_start
  567. assert run_start
  568. assert run_start.run
  569. self._run_proto = run_start.run
  570. self._run_start_time = run_start.run.start_time.ToMicroseconds() / 1e6
  571. self._track_time = time.time()
  572. if run_start.run.resumed and run_start.run.runtime:
  573. self._accumulate_time = run_start.run.runtime
  574. else:
  575. self._accumulate_time = 0
  576. self._tb_watcher = tb_watcher.TBWatcher(
  577. self._settings, interface=self._interface, run_proto=run_start.run
  578. )
  579. if run_start.run.resumed or run_start.run.forked:
  580. self._step = run_start.run.starting_step
  581. result = proto_util._result_from_record(record)
  582. self._respond_result(result)
  583. def handle_request_resume(self, record: Record) -> None:
  584. if self._track_time is not None:
  585. self._accumulate_time += time.time() - self._track_time
  586. self._track_time = time.time()
  587. def handle_request_pause(self, record: Record) -> None:
  588. if self._track_time is not None:
  589. self._accumulate_time += time.time() - self._track_time
  590. self._track_time = None
  591. def handle_request_poll_exit(self, record: Record) -> None:
  592. self._dispatch_record(record, always_send=True)
  593. def handle_request_stop_status(self, record: Record) -> None:
  594. self._dispatch_record(record)
  595. def handle_request_network_status(self, record: Record) -> None:
  596. self._dispatch_record(record)
  597. def handle_request_internal_messages(self, record: Record) -> None:
  598. result = proto_util._result_from_record(record)
  599. result.response.internal_messages_response.messages.CopyFrom(
  600. self._internal_messages
  601. )
  602. self._internal_messages.Clear()
  603. self._respond_result(result)
  604. def handle_request_status(self, record: Record) -> None:
  605. result = proto_util._result_from_record(record)
  606. self._respond_result(result)
  607. def handle_request_get_summary(self, record: Record) -> None:
  608. result = proto_util._result_from_record(record)
  609. for key, value in self._consolidated_summary.items():
  610. item = SummaryItem()
  611. item.key = key
  612. item.value_json = json.dumps(value)
  613. result.response.get_summary_response.item.append(item)
  614. self._respond_result(result)
  615. def handle_tbrecord(self, record: Record) -> None:
  616. logger.info("handling tbrecord: %s", record)
  617. if self._tb_watcher:
  618. tbrecord = record.tbrecord
  619. self._tb_watcher.add(tbrecord.log_dir, tbrecord.save, tbrecord.root_dir)
  620. self._dispatch_record(record)
  621. def _handle_defined_metric(self, record: Record) -> None:
  622. metric = record.metric
  623. if metric._control.overwrite:
  624. self._metric_defines[metric.name].CopyFrom(metric)
  625. else:
  626. self._metric_defines[metric.name].MergeFrom(metric)
  627. # before dispatching, make sure step_metric is defined, if not define it and
  628. # dispatch it locally first
  629. metric = self._metric_defines[metric.name]
  630. if metric.step_metric and metric.step_metric not in self._metric_defines:
  631. m = MetricRecord(name=metric.step_metric)
  632. self._metric_defines[metric.step_metric] = m
  633. mr = Record()
  634. mr.metric.CopyFrom(m)
  635. mr.control.local = True # Don't store this, just send it
  636. self._dispatch_record(mr)
  637. self._dispatch_record(record)
  638. def _handle_glob_metric(self, record: Record) -> None:
  639. metric = record.metric
  640. if metric._control.overwrite:
  641. self._metric_globs[metric.glob_name].CopyFrom(metric)
  642. else:
  643. self._metric_globs[metric.glob_name].MergeFrom(metric)
  644. self._dispatch_record(record)
  645. def handle_metric(self, record: Record) -> None:
  646. """Handle MetricRecord.
  647. Walkthrough of the life of a MetricRecord:
  648. Metric defined:
  649. - run.define_metric() parses arguments create wandb_metric.Metric
  650. - build MetricRecord publish to interface
  651. - handler (this function) keeps list of metrics published:
  652. - self._metric_defines: Fully defined metrics
  653. - self._metric_globs: metrics that have a wildcard
  654. - dispatch writer and sender thread
  655. - writer: records are saved to persistent store
  656. - sender: fully defined metrics get mapped into metadata for UI
  657. History logged:
  658. - handle_history
  659. - check if metric matches _metric_defines
  660. - if not, check if metric matches _metric_globs
  661. - if _metric globs match, generate defined metric and call _handle_metric
  662. Args:
  663. record (Record): Metric record to process
  664. """
  665. if record.metric.name:
  666. self._handle_defined_metric(record)
  667. elif record.metric.glob_name:
  668. self._handle_glob_metric(record)
  669. def handle_request_sampled_history(self, record: Record) -> None:
  670. result = proto_util._result_from_record(record)
  671. for key, sampled in self._sampled_history.items():
  672. item = SampledHistoryItem()
  673. item.key = key
  674. values: Iterable[Any] = sampled.get()
  675. if all(isinstance(i, numbers.Integral) for i in values):
  676. try:
  677. item.values_int.extend(values)
  678. except ValueError:
  679. # it is safe to ignore these as this is for display information
  680. pass
  681. elif all(isinstance(i, numbers.Real) for i in values):
  682. item.values_float.extend(values)
  683. result.response.sampled_history_response.item.append(item)
  684. self._respond_result(result)
  685. def handle_request_keepalive(self, record: Record) -> None:
  686. """Handle a keepalive request.
  687. Keepalive is a noop, we just want to verify transport is alive.
  688. """
  689. def handle_request_run_status(self, record: Record) -> None:
  690. self._dispatch_record(record, always_send=True)
  691. def handle_request_shutdown(self, record: Record) -> None:
  692. # TODO(jhr): should we drain things and stop new requests from coming in?
  693. result = proto_util._result_from_record(record)
  694. self._respond_result(result)
  695. self._stopped.set()
  696. def handle_request_operations(self, record: Record) -> None:
  697. """No-op. Not implemented for the legacy-service."""
  698. self._respond_result(proto_util._result_from_record(record))
  699. def finish(self) -> None:
  700. logger.info("shutting down handler")
  701. if self._tb_watcher:
  702. self._tb_watcher.finish()
  703. # self._context_keeper._debug_print_orphans()
  704. def __next__(self) -> Record:
  705. return self._record_q.get(block=True)
  706. next = __next__
  707. def _history_assign_runtime(
  708. self,
  709. history: HistoryRecord,
  710. history_dict: dict[str, Any],
  711. ) -> None:
  712. # _runtime calculation is meaningless if there is no _timestamp
  713. if "_timestamp" not in history_dict:
  714. return
  715. # if it is offline sync, self._run_start_time is None
  716. # in that case set it to the first tfevent timestamp
  717. if self._run_start_time is None:
  718. self._run_start_time = history_dict["_timestamp"]
  719. history_dict["_runtime"] = history_dict["_timestamp"] - self._run_start_time
  720. item = history.item.add()
  721. item.key = "_runtime"
  722. item.value_json = json.dumps(history_dict[item.key])