interface_shared.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513
  1. from __future__ import annotations
  2. import abc
  3. import logging
  4. from typing import Any, cast
  5. from typing_extensions import override
  6. from wandb.proto import wandb_internal_pb2 as pb
  7. from wandb.proto import wandb_telemetry_pb2 as tpb
  8. from wandb.sdk.mailbox import MailboxHandle
  9. from wandb.util import json_dumps_safer, json_friendly
  10. from .interface import InterfaceBase
  11. logger = logging.getLogger("wandb")
  12. class InterfaceShared(InterfaceBase, abc.ABC):
  13. """Partially implemented InterfaceBase.
  14. There is little reason for this to exist separately from InterfaceBase,
  15. which itself is not a pure abstract class and has no other direct
  16. subclasses. Most methods are implemented in this class in terms of the
  17. protected _publish and _deliver methods defined by subclasses.
  18. """
  19. def __init__(self) -> None:
  20. super().__init__()
  21. @abc.abstractmethod
  22. def _publish(
  23. self,
  24. record: pb.Record,
  25. *,
  26. nowait: bool = False,
  27. ) -> None:
  28. """Send a record to the internal service.
  29. Args:
  30. record: The record to send. This method assigns its stream ID.
  31. nowait: If true, this does not block on socket IO and is safe
  32. to call in W&B's asyncio thread, but it will also not slow
  33. down even if the socket is blocked and allow data to accumulate
  34. in the Python memory.
  35. """
  36. @abc.abstractmethod
  37. def _deliver(self, record: pb.Record) -> MailboxHandle[pb.Result]:
  38. """Send a record to the internal service and return a response handle.
  39. Args:
  40. record: The record to send. This method assigns its stream ID.
  41. Returns:
  42. A mailbox handle for waiting for a response.
  43. """
  44. @override
  45. def _publish_output(
  46. self,
  47. outdata: pb.OutputRecord,
  48. *,
  49. nowait: bool = False,
  50. ) -> None:
  51. rec = pb.Record()
  52. rec.output.CopyFrom(outdata)
  53. self._publish(rec, nowait=nowait)
  54. @override
  55. def _publish_output_raw(
  56. self,
  57. outdata: pb.OutputRawRecord,
  58. *,
  59. nowait: bool = False,
  60. ) -> None:
  61. rec = pb.Record()
  62. rec.output_raw.CopyFrom(outdata)
  63. self._publish(rec, nowait=nowait)
  64. def _publish_cancel(self, cancel: pb.CancelRequest) -> None:
  65. rec = self._make_request(cancel=cancel)
  66. self._publish(rec)
  67. def _publish_tbdata(self, tbrecord: pb.TBRecord) -> None:
  68. rec = self._make_record(tbrecord=tbrecord)
  69. self._publish(rec)
  70. def _publish_partial_history(
  71. self, partial_history: pb.PartialHistoryRequest
  72. ) -> None:
  73. rec = self._make_request(partial_history=partial_history)
  74. self._publish(rec)
  75. def _publish_history(self, history: pb.HistoryRecord) -> None:
  76. rec = self._make_record(history=history)
  77. self._publish(rec)
  78. def _publish_preempting(self, preempt_rec: pb.RunPreemptingRecord) -> None:
  79. rec = self._make_record(preempting=preempt_rec)
  80. self._publish(rec)
  81. def _publish_telemetry(self, telem: tpb.TelemetryRecord) -> None:
  82. rec = self._make_record(telemetry=telem)
  83. self._publish(rec)
  84. def _publish_environment(self, environment: pb.EnvironmentRecord) -> None:
  85. rec = self._make_record(environment=environment)
  86. self._publish(rec)
  87. def _publish_job_input(
  88. self, job_input: pb.JobInputRequest
  89. ) -> MailboxHandle[pb.Result]:
  90. record = self._make_request(job_input=job_input)
  91. return self._deliver(record)
  92. def _make_stats(self, stats_dict: dict) -> pb.StatsRecord:
  93. stats = pb.StatsRecord()
  94. stats.stats_type = pb.StatsRecord.StatsType.SYSTEM
  95. stats.timestamp.GetCurrentTime() # todo: fix this, this is wrong :)
  96. for k, v in stats_dict.items():
  97. item = stats.item.add()
  98. item.key = k
  99. item.value_json = json_dumps_safer(json_friendly(v)[0])
  100. return stats
  101. def _make_request( # noqa: C901
  102. self,
  103. get_summary: pb.GetSummaryRequest | None = None,
  104. pause: pb.PauseRequest | None = None,
  105. resume: pb.ResumeRequest | None = None,
  106. status: pb.StatusRequest | None = None,
  107. stop_status: pb.StopStatusRequest | None = None,
  108. internal_messages: pb.InternalMessagesRequest | None = None,
  109. network_status: pb.NetworkStatusRequest | None = None,
  110. poll_exit: pb.PollExitRequest | None = None,
  111. partial_history: pb.PartialHistoryRequest | None = None,
  112. sampled_history: pb.SampledHistoryRequest | None = None,
  113. run_start: pb.RunStartRequest | None = None,
  114. check_version: pb.CheckVersionRequest | None = None,
  115. log_artifact: pb.LogArtifactRequest | None = None,
  116. download_artifact: pb.DownloadArtifactRequest | None = None,
  117. link_artifact: pb.LinkArtifactRequest | None = None,
  118. defer: pb.DeferRequest | None = None,
  119. attach: pb.AttachRequest | None = None,
  120. server_info: pb.ServerInfoRequest | None = None,
  121. keepalive: pb.KeepaliveRequest | None = None,
  122. run_status: pb.RunStatusRequest | None = None,
  123. sender_mark: pb.SenderMarkRequest | None = None,
  124. sender_read: pb.SenderReadRequest | None = None,
  125. sync_finish: pb.SyncFinishRequest | None = None,
  126. status_report: pb.StatusReportRequest | None = None,
  127. cancel: pb.CancelRequest | None = None,
  128. summary_record: pb.SummaryRecordRequest | None = None,
  129. telemetry_record: pb.TelemetryRecordRequest | None = None,
  130. get_system_metrics: pb.GetSystemMetricsRequest | None = None,
  131. python_packages: pb.PythonPackagesRequest | None = None,
  132. job_input: pb.JobInputRequest | None = None,
  133. run_finish_without_exit: pb.RunFinishWithoutExitRequest | None = None,
  134. probe_system_info: pb.ProbeSystemInfoRequest | None = None,
  135. ) -> pb.Record:
  136. request = pb.Request()
  137. if get_summary:
  138. request.get_summary.CopyFrom(get_summary)
  139. elif pause:
  140. request.pause.CopyFrom(pause)
  141. elif resume:
  142. request.resume.CopyFrom(resume)
  143. elif status:
  144. request.status.CopyFrom(status)
  145. elif stop_status:
  146. request.stop_status.CopyFrom(stop_status)
  147. elif internal_messages:
  148. request.internal_messages.CopyFrom(internal_messages)
  149. elif network_status:
  150. request.network_status.CopyFrom(network_status)
  151. elif poll_exit:
  152. request.poll_exit.CopyFrom(poll_exit)
  153. elif partial_history:
  154. request.partial_history.CopyFrom(partial_history)
  155. elif sampled_history:
  156. request.sampled_history.CopyFrom(sampled_history)
  157. elif run_start:
  158. request.run_start.CopyFrom(run_start)
  159. elif check_version:
  160. request.check_version.CopyFrom(check_version)
  161. elif log_artifact:
  162. request.log_artifact.CopyFrom(log_artifact)
  163. elif download_artifact:
  164. request.download_artifact.CopyFrom(download_artifact)
  165. elif link_artifact:
  166. request.link_artifact.CopyFrom(link_artifact)
  167. elif defer:
  168. request.defer.CopyFrom(defer)
  169. elif attach:
  170. request.attach.CopyFrom(attach)
  171. elif server_info:
  172. request.server_info.CopyFrom(server_info)
  173. elif keepalive:
  174. request.keepalive.CopyFrom(keepalive)
  175. elif run_status:
  176. request.run_status.CopyFrom(run_status)
  177. elif sender_mark:
  178. request.sender_mark.CopyFrom(sender_mark)
  179. elif sender_read:
  180. request.sender_read.CopyFrom(sender_read)
  181. elif cancel:
  182. request.cancel.CopyFrom(cancel)
  183. elif status_report:
  184. request.status_report.CopyFrom(status_report)
  185. elif summary_record:
  186. request.summary_record.CopyFrom(summary_record)
  187. elif telemetry_record:
  188. request.telemetry_record.CopyFrom(telemetry_record)
  189. elif get_system_metrics:
  190. request.get_system_metrics.CopyFrom(get_system_metrics)
  191. elif sync_finish:
  192. request.sync_finish.CopyFrom(sync_finish)
  193. elif python_packages:
  194. request.python_packages.CopyFrom(python_packages)
  195. elif job_input:
  196. request.job_input.CopyFrom(job_input)
  197. elif run_finish_without_exit:
  198. request.run_finish_without_exit.CopyFrom(run_finish_without_exit)
  199. elif probe_system_info:
  200. request.probe_system_info.CopyFrom(probe_system_info)
  201. else:
  202. raise Exception("Invalid request")
  203. record = self._make_record(request=request)
  204. # All requests do not get persisted
  205. record.control.local = True
  206. if status_report:
  207. record.control.flow_control = True
  208. return record
  209. def _make_record( # noqa: C901
  210. self,
  211. run: pb.RunRecord | None = None,
  212. config: pb.ConfigRecord | None = None,
  213. files: pb.FilesRecord | None = None,
  214. summary: pb.SummaryRecord | None = None,
  215. history: pb.HistoryRecord | None = None,
  216. stats: pb.StatsRecord | None = None,
  217. exit: pb.RunExitRecord | None = None,
  218. artifact: pb.ArtifactRecord | None = None,
  219. tbrecord: pb.TBRecord | None = None,
  220. alert: pb.AlertRecord | None = None,
  221. final: pb.FinalRecord | None = None,
  222. metric: pb.MetricRecord | None = None,
  223. header: pb.HeaderRecord | None = None,
  224. footer: pb.FooterRecord | None = None,
  225. request: pb.Request | None = None,
  226. telemetry: tpb.TelemetryRecord | None = None,
  227. preempting: pb.RunPreemptingRecord | None = None,
  228. use_artifact: pb.UseArtifactRecord | None = None,
  229. output: pb.OutputRecord | None = None,
  230. output_raw: pb.OutputRawRecord | None = None,
  231. environment: pb.EnvironmentRecord | None = None,
  232. ) -> pb.Record:
  233. record = pb.Record()
  234. if run:
  235. record.run.CopyFrom(run)
  236. elif config:
  237. record.config.CopyFrom(config)
  238. elif summary:
  239. record.summary.CopyFrom(summary)
  240. elif history:
  241. record.history.CopyFrom(history)
  242. elif files:
  243. record.files.CopyFrom(files)
  244. elif stats:
  245. record.stats.CopyFrom(stats)
  246. elif exit:
  247. record.exit.CopyFrom(exit)
  248. elif artifact:
  249. record.artifact.CopyFrom(artifact)
  250. elif tbrecord:
  251. record.tbrecord.CopyFrom(tbrecord)
  252. elif alert:
  253. record.alert.CopyFrom(alert)
  254. elif final:
  255. record.final.CopyFrom(final)
  256. elif header:
  257. record.header.CopyFrom(header)
  258. elif footer:
  259. record.footer.CopyFrom(footer)
  260. elif request:
  261. record.request.CopyFrom(request)
  262. elif telemetry:
  263. record.telemetry.CopyFrom(telemetry)
  264. elif metric:
  265. record.metric.CopyFrom(metric)
  266. elif preempting:
  267. record.preempting.CopyFrom(preempting)
  268. elif use_artifact:
  269. record.use_artifact.CopyFrom(use_artifact)
  270. elif output:
  271. record.output.CopyFrom(output)
  272. elif output_raw:
  273. record.output_raw.CopyFrom(output_raw)
  274. elif environment:
  275. record.environment.CopyFrom(environment)
  276. else:
  277. raise Exception("Invalid record")
  278. return record
  279. def _publish_defer(self, state: pb.DeferRequest.DeferState) -> None:
  280. defer = pb.DeferRequest(state=state)
  281. rec = self._make_request(defer=defer)
  282. rec.control.local = True
  283. self._publish(rec)
  284. def publish_defer(self, state: int = 0) -> None:
  285. self._publish_defer(cast("pb.DeferRequest.DeferState", state))
  286. def _publish_header(self, header: pb.HeaderRecord) -> None:
  287. rec = self._make_record(header=header)
  288. self._publish(rec)
  289. def publish_footer(self) -> None:
  290. footer = pb.FooterRecord()
  291. rec = self._make_record(footer=footer)
  292. self._publish(rec)
  293. def publish_final(self) -> None:
  294. final = pb.FinalRecord()
  295. rec = self._make_record(final=final)
  296. self._publish(rec)
  297. def _publish_pause(self, pause: pb.PauseRequest) -> None:
  298. rec = self._make_request(pause=pause)
  299. self._publish(rec)
  300. def _publish_resume(self, resume: pb.ResumeRequest) -> None:
  301. rec = self._make_request(resume=resume)
  302. self._publish(rec)
  303. def _publish_run(self, run: pb.RunRecord) -> None:
  304. rec = self._make_record(run=run)
  305. self._publish(rec)
  306. def _publish_config(self, cfg: pb.ConfigRecord) -> None:
  307. rec = self._make_record(config=cfg)
  308. self._publish(rec)
  309. def _publish_summary(self, summary: pb.SummaryRecord) -> None:
  310. rec = self._make_record(summary=summary)
  311. self._publish(rec)
  312. def _publish_metric(self, metric: pb.MetricRecord) -> None:
  313. rec = self._make_record(metric=metric)
  314. self._publish(rec)
  315. def publish_stats(self, stats_dict: dict) -> None:
  316. stats = self._make_stats(stats_dict)
  317. rec = self._make_record(stats=stats)
  318. self._publish(rec)
  319. def _publish_python_packages(
  320. self, python_packages: pb.PythonPackagesRequest
  321. ) -> None:
  322. rec = self._make_request(python_packages=python_packages)
  323. self._publish(rec)
  324. def _publish_files(self, files: pb.FilesRecord) -> None:
  325. rec = self._make_record(files=files)
  326. self._publish(rec)
  327. def _publish_use_artifact(self, use_artifact: pb.UseArtifactRecord) -> Any:
  328. rec = self._make_record(use_artifact=use_artifact)
  329. self._publish(rec)
  330. def _publish_probe_system_info(
  331. self, probe_system_info: pb.ProbeSystemInfoRequest
  332. ) -> None:
  333. record = self._make_request(probe_system_info=probe_system_info)
  334. self._publish(record)
  335. def _deliver_artifact(
  336. self,
  337. log_artifact: pb.LogArtifactRequest,
  338. ) -> MailboxHandle[pb.Result]:
  339. rec = self._make_request(log_artifact=log_artifact)
  340. return self._deliver(rec)
  341. def _deliver_download_artifact(
  342. self, download_artifact: pb.DownloadArtifactRequest
  343. ) -> MailboxHandle[pb.Result]:
  344. rec = self._make_request(download_artifact=download_artifact)
  345. return self._deliver(rec)
  346. def _deliver_link_artifact(
  347. self, link_artifact: pb.LinkArtifactRequest
  348. ) -> MailboxHandle[pb.Result]:
  349. rec = self._make_request(link_artifact=link_artifact)
  350. return self._deliver(rec)
  351. def _publish_artifact(self, proto_artifact: pb.ArtifactRecord) -> None:
  352. rec = self._make_record(artifact=proto_artifact)
  353. self._publish(rec)
  354. def _publish_alert(self, proto_alert: pb.AlertRecord) -> None:
  355. rec = self._make_record(alert=proto_alert)
  356. self._publish(rec)
  357. def _deliver_status(
  358. self,
  359. status: pb.StatusRequest,
  360. ) -> MailboxHandle[pb.Result]:
  361. req = self._make_request(status=status)
  362. return self._deliver(req)
  363. def _publish_exit(self, exit_data: pb.RunExitRecord) -> None:
  364. rec = self._make_record(exit=exit_data)
  365. self._publish(rec)
  366. def _publish_keepalive(self, keepalive: pb.KeepaliveRequest) -> None:
  367. record = self._make_request(keepalive=keepalive)
  368. self._publish(record)
  369. def _deliver_shutdown(self) -> MailboxHandle[pb.Result]:
  370. request = pb.Request(shutdown=pb.ShutdownRequest())
  371. record = self._make_record(request=request)
  372. return self._deliver(record)
  373. def _deliver_run(self, run: pb.RunRecord) -> MailboxHandle[pb.Result]:
  374. record = self._make_record(run=run)
  375. return self._deliver(record)
  376. def _deliver_finish_sync(
  377. self,
  378. sync_finish: pb.SyncFinishRequest,
  379. ) -> MailboxHandle[pb.Result]:
  380. record = self._make_request(sync_finish=sync_finish)
  381. return self._deliver(record)
  382. def _deliver_run_start(
  383. self,
  384. run_start: pb.RunStartRequest,
  385. ) -> MailboxHandle[pb.Result]:
  386. record = self._make_request(run_start=run_start)
  387. return self._deliver(record)
  388. def _deliver_get_summary(
  389. self,
  390. get_summary: pb.GetSummaryRequest,
  391. ) -> MailboxHandle[pb.Result]:
  392. record = self._make_request(get_summary=get_summary)
  393. return self._deliver(record)
  394. def _deliver_get_system_metrics(
  395. self, get_system_metrics: pb.GetSystemMetricsRequest
  396. ) -> MailboxHandle[pb.Result]:
  397. record = self._make_request(get_system_metrics=get_system_metrics)
  398. return self._deliver(record)
  399. def _deliver_exit(
  400. self,
  401. exit_data: pb.RunExitRecord,
  402. ) -> MailboxHandle[pb.Result]:
  403. record = self._make_record(exit=exit_data)
  404. return self._deliver(record)
  405. def _deliver_poll_exit(
  406. self,
  407. poll_exit: pb.PollExitRequest,
  408. ) -> MailboxHandle[pb.Result]:
  409. record = self._make_request(poll_exit=poll_exit)
  410. return self._deliver(record)
  411. def _deliver_finish_without_exit(
  412. self, run_finish_without_exit: pb.RunFinishWithoutExitRequest
  413. ) -> MailboxHandle[pb.Result]:
  414. record = self._make_request(run_finish_without_exit=run_finish_without_exit)
  415. return self._deliver(record)
  416. def _deliver_stop_status(
  417. self,
  418. stop_status: pb.StopStatusRequest,
  419. ) -> MailboxHandle[pb.Result]:
  420. record = self._make_request(stop_status=stop_status)
  421. return self._deliver(record)
  422. def _deliver_attach(
  423. self,
  424. attach: pb.AttachRequest,
  425. ) -> MailboxHandle[pb.Result]:
  426. record = self._make_request(attach=attach)
  427. return self._deliver(record)
  428. def _deliver_network_status(
  429. self, network_status: pb.NetworkStatusRequest
  430. ) -> MailboxHandle[pb.Result]:
  431. record = self._make_request(network_status=network_status)
  432. return self._deliver(record)
  433. def _deliver_internal_messages(
  434. self, internal_message: pb.InternalMessagesRequest
  435. ) -> MailboxHandle[pb.Result]:
  436. record = self._make_request(internal_messages=internal_message)
  437. return self._deliver(record)
  438. def _deliver_request_sampled_history(
  439. self, sampled_history: pb.SampledHistoryRequest
  440. ) -> MailboxHandle[pb.Result]:
  441. record = self._make_request(sampled_history=sampled_history)
  442. return self._deliver(record)
  443. def _deliver_request_run_status(
  444. self, run_status: pb.RunStatusRequest
  445. ) -> MailboxHandle[pb.Result]:
  446. record = self._make_request(run_status=run_status)
  447. return self._deliver(record)