sender.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682
  1. """sender."""
from __future__ import annotations

import contextlib
import glob
import gzip
import json
import logging
import os
import queue
import threading
import time
import traceback
from collections import defaultdict
from collections.abc import Generator
from datetime import datetime
from queue import Queue
from types import TracebackType
from typing import TYPE_CHECKING, Any, Literal

import requests

import wandb
from wandb import util
from wandb.analytics import get_sentry
from wandb.errors import CommError, UsageError
from wandb.errors.util import ProtobufErrorHandler
from wandb.filesync.dir_watcher import DirWatcher
from wandb.proto import wandb_internal_pb2
from wandb.sdk.artifacts.artifact_saver import ArtifactSaver
from wandb.sdk.interface import interface
from wandb.sdk.interface.interface_queue import InterfaceQueue
from wandb.sdk.internal import (
    context,
    datastore,
    file_stream,
    internal_api,
    sender_config,
)
from wandb.sdk.internal.file_pusher import FilePusher
from wandb.sdk.internal.job_builder import JobBuilder
from wandb.sdk.internal.settings_static import SettingsStatic
from wandb.sdk.lib import (
    config_util,
    filenames,
    filesystem,
    proto_util,
    redirect,
    retry,
    telemetry,
)
from wandb.sdk.lib.proto_util import message_to_dict

if TYPE_CHECKING:
    from wandb.proto.wandb_internal_pb2 import (
        ArtifactManifest,
        ArtifactManifestEntry,
        ArtifactRecord,
        EnvironmentRecord,
        HttpResponse,
        LocalInfo,
        Record,
        Result,
        RunExitResult,
        RunRecord,
        SummaryRecord,
    )
  63. StreamLiterals = Literal["stdout", "stderr"]
  64. logger = logging.getLogger(__name__)
  65. _OUTPUT_MIN_CALLBACK_INTERVAL = 2 # seconds
  66. def _framework_priority() -> Generator[tuple[str, str], None, None]:
  67. yield from [
  68. ("lightgbm", "lightgbm"),
  69. ("catboost", "catboost"),
  70. ("xgboost", "xgboost"),
  71. ("transformers_huggingface", "huggingface"), # backwards compatibility
  72. ("transformers", "huggingface"),
  73. ("pytorch_ignite", "ignite"), # backwards compatibility
  74. ("ignite", "ignite"),
  75. ("pytorch_lightning", "lightning"),
  76. ("fastai", "fastai"),
  77. ("torch", "torch"),
  78. ("keras", "keras"),
  79. ("tensorflow", "tensorflow"),
  80. ("sklearn", "sklearn"),
  81. ]
  82. def _manifest_json_from_proto(manifest: ArtifactManifest) -> dict:
  83. if manifest.version == 1:
  84. if manifest.manifest_file_path:
  85. contents = {}
  86. with gzip.open(manifest.manifest_file_path, "rt") as f:
  87. for line in f:
  88. entry_json = json.loads(line)
  89. path = entry_json.pop("path")
  90. contents[path] = entry_json
  91. else:
  92. contents = {
  93. content.path: _manifest_entry_from_proto(content)
  94. for content in manifest.contents
  95. }
  96. else:
  97. raise ValueError(f"unknown artifact manifest version: {manifest.version}")
  98. return {
  99. "version": manifest.version,
  100. "storagePolicy": manifest.storage_policy,
  101. "storagePolicyConfig": {
  102. config.key: json.loads(config.value_json)
  103. for config in manifest.storage_policy_config
  104. },
  105. "contents": contents,
  106. }
  107. def _manifest_entry_from_proto(entry: ArtifactManifestEntry) -> dict:
  108. birth_artifact_id = entry.birth_artifact_id if entry.birth_artifact_id else None
  109. return {
  110. "digest": entry.digest,
  111. "birthArtifactID": birth_artifact_id,
  112. "ref": entry.ref if entry.ref else None,
  113. "size": entry.size if entry.size is not None else None,
  114. "local_path": entry.local_path if entry.local_path else None,
  115. "skip_cache": entry.skip_cache,
  116. "extra": {extra.key: json.loads(extra.value_json) for extra in entry.extra},
  117. }
  118. class ResumeState:
  119. resumed: bool
  120. step: int
  121. history: int
  122. events: int
  123. output: int
  124. runtime: float
  125. wandb_runtime: int | None
  126. summary: dict[str, Any] | None
  127. config: dict[str, Any] | None
  128. tags: list[str] | None
  129. def __init__(self) -> None:
  130. self.resumed = False
  131. self.step = 0
  132. self.history = 0
  133. self.events = 0
  134. self.output = 0
  135. self.runtime = 0
  136. # wandb_runtime is the canonical runtime (stored in summary._wandb.runtime)
  137. self.wandb_runtime = None
  138. self.summary = None
  139. self.config = None
  140. self.tags = None
  141. def __str__(self) -> str:
  142. obj = ",".join(map(lambda it: f"{it[0]}={it[1]}", vars(self).items()))
  143. return f"ResumeState({obj})"
  144. class _OutputRawStream:
  145. _stopped: threading.Event
  146. _queue: queue.Queue
  147. _emulator: redirect.TerminalEmulator
  148. _writer_thr: threading.Thread
  149. _reader_thr: threading.Thread
  150. def __init__(self, stream: str, sm: SendManager):
  151. self._stopped = threading.Event()
  152. self._queue = queue.Queue()
  153. self._emulator = redirect.TerminalEmulator()
  154. self._writer_thr = threading.Thread(
  155. target=sm._output_raw_writer_thread,
  156. kwargs=dict(stream=stream),
  157. daemon=True,
  158. name=f"OutRawWr-{stream}",
  159. )
  160. self._reader_thr = threading.Thread(
  161. target=sm._output_raw_reader_thread,
  162. kwargs=dict(stream=stream),
  163. daemon=True,
  164. name=f"OutRawRd-{stream}",
  165. )
  166. def start(self) -> None:
  167. self._writer_thr.start()
  168. self._reader_thr.start()
class SendManager:
    """Send records to the backend and answer sender requests.

    ``send`` dispatches each record to the ``send_<record_type>`` method and
    ``send_request`` dispatches requests to ``send_request_<request_type>``;
    request results are put on the result queue.
    """

    # Debounce intervals in seconds, measured with time.monotonic().
    UPDATE_CONFIG_TIME: int = 30
    UPDATE_STATUS_TIME: int = 5
    _settings: SettingsStatic
    # Incoming records; len(self) reports this queue's size.
    _record_q: Queue[Record]
    # Outgoing results for requests.
    _result_q: Queue[Result]
    _interface: InterfaceQueue
    # Extra keyword arguments forwarded to upsert_run when config is flushed.
    _api_settings: dict[str, str]
    _partial_output: dict[str, str]
    _context_keeper: context.ContextKeeper
    _telemetry_obj: telemetry.TelemetryRecord
    _environment_obj: EnvironmentRecord
    # File stream; finished and cleared in the FLUSH_FS defer state.
    _fs: file_stream.FileStreamApi | None
    _run: RunRecord | None
    _entity: str | None
    _project: str | None
    # Finished and cleared in the FLUSH_DIR defer state.
    _dir_watcher: DirWatcher | None
    _pusher: FilePusher | None
    # The exit record, kept so its result can be answered when shutdown ends.
    _record_exit: Record | None
    # Set once the defer state machine reaches END; reported by poll_exit.
    _exit_result: RunExitResult | None
    _resume_state: ResumeState
    _rewind_response: dict[str, Any] | None
    _cached_server_info: dict[str, Any]
    _cached_viewer: dict[str, Any]
    # Transaction-log reader, opened lazily by send_request_sender_read.
    _ds: datastore.DataStore | None
    _output_raw_streams: dict[StreamLiterals, _OutputRawStream]
    _output_raw_file: filesystem.CRDedupedFile | None
    # Number of the last record sent and the latest end offset seen.
    _send_record_num: int
    _send_end_offset: int
    # time.monotonic() of the last config flush / status report.
    _debounce_config_time: float
    _debounce_status_time: float
  200. def __init__(
  201. self,
  202. settings: SettingsStatic,
  203. record_q: Queue[Record],
  204. result_q: Queue[Result],
  205. interface: InterfaceQueue,
  206. context_keeper: context.ContextKeeper,
  207. ) -> None:
  208. self._settings = settings
  209. self._record_q = record_q
  210. self._result_q = result_q
  211. self._interface = interface
  212. self._context_keeper = context_keeper
  213. self._ds = None
  214. self._send_record_num = 0
  215. self._send_end_offset = 0
  216. self._fs = None
  217. self._pusher = None
  218. self._dir_watcher = None
  219. # State updated by login
  220. self._entity = None
  221. self._flags = None
  222. # State updated by wandb.init
  223. self._run = None
  224. self._project = None
  225. # keep track of config from key/val updates
  226. self._consolidated_config = sender_config.ConfigState()
  227. self._start_time: int = 0
  228. self._telemetry_obj = telemetry.TelemetryRecord()
  229. self._environment_obj = wandb_internal_pb2.EnvironmentRecord()
  230. self._config_metric_pbdict_list: list[dict[int, Any]] = []
  231. self._metadata_summary: dict[str, Any] = defaultdict()
  232. self._cached_summary: dict[str, Any] = dict()
  233. self._config_metric_index_dict: dict[str, int] = {}
  234. self._config_metric_dict: dict[str, wandb_internal_pb2.MetricRecord] = {}
  235. self._consolidated_summary: dict[str, Any] = dict()
  236. self._cached_server_info = dict()
  237. self._cached_viewer = dict()
  238. # State updated by resuming
  239. self._resume_state = ResumeState()
  240. self._rewind_response = None
  241. # State added when run_exit is initiated and complete
  242. self._record_exit = None
  243. self._exit_result = None
  244. self._api = internal_api.Api(
  245. default_settings=settings, retry_callback=self.retry_callback
  246. )
  247. self._api_settings = dict()
  248. # queue filled by retry_callback
  249. self._retry_q: Queue[HttpResponse] = queue.Queue()
  250. # do we need to debounce?
  251. self._config_needs_debounce: bool = False
  252. # TODO(jhr): do something better, why do we need to send full lines?
  253. self._partial_output = dict()
  254. self._exit_code = 0
  255. # internal vars for handing raw console output
  256. self._output_raw_streams = dict()
  257. self._output_raw_file = None
  258. # job builder
  259. self._job_builder = JobBuilder(
  260. settings,
  261. files_dir=settings.files_dir,
  262. )
  263. time_now = time.monotonic()
  264. self._debounce_config_time = time_now
  265. self._debounce_status_time = time_now
  266. @classmethod
  267. def setup(
  268. cls,
  269. root_dir: str,
  270. resume: None | bool | str,
  271. ) -> SendManager:
  272. """Set up a standalone SendManager.
  273. Exclusively used in `sync.py`.
  274. """
  275. files_dir = os.path.join(root_dir, "files")
  276. settings = wandb.Settings(
  277. x_files_dir=files_dir,
  278. root_dir=root_dir,
  279. # _start_time=0,
  280. resume=resume,
  281. # ignore_globs=(),
  282. x_sync=True,
  283. disable_job_creation=False,
  284. x_file_stream_timeout_seconds=0,
  285. )
  286. record_q: Queue[Record] = queue.Queue()
  287. result_q: Queue[Result] = queue.Queue()
  288. publish_interface = InterfaceQueue(record_q=record_q)
  289. context_keeper = context.ContextKeeper()
  290. return SendManager(
  291. settings=SettingsStatic(dict(settings)),
  292. record_q=record_q,
  293. result_q=result_q,
  294. interface=publish_interface,
  295. context_keeper=context_keeper,
  296. )
  297. def __len__(self) -> int:
  298. return self._record_q.qsize()
  299. def __enter__(self) -> SendManager:
  300. return self
  301. def __exit__(
  302. self,
  303. exc_type: type[BaseException] | None,
  304. exc_value: BaseException | None,
  305. exc_traceback: traceback.TracebackException | None,
  306. ) -> Literal[False]:
  307. while self:
  308. data = next(self)
  309. self.send(data)
  310. self.finish()
  311. return False
  312. def retry_callback(self, status: int, response_text: str) -> None:
  313. response = wandb_internal_pb2.HttpResponse()
  314. response.http_status_code = status
  315. response.http_response_text = response_text
  316. self._retry_q.put(response)
  317. def send(self, record: Record) -> None:
  318. self._update_record_num(record.num)
  319. self._update_end_offset(record.control.end_offset)
  320. record_type = record.WhichOneof("record_type")
  321. assert record_type
  322. handler_str = "send_" + record_type
  323. send_handler = getattr(self, handler_str, None)
  324. # Don't log output to reduce log noise
  325. if record_type not in {"output", "request", "output_raw"}:
  326. logger.debug(f"send: {record_type}")
  327. assert send_handler, f"unknown send handler: {handler_str}"
  328. context_id = context.context_id_from_record(record)
  329. api_context = self._context_keeper.get(context_id)
  330. try:
  331. self._api.set_local_context(api_context)
  332. send_handler(record)
  333. except retry.RetryCancelledError:
  334. logger.debug(f"Record cancelled: {record_type}")
  335. self._context_keeper.release(context_id)
  336. finally:
  337. self._api.clear_local_context()
  338. def send_preempting(self, _: Record) -> None:
  339. if self._fs:
  340. self._fs.enqueue_preempting()
  341. def send_request_sender_mark(self, _: Record) -> None:
  342. self._maybe_report_status(always=True)
  343. def send_request(self, record: Record) -> None:
  344. request_type = record.request.WhichOneof("request_type")
  345. assert request_type
  346. handler_str = "send_request_" + request_type
  347. send_handler = getattr(self, handler_str, None)
  348. if request_type != "network_status":
  349. logger.debug(f"send_request: {request_type}")
  350. assert send_handler, f"unknown handle: {handler_str}"
  351. send_handler(record)
  352. def _respond_result(self, result: Result) -> None:
  353. context_id = context.context_id_from_result(result)
  354. self._context_keeper.release(context_id)
  355. self._result_q.put(result)
  356. def _flatten(self, dictionary: dict) -> None:
  357. if isinstance(dictionary, dict):
  358. for k, v in list(dictionary.items()):
  359. if isinstance(v, dict):
  360. self._flatten(v)
  361. dictionary.pop(k)
  362. for k2, v2 in v.items():
  363. dictionary[k + "." + k2] = v2
  364. def _update_record_num(self, record_num: int) -> None:
  365. if not record_num:
  366. return
  367. # Currently how we handle offline mode and syncing is not
  368. # compatible with this assertion due to how the exit record
  369. # is (mis)handled:
  370. # - using "always_send" in offline mode to trigger defer
  371. # state machine
  372. # - skipping the exit record in `wandb sync` mode so that
  373. # it is always executed as the last record
  374. if not self._settings._offline and not self._settings.x_sync:
  375. assert record_num == self._send_record_num + 1
  376. self._send_record_num = record_num
  377. def _update_end_offset(self, end_offset: int) -> None:
  378. if not end_offset:
  379. return
  380. self._send_end_offset = end_offset
  381. def send_request_sender_read(self, record: Record) -> None:
  382. if self._ds is None:
  383. self._ds = datastore.DataStore()
  384. self._ds.open_for_scan(self._settings.sync_file)
  385. # TODO(cancel_paused): implement cancel_set logic
  386. # The idea is that there is an active request to cancel a
  387. # message that is being read from the transaction log below
  388. start_offset = record.request.sender_read.start_offset
  389. final_offset = record.request.sender_read.final_offset
  390. self._ds.seek(start_offset)
  391. current_end_offset = 0
  392. while current_end_offset < final_offset:
  393. data = self._ds.scan_data()
  394. assert data
  395. current_end_offset = self._ds.get_offset()
  396. send_record = wandb_internal_pb2.Record()
  397. send_record.ParseFromString(data)
  398. self._update_end_offset(current_end_offset)
  399. self.send(send_record)
  400. # make sure we perform deferred operations
  401. self.debounce()
  402. # make sure that we always update writer for every sended read request
  403. self._maybe_report_status(always=True)
  404. def send_request_stop_status(self, record: Record) -> None:
  405. result = proto_util._result_from_record(record)
  406. status_resp = result.response.stop_status_response
  407. status_resp.run_should_stop = False
  408. if self._entity and self._project and self._run and self._run.run_id:
  409. try:
  410. status_resp.run_should_stop = self._api.check_stop_requested(
  411. self._project, self._entity, self._run.run_id
  412. )
  413. except Exception as e:
  414. logger.warning("Failed to check stop requested status: %s", e)
  415. self._respond_result(result)
  416. def _maybe_update_config(self, always: bool = False) -> None:
  417. time_now = time.monotonic()
  418. if (
  419. not always
  420. and time_now < self._debounce_config_time + self.UPDATE_CONFIG_TIME
  421. ):
  422. return
  423. if self._config_needs_debounce:
  424. self._debounce_config()
  425. self._debounce_config_time = time_now
  426. def _maybe_report_status(self, always: bool = False) -> None:
  427. time_now = time.monotonic()
  428. if (
  429. not always
  430. and time_now < self._debounce_status_time + self.UPDATE_STATUS_TIME
  431. ):
  432. return
  433. self._debounce_status_time = time_now
  434. status_report = wandb_internal_pb2.StatusReportRequest(
  435. record_num=self._send_record_num,
  436. sent_offset=self._send_end_offset,
  437. )
  438. status_time = time.time()
  439. status_report.sync_time.FromMicroseconds(int(status_time * 1e6))
  440. record = self._interface._make_request(status_report=status_report)
  441. self._interface._publish(record)
  442. def debounce(self, final: bool = False) -> None:
  443. self._maybe_report_status(always=final)
  444. self._maybe_update_config(always=final)
  445. def _debounce_config(self) -> None:
  446. config_value_dict = self._config_backend_dict()
  447. # TODO(jhr): check result of upsert_run?
  448. if self._run:
  449. self._api.upsert_run(
  450. name=self._run.run_id,
  451. config=config_value_dict,
  452. **self._api_settings, # type: ignore
  453. )
  454. self._config_save(config_value_dict)
  455. self._config_needs_debounce = False
  456. def send_request_network_status(self, record: Record) -> None:
  457. result = proto_util._result_from_record(record)
  458. status_resp = result.response.network_status_response
  459. while True:
  460. try:
  461. status_resp.network_responses.append(self._retry_q.get_nowait())
  462. except queue.Empty:
  463. break
  464. except Exception as e:
  465. logger.warning(f"Error emptying retry queue: {e}")
  466. self._respond_result(result)
  467. def send_exit(self, record: Record) -> None:
  468. # track where the exit came from
  469. self._record_exit = record
  470. run_exit = record.exit
  471. self._exit_code = run_exit.exit_code
  472. logger.info("handling exit code: %s", run_exit.exit_code)
  473. runtime = run_exit.runtime
  474. logger.info("handling runtime: %s", run_exit.runtime)
  475. self._metadata_summary["runtime"] = runtime
  476. self._update_summary()
  477. # We need to give the request queue a chance to empty between states
  478. # so use handle_request_defer as a state machine.
  479. logger.info("send defer")
  480. self._interface.publish_defer()
  481. def send_final(self, record: Record) -> None:
  482. pass
  483. def _flush_run(self) -> None:
  484. pass
  485. def send_request_status_report(self, record: Record) -> None:
  486. # todo? this is just a noop to please wandb sync
  487. pass
  488. def send_request_defer(self, record: Record) -> None: # noqa: C901
  489. defer = record.request.defer
  490. state = defer.state
  491. logger.info(f"handle sender defer: {state}")
  492. def transition_state() -> None:
  493. state = defer.state + 1
  494. logger.info(f"send defer: {state}")
  495. self._interface.publish_defer(state)
  496. done = False
  497. if state == defer.BEGIN:
  498. transition_state()
  499. elif state == defer.FLUSH_RUN:
  500. self._flush_run()
  501. transition_state()
  502. elif state == defer.FLUSH_STATS:
  503. # NOTE: this is handled in handler.py:handle_request_defer()
  504. transition_state()
  505. elif state == defer.FLUSH_PARTIAL_HISTORY:
  506. # NOTE: this is handled in handler.py:handle_request_defer()
  507. transition_state()
  508. elif state == defer.FLUSH_TB:
  509. # NOTE: this is handled in handler.py:handle_request_defer()
  510. transition_state()
  511. elif state == defer.FLUSH_SUM:
  512. # NOTE: this is handled in handler.py:handle_request_defer()
  513. transition_state()
  514. elif state == defer.FLUSH_DEBOUNCER:
  515. self.debounce(final=True)
  516. transition_state()
  517. elif state == defer.FLUSH_OUTPUT:
  518. self._output_raw_finish()
  519. transition_state()
  520. elif state == defer.FLUSH_JOB:
  521. self._flush_job()
  522. transition_state()
  523. elif state == defer.FLUSH_DIR:
  524. if self._dir_watcher:
  525. self._dir_watcher.finish()
  526. self._dir_watcher = None
  527. transition_state()
  528. elif state == defer.FLUSH_FP:
  529. if self._pusher:
  530. # FilePusher generates some events for FileStreamApi, so we
  531. # need to wait for pusher to finish before going to the next
  532. # state to ensure that filestream gets all the events that we
  533. # want before telling it to finish up
  534. self._pusher.finish(transition_state)
  535. else:
  536. transition_state()
  537. elif state == defer.JOIN_FP:
  538. if self._pusher:
  539. self._pusher.join()
  540. transition_state()
  541. elif state == defer.FLUSH_FS:
  542. if self._fs:
  543. # TODO(jhr): now is a good time to output pending output lines
  544. self._fs.finish(self._exit_code)
  545. self._fs = None
  546. transition_state()
  547. elif state == defer.FLUSH_FINAL:
  548. self._interface.publish_final()
  549. self._interface.publish_footer()
  550. transition_state()
  551. elif state == defer.END:
  552. done = True
  553. else:
  554. raise AssertionError("unknown state")
  555. if not done:
  556. return
  557. exit_result = wandb_internal_pb2.RunExitResult()
  558. # mark exit done in case we are polling on exit
  559. self._exit_result = exit_result
  560. # Report response to mailbox
  561. if self._record_exit and self._record_exit.control.mailbox_slot:
  562. result = proto_util._result_from_record(self._record_exit)
  563. result.exit_result.CopyFrom(exit_result)
  564. self._respond_result(result)
  565. def send_request_poll_exit(self, record: Record) -> None:
  566. if not record.control.req_resp and not record.control.mailbox_slot:
  567. return
  568. result = proto_util._result_from_record(record)
  569. if self._pusher:
  570. _alive, status = self._pusher.get_status()
  571. file_counts = self._pusher.file_counts_by_category()
  572. resp = result.response.poll_exit_response
  573. resp.pusher_stats.uploaded_bytes = status.uploaded_bytes
  574. resp.pusher_stats.total_bytes = status.total_bytes
  575. resp.pusher_stats.deduped_bytes = status.deduped_bytes
  576. resp.file_counts.wandb_count = file_counts.wandb
  577. resp.file_counts.media_count = file_counts.media
  578. resp.file_counts.artifact_count = file_counts.artifact
  579. resp.file_counts.other_count = file_counts.other
  580. if self._exit_result:
  581. result.response.poll_exit_response.done = True
  582. result.response.poll_exit_response.exit_result.CopyFrom(self._exit_result)
  583. self._respond_result(result)
  584. def _setup_resume(self, run: RunRecord) -> wandb_internal_pb2.ErrorInfo | None:
  585. """Queries the backend for a run; fail if the settings are incompatible."""
  586. if not self._settings.resume:
  587. return None
  588. # TODO: This causes a race, we need to make the upsert atomically
  589. # only create or update depending on the resume config
  590. # we use the runs entity if set, otherwise fallback to users entity
  591. # todo: ensure entity is not None as self._entity is Optional[str]
  592. entity = run.entity or self._entity
  593. logger.info(
  594. "checking resume status for %s/%s/%s", entity, run.project, run.run_id
  595. )
  596. resume_status = self._api.run_resume_status(
  597. entity=entity, # type: ignore
  598. project_name=run.project,
  599. name=run.run_id,
  600. )
  601. # No resume status = run does not exist; No t key in wandbConfig = run exists but hasn't been inited
  602. if not resume_status or '"t":' not in resume_status.get("wandbConfig", ""):
  603. if self._settings.resume == "must":
  604. error = wandb_internal_pb2.ErrorInfo()
  605. error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.USAGE
  606. error.message = (
  607. "You provided an invalid value for the `resume` argument."
  608. f" The value 'must' is not a valid option for resuming a run ({run.run_id}) that has not been initialized."
  609. " Please check your inputs and try again with a valid run ID."
  610. " If you are trying to start a new run, please omit the `resume` argument or use `resume='allow'`."
  611. )
  612. return error
  613. return None
  614. #
  615. # handle cases where we have resume_status
  616. #
  617. if self._settings.resume == "never":
  618. error = wandb_internal_pb2.ErrorInfo()
  619. error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.USAGE
  620. error.message = (
  621. "You provided an invalid value for the `resume` argument."
  622. f" The value 'never' is not a valid option for resuming a run ({run.run_id}) that already exists."
  623. " Please check your inputs and try again with a valid value for the `resume` argument."
  624. )
  625. return error
  626. history = {}
  627. events = {}
  628. config = {}
  629. summary = {}
  630. try:
  631. events_rt = 0
  632. history_rt = 0
  633. history = json.loads(resume_status["historyTail"])
  634. if history:
  635. history = json.loads(history[-1])
  636. history_rt = history.get("_runtime", 0)
  637. events = json.loads(resume_status["eventsTail"])
  638. if events:
  639. events = json.loads(events[-1])
  640. events_rt = events.get("_runtime", 0)
  641. config = json.loads(resume_status["config"] or "{}")
  642. summary = json.loads(resume_status["summaryMetrics"] or "{}")
  643. new_runtime = summary.get("_wandb", {}).get("runtime", None)
  644. if new_runtime is not None:
  645. self._resume_state.wandb_runtime = new_runtime
  646. tags = resume_status.get("tags") or []
  647. except (IndexError, ValueError):
  648. logger.exception("unable to load resume tails")
  649. if self._settings.resume == "must":
  650. error = wandb_internal_pb2.ErrorInfo()
  651. error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.USAGE
  652. error.message = f"resume='must' but could not resume ({run.run_id}) "
  653. return error
  654. # TODO: Do we need to restore config / summary?
  655. # System metrics runtime is usually greater than history
  656. self._resume_state.runtime = max(events_rt, history_rt)
  657. last_step = history.get("_step", 0)
  658. history_line_count = resume_status["historyLineCount"]
  659. self._resume_state.step = last_step + 1 if history_line_count > 0 else last_step
  660. self._resume_state.history = history_line_count
  661. self._resume_state.events = resume_status["eventsLineCount"]
  662. self._resume_state.output = resume_status["logLineCount"]
  663. self._resume_state.config = config
  664. self._resume_state.summary = summary
  665. self._resume_state.tags = tags
  666. self._resume_state.resumed = True
  667. logger.info(f"configured resuming with: {self._resume_state}")
  668. return None
  669. def _telemetry_get_framework(self) -> str:
  670. """Get telemetry data for internal config structure."""
  671. # detect framework by checking what is loaded
  672. imports: telemetry.TelemetryImports
  673. if self._telemetry_obj.HasField("imports_finish"):
  674. imports = self._telemetry_obj.imports_finish
  675. elif self._telemetry_obj.HasField("imports_init"):
  676. imports = self._telemetry_obj.imports_init
  677. else:
  678. return ""
  679. framework = next(
  680. (n for f, n in _framework_priority() if getattr(imports, f, False)), ""
  681. )
  682. return framework
  683. def _config_backend_dict(self) -> sender_config.BackendConfigDict:
  684. config = self._consolidated_config or sender_config.ConfigState()
  685. return config.to_backend_dict(
  686. telemetry_record=self._telemetry_obj,
  687. framework=self._telemetry_get_framework(),
  688. start_time_millis=self._start_time,
  689. metric_pbdicts=self._config_metric_pbdict_list,
  690. environment_record=self._environment_obj,
  691. )
  692. def _config_save(
  693. self,
  694. config_value_dict: sender_config.BackendConfigDict,
  695. ) -> None:
  696. config_path = os.path.join(self._settings.files_dir, "config.yaml")
  697. config_util.save_config_file_from_dict(config_path, config_value_dict)
  698. def _sync_spell(self) -> None:
  699. """Sync this run with spell."""
  700. if not self._run:
  701. return
  702. try:
  703. env = os.environ
  704. self._interface.publish_config(
  705. key=("_wandb", "spell_url"), val=env.get("SPELL_RUN_URL")
  706. )
  707. url = f"{self._api.app_url}/{self._run.entity}/{self._run.project}/runs/{self._run.run_id}"
  708. requests.put(
  709. env.get("SPELL_API_URL", "https://api.spell.run") + "/wandb_url",
  710. json={"access_token": env.get("WANDB_ACCESS_TOKEN"), "url": url},
  711. timeout=2,
  712. )
  713. except requests.RequestException:
  714. pass
  715. # TODO: do something if sync spell is not successful?
  716. def _setup_fork(self, server_run: dict):
  717. assert self._run
  718. assert self._run.branch_point
  719. first_step = int(self._run.branch_point.value) + 1
  720. self._resume_state.step = first_step
  721. self._resume_state.history = server_run.get("historyLineCount", 0)
  722. self._run.forked = True
  723. self._run.starting_step = first_step
  724. def _load_rewind_state(self, run: RunRecord):
  725. assert run.branch_point
  726. self._rewind_response = self._api.rewind_run(
  727. run_name=run.run_id,
  728. entity=run.entity or None,
  729. project=run.project or None,
  730. metric_name=run.branch_point.metric,
  731. metric_value=run.branch_point.value,
  732. program_path=self._settings.program or None,
  733. )
  734. self._resume_state.history = self._rewind_response.get("historyLineCount", 0)
  735. self._resume_state.config = json.loads(
  736. self._rewind_response.get("config", "{}")
  737. )
  738. def _install_rewind_state(self):
  739. assert self._run
  740. assert self._run.branch_point
  741. assert self._rewind_response
  742. first_step = int(self._run.branch_point.value) + 1
  743. self._resume_state.step = first_step
  744. # We set the fork flag here because rewind uses the forking
  745. # infrastructure under the hood. Setting `forked` here
  746. # ensures that run._step is properly set in the user process.
  747. self._run.forked = True
  748. self._run.starting_step = first_step
  749. def _handle_error(
  750. self,
  751. record: Record,
  752. error: wandb_internal_pb2.ErrorInfo,
  753. run: RunRecord,
  754. ) -> None:
  755. if record.control.req_resp or record.control.mailbox_slot:
  756. result = proto_util._result_from_record(record)
  757. result.run_result.run.CopyFrom(run)
  758. result.run_result.error.CopyFrom(error)
  759. self._respond_result(result)
  760. else:
  761. logger.error("Got error in async mode: %s", error.message)
  762. def send_run(self, record: Record, file_dir: str | None = None) -> None:
  763. run = record.run
  764. error = None
  765. is_wandb_init = self._run is None
  766. # save start time of a run
  767. self._start_time = int(run.start_time.ToMicroseconds() // 1e6)
  768. # update telemetry
  769. if run.telemetry:
  770. self._telemetry_obj.MergeFrom(run.telemetry)
  771. if self._settings.x_sync:
  772. self._telemetry_obj.feature.sync = True
  773. # build config dict
  774. config_value_dict: sender_config.BackendConfigDict | None = None
  775. if run.config:
  776. self._consolidated_config.update_from_proto(run.config)
  777. config_value_dict = self._config_backend_dict()
  778. self._config_save(config_value_dict)
  779. do_rewind = run.branch_point.run == run.run_id
  780. do_fork = not do_rewind and run.branch_point.run != ""
  781. do_resume = bool(self._settings.resume)
  782. num_resume_options_set = sum([do_fork, do_rewind, do_resume])
  783. if num_resume_options_set > 1:
  784. error = wandb_internal_pb2.ErrorInfo()
  785. error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.USAGE
  786. error.message = (
  787. "Multiple resume options specified. "
  788. "Please specify only one of `fork_from`, `resume`, or `resume_from`."
  789. )
  790. self._handle_error(record, error, run)
  791. if is_wandb_init:
  792. # Ensure we have a project to query for status
  793. if run.project == "":
  794. run.project = util.auto_project_name(self._settings.program)
  795. # Only check resume status on `wandb.init`
  796. if do_resume:
  797. error = self._setup_resume(run)
  798. elif do_rewind:
  799. error = self._load_rewind_state(run)
  800. if error is not None:
  801. self._handle_error(record, error, run)
  802. return
  803. # Save the resumed config
  804. if self._resume_state.config is not None:
  805. self._consolidated_config.merge_resumed_config(
  806. config_util.dict_strip_value_dict(self._resume_state.config)
  807. )
  808. config_value_dict = self._config_backend_dict()
  809. self._config_save(config_value_dict)
  810. # handle empty config
  811. # TODO(jhr): consolidate the 4 ways config is built:
  812. # (passed config, empty config, resume config, send_config)
  813. if not config_value_dict:
  814. config_value_dict = self._config_backend_dict()
  815. self._config_save(config_value_dict)
  816. try:
  817. server_run = self._init_run(run, config_value_dict)
  818. except (CommError, UsageError) as e:
  819. logger.error(e, exc_info=True)
  820. error = ProtobufErrorHandler.from_exception(e)
  821. self._handle_error(record, error, run)
  822. return
  823. assert self._run # self._run is configured in _init_run()
  824. if do_fork:
  825. error = self._setup_fork(server_run)
  826. if error is not None:
  827. self._handle_error(record, error, run)
  828. return
  829. if record.control.req_resp or record.control.mailbox_slot:
  830. result = proto_util._result_from_record(record)
  831. # TODO: we could do self._interface.publish_defer(resp) to notify
  832. # the handler not to actually perform server updates for this uuid
  833. # because the user process will send a summary update when we resume
  834. result.run_result.run.CopyFrom(self._run)
  835. self._respond_result(result)
  836. # Only spin up our threads on the first run message
  837. if is_wandb_init:
  838. self._start_run_threads(file_dir)
  839. else:
  840. logger.info("updated run: %s", self._run.run_id)
  841. def _update_resume_state(self, is_rewinding: bool, inserted: bool):
  842. assert self._run
  843. if self._resume_state.resumed:
  844. self._run.resumed = True
  845. if self._resume_state.wandb_runtime is not None:
  846. self._run.runtime = self._resume_state.wandb_runtime
  847. elif is_rewinding:
  848. # because is_rewinding is mutually exclusive with self._resume_state.resumed,
  849. # this block will always execute if is_rewinding is set
  850. self._install_rewind_state()
  851. else:
  852. # If the user is not resuming, and we didn't insert on upsert_run then
  853. # it is likely that we are overwriting the run which we might want to
  854. # prevent in the future. This could be a false signal since an upsert_run
  855. # message which gets retried in the network could also show up as not
  856. # inserted.
  857. if not inserted:
  858. # no need to flush this, it will get updated eventually
  859. self._telemetry_obj.feature.maybe_run_overwrite = True
  860. def _init_run(
  861. self,
  862. run: RunRecord,
  863. config_dict: sender_config.BackendConfigDict | None,
  864. ) -> dict:
  865. # We subtract the previous runs runtime when resuming
  866. start_time = (
  867. run.start_time.ToMicroseconds() / 1e6
  868. ) - self._resume_state.runtime
  869. # TODO: we don't check inserted currently, ultimately we should make
  870. # the upsert know the resume state and fail transactionally
  871. if self._resume_state and self._resume_state.tags and not run.tags:
  872. run.tags.extend(self._resume_state.tags)
  873. is_rewinding = bool(self._settings.resume_from)
  874. if is_rewinding:
  875. assert self._rewind_response
  876. server_run = self._rewind_response
  877. inserted = True
  878. else:
  879. server_run, inserted = self._api.upsert_run(
  880. name=run.run_id,
  881. entity=run.entity or None,
  882. project=run.project or None,
  883. group=run.run_group or None,
  884. job_type=run.job_type or None,
  885. display_name=run.display_name or None,
  886. notes=run.notes or None,
  887. tags=run.tags[:] or None,
  888. config=config_dict or None,
  889. sweep_name=run.sweep_id or None,
  890. host=run.host or None,
  891. program_path=self._settings.program or None,
  892. repo=run.git.remote_url or None,
  893. commit=run.git.commit or None,
  894. )
  895. # TODO: we don't want to create jobs in sweeps, since the
  896. # executable doesn't appear to be consistent
  897. if run.sweep_id:
  898. self._job_builder.disable = True
  899. self._run = run
  900. if self._resume_state.resumed and is_rewinding:
  901. # this should not ever be possible to hit, since we check for
  902. # resumption above and raise an error if resumption is specified
  903. # twice.
  904. raise ValueError(
  905. "Cannot attempt to rewind and resume a run - only one of "
  906. "`resume` or `resume_from` can be specified."
  907. )
  908. self._update_resume_state(is_rewinding, inserted)
  909. self._run.starting_step = self._resume_state.step
  910. self._run.start_time.FromMicroseconds(int(start_time * 1e6))
  911. self._run.config.CopyFrom(self._interface._make_config(config_dict))
  912. if self._resume_state.summary is not None:
  913. self._run.summary.CopyFrom(
  914. self._interface._make_summary_from_dict(self._resume_state.summary)
  915. )
  916. storage_id = server_run.get("id")
  917. if storage_id:
  918. self._run.storage_id = storage_id
  919. id = server_run.get("name")
  920. if id:
  921. self._api.set_current_run_id(id)
  922. display_name = server_run.get("displayName")
  923. if display_name:
  924. self._run.display_name = display_name
  925. project = server_run.get("project")
  926. # TODO: remove self._api.set_settings, and make self._project a property?
  927. if project:
  928. project_name = project.get("name")
  929. if project_name:
  930. self._run.project = project_name
  931. self._project = project_name
  932. self._api_settings["project"] = project_name
  933. self._api.set_setting("project", project_name)
  934. entity = project.get("entity")
  935. if entity:
  936. entity_name = entity.get("name")
  937. if entity_name:
  938. self._run.entity = entity_name
  939. self._entity = entity_name
  940. self._api_settings["entity"] = entity_name
  941. self._api.set_setting("entity", entity_name)
  942. sweep_id = server_run.get("sweepName")
  943. if sweep_id:
  944. self._run.sweep_id = sweep_id
  945. if os.getenv("SPELL_RUN_URL"):
  946. self._sync_spell()
  947. return server_run
  948. def _start_run_threads(self, file_dir: str | None = None) -> None:
  949. assert self._run # self._run is configured by caller
  950. self._fs = file_stream.FileStreamApi(
  951. self._api,
  952. self._run.run_id,
  953. self._run.start_time.ToMicroseconds() / 1e6,
  954. timeout=self._settings.x_file_stream_timeout_seconds or 0,
  955. settings=self._api_settings,
  956. )
  957. # Ensure the streaming polices have the proper offsets
  958. self._fs.set_file_policy("wandb-summary.json", file_stream.SummaryFilePolicy())
  959. self._fs.set_file_policy(
  960. "wandb-history.jsonl",
  961. file_stream.JsonlFilePolicy(start_chunk_id=self._resume_state.history),
  962. )
  963. self._fs.set_file_policy(
  964. "wandb-events.jsonl",
  965. file_stream.JsonlFilePolicy(start_chunk_id=self._resume_state.events),
  966. )
  967. self._fs.set_file_policy(
  968. "output.log",
  969. file_stream.CRDedupeFilePolicy(start_chunk_id=self._resume_state.output),
  970. )
  971. # hack to merge run_settings and self._settings object together
  972. # so that fields like entity or project are available to be attached to Sentry events.
  973. run_settings = message_to_dict(self._run)
  974. _settings = dict(self._settings)
  975. _settings.update(run_settings)
  976. get_sentry().configure_scope(tags=_settings, process_context="internal")
  977. self._fs.start()
  978. self._pusher = FilePusher(self._api, self._fs, settings=self._settings)
  979. self._dir_watcher = DirWatcher(self._settings, self._pusher, file_dir)
  980. logger.info(
  981. "run started: %s with start time %s",
  982. self._run.run_id,
  983. self._run.start_time.ToMicroseconds() / 1e6,
  984. )
  985. def _save_history(self, history_dict: dict[str, Any]) -> None:
  986. if self._fs:
  987. self._fs.push(filenames.HISTORY_FNAME, json.dumps(history_dict))
  988. def send_history(self, record: Record) -> None:
  989. history = record.history
  990. history_dict = proto_util.dict_from_proto_list(history.item)
  991. self._save_history(history_dict)
  992. def _update_summary_record(self, summary: SummaryRecord) -> None:
  993. summary_dict = proto_util.dict_from_proto_list(summary.update)
  994. self._cached_summary = summary_dict
  995. self._update_summary()
  996. def send_summary(self, record: Record) -> None:
  997. self._update_summary_record(record.summary)
  998. def send_request_summary_record(self, record: Record) -> None:
  999. self._update_summary_record(record.request.summary_record.summary)
  1000. def _update_summary(self) -> None:
  1001. summary_dict = self._cached_summary.copy()
  1002. summary_dict.pop("_wandb", None)
  1003. if self._metadata_summary:
  1004. summary_dict["_wandb"] = self._metadata_summary
  1005. # merge with consolidated summary
  1006. self._consolidated_summary.update(summary_dict)
  1007. json_summary = json.dumps(self._consolidated_summary)
  1008. if self._fs:
  1009. self._fs.push(filenames.SUMMARY_FNAME, json_summary)
  1010. # TODO(jhr): we should only write this at the end of the script
  1011. summary_path = os.path.join(self._settings.files_dir, filenames.SUMMARY_FNAME)
  1012. with open(summary_path, "w") as f:
  1013. f.write(json_summary)
  1014. self._save_file(filesystem.GlobStr(filenames.SUMMARY_FNAME))
  1015. def send_stats(self, record: Record) -> None:
  1016. stats = record.stats
  1017. if stats.stats_type != wandb_internal_pb2.StatsRecord.StatsType.SYSTEM:
  1018. return
  1019. if not self._fs:
  1020. return
  1021. if not self._run:
  1022. return
  1023. now_us = stats.timestamp.ToMicroseconds()
  1024. start_us = self._run.start_time.ToMicroseconds()
  1025. d = dict()
  1026. for item in stats.item:
  1027. try:
  1028. d[item.key] = json.loads(item.value_json)
  1029. except json.JSONDecodeError:
  1030. logger.exception("error decoding stats json: %s", item.value_json)
  1031. row: dict[str, Any] = dict(system=d)
  1032. self._flatten(row)
  1033. row["_wandb"] = True
  1034. row["_timestamp"] = now_us / 1e6
  1035. row["_runtime"] = (now_us - start_us) / 1e6
  1036. self._fs.push(filenames.EVENTS_FNAME, json.dumps(row))
  1037. # TODO(jhr): check fs.push results?
  1038. def _output_raw_finish(self) -> None:
  1039. for stream, output_raw in self._output_raw_streams.items():
  1040. output_raw._stopped.set()
  1041. # shut down threads
  1042. output_raw._writer_thr.join(timeout=5)
  1043. if output_raw._writer_thr.is_alive():
  1044. logger.info("processing output...")
  1045. output_raw._writer_thr.join()
  1046. output_raw._reader_thr.join()
  1047. # flush output buffers and files
  1048. self._output_raw_flush(stream)
  1049. self._output_raw_streams = {}
  1050. if self._output_raw_file:
  1051. self._output_raw_file.close()
  1052. self._output_raw_file = None
def _output_raw_writer_thread(self, stream: StreamLiterals) -> None:
    """Feed queued raw console output for ``stream`` into its terminal emulator.

    The loop:
    - Runs until the stream's stop event is set and its queue is empty.
    - Polls every 0.5s while the queue is empty and no stop was requested.
    - Drains everything currently queued into one batch and writes it,
      joined, to the emulator. Write errors are logged, not raised.
    - If a stop was requested and the batch's total length exceeds 100000,
      flushes the emulator, writes each queued chunk straight to the output
      (bypassing the emulator), and returns.
    """
    while True:
        output_raw = self._output_raw_streams[stream]
        if output_raw._queue.empty():
            if output_raw._stopped.is_set():
                # Stop requested and nothing left to process.
                return
            time.sleep(0.5)
            continue
        # Drain everything currently queued into a single batch.
        data = []
        while not output_raw._queue.empty():
            data.append(output_raw._queue.get())
        if output_raw._stopped.is_set() and sum(map(len, data)) > 100000:
            # During shutdown, skip emulation of a very large batch (presumably
            # to avoid a slow shutdown; NOTE(review): confirm) and emit the raw
            # chunks directly after flushing the emulator.
            logger.warning("Terminal output too large. Logging without processing.")
            self._output_raw_flush(stream)
            for line in data:
                self._output_raw_flush(stream, line)
            # TODO: lets mark that this happened in telemetry
            return
        try:
            output_raw._emulator.write("".join(data))
        except Exception as e:
            logger.warning(f"problem writing to output_raw emulator: {e}")
  1075. def _output_raw_reader_thread(self, stream: StreamLiterals) -> None:
  1076. output_raw = self._output_raw_streams[stream]
  1077. while not (output_raw._stopped.is_set() and output_raw._queue.empty()):
  1078. self._output_raw_flush(stream)
  1079. time.sleep(_OUTPUT_MIN_CALLBACK_INTERVAL)
  1080. def _output_raw_flush(
  1081. self, stream: StreamLiterals, data: str | None = None
  1082. ) -> None:
  1083. if data is None:
  1084. output_raw = self._output_raw_streams[stream]
  1085. try:
  1086. data = output_raw._emulator.read()
  1087. except Exception as e:
  1088. logger.warning(f"problem reading from output_raw emulator: {e}")
  1089. if data:
  1090. self._send_output_line(stream, data)
  1091. if self._output_raw_file:
  1092. self._output_raw_file.write(data.encode("utf-8"))
  1093. def send_request_python_packages(self, record: Record) -> None:
  1094. import os
  1095. from wandb.sdk.lib.filenames import REQUIREMENTS_FNAME
  1096. installed_packages_list = sorted(
  1097. f"{r.name}=={r.version}" for r in record.request.python_packages.package
  1098. )
  1099. with open(os.path.join(self._settings.files_dir, REQUIREMENTS_FNAME), "w") as f:
  1100. f.write("\n".join(installed_packages_list))
  1101. def send_output(self, record: Record) -> None:
  1102. if not self._fs:
  1103. return
  1104. out = record.output
  1105. stream: StreamLiterals = "stdout"
  1106. if out.output_type == wandb_internal_pb2.OutputRecord.OutputType.STDERR:
  1107. stream = "stderr"
  1108. line = out.line
  1109. self._send_output_line(stream, line)
  1110. def send_output_raw(self, record: Record) -> None:
  1111. if not self._fs:
  1112. return
  1113. out = record.output_raw
  1114. stream: StreamLiterals = "stdout"
  1115. if out.output_type == wandb_internal_pb2.OutputRawRecord.OutputType.STDERR:
  1116. stream = "stderr"
  1117. line = out.line
  1118. output_raw = self._output_raw_streams.get(stream)
  1119. if not output_raw:
  1120. output_raw = _OutputRawStream(stream=stream, sm=self)
  1121. self._output_raw_streams[stream] = output_raw
  1122. # open the console output file shared between both streams
  1123. if not self._output_raw_file:
  1124. output_log_path = os.path.join(
  1125. self._settings.files_dir, filenames.OUTPUT_FNAME
  1126. )
  1127. output_raw_file = None
  1128. try:
  1129. output_raw_file = filesystem.CRDedupedFile(
  1130. open(output_log_path, "wb")
  1131. )
  1132. except OSError as e:
  1133. logger.warning(f"could not open output_raw_file: {e}")
  1134. if output_raw_file:
  1135. self._output_raw_file = output_raw_file
  1136. output_raw.start()
  1137. output_raw._queue.put(line)
  1138. def _send_output_line(self, stream: StreamLiterals, line: str) -> None:
  1139. """Combined writer for raw and non raw output lines.
  1140. This is combined because they are both post emulator.
  1141. """
  1142. prepend = ""
  1143. if stream == "stderr":
  1144. prepend = "ERROR "
  1145. if not line.endswith("\n"):
  1146. self._partial_output.setdefault(stream, "")
  1147. if line.startswith("\r"):
  1148. # TODO: maybe we shouldn't just drop this, what if there was some \ns in the partial
  1149. # that should probably be the check instead of not line.endswith(\n")
  1150. # logger.info(f"Dropping data {self._partial_output[stream]}")
  1151. self._partial_output[stream] = ""
  1152. self._partial_output[stream] += line
  1153. # TODO(jhr): how do we make sure this gets flushed?
  1154. # we might need this for other stuff like telemetry
  1155. else:
  1156. # TODO(jhr): use time from timestamp proto
  1157. # TODO(jhr): do we need to make sure we write full lines?
  1158. # seems to be some issues with line breaks
  1159. cur_time = time.time()
  1160. timestamp = datetime.utcfromtimestamp(cur_time).isoformat() + " "
  1161. prev_str = self._partial_output.get(stream, "")
  1162. line = f"{prepend}{timestamp}{prev_str}{line}"
  1163. if self._fs:
  1164. self._fs.push(filenames.OUTPUT_FNAME, line)
  1165. self._partial_output[stream] = ""
  1166. def _update_config(self) -> None:
  1167. self._config_needs_debounce = True
  1168. def send_config(self, record: Record) -> None:
  1169. self._consolidated_config.update_from_proto(record.config)
  1170. self._update_config()
  1171. def send_metric(self, record: Record) -> None:
  1172. metric = record.metric
  1173. if metric.glob_name:
  1174. logger.warning("Seen metric with glob (shouldn't happen)")
  1175. return
  1176. # merge or overwrite
  1177. old_metric = self._config_metric_dict.get(
  1178. metric.name, wandb_internal_pb2.MetricRecord()
  1179. )
  1180. if metric._control.overwrite:
  1181. old_metric.CopyFrom(metric)
  1182. else:
  1183. old_metric.MergeFrom(metric)
  1184. self._config_metric_dict[metric.name] = old_metric
  1185. metric = old_metric
  1186. # convert step_metric to index
  1187. if metric.step_metric:
  1188. find_step_idx = self._config_metric_index_dict.get(metric.step_metric)
  1189. if find_step_idx is not None:
  1190. # make a copy of this metric as we will be modifying it
  1191. rec = wandb_internal_pb2.Record()
  1192. rec.metric.CopyFrom(metric)
  1193. metric = rec.metric
  1194. metric.ClearField("step_metric")
  1195. metric.step_metric_index = find_step_idx + 1
  1196. md: dict[int, Any] = proto_util.proto_encode_to_dict(metric)
  1197. find_idx = self._config_metric_index_dict.get(metric.name)
  1198. if find_idx is not None:
  1199. self._config_metric_pbdict_list[find_idx] = md
  1200. else:
  1201. next_idx = len(self._config_metric_pbdict_list)
  1202. self._config_metric_pbdict_list.append(md)
  1203. self._config_metric_index_dict[metric.name] = next_idx
  1204. self._debounce_config()
  1205. def _update_telemetry_record(self, telemetry: telemetry.TelemetryRecord) -> None:
  1206. self._telemetry_obj.MergeFrom(telemetry)
  1207. self._debounce_config()
  1208. def send_telemetry(self, record: Record) -> None:
  1209. self._update_telemetry_record(record.telemetry)
  1210. def send_request_telemetry_record(self, record: Record) -> None:
  1211. self._update_telemetry_record(record.request.telemetry_record.telemetry)
  1212. def _save_file(
  1213. self, fname: filesystem.GlobStr, policy: filesystem.PolicyName = "end"
  1214. ) -> None:
  1215. logger.info("saving file %s with policy %s", fname, policy)
  1216. if self._dir_watcher:
  1217. self._dir_watcher.update_policy(fname, policy)
  1218. def send_files(self, record: Record) -> None:
  1219. files = record.files
  1220. for k in files.files:
  1221. # TODO(jhr): fix paths with directories
  1222. self._save_file(
  1223. filesystem.GlobStr(glob.escape(k.path)),
  1224. interface.file_enum_to_policy(k.policy),
  1225. )
  1226. def send_header(self, record: Record) -> None:
  1227. pass
  1228. def send_footer(self, record: Record) -> None:
  1229. pass
  1230. def send_tbrecord(self, record: Record) -> None:
  1231. # tbrecord watching threads are handled by handler.py
  1232. pass
  1233. def _update_environment_record(self, environment: EnvironmentRecord) -> None:
  1234. self._environment_obj.MergeFrom(environment)
  1235. self._debounce_config()
  1236. def send_environment(self, record: Record) -> None:
  1237. """Inject environment info into config and upload as a JSON file."""
  1238. self._update_environment_record(record.environment)
  1239. environment_json = json.dumps(proto_util.message_to_dict(self._environment_obj))
  1240. with open(
  1241. os.path.join(self._settings.files_dir, filenames.METADATA_FNAME), "w"
  1242. ) as f:
  1243. f.write(environment_json)
  1244. self._save_file(filesystem.GlobStr(filenames.METADATA_FNAME), policy="now")
  1245. def send_request_link_artifact(self, record: Record) -> None:
  1246. if not (record.control.req_resp or record.control.mailbox_slot):
  1247. raise ValueError(
  1248. f"Expected either `req_resp` or `mailbox_slot`, got: {record.control!r}"
  1249. )
  1250. result = proto_util._result_from_record(record)
  1251. link = record.request.link_artifact
  1252. client_id = link.client_id
  1253. server_id = link.server_id
  1254. portfolio_name = link.portfolio_name
  1255. entity = link.portfolio_entity
  1256. project = link.portfolio_project
  1257. aliases = link.portfolio_aliases
  1258. organization = link.portfolio_organization
  1259. logger.debug(
  1260. f"link_artifact params - client_id={client_id}, server_id={server_id}, "
  1261. f"portfolio_name={portfolio_name}, entity={entity}, project={project}, "
  1262. f"organization={organization}"
  1263. )
  1264. if (client_id or server_id) and portfolio_name and entity and project:
  1265. try:
  1266. response = self._api.link_artifact(
  1267. client_id,
  1268. server_id,
  1269. portfolio_name,
  1270. entity,
  1271. project,
  1272. aliases,
  1273. organization,
  1274. )
  1275. result.response.link_artifact_response.version_index = response[
  1276. "versionIndex"
  1277. ]
  1278. except Exception as e:
  1279. org_or_entity = organization or entity
  1280. result.response.link_artifact_response.error_message = (
  1281. f"error linking artifact to "
  1282. f'"{org_or_entity}/{project}/{portfolio_name}"; error: {e}'
  1283. )
  1284. logger.warning("Failed to link artifact to portfolio: %s", e)
  1285. self._respond_result(result)
  1286. def send_use_artifact(self, record: Record) -> None:
  1287. """Pretend to send a used artifact.
  1288. This function doesn't actually send anything, it is just used internally.
  1289. """
  1290. use = record.use_artifact
  1291. if use.type == "job" and not use.partial.job_name:
  1292. self._job_builder.disable = True
  1293. elif use.partial.job_name:
  1294. # job is partial, let job builder rebuild job, set job source dict
  1295. self._job_builder.set_partial_source_id(use.id)
  1296. def send_request_log_artifact(self, record: Record) -> None:
  1297. result = proto_util._result_from_record(record)
  1298. artifact = record.request.log_artifact.artifact
  1299. history_step = record.request.log_artifact.history_step
  1300. try:
  1301. res = self._send_artifact(artifact, history_step)
  1302. assert res, "Unable to send artifact"
  1303. result.response.log_artifact_response.artifact_id = res["id"]
  1304. logger.info(f"logged artifact {artifact.name} - {res}")
  1305. except Exception as e:
  1306. result.response.log_artifact_response.error_message = (
  1307. f'error logging artifact "{artifact.type}/{artifact.name}": {e}'
  1308. )
  1309. self._respond_result(result)
  1310. def send_artifact(self, record: Record) -> None:
  1311. artifact = record.artifact
  1312. try:
  1313. res = self._send_artifact(artifact)
  1314. logger.info(f"sent artifact {artifact.name} - {res}")
  1315. except Exception:
  1316. logger.exception(
  1317. f'send_artifact: failed for artifact "{artifact.type}/{artifact.name}"'
  1318. )
  1319. def _send_artifact(
  1320. self, artifact: ArtifactRecord, history_step: int | None = None
  1321. ) -> dict | None:
  1322. from packaging.version import parse
  1323. assert self._pusher
  1324. saver = ArtifactSaver(
  1325. api=self._api,
  1326. digest=artifact.digest,
  1327. manifest_json=_manifest_json_from_proto(artifact.manifest),
  1328. file_pusher=self._pusher,
  1329. is_user_created=artifact.user_created,
  1330. )
  1331. if artifact.distributed_id:
  1332. max_cli_version = self._max_cli_version()
  1333. if max_cli_version is None or parse(max_cli_version) < parse("0.10.16"):
  1334. logger.warning(
  1335. "This W&B Server doesn't support distributed artifacts, "
  1336. "have your administrator install wandb/local >= 0.9.37"
  1337. )
  1338. return None
  1339. metadata = json.loads(artifact.metadata) if artifact.metadata else None
  1340. res = saver.save(
  1341. entity=artifact.entity,
  1342. project=artifact.project,
  1343. type=artifact.type,
  1344. name=artifact.name,
  1345. client_id=artifact.client_id,
  1346. sequence_client_id=artifact.sequence_client_id,
  1347. metadata=metadata,
  1348. ttl_duration_seconds=artifact.ttl_duration_seconds or None,
  1349. description=artifact.description or None,
  1350. aliases=artifact.aliases,
  1351. tags=artifact.tags,
  1352. use_after_commit=artifact.use_after_commit,
  1353. distributed_id=artifact.distributed_id,
  1354. finalize=artifact.finalize,
  1355. incremental=artifact.incremental_beta1,
  1356. history_step=history_step,
  1357. base_id=artifact.base_id or None,
  1358. )
  1359. self._job_builder._handle_server_artifact(res, artifact)
  1360. if artifact.manifest.manifest_file_path:
  1361. with contextlib.suppress(FileNotFoundError):
  1362. os.remove(artifact.manifest.manifest_file_path)
  1363. return res
  1364. def send_alert(self, record: Record) -> None:
  1365. from packaging.version import parse
  1366. alert = record.alert
  1367. max_cli_version = self._max_cli_version()
  1368. if max_cli_version is None or parse(max_cli_version) < parse("0.10.9"):
  1369. logger.warning(
  1370. "This W&B server doesn't support alerts, "
  1371. "have your administrator install wandb/local >= 0.9.31"
  1372. )
  1373. else:
  1374. try:
  1375. self._api.notify_scriptable_run_alert(
  1376. title=alert.title,
  1377. text=alert.text,
  1378. level=alert.level,
  1379. wait_duration=alert.wait_duration,
  1380. )
  1381. except Exception:
  1382. logger.exception(f"send_alert: failed for alert {alert.title!r}")
  1383. def finish(self) -> None:
  1384. logger.info("shutting down sender")
  1385. # if self._tb_watcher:
  1386. # self._tb_watcher.finish()
  1387. self._output_raw_finish()
  1388. if self._dir_watcher:
  1389. self._dir_watcher.finish()
  1390. self._dir_watcher = None
  1391. if self._pusher:
  1392. self._pusher.finish()
  1393. self._pusher.join()
  1394. self._pusher = None
  1395. if self._fs:
  1396. self._fs.finish(self._exit_code)
  1397. self._fs = None
  1398. get_sentry().end_session()
  1399. def _max_cli_version(self) -> str | None:
  1400. server_info = self.get_server_info()
  1401. max_cli_version = server_info.get("cliVersionInfo", {}).get(
  1402. "max_cli_version", None
  1403. )
  1404. if not isinstance(max_cli_version, str):
  1405. return None
  1406. return max_cli_version
  1407. def get_viewer_server_info(self) -> None:
  1408. if self._cached_server_info and self._cached_viewer:
  1409. return
  1410. self._cached_viewer, self._cached_server_info = self._api.viewer_server_info()
  1411. def get_viewer_info(self) -> dict[str, Any]:
  1412. if not self._cached_viewer:
  1413. self.get_viewer_server_info()
  1414. return self._cached_viewer
  1415. def get_server_info(self) -> dict[str, Any]:
  1416. if not self._cached_server_info:
  1417. self.get_viewer_server_info()
  1418. return self._cached_server_info
  1419. def get_local_info(self) -> LocalInfo:
  1420. """Queries the server to get the local version information.
  1421. First, we perform an introspection, if it returns empty we deduce that the
  1422. docker image is out-of-date. Otherwise, we use the returned values to deduce the
  1423. state of the local server.
  1424. """
  1425. local_info = wandb_internal_pb2.LocalInfo()
  1426. if self._settings._offline:
  1427. local_info.out_of_date = False
  1428. return local_info
  1429. latest_local_version = "latest"
  1430. # Assuming the query is successful if the result is empty it indicates that
  1431. # the backend is out of date since it doesn't have the desired field
  1432. server_info = self.get_server_info()
  1433. latest_local_version_info = server_info.get("latestLocalVersionInfo", {})
  1434. if latest_local_version_info is None:
  1435. local_info.out_of_date = False
  1436. else:
  1437. local_info.out_of_date = latest_local_version_info.get("outOfDate", True)
  1438. local_info.version = latest_local_version_info.get(
  1439. "latestVersionString", latest_local_version
  1440. )
  1441. return local_info
  1442. def _flush_job(self) -> None:
  1443. if self._job_builder.disable or self._settings._offline:
  1444. return
  1445. self._job_builder.set_config(self._consolidated_config.non_internal_config())
  1446. summary_dict = self._cached_summary.copy()
  1447. summary_dict.pop("_wandb", None)
  1448. self._job_builder.set_summary(summary_dict)
  1449. artifact = self._job_builder.build(api=self._api)
  1450. if artifact is not None and self._run is not None:
  1451. proto_artifact = self._interface._make_artifact(artifact)
  1452. proto_artifact.run_id = self._run.run_id
  1453. proto_artifact.project = self._run.project
  1454. proto_artifact.entity = self._run.entity
  1455. # TODO: this should be removed when the latest tag is handled
  1456. # by the backend (WB-12116)
  1457. proto_artifact.aliases.append("latest")
  1458. # add docker image tag
  1459. for alias in self._job_builder._aliases:
  1460. proto_artifact.aliases.append(alias)
  1461. proto_artifact.user_created = True
  1462. proto_artifact.use_after_commit = True
  1463. proto_artifact.finalize = True
  1464. self._interface._publish_artifact(proto_artifact)
  1465. def __next__(self) -> Record:
  1466. return self._record_q.get(block=True)
  1467. next = __next__