| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513 |
- from __future__ import annotations
- import abc
- import logging
- from typing import Any, cast
- from typing_extensions import override
- from wandb.proto import wandb_internal_pb2 as pb
- from wandb.proto import wandb_telemetry_pb2 as tpb
- from wandb.sdk.mailbox import MailboxHandle
- from wandb.util import json_dumps_safer, json_friendly
- from .interface import InterfaceBase
# Module-level logger registered under the "wandb" logger name.
logger = logging.getLogger(name="wandb")
class InterfaceShared(InterfaceBase, abc.ABC):
    """Partially implemented InterfaceBase.

    There is little reason for this to exist separately from InterfaceBase,
    which itself is not a pure abstract class and has no other direct
    subclasses. Nearly everything here is expressed in terms of the two
    protected hooks ``_publish`` and ``_deliver``, which concrete subclasses
    must implement.

    Naming convention used by the helpers below:

    * ``_publish_*`` / ``publish_*`` wrap a message in a ``pb.Record`` and
      hand it to ``_publish`` (fire-and-forget).
    * ``_deliver_*`` wrap a message in a ``pb.Record`` and return the
      ``MailboxHandle`` produced by ``_deliver``.
    * ``_publish_job_input`` is the one exception: despite its name it goes
      through ``_deliver`` and returns a handle.
    """

    def __init__(self) -> None:
        """Initialize the base interface; this class adds no state of its own."""
        super().__init__()

    @abc.abstractmethod
    def _publish(
        self,
        record: pb.Record,
        *,
        nowait: bool = False,
    ) -> None:
        """Send a record to the internal service without waiting for a reply.

        Args:
            record: The record to send. Implementations assign its stream ID.
            nowait: When true, the call does not block on socket IO, so it is
                safe to use from W&B's asyncio thread. The trade-off is that
                it applies no back-pressure: if the socket is blocked, data
                keeps accumulating in Python memory instead of slowing the
                caller down.
        """

    @abc.abstractmethod
    def _deliver(self, record: pb.Record) -> MailboxHandle[pb.Result]:
        """Send a record to the internal service and return a response handle.

        Args:
            record: The record to send. Implementations assign its stream ID.

        Returns:
            A mailbox handle the caller can use to wait for the response.
        """

    @override
    def _publish_output(
        self,
        outdata: pb.OutputRecord,
        *,
        nowait: bool = False,
    ) -> None:
        """Publish an ``OutputRecord``; ``nowait`` is forwarded to ``_publish``."""
        envelope = pb.Record()
        envelope.output.CopyFrom(outdata)
        self._publish(envelope, nowait=nowait)

    @override
    def _publish_output_raw(
        self,
        outdata: pb.OutputRawRecord,
        *,
        nowait: bool = False,
    ) -> None:
        """Publish an ``OutputRawRecord``; ``nowait`` is forwarded to ``_publish``."""
        envelope = pb.Record()
        envelope.output_raw.CopyFrom(outdata)
        self._publish(envelope, nowait=nowait)

    def _publish_cancel(self, cancel: pb.CancelRequest) -> None:
        """Publish a cancel request."""
        self._publish(self._make_request(cancel=cancel))

    def _publish_tbdata(self, tbrecord: pb.TBRecord) -> None:
        """Publish a ``TBRecord``."""
        self._publish(self._make_record(tbrecord=tbrecord))

    def _publish_partial_history(
        self, partial_history: pb.PartialHistoryRequest
    ) -> None:
        """Publish a partial-history request."""
        self._publish(self._make_request(partial_history=partial_history))

    def _publish_history(self, history: pb.HistoryRecord) -> None:
        """Publish a history record."""
        self._publish(self._make_record(history=history))

    def _publish_preempting(self, preempt_rec: pb.RunPreemptingRecord) -> None:
        """Publish a run-preempting record."""
        self._publish(self._make_record(preempting=preempt_rec))

    def _publish_telemetry(self, telem: tpb.TelemetryRecord) -> None:
        """Publish a telemetry record."""
        self._publish(self._make_record(telemetry=telem))

    def _publish_environment(self, environment: pb.EnvironmentRecord) -> None:
        """Publish an environment record."""
        self._publish(self._make_record(environment=environment))

    def _publish_job_input(
        self, job_input: pb.JobInputRequest
    ) -> MailboxHandle[pb.Result]:
        """Deliver a job-input request and return its response handle."""
        return self._deliver(self._make_request(job_input=job_input))

    def _make_stats(self, stats_dict: dict) -> pb.StatsRecord:
        """Build a SYSTEM ``StatsRecord`` from a ``{key: value}`` mapping.

        The record is stamped with the current time. Each value is passed
        through ``json_friendly`` and the converted object is serialized with
        ``json_dumps_safer`` into the item's ``value_json``.
        """
        stats_record = pb.StatsRecord(stats_type=pb.StatsRecord.StatsType.SYSTEM)
        # TODO: the original author flagged this timestamp as wrong without
        # saying why; behavior intentionally left unchanged.
        stats_record.timestamp.GetCurrentTime()
        for stat_key, stat_value in stats_dict.items():
            encoded = json_dumps_safer(json_friendly(stat_value)[0])
            stats_record.item.add(key=stat_key, value_json=encoded)
        return stats_record

    def _make_request(
        self,
        get_summary: pb.GetSummaryRequest | None = None,
        pause: pb.PauseRequest | None = None,
        resume: pb.ResumeRequest | None = None,
        status: pb.StatusRequest | None = None,
        stop_status: pb.StopStatusRequest | None = None,
        internal_messages: pb.InternalMessagesRequest | None = None,
        network_status: pb.NetworkStatusRequest | None = None,
        poll_exit: pb.PollExitRequest | None = None,
        partial_history: pb.PartialHistoryRequest | None = None,
        sampled_history: pb.SampledHistoryRequest | None = None,
        run_start: pb.RunStartRequest | None = None,
        check_version: pb.CheckVersionRequest | None = None,
        log_artifact: pb.LogArtifactRequest | None = None,
        download_artifact: pb.DownloadArtifactRequest | None = None,
        link_artifact: pb.LinkArtifactRequest | None = None,
        defer: pb.DeferRequest | None = None,
        attach: pb.AttachRequest | None = None,
        server_info: pb.ServerInfoRequest | None = None,
        keepalive: pb.KeepaliveRequest | None = None,
        run_status: pb.RunStatusRequest | None = None,
        sender_mark: pb.SenderMarkRequest | None = None,
        sender_read: pb.SenderReadRequest | None = None,
        sync_finish: pb.SyncFinishRequest | None = None,
        status_report: pb.StatusReportRequest | None = None,
        cancel: pb.CancelRequest | None = None,
        summary_record: pb.SummaryRecordRequest | None = None,
        telemetry_record: pb.TelemetryRecordRequest | None = None,
        get_system_metrics: pb.GetSystemMetricsRequest | None = None,
        python_packages: pb.PythonPackagesRequest | None = None,
        job_input: pb.JobInputRequest | None = None,
        run_finish_without_exit: pb.RunFinishWithoutExitRequest | None = None,
        probe_system_info: pb.ProbeSystemInfoRequest | None = None,
    ) -> pb.Record:
        """Wrap one request message in a local ``pb.Record``.

        Callers are expected to pass exactly one keyword. The arguments are
        checked in the fixed order of ``candidates`` below, and only the first
        truthy one is copied into the new ``pb.Request``; any others are
        ignored.

        Returns:
            A record built by ``_make_record(request=...)`` with
            ``control.local`` set. ``control.flow_control`` is also set
            whenever ``status_report`` is truthy.

        Raises:
            Exception: ``"Invalid request"`` if no argument is truthy.
        """
        # Field names on pb.Request match the keyword names one-to-one; the
        # tuple order is the precedence order.
        candidates = (
            ("get_summary", get_summary),
            ("pause", pause),
            ("resume", resume),
            ("status", status),
            ("stop_status", stop_status),
            ("internal_messages", internal_messages),
            ("network_status", network_status),
            ("poll_exit", poll_exit),
            ("partial_history", partial_history),
            ("sampled_history", sampled_history),
            ("run_start", run_start),
            ("check_version", check_version),
            ("log_artifact", log_artifact),
            ("download_artifact", download_artifact),
            ("link_artifact", link_artifact),
            ("defer", defer),
            ("attach", attach),
            ("server_info", server_info),
            ("keepalive", keepalive),
            ("run_status", run_status),
            ("sender_mark", sender_mark),
            ("sender_read", sender_read),
            ("cancel", cancel),
            ("status_report", status_report),
            ("summary_record", summary_record),
            ("telemetry_record", telemetry_record),
            ("get_system_metrics", get_system_metrics),
            ("sync_finish", sync_finish),
            ("python_packages", python_packages),
            ("job_input", job_input),
            ("run_finish_without_exit", run_finish_without_exit),
            ("probe_system_info", probe_system_info),
        )

        request = pb.Request()
        for field_name, message in candidates:
            if message:
                getattr(request, field_name).CopyFrom(message)
                break
        else:
            raise Exception("Invalid request")

        wrapped = self._make_record(request=request)
        # Requests are never persisted.
        wrapped.control.local = True
        if status_report:
            wrapped.control.flow_control = True
        return wrapped

    def _make_record(
        self,
        run: pb.RunRecord | None = None,
        config: pb.ConfigRecord | None = None,
        files: pb.FilesRecord | None = None,
        summary: pb.SummaryRecord | None = None,
        history: pb.HistoryRecord | None = None,
        stats: pb.StatsRecord | None = None,
        exit: pb.RunExitRecord | None = None,
        artifact: pb.ArtifactRecord | None = None,
        tbrecord: pb.TBRecord | None = None,
        alert: pb.AlertRecord | None = None,
        final: pb.FinalRecord | None = None,
        metric: pb.MetricRecord | None = None,
        header: pb.HeaderRecord | None = None,
        footer: pb.FooterRecord | None = None,
        request: pb.Request | None = None,
        telemetry: tpb.TelemetryRecord | None = None,
        preempting: pb.RunPreemptingRecord | None = None,
        use_artifact: pb.UseArtifactRecord | None = None,
        output: pb.OutputRecord | None = None,
        output_raw: pb.OutputRawRecord | None = None,
        environment: pb.EnvironmentRecord | None = None,
    ) -> pb.Record:
        """Wrap one message in a new ``pb.Record``.

        Callers are expected to pass exactly one keyword. The arguments are
        checked in the fixed order of ``candidates`` below, and only the first
        truthy one is copied into the record; any others are ignored.

        Raises:
            Exception: ``"Invalid record"`` if no argument is truthy.
        """
        # Field names on pb.Record match the keyword names one-to-one; the
        # tuple order is the precedence order.
        candidates = (
            ("run", run),
            ("config", config),
            ("summary", summary),
            ("history", history),
            ("files", files),
            ("stats", stats),
            ("exit", exit),
            ("artifact", artifact),
            ("tbrecord", tbrecord),
            ("alert", alert),
            ("final", final),
            ("header", header),
            ("footer", footer),
            ("request", request),
            ("telemetry", telemetry),
            ("metric", metric),
            ("preempting", preempting),
            ("use_artifact", use_artifact),
            ("output", output),
            ("output_raw", output_raw),
            ("environment", environment),
        )

        envelope = pb.Record()
        for field_name, message in candidates:
            if message:
                getattr(envelope, field_name).CopyFrom(message)
                break
        else:
            raise Exception("Invalid record")
        return envelope

    def _publish_defer(self, state: pb.DeferRequest.DeferState) -> None:
        """Publish a defer request carrying ``state``."""
        deferred = self._make_request(defer=pb.DeferRequest(state=state))
        # Redundant with _make_request, which already marks requests local.
        deferred.control.local = True
        self._publish(deferred)

    def publish_defer(self, state: int = 0) -> None:
        """Publish a defer request; ``state`` is a raw ``DeferState`` value."""
        self._publish_defer(cast("pb.DeferRequest.DeferState", state))

    def _publish_header(self, header: pb.HeaderRecord) -> None:
        """Publish a header record."""
        self._publish(self._make_record(header=header))

    def publish_footer(self) -> None:
        """Publish an empty footer record."""
        self._publish(self._make_record(footer=pb.FooterRecord()))

    def publish_final(self) -> None:
        """Publish an empty final record."""
        self._publish(self._make_record(final=pb.FinalRecord()))

    def _publish_pause(self, pause: pb.PauseRequest) -> None:
        """Publish a pause request."""
        self._publish(self._make_request(pause=pause))

    def _publish_resume(self, resume: pb.ResumeRequest) -> None:
        """Publish a resume request."""
        self._publish(self._make_request(resume=resume))

    def _publish_run(self, run: pb.RunRecord) -> None:
        """Publish a run record."""
        self._publish(self._make_record(run=run))

    def _publish_config(self, cfg: pb.ConfigRecord) -> None:
        """Publish a config record."""
        self._publish(self._make_record(config=cfg))

    def _publish_summary(self, summary: pb.SummaryRecord) -> None:
        """Publish a summary record."""
        self._publish(self._make_record(summary=summary))

    def _publish_metric(self, metric: pb.MetricRecord) -> None:
        """Publish a metric record."""
        self._publish(self._make_record(metric=metric))

    def publish_stats(self, stats_dict: dict) -> None:
        """Convert ``stats_dict`` with ``_make_stats`` and publish the result."""
        self._publish(self._make_record(stats=self._make_stats(stats_dict)))

    def _publish_python_packages(
        self, python_packages: pb.PythonPackagesRequest
    ) -> None:
        """Publish a python-packages request."""
        self._publish(self._make_request(python_packages=python_packages))

    def _publish_files(self, files: pb.FilesRecord) -> None:
        """Publish a files record."""
        self._publish(self._make_record(files=files))

    def _publish_use_artifact(self, use_artifact: pb.UseArtifactRecord) -> Any:
        """Publish a use-artifact record. Returns ``None``."""
        self._publish(self._make_record(use_artifact=use_artifact))

    def _publish_probe_system_info(
        self, probe_system_info: pb.ProbeSystemInfoRequest
    ) -> None:
        """Publish a probe-system-info request."""
        self._publish(self._make_request(probe_system_info=probe_system_info))

    def _deliver_artifact(
        self,
        log_artifact: pb.LogArtifactRequest,
    ) -> MailboxHandle[pb.Result]:
        """Deliver a log-artifact request and return its response handle."""
        return self._deliver(self._make_request(log_artifact=log_artifact))

    def _deliver_download_artifact(
        self, download_artifact: pb.DownloadArtifactRequest
    ) -> MailboxHandle[pb.Result]:
        """Deliver a download-artifact request and return its response handle."""
        return self._deliver(
            self._make_request(download_artifact=download_artifact)
        )

    def _deliver_link_artifact(
        self, link_artifact: pb.LinkArtifactRequest
    ) -> MailboxHandle[pb.Result]:
        """Deliver a link-artifact request and return its response handle."""
        return self._deliver(self._make_request(link_artifact=link_artifact))

    def _publish_artifact(self, proto_artifact: pb.ArtifactRecord) -> None:
        """Publish an artifact record."""
        self._publish(self._make_record(artifact=proto_artifact))

    def _publish_alert(self, proto_alert: pb.AlertRecord) -> None:
        """Publish an alert record."""
        self._publish(self._make_record(alert=proto_alert))

    def _deliver_status(
        self,
        status: pb.StatusRequest,
    ) -> MailboxHandle[pb.Result]:
        """Deliver a status request and return its response handle."""
        return self._deliver(self._make_request(status=status))

    def _publish_exit(self, exit_data: pb.RunExitRecord) -> None:
        """Publish a run-exit record."""
        self._publish(self._make_record(exit=exit_data))

    def _publish_keepalive(self, keepalive: pb.KeepaliveRequest) -> None:
        """Publish a keepalive request."""
        self._publish(self._make_request(keepalive=keepalive))

    def _deliver_shutdown(self) -> MailboxHandle[pb.Result]:
        """Deliver a shutdown request and return its response handle.

        The request is wrapped with ``_make_record`` directly rather than
        ``_make_request``, so ``control.local`` is not set on this record.
        NOTE(review): confirm whether that difference is intentional.
        """
        shutdown_request = pb.Request(shutdown=pb.ShutdownRequest())
        return self._deliver(self._make_record(request=shutdown_request))

    def _deliver_run(self, run: pb.RunRecord) -> MailboxHandle[pb.Result]:
        """Deliver a run record and return its response handle."""
        return self._deliver(self._make_record(run=run))

    def _deliver_finish_sync(
        self,
        sync_finish: pb.SyncFinishRequest,
    ) -> MailboxHandle[pb.Result]:
        """Deliver a sync-finish request and return its response handle."""
        return self._deliver(self._make_request(sync_finish=sync_finish))

    def _deliver_run_start(
        self,
        run_start: pb.RunStartRequest,
    ) -> MailboxHandle[pb.Result]:
        """Deliver a run-start request and return its response handle."""
        return self._deliver(self._make_request(run_start=run_start))

    def _deliver_get_summary(
        self,
        get_summary: pb.GetSummaryRequest,
    ) -> MailboxHandle[pb.Result]:
        """Deliver a get-summary request and return its response handle."""
        return self._deliver(self._make_request(get_summary=get_summary))

    def _deliver_get_system_metrics(
        self, get_system_metrics: pb.GetSystemMetricsRequest
    ) -> MailboxHandle[pb.Result]:
        """Deliver a get-system-metrics request and return its response handle."""
        return self._deliver(
            self._make_request(get_system_metrics=get_system_metrics)
        )

    def _deliver_exit(
        self,
        exit_data: pb.RunExitRecord,
    ) -> MailboxHandle[pb.Result]:
        """Deliver a run-exit record and return its response handle."""
        return self._deliver(self._make_record(exit=exit_data))

    def _deliver_poll_exit(
        self,
        poll_exit: pb.PollExitRequest,
    ) -> MailboxHandle[pb.Result]:
        """Deliver a poll-exit request and return its response handle."""
        return self._deliver(self._make_request(poll_exit=poll_exit))

    def _deliver_finish_without_exit(
        self, run_finish_without_exit: pb.RunFinishWithoutExitRequest
    ) -> MailboxHandle[pb.Result]:
        """Deliver a run-finish-without-exit request and return its handle."""
        return self._deliver(
            self._make_request(run_finish_without_exit=run_finish_without_exit)
        )

    def _deliver_stop_status(
        self,
        stop_status: pb.StopStatusRequest,
    ) -> MailboxHandle[pb.Result]:
        """Deliver a stop-status request and return its response handle."""
        return self._deliver(self._make_request(stop_status=stop_status))

    def _deliver_attach(
        self,
        attach: pb.AttachRequest,
    ) -> MailboxHandle[pb.Result]:
        """Deliver an attach request and return its response handle."""
        return self._deliver(self._make_request(attach=attach))

    def _deliver_network_status(
        self, network_status: pb.NetworkStatusRequest
    ) -> MailboxHandle[pb.Result]:
        """Deliver a network-status request and return its response handle."""
        return self._deliver(self._make_request(network_status=network_status))

    def _deliver_internal_messages(
        self, internal_message: pb.InternalMessagesRequest
    ) -> MailboxHandle[pb.Result]:
        """Deliver an internal-messages request and return its response handle."""
        return self._deliver(
            self._make_request(internal_messages=internal_message)
        )

    def _deliver_request_sampled_history(
        self, sampled_history: pb.SampledHistoryRequest
    ) -> MailboxHandle[pb.Result]:
        """Deliver a sampled-history request and return its response handle."""
        return self._deliver(self._make_request(sampled_history=sampled_history))

    def _deliver_request_run_status(
        self, run_status: pb.RunStatusRequest
    ) -> MailboxHandle[pb.Result]:
        """Deliver a run-status request and return its response handle."""
        return self._deliver(self._make_request(run_status=run_status))
|