from __future__ import annotations import abc import logging from typing import Any, cast from typing_extensions import override from wandb.proto import wandb_internal_pb2 as pb from wandb.proto import wandb_telemetry_pb2 as tpb from wandb.sdk.mailbox import MailboxHandle from wandb.util import json_dumps_safer, json_friendly from .interface import InterfaceBase logger = logging.getLogger("wandb") class InterfaceShared(InterfaceBase, abc.ABC): """Partially implemented InterfaceBase. There is little reason for this to exist separately from InterfaceBase, which itself is not a pure abstract class and has no other direct subclasses. Most methods are implemented in this class in terms of the protected _publish and _deliver methods defined by subclasses. """ def __init__(self) -> None: super().__init__() @abc.abstractmethod def _publish( self, record: pb.Record, *, nowait: bool = False, ) -> None: """Send a record to the internal service. Args: record: The record to send. This method assigns its stream ID. nowait: If true, this does not block on socket IO and is safe to call in W&B's asyncio thread, but it will also not slow down even if the socket is blocked and allow data to accumulate in the Python memory. """ @abc.abstractmethod def _deliver(self, record: pb.Record) -> MailboxHandle[pb.Result]: """Send a record to the internal service and return a response handle. Args: record: The record to send. This method assigns its stream ID. Returns: A mailbox handle for waiting for a response. """ @override def _publish_output( self, outdata: pb.OutputRecord, *, nowait: bool = False, ) -> None: rec = pb.Record() rec.output.CopyFrom(outdata) self._publish(rec, nowait=nowait) @override def _publish_output_raw( self, outdata: pb.OutputRawRecord, *, nowait: bool = False, ) -> None: rec = pb.Record() rec.output_raw.CopyFrom(outdata) self._publish(rec, nowait=nowait) def _publish_cancel(self, cancel: pb.CancelRequest) -> None: rec = self._make_request(cancel=cancel) self._publish(rec) def _publish_tbdata(self, tbrecord: pb.TBRecord) -> None: rec = self._make_record(tbrecord=tbrecord) self._publish(rec) def _publish_partial_history( self, partial_history: pb.PartialHistoryRequest ) -> None: rec = self._make_request(partial_history=partial_history) self._publish(rec) def _publish_history(self, history: pb.HistoryRecord) -> None: rec = self._make_record(history=history) self._publish(rec) def _publish_preempting(self, preempt_rec: pb.RunPreemptingRecord) -> None: rec = self._make_record(preempting=preempt_rec) self._publish(rec) def _publish_telemetry(self, telem: tpb.TelemetryRecord) -> None: rec = self._make_record(telemetry=telem) self._publish(rec) def _publish_environment(self, environment: pb.EnvironmentRecord) -> None: rec = self._make_record(environment=environment) self._publish(rec) def _publish_job_input( self, job_input: pb.JobInputRequest ) -> MailboxHandle[pb.Result]: record = self._make_request(job_input=job_input) return self._deliver(record) def _make_stats(self, stats_dict: dict) -> pb.StatsRecord: stats = pb.StatsRecord() stats.stats_type = pb.StatsRecord.StatsType.SYSTEM stats.timestamp.GetCurrentTime() # todo: fix this, this is wrong :) for k, v in stats_dict.items(): item = stats.item.add() item.key = k item.value_json = json_dumps_safer(json_friendly(v)[0]) return stats def _make_request( # noqa: C901 self, get_summary: pb.GetSummaryRequest | None = None, pause: pb.PauseRequest | None = None, resume: pb.ResumeRequest | None = None, status: pb.StatusRequest | None = None, stop_status: pb.StopStatusRequest | None = None, internal_messages: pb.InternalMessagesRequest | None = None, network_status: pb.NetworkStatusRequest | None = None, poll_exit: pb.PollExitRequest | None = None, partial_history: pb.PartialHistoryRequest | None = None, sampled_history: pb.SampledHistoryRequest | None = None, run_start: pb.RunStartRequest | None = None, check_version: pb.CheckVersionRequest | None = None, log_artifact: pb.LogArtifactRequest | None = None, download_artifact: pb.DownloadArtifactRequest | None = None, link_artifact: pb.LinkArtifactRequest | None = None, defer: pb.DeferRequest | None = None, attach: pb.AttachRequest | None = None, server_info: pb.ServerInfoRequest | None = None, keepalive: pb.KeepaliveRequest | None = None, run_status: pb.RunStatusRequest | None = None, sender_mark: pb.SenderMarkRequest | None = None, sender_read: pb.SenderReadRequest | None = None, sync_finish: pb.SyncFinishRequest | None = None, status_report: pb.StatusReportRequest | None = None, cancel: pb.CancelRequest | None = None, summary_record: pb.SummaryRecordRequest | None = None, telemetry_record: pb.TelemetryRecordRequest | None = None, get_system_metrics: pb.GetSystemMetricsRequest | None = None, python_packages: pb.PythonPackagesRequest | None = None, job_input: pb.JobInputRequest | None = None, run_finish_without_exit: pb.RunFinishWithoutExitRequest | None = None, probe_system_info: pb.ProbeSystemInfoRequest | None = None, ) -> pb.Record: request = pb.Request() if get_summary: request.get_summary.CopyFrom(get_summary) elif pause: request.pause.CopyFrom(pause) elif resume: request.resume.CopyFrom(resume) elif status: request.status.CopyFrom(status) elif stop_status: request.stop_status.CopyFrom(stop_status) elif internal_messages: request.internal_messages.CopyFrom(internal_messages) elif network_status: request.network_status.CopyFrom(network_status) elif poll_exit: request.poll_exit.CopyFrom(poll_exit) elif partial_history: request.partial_history.CopyFrom(partial_history) elif sampled_history: request.sampled_history.CopyFrom(sampled_history) elif run_start: request.run_start.CopyFrom(run_start) elif check_version: request.check_version.CopyFrom(check_version) elif log_artifact: request.log_artifact.CopyFrom(log_artifact) elif download_artifact: request.download_artifact.CopyFrom(download_artifact) elif link_artifact: request.link_artifact.CopyFrom(link_artifact) elif defer: request.defer.CopyFrom(defer) elif attach: request.attach.CopyFrom(attach) elif server_info: request.server_info.CopyFrom(server_info) elif keepalive: request.keepalive.CopyFrom(keepalive) elif run_status: request.run_status.CopyFrom(run_status) elif sender_mark: request.sender_mark.CopyFrom(sender_mark) elif sender_read: request.sender_read.CopyFrom(sender_read) elif cancel: request.cancel.CopyFrom(cancel) elif status_report: request.status_report.CopyFrom(status_report) elif summary_record: request.summary_record.CopyFrom(summary_record) elif telemetry_record: request.telemetry_record.CopyFrom(telemetry_record) elif get_system_metrics: request.get_system_metrics.CopyFrom(get_system_metrics) elif sync_finish: request.sync_finish.CopyFrom(sync_finish) elif python_packages: request.python_packages.CopyFrom(python_packages) elif job_input: request.job_input.CopyFrom(job_input) elif run_finish_without_exit: request.run_finish_without_exit.CopyFrom(run_finish_without_exit) elif probe_system_info: request.probe_system_info.CopyFrom(probe_system_info) else: raise Exception("Invalid request") record = self._make_record(request=request) # All requests do not get persisted record.control.local = True if status_report: record.control.flow_control = True return record def _make_record( # noqa: C901 self, run: pb.RunRecord | None = None, config: pb.ConfigRecord | None = None, files: pb.FilesRecord | None = None, summary: pb.SummaryRecord | None = None, history: pb.HistoryRecord | None = None, stats: pb.StatsRecord | None = None, exit: pb.RunExitRecord | None = None, artifact: pb.ArtifactRecord | None = None, tbrecord: pb.TBRecord | None = None, alert: pb.AlertRecord | None = None, final: pb.FinalRecord | None = None, metric: pb.MetricRecord | None = None, header: pb.HeaderRecord | None = None, footer: pb.FooterRecord | None = None, request: pb.Request | None = None, telemetry: tpb.TelemetryRecord | None = None, preempting: pb.RunPreemptingRecord | None = None, use_artifact: pb.UseArtifactRecord | None = None, output: pb.OutputRecord | None = None, output_raw: pb.OutputRawRecord | None = None, environment: pb.EnvironmentRecord | None = None, ) -> pb.Record: record = pb.Record() if run: record.run.CopyFrom(run) elif config: record.config.CopyFrom(config) elif summary: record.summary.CopyFrom(summary) elif history: record.history.CopyFrom(history) elif files: record.files.CopyFrom(files) elif stats: record.stats.CopyFrom(stats) elif exit: record.exit.CopyFrom(exit) elif artifact: record.artifact.CopyFrom(artifact) elif tbrecord: record.tbrecord.CopyFrom(tbrecord) elif alert: record.alert.CopyFrom(alert) elif final: record.final.CopyFrom(final) elif header: record.header.CopyFrom(header) elif footer: record.footer.CopyFrom(footer) elif request: record.request.CopyFrom(request) elif telemetry: record.telemetry.CopyFrom(telemetry) elif metric: record.metric.CopyFrom(metric) elif preempting: record.preempting.CopyFrom(preempting) elif use_artifact: record.use_artifact.CopyFrom(use_artifact) elif output: record.output.CopyFrom(output) elif output_raw: record.output_raw.CopyFrom(output_raw) elif environment: record.environment.CopyFrom(environment) else: raise Exception("Invalid record") return record def _publish_defer(self, state: pb.DeferRequest.DeferState) -> None: defer = pb.DeferRequest(state=state) rec = self._make_request(defer=defer) rec.control.local = True self._publish(rec) def publish_defer(self, state: int = 0) -> None: self._publish_defer(cast("pb.DeferRequest.DeferState", state)) def _publish_header(self, header: pb.HeaderRecord) -> None: rec = self._make_record(header=header) self._publish(rec) def publish_footer(self) -> None: footer = pb.FooterRecord() rec = self._make_record(footer=footer) self._publish(rec) def publish_final(self) -> None: final = pb.FinalRecord() rec = self._make_record(final=final) self._publish(rec) def _publish_pause(self, pause: pb.PauseRequest) -> None: rec = self._make_request(pause=pause) self._publish(rec) def _publish_resume(self, resume: pb.ResumeRequest) -> None: rec = self._make_request(resume=resume) self._publish(rec) def _publish_run(self, run: pb.RunRecord) -> None: rec = self._make_record(run=run) self._publish(rec) def _publish_config(self, cfg: pb.ConfigRecord) -> None: rec = self._make_record(config=cfg) self._publish(rec) def _publish_summary(self, summary: pb.SummaryRecord) -> None: rec = self._make_record(summary=summary) self._publish(rec) def _publish_metric(self, metric: pb.MetricRecord) -> None: rec = self._make_record(metric=metric) self._publish(rec) def publish_stats(self, stats_dict: dict) -> None: stats = self._make_stats(stats_dict) rec = self._make_record(stats=stats) self._publish(rec) def _publish_python_packages( self, python_packages: pb.PythonPackagesRequest ) -> None: rec = self._make_request(python_packages=python_packages) self._publish(rec) def _publish_files(self, files: pb.FilesRecord) -> None: rec = self._make_record(files=files) self._publish(rec) def _publish_use_artifact(self, use_artifact: pb.UseArtifactRecord) -> Any: rec = self._make_record(use_artifact=use_artifact) self._publish(rec) def _publish_probe_system_info( self, probe_system_info: pb.ProbeSystemInfoRequest ) -> None: record = self._make_request(probe_system_info=probe_system_info) self._publish(record) def _deliver_artifact( self, log_artifact: pb.LogArtifactRequest, ) -> MailboxHandle[pb.Result]: rec = self._make_request(log_artifact=log_artifact) return self._deliver(rec) def _deliver_download_artifact( self, download_artifact: pb.DownloadArtifactRequest ) -> MailboxHandle[pb.Result]: rec = self._make_request(download_artifact=download_artifact) return self._deliver(rec) def _deliver_link_artifact( self, link_artifact: pb.LinkArtifactRequest ) -> MailboxHandle[pb.Result]: rec = self._make_request(link_artifact=link_artifact) return self._deliver(rec) def _publish_artifact(self, proto_artifact: pb.ArtifactRecord) -> None: rec = self._make_record(artifact=proto_artifact) self._publish(rec) def _publish_alert(self, proto_alert: pb.AlertRecord) -> None: rec = self._make_record(alert=proto_alert) self._publish(rec) def _deliver_status( self, status: pb.StatusRequest, ) -> MailboxHandle[pb.Result]: req = self._make_request(status=status) return self._deliver(req) def _publish_exit(self, exit_data: pb.RunExitRecord) -> None: rec = self._make_record(exit=exit_data) self._publish(rec) def _publish_keepalive(self, keepalive: pb.KeepaliveRequest) -> None: record = self._make_request(keepalive=keepalive) self._publish(record) def _deliver_shutdown(self) -> MailboxHandle[pb.Result]: request = pb.Request(shutdown=pb.ShutdownRequest()) record = self._make_record(request=request) return self._deliver(record) def _deliver_run(self, run: pb.RunRecord) -> MailboxHandle[pb.Result]: record = self._make_record(run=run) return self._deliver(record) def _deliver_finish_sync( self, sync_finish: pb.SyncFinishRequest, ) -> MailboxHandle[pb.Result]: record = self._make_request(sync_finish=sync_finish) return self._deliver(record) def _deliver_run_start( self, run_start: pb.RunStartRequest, ) -> MailboxHandle[pb.Result]: record = self._make_request(run_start=run_start) return self._deliver(record) def _deliver_get_summary( self, get_summary: pb.GetSummaryRequest, ) -> MailboxHandle[pb.Result]: record = self._make_request(get_summary=get_summary) return self._deliver(record) def _deliver_get_system_metrics( self, get_system_metrics: pb.GetSystemMetricsRequest ) -> MailboxHandle[pb.Result]: record = self._make_request(get_system_metrics=get_system_metrics) return self._deliver(record) def _deliver_exit( self, exit_data: pb.RunExitRecord, ) -> MailboxHandle[pb.Result]: record = self._make_record(exit=exit_data) return self._deliver(record) def _deliver_poll_exit( self, poll_exit: pb.PollExitRequest, ) -> MailboxHandle[pb.Result]: record = self._make_request(poll_exit=poll_exit) return self._deliver(record) def _deliver_finish_without_exit( self, run_finish_without_exit: pb.RunFinishWithoutExitRequest ) -> MailboxHandle[pb.Result]: record = self._make_request(run_finish_without_exit=run_finish_without_exit) return self._deliver(record) def _deliver_stop_status( self, stop_status: pb.StopStatusRequest, ) -> MailboxHandle[pb.Result]: record = self._make_request(stop_status=stop_status) return self._deliver(record) def _deliver_attach( self, attach: pb.AttachRequest, ) -> MailboxHandle[pb.Result]: record = self._make_request(attach=attach) return self._deliver(record) def _deliver_network_status( self, network_status: pb.NetworkStatusRequest ) -> MailboxHandle[pb.Result]: record = self._make_request(network_status=network_status) return self._deliver(record) def _deliver_internal_messages( self, internal_message: pb.InternalMessagesRequest ) -> MailboxHandle[pb.Result]: record = self._make_request(internal_messages=internal_message) return self._deliver(record) def _deliver_request_sampled_history( self, sampled_history: pb.SampledHistoryRequest ) -> MailboxHandle[pb.Result]: record = self._make_request(sampled_history=sampled_history) return self._deliver(record) def _deliver_request_run_status( self, run_status: pb.RunStatusRequest ) -> MailboxHandle[pb.Result]: record = self._make_request(run_status=run_status) return self._deliver(record)