interface.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087
  1. from __future__ import annotations
  2. import abc
  3. import gzip
  4. import logging
  5. import time
  6. from collections.abc import Iterable
  7. from pathlib import Path
  8. from secrets import token_hex
  9. from typing import TYPE_CHECKING, Any
  10. from wandb import termwarn
  11. from wandb.proto import wandb_internal_pb2 as pb
  12. from wandb.proto import wandb_telemetry_pb2 as tpb
  13. from wandb.sdk.lib import json_util as json
  14. from wandb.sdk.lib.filesystem import FilesDict, PolicyName
  15. from wandb.sdk.mailbox import HandleAbandonedError, MailboxHandle
  16. from wandb.util import (
  17. WandBJSONEncoderOld,
  18. get_h5_typename,
  19. json_dumps_safer,
  20. json_dumps_safer_history,
  21. json_friendly,
  22. json_friendly_val,
  23. maybe_compress_summary,
  24. )
  25. from ..data_types.utils import history_dict_to_json, val_to_json
  26. from . import summary_record as sr
# Manifests whose len() exceeds this value are written to a gzipped file
# instead of being embedded in the ArtifactManifest protobuf message
# (see `_make_artifact_manifest`).
MANIFEST_FILE_SIZE_THRESHOLD = 100_000
# Imported for type annotations only, to avoid runtime import cycles.
if TYPE_CHECKING:
    from wandb.sdk.artifacts.artifact import Artifact
    from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
    from ..wandb_run import Run
logger = logging.getLogger("wandb")
  33. def file_policy_to_enum(policy: PolicyName) -> pb.FilesItem.PolicyType:
  34. if policy == "now":
  35. enum = pb.FilesItem.PolicyType.NOW
  36. elif policy == "end":
  37. enum = pb.FilesItem.PolicyType.END
  38. elif policy == "live":
  39. enum = pb.FilesItem.PolicyType.LIVE
  40. return enum
  41. def file_enum_to_policy(enum: pb.FilesItem.PolicyType) -> PolicyName:
  42. if enum == pb.FilesItem.PolicyType.NOW:
  43. policy: PolicyName = "now"
  44. elif enum == pb.FilesItem.PolicyType.END:
  45. policy = "end"
  46. elif enum == pb.FilesItem.PolicyType.LIVE:
  47. policy = "live"
  48. return policy
  49. class InterfaceBase(abc.ABC):
  50. """Methods for sending run messages (Records) to the service.
  51. None of the methods may be called from an asyncio context other than
  52. deliver_async() or those with a `nowait=True` argument.
  53. """
  54. _drop: bool
  55. def __init__(self) -> None:
  56. self._drop = False
  57. @abc.abstractmethod
  58. async def deliver_async(
  59. self,
  60. record: pb.Record,
  61. ) -> MailboxHandle[pb.Result]:
  62. """Send a record and create a handle to wait for the response.
  63. The synchronous publish and deliver methods on this class cannot be
  64. called in the asyncio thread because they block. Instead of having
  65. an async copy of every method, this is a general method for sending
  66. any kind of record in the asyncio thread.
  67. Args:
  68. record: The record to send. This method takes ownership of the
  69. record and it must not be used afterward.
  70. Returns:
  71. A handle to wait for a response to the record.
  72. """
  73. def publish_header(self) -> None:
  74. header = pb.HeaderRecord()
  75. self._publish_header(header)
  76. @abc.abstractmethod
  77. def _publish_header(self, header: pb.HeaderRecord) -> None:
  78. raise NotImplementedError
  79. def deliver_status(self) -> MailboxHandle[pb.Result]:
  80. return self._deliver_status(pb.StatusRequest())
  81. @abc.abstractmethod
  82. def _deliver_status(
  83. self,
  84. status: pb.StatusRequest,
  85. ) -> MailboxHandle[pb.Result]:
  86. raise NotImplementedError
  87. def _make_config(
  88. self,
  89. data: dict | None = None,
  90. key: tuple[str, ...] | str | None = None,
  91. val: Any | None = None,
  92. obj: pb.ConfigRecord | None = None,
  93. ) -> pb.ConfigRecord:
  94. config = obj or pb.ConfigRecord()
  95. if data:
  96. for k, v in data.items():
  97. update = config.update.add()
  98. update.key = k
  99. update.value_json = json_dumps_safer(json_friendly(v)[0])
  100. if key:
  101. update = config.update.add()
  102. if isinstance(key, tuple):
  103. for k in key:
  104. update.nested_key.append(k)
  105. else:
  106. update.key = key
  107. update.value_json = json_dumps_safer(json_friendly(val)[0])
  108. return config
  109. def _make_run(self, run: Run) -> pb.RunRecord: # noqa: C901
  110. proto_run = pb.RunRecord()
  111. if run._settings.entity is not None:
  112. proto_run.entity = run._settings.entity
  113. if run._settings.project is not None:
  114. proto_run.project = run._settings.project
  115. if run._settings.run_group is not None:
  116. proto_run.run_group = run._settings.run_group
  117. if run._settings.run_job_type is not None:
  118. proto_run.job_type = run._settings.run_job_type
  119. if run._settings.run_id is not None:
  120. proto_run.run_id = run._settings.run_id
  121. if run._settings.run_name is not None:
  122. proto_run.display_name = run._settings.run_name
  123. if run._settings.run_notes is not None:
  124. proto_run.notes = run._settings.run_notes
  125. if run._settings.run_tags is not None:
  126. proto_run.tags.extend(run._settings.run_tags)
  127. if run._start_time is not None:
  128. proto_run.start_time.FromMicroseconds(int(run._start_time * 1e6))
  129. if run._starting_step is not None:
  130. proto_run.starting_step = run._starting_step
  131. if run._settings.git_remote_url is not None:
  132. proto_run.git.remote_url = run._settings.git_remote_url
  133. if run._settings.git_commit is not None:
  134. proto_run.git.commit = run._settings.git_commit
  135. if run._settings.sweep_id is not None:
  136. proto_run.sweep_id = run._settings.sweep_id
  137. if run._settings.host:
  138. proto_run.host = run._settings.host
  139. if run._settings.resumed:
  140. proto_run.resumed = run._settings.resumed
  141. if run._settings.fork_from:
  142. run_moment = run._settings.fork_from
  143. proto_run.branch_point.run = run_moment.run
  144. proto_run.branch_point.metric = run_moment.metric
  145. proto_run.branch_point.value = run_moment.value
  146. if run._settings.resume_from:
  147. run_moment = run._settings.resume_from
  148. proto_run.branch_point.run = run_moment.run
  149. proto_run.branch_point.metric = run_moment.metric
  150. proto_run.branch_point.value = run_moment.value
  151. if run._forked:
  152. proto_run.forked = run._forked
  153. if run._config is not None:
  154. config_dict = run._config._as_dict() # type: ignore
  155. self._make_config(data=config_dict, obj=proto_run.config)
  156. if run._telemetry_obj:
  157. proto_run.telemetry.MergeFrom(run._telemetry_obj)
  158. if run._start_runtime:
  159. proto_run.runtime = run._start_runtime
  160. return proto_run
  161. def publish_run(self, run: Run) -> None:
  162. run_record = self._make_run(run)
  163. self._publish_run(run_record)
  164. @abc.abstractmethod
  165. def _publish_run(self, run: pb.RunRecord) -> None:
  166. raise NotImplementedError
  167. def publish_cancel(self, cancel_slot: str) -> None:
  168. cancel = pb.CancelRequest(cancel_slot=cancel_slot)
  169. self._publish_cancel(cancel)
  170. @abc.abstractmethod
  171. def _publish_cancel(self, cancel: pb.CancelRequest) -> None:
  172. raise NotImplementedError
  173. def publish_config(
  174. self,
  175. data: dict | None = None,
  176. key: tuple[str, ...] | str | None = None,
  177. val: Any | None = None,
  178. ) -> None:
  179. cfg = self._make_config(data=data, key=key, val=val)
  180. self._publish_config(cfg)
  181. @abc.abstractmethod
  182. def _publish_config(self, cfg: pb.ConfigRecord) -> None:
  183. raise NotImplementedError
  184. @abc.abstractmethod
  185. def _publish_metric(self, metric: pb.MetricRecord) -> None:
  186. raise NotImplementedError
  187. def _make_summary_from_dict(self, summary_dict: dict) -> pb.SummaryRecord:
  188. summary = pb.SummaryRecord()
  189. for k, v in summary_dict.items():
  190. update = summary.update.add()
  191. update.key = k
  192. update.value_json = json.dumps(v)
  193. return summary
  194. def _summary_encode(
  195. self,
  196. value: Any,
  197. path_from_root: str,
  198. run: Run,
  199. ) -> dict:
  200. """Normalize, compress, and encode sub-objects for backend storage.
  201. value: Object to encode.
  202. path_from_root: `str` dot separated string from the top-level summary to the
  203. current `value`.
  204. Returns:
  205. A new tree of dict's with large objects replaced with dictionaries
  206. with "_type" entries that say which type the original data was.
  207. """
  208. # Constructs a new `dict` tree in `json_value` that discards and/or
  209. # encodes objects that aren't JSON serializable.
  210. if isinstance(value, dict):
  211. json_value = {}
  212. for key, value in value.items(): # noqa: B020
  213. json_value[key] = self._summary_encode(
  214. value,
  215. path_from_root + "." + key,
  216. run=run,
  217. )
  218. return json_value
  219. else:
  220. friendly_value, converted = json_friendly(
  221. val_to_json(run, path_from_root, value, namespace="summary")
  222. )
  223. json_value, compressed = maybe_compress_summary(
  224. friendly_value, get_h5_typename(value)
  225. )
  226. if compressed:
  227. # TODO(jhr): impleement me
  228. pass
  229. # self.write_h5(path_from_root, friendly_value)
  230. return json_value
  231. def _make_summary(
  232. self,
  233. summary_record: sr.SummaryRecord,
  234. run: Run,
  235. ) -> pb.SummaryRecord:
  236. pb_summary_record = pb.SummaryRecord()
  237. for item in summary_record.update:
  238. pb_summary_item = pb_summary_record.update.add()
  239. key_length = len(item.key)
  240. assert key_length > 0
  241. if key_length > 1:
  242. pb_summary_item.nested_key.extend(item.key)
  243. else:
  244. pb_summary_item.key = item.key[0]
  245. path_from_root = ".".join(item.key)
  246. json_value = self._summary_encode(
  247. item.value,
  248. path_from_root,
  249. run=run,
  250. )
  251. json_value, _ = json_friendly(json_value) # type: ignore
  252. pb_summary_item.value_json = json.dumps(
  253. json_value,
  254. cls=WandBJSONEncoderOld,
  255. )
  256. for item in summary_record.remove:
  257. pb_summary_item = pb_summary_record.remove.add()
  258. key_length = len(item.key)
  259. assert key_length > 0
  260. if key_length > 1:
  261. pb_summary_item.nested_key.extend(item.key)
  262. else:
  263. pb_summary_item.key = item.key[0]
  264. return pb_summary_record
  265. def publish_summary(
  266. self,
  267. run: Run,
  268. summary_record: sr.SummaryRecord,
  269. ) -> None:
  270. pb_summary_record = self._make_summary(summary_record, run=run)
  271. self._publish_summary(pb_summary_record)
  272. @abc.abstractmethod
  273. def _publish_summary(self, summary: pb.SummaryRecord) -> None:
  274. raise NotImplementedError
  275. def _make_files(self, files_dict: FilesDict) -> pb.FilesRecord:
  276. files = pb.FilesRecord()
  277. for path, policy in files_dict["files"]:
  278. f = files.files.add()
  279. f.path = path
  280. f.policy = file_policy_to_enum(policy)
  281. return files
  282. def publish_files(self, files_dict: FilesDict) -> None:
  283. files = self._make_files(files_dict)
  284. self._publish_files(files)
  285. @abc.abstractmethod
  286. def _publish_files(self, files: pb.FilesRecord) -> None:
  287. raise NotImplementedError
  288. def publish_python_packages(self, working_set) -> None:
  289. python_packages = pb.PythonPackagesRequest()
  290. for pkg in working_set:
  291. python_packages.package.add(name=pkg.key, version=pkg.version)
  292. self._publish_python_packages(python_packages)
  293. @abc.abstractmethod
  294. def _publish_python_packages(
  295. self, python_packages: pb.PythonPackagesRequest
  296. ) -> None:
  297. raise NotImplementedError
  298. def _make_artifact(self, artifact: Artifact) -> pb.ArtifactRecord:
  299. proto_artifact = pb.ArtifactRecord()
  300. proto_artifact.type = artifact.type
  301. proto_artifact.name = artifact.name
  302. proto_artifact.client_id = artifact._client_id
  303. proto_artifact.sequence_client_id = artifact._sequence_client_id
  304. proto_artifact.digest = artifact.digest
  305. if artifact.distributed_id:
  306. proto_artifact.distributed_id = artifact.distributed_id
  307. if artifact.description:
  308. proto_artifact.description = artifact.description
  309. if artifact.metadata:
  310. proto_artifact.metadata = json.dumps(json_friendly_val(artifact.metadata))
  311. if artifact._base_id:
  312. proto_artifact.base_id = artifact._base_id
  313. ttl_duration_input = artifact._ttl_duration_seconds_to_gql()
  314. if ttl_duration_input:
  315. proto_artifact.ttl_duration_seconds = ttl_duration_input
  316. proto_artifact.incremental_beta1 = artifact.incremental
  317. self._make_artifact_manifest(artifact.manifest, obj=proto_artifact.manifest)
  318. return proto_artifact
  319. def _make_artifact_manifest(
  320. self,
  321. artifact_manifest: ArtifactManifest,
  322. obj: pb.ArtifactManifest | None = None,
  323. ) -> pb.ArtifactManifest:
  324. proto_manifest = obj or pb.ArtifactManifest()
  325. proto_manifest.version = artifact_manifest.version()
  326. proto_manifest.storage_policy = artifact_manifest.storage_policy.name()
  327. # Very large manifests need to be written to file to avoid protobuf size limits.
  328. if len(artifact_manifest) > MANIFEST_FILE_SIZE_THRESHOLD:
  329. path = self._write_artifact_manifest_file(artifact_manifest)
  330. proto_manifest.manifest_file_path = path
  331. return proto_manifest
  332. # Set storage policy on storageLayout (always V2) and storageRegion, only allow coreweave-us on wandb.ai for now.
  333. # NOTE: the decode logic is NewManifestFromProto in core/pkg/artifacts/manifest.go
  334. # The creation logic is in artifacts/_factories.py make_storage_policy
  335. for k, v in artifact_manifest.storage_policy.config().items() or {}.items():
  336. cfg = proto_manifest.storage_policy_config.add()
  337. cfg.key = k
  338. # TODO: Why json.dumps when existing values are plain string? We want to send complex structure without defining the proto?
  339. cfg.value_json = json.dumps(v)
  340. for entry in sorted(artifact_manifest.entries.values(), key=lambda k: k.path):
  341. proto_entry = proto_manifest.contents.add()
  342. proto_entry.path = entry.path
  343. proto_entry.digest = entry.digest
  344. if entry.size:
  345. proto_entry.size = entry.size
  346. if entry.birth_artifact_id:
  347. proto_entry.birth_artifact_id = entry.birth_artifact_id
  348. if entry.ref:
  349. proto_entry.ref = entry.ref
  350. if entry.local_path:
  351. proto_entry.local_path = entry.local_path
  352. proto_entry.skip_cache = entry.skip_cache
  353. for k, v in entry.extra.items():
  354. proto_extra = proto_entry.extra.add()
  355. proto_extra.key = k
  356. proto_extra.value_json = json.dumps(v)
  357. return proto_manifest
  358. def _write_artifact_manifest_file(self, manifest: ArtifactManifest) -> str:
  359. from wandb.sdk.artifacts.staging import get_staging_dir
  360. manifest_dir = Path(get_staging_dir()) / "artifact_manifests"
  361. manifest_dir.mkdir(parents=True, exist_ok=True)
  362. # It would be simpler to use `manifest.to_json()`, but that gets very slow for
  363. # large manifests since it encodes the whole thing as a single JSON object.
  364. filename = f"{time.time()}_{token_hex(8)}.manifest_contents.jl.gz"
  365. manifest_file_path = manifest_dir / filename
  366. with gzip.open(manifest_file_path, mode="wt", compresslevel=1) as f:
  367. for entry in manifest.entries.values():
  368. f.write(f"{json.dumps(entry.to_json())}\n")
  369. return str(manifest_file_path)
  370. def deliver_link_artifact(
  371. self,
  372. artifact: Artifact,
  373. portfolio_name: str,
  374. aliases: Iterable[str],
  375. entity: str | None = None,
  376. project: str | None = None,
  377. organization: str | None = None,
  378. ) -> MailboxHandle[pb.Result]:
  379. link_artifact = pb.LinkArtifactRequest()
  380. if artifact.is_draft():
  381. link_artifact.client_id = artifact._client_id
  382. else:
  383. link_artifact.server_id = artifact.id if artifact.id else ""
  384. link_artifact.portfolio_name = portfolio_name
  385. link_artifact.portfolio_entity = entity or ""
  386. link_artifact.portfolio_organization = organization or ""
  387. link_artifact.portfolio_project = project or ""
  388. link_artifact.portfolio_aliases.extend(aliases)
  389. return self._deliver_link_artifact(link_artifact)
  390. @abc.abstractmethod
  391. def _deliver_link_artifact(
  392. self, link_artifact: pb.LinkArtifactRequest
  393. ) -> MailboxHandle[pb.Result]:
  394. raise NotImplementedError
  395. @staticmethod
  396. def _make_partial_source_str(
  397. source: Any, job_info: dict[str, Any], metadata: dict[str, Any]
  398. ) -> str:
  399. """Construct use_artifact.partial.source_info.source as str."""
  400. source_type = job_info.get("source_type", "").strip()
  401. if source_type == "artifact":
  402. info_source = job_info.get("source", {})
  403. source.artifact.artifact = info_source.get("artifact", "")
  404. source.artifact.entrypoint.extend(info_source.get("entrypoint", []))
  405. source.artifact.notebook = info_source.get("notebook", False)
  406. build_context = info_source.get("build_context")
  407. if build_context:
  408. source.artifact.build_context = build_context
  409. dockerfile = info_source.get("dockerfile")
  410. if dockerfile:
  411. source.artifact.dockerfile = dockerfile
  412. elif source_type == "repo":
  413. source.git.git_info.remote = metadata.get("git", {}).get("remote", "")
  414. source.git.git_info.commit = metadata.get("git", {}).get("commit", "")
  415. source.git.entrypoint.extend(metadata.get("entrypoint", []))
  416. source.git.notebook = metadata.get("notebook", False)
  417. build_context = metadata.get("build_context")
  418. if build_context:
  419. source.git.build_context = build_context
  420. dockerfile = metadata.get("dockerfile")
  421. if dockerfile:
  422. source.git.dockerfile = dockerfile
  423. elif source_type == "image":
  424. source.image.image = metadata.get("docker", "")
  425. else:
  426. raise ValueError("Invalid source type")
  427. source_str: str = source.SerializeToString()
  428. return source_str
  429. def _make_proto_use_artifact(
  430. self,
  431. use_artifact: pb.UseArtifactRecord,
  432. job_name: str,
  433. job_info: dict[str, Any],
  434. metadata: dict[str, Any],
  435. ) -> pb.UseArtifactRecord:
  436. use_artifact.partial.job_name = job_name
  437. use_artifact.partial.source_info._version = job_info.get("_version", "")
  438. use_artifact.partial.source_info.source_type = job_info.get("source_type", "")
  439. use_artifact.partial.source_info.runtime = job_info.get("runtime", "")
  440. src_str = self._make_partial_source_str(
  441. source=use_artifact.partial.source_info.source,
  442. job_info=job_info,
  443. metadata=metadata,
  444. )
  445. use_artifact.partial.source_info.source.ParseFromString(src_str) # type: ignore[arg-type]
  446. return use_artifact
  447. def publish_use_artifact(
  448. self,
  449. artifact: Artifact,
  450. ) -> None:
  451. assert artifact.id is not None, "Artifact must have an id"
  452. use_artifact = pb.UseArtifactRecord(
  453. id=artifact.id,
  454. type=artifact.type,
  455. name=artifact.name,
  456. )
  457. # TODO(gst): move to internal process
  458. if "_partial" in artifact.metadata:
  459. # Download source info from logged partial job artifact
  460. job_info = {}
  461. try:
  462. path = artifact.get_entry("wandb-job.json").download()
  463. with open(path) as f:
  464. job_info = json.load(f)
  465. except Exception as e:
  466. logger.warning(
  467. f"Failed to download partial job info from artifact {artifact}, : {e}"
  468. )
  469. termwarn(
  470. f"Failed to download partial job info from artifact {artifact}, : {e}"
  471. )
  472. return
  473. try:
  474. use_artifact = self._make_proto_use_artifact(
  475. use_artifact=use_artifact,
  476. job_name=artifact.name,
  477. job_info=job_info,
  478. metadata=artifact.metadata,
  479. )
  480. except Exception as e:
  481. logger.warning(f"Failed to construct use artifact proto: {e}")
  482. termwarn(f"Failed to construct use artifact proto: {e}")
  483. return
  484. self._publish_use_artifact(use_artifact)
  485. @abc.abstractmethod
  486. def _publish_use_artifact(self, proto_artifact: pb.UseArtifactRecord) -> None:
  487. raise NotImplementedError
  488. def deliver_artifact(
  489. self,
  490. run: Run,
  491. artifact: Artifact,
  492. aliases: Iterable[str],
  493. tags: Iterable[str] | None = None,
  494. history_step: int | None = None,
  495. is_user_created: bool = False,
  496. use_after_commit: bool = False,
  497. finalize: bool = True,
  498. ) -> MailboxHandle[pb.Result]:
  499. from wandb.sdk.artifacts.staging import get_staging_dir
  500. proto_run = self._make_run(run)
  501. proto_artifact = self._make_artifact(artifact)
  502. proto_artifact.run_id = proto_run.run_id
  503. proto_artifact.project = proto_run.project
  504. proto_artifact.entity = proto_run.entity
  505. proto_artifact.user_created = is_user_created
  506. proto_artifact.use_after_commit = use_after_commit
  507. proto_artifact.finalize = finalize
  508. proto_artifact.aliases.extend(aliases or [])
  509. proto_artifact.tags.extend(tags or [])
  510. log_artifact = pb.LogArtifactRequest()
  511. log_artifact.artifact.CopyFrom(proto_artifact)
  512. if history_step is not None:
  513. log_artifact.history_step = history_step
  514. log_artifact.staging_dir = get_staging_dir()
  515. resp = self._deliver_artifact(log_artifact)
  516. return resp
  517. @abc.abstractmethod
  518. def _deliver_artifact(
  519. self,
  520. log_artifact: pb.LogArtifactRequest,
  521. ) -> MailboxHandle[pb.Result]:
  522. raise NotImplementedError
  523. def deliver_download_artifact(
  524. self,
  525. artifact_id: str,
  526. download_root: str,
  527. allow_missing_references: bool,
  528. skip_cache: bool,
  529. path_prefix: str | None,
  530. ) -> MailboxHandle[pb.Result]:
  531. download_artifact = pb.DownloadArtifactRequest()
  532. download_artifact.artifact_id = artifact_id
  533. download_artifact.download_root = download_root
  534. download_artifact.allow_missing_references = allow_missing_references
  535. download_artifact.skip_cache = skip_cache
  536. download_artifact.path_prefix = path_prefix or ""
  537. resp = self._deliver_download_artifact(download_artifact)
  538. return resp
  539. @abc.abstractmethod
  540. def _deliver_download_artifact(
  541. self, download_artifact: pb.DownloadArtifactRequest
  542. ) -> MailboxHandle[pb.Result]:
  543. raise NotImplementedError
  544. def publish_artifact(
  545. self,
  546. run: Run,
  547. artifact: Artifact,
  548. aliases: Iterable[str],
  549. tags: Iterable[str] | None = None,
  550. is_user_created: bool = False,
  551. use_after_commit: bool = False,
  552. finalize: bool = True,
  553. ) -> None:
  554. proto_run = self._make_run(run)
  555. proto_artifact = self._make_artifact(artifact)
  556. proto_artifact.run_id = proto_run.run_id
  557. proto_artifact.project = proto_run.project
  558. proto_artifact.entity = proto_run.entity
  559. proto_artifact.user_created = is_user_created
  560. proto_artifact.use_after_commit = use_after_commit
  561. proto_artifact.finalize = finalize
  562. proto_artifact.aliases.extend(aliases or [])
  563. proto_artifact.tags.extend(tags or [])
  564. self._publish_artifact(proto_artifact)
  565. @abc.abstractmethod
  566. def _publish_artifact(self, proto_artifact: pb.ArtifactRecord) -> None:
  567. raise NotImplementedError
  568. def publish_tbdata(self, log_dir: str, save: bool, root_logdir: str = "") -> None:
  569. tbrecord = pb.TBRecord()
  570. tbrecord.log_dir = log_dir
  571. tbrecord.save = save
  572. tbrecord.root_dir = root_logdir
  573. self._publish_tbdata(tbrecord)
  574. @abc.abstractmethod
  575. def _publish_tbdata(self, tbrecord: pb.TBRecord) -> None:
  576. raise NotImplementedError
  577. @abc.abstractmethod
  578. def _publish_telemetry(self, telem: tpb.TelemetryRecord) -> None:
  579. raise NotImplementedError
  580. def publish_environment(self, environment: pb.EnvironmentRecord) -> None:
  581. self._publish_environment(environment)
  582. @abc.abstractmethod
  583. def _publish_environment(self, environment: pb.EnvironmentRecord) -> None:
  584. raise NotImplementedError
  585. def publish_partial_history(
  586. self,
  587. run: Run,
  588. data: dict,
  589. user_step: int,
  590. step: int | None = None,
  591. flush: bool | None = None,
  592. publish_step: bool = True,
  593. ) -> None:
  594. data = history_dict_to_json(run, data, step=user_step, ignore_copy_err=True)
  595. data.pop("_step", None)
  596. # add timestamp to the history request, if not already present
  597. # the timestamp might come from the tensorboard log logic
  598. if "_timestamp" not in data:
  599. data["_timestamp"] = time.time()
  600. partial_history = pb.PartialHistoryRequest()
  601. for k, v in data.items():
  602. item = partial_history.item.add()
  603. item.key = k
  604. item.value_json = json_dumps_safer_history(v)
  605. if publish_step and step is not None:
  606. partial_history.step.num = step
  607. if flush is not None:
  608. partial_history.action.flush = flush
  609. self._publish_partial_history(partial_history)
  610. @abc.abstractmethod
  611. def _publish_partial_history(self, history: pb.PartialHistoryRequest) -> None:
  612. raise NotImplementedError
  613. def publish_history(
  614. self,
  615. run: Run,
  616. data: dict,
  617. step: int | None = None,
  618. publish_step: bool = True,
  619. ) -> None:
  620. data = history_dict_to_json(run, data, step=step)
  621. history = pb.HistoryRecord()
  622. if publish_step:
  623. assert step is not None
  624. history.step.num = step
  625. data.pop("_step", None)
  626. for k, v in data.items():
  627. item = history.item.add()
  628. item.key = k
  629. item.value_json = json_dumps_safer_history(v)
  630. self._publish_history(history)
  631. @abc.abstractmethod
  632. def _publish_history(self, history: pb.HistoryRecord) -> None:
  633. raise NotImplementedError
  634. def publish_preempting(self) -> None:
  635. preempt_rec = pb.RunPreemptingRecord()
  636. self._publish_preempting(preempt_rec)
  637. @abc.abstractmethod
  638. def _publish_preempting(self, preempt_rec: pb.RunPreemptingRecord) -> None:
  639. raise NotImplementedError
  640. def publish_output(
  641. self,
  642. name: str,
  643. data: str,
  644. *,
  645. nowait: bool = False,
  646. ) -> None:
  647. # from vendor.protobuf import google3.protobuf.timestamp
  648. # ts = timestamp.Timestamp()
  649. # ts.GetCurrentTime()
  650. # now = datetime.now()
  651. if name == "stdout":
  652. otype = pb.OutputRecord.OutputType.STDOUT
  653. elif name == "stderr":
  654. otype = pb.OutputRecord.OutputType.STDERR
  655. else:
  656. # TODO(jhr): throw error?
  657. termwarn("unknown type")
  658. o = pb.OutputRecord(output_type=otype, line=data)
  659. o.timestamp.GetCurrentTime()
  660. self._publish_output(o, nowait=nowait)
  661. @abc.abstractmethod
  662. def _publish_output(self, outdata: pb.OutputRecord, *, nowait: bool) -> None:
  663. raise NotImplementedError
  664. def publish_output_raw(
  665. self,
  666. name: str,
  667. data: str,
  668. *,
  669. nowait: bool = False,
  670. ) -> None:
  671. # from vendor.protobuf import google3.protobuf.timestamp
  672. # ts = timestamp.Timestamp()
  673. # ts.GetCurrentTime()
  674. # now = datetime.now()
  675. if name == "stdout":
  676. otype = pb.OutputRawRecord.OutputType.STDOUT
  677. elif name == "stderr":
  678. otype = pb.OutputRawRecord.OutputType.STDERR
  679. else:
  680. # TODO(jhr): throw error?
  681. termwarn("unknown type")
  682. o = pb.OutputRawRecord(output_type=otype, line=data)
  683. o.timestamp.GetCurrentTime()
  684. self._publish_output_raw(o, nowait=nowait)
  685. @abc.abstractmethod
  686. def _publish_output_raw(
  687. self,
  688. outdata: pb.OutputRawRecord,
  689. *,
  690. nowait: bool,
  691. ) -> None:
  692. raise NotImplementedError
  693. def publish_pause(self) -> None:
  694. pause = pb.PauseRequest()
  695. self._publish_pause(pause)
  696. @abc.abstractmethod
  697. def _publish_pause(self, pause: pb.PauseRequest) -> None:
  698. raise NotImplementedError
  699. def publish_resume(self) -> None:
  700. resume = pb.ResumeRequest()
  701. self._publish_resume(resume)
  702. @abc.abstractmethod
  703. def _publish_resume(self, resume: pb.ResumeRequest) -> None:
  704. raise NotImplementedError
  def publish_alert(
      self, title: str, text: str, level: str, wait_duration: int
  ) -> None:
      """Publish a ``pb.AlertRecord`` populated from the arguments.

      Args:
          title: Copied into ``AlertRecord.title``.
          text: Copied into ``AlertRecord.text``.
          level: Copied into ``AlertRecord.level``.
          wait_duration: Copied into ``AlertRecord.wait_duration``.
      """
      # Fields are assigned one by one rather than passed to the constructor:
      # protobuf constructors silently skip None values, whereas attribute
      # assignment rejects them.
      record = pb.AlertRecord()
      record.title = title
      record.text = text
      record.level = level
      record.wait_duration = wait_duration
      self._publish_alert(record)
  @abc.abstractmethod
  def _publish_alert(self, alert: pb.AlertRecord) -> None:
      """Subclass hook that publishes a populated ``pb.AlertRecord``."""
      raise NotImplementedError()
  def _make_exit(self, exit_code: int | None) -> pb.RunExitRecord:
      """Build a ``pb.RunExitRecord`` for the given exit code.

      Args:
          exit_code: The exit code to record, or ``None`` to leave the
              ``exit_code`` field unset.

      Returns:
          A new ``pb.RunExitRecord``.
      """
      # Named `record` rather than `exit` to avoid shadowing the builtin.
      record = pb.RunExitRecord()
      if exit_code is not None:
          record.exit_code = exit_code
      return record
  def publish_exit(self, exit_code: int | None) -> None:
      """Publish a run-exit record built by ``_make_exit(exit_code)``."""
      self._publish_exit(self._make_exit(exit_code))
  @abc.abstractmethod
  def _publish_exit(self, exit_data: pb.RunExitRecord) -> None:
      """Subclass hook that publishes a ``pb.RunExitRecord``."""
      raise NotImplementedError()
  def publish_keepalive(self) -> None:
      """Publish an empty ``pb.KeepaliveRequest`` via ``_publish_keepalive``."""
      self._publish_keepalive(pb.KeepaliveRequest())
  @abc.abstractmethod
  def _publish_keepalive(self, keepalive: pb.KeepaliveRequest) -> None:
      """Subclass hook that publishes a ``pb.KeepaliveRequest``."""
      raise NotImplementedError()
  def publish_job_input(
      self,
      include_paths: list[list[str]],
      exclude_paths: list[list[str]],
      input_schema: dict | None,
      run_config: bool = False,
      file_path: str = "",
  ) -> MailboxHandle[pb.Result]:
      """Publish a request to add inputs to the job.

      If ``run_config`` is True, the wandb.config will be used as the job
      input source. Otherwise the file at ``file_path`` is used as the source
      (note: this branch is taken even when ``file_path`` is empty).

      The paths provided as arguments are sequences of dictionary keys that
      specify a path within the wandb.config. If a path is included, the
      corresponding field will be treated as a job input. If a path is
      excluded, the corresponding field will not be treated as a job input.

      Args:
          include_paths: paths within config to include as job inputs.
          exclude_paths: paths within config to exclude as job inputs.
          input_schema: A JSON Schema describing which attributes will be
              editable from the Launch drawer. Omitted from the request when
              falsy.
          run_config: bool indicating whether wandb.config is the input source.
          file_path: path to file to include as a job input.

      Returns:
          The handle returned by ``_publish_job_input``.

      Raises:
          ValueError: if both ``run_config`` and ``file_path`` are given.
      """
      if run_config and file_path:
          raise ValueError(
              "run_config and file_path are mutually exclusive arguments."
          )
      request = pb.JobInputRequest()
      request.include_paths.extend(
          [pb.JobInputPath(path=path) for path in include_paths]
      )
      request.exclude_paths.extend(
          [pb.JobInputPath(path=path) for path in exclude_paths]
      )
      # Populate exactly one field of the source. Constructing it with
      # run_config and then overwriting with `file` only worked if the two
      # fields share a oneof; building the chosen variant directly does not
      # depend on that.
      if run_config:
          source = pb.JobInputSource(
              run_config=pb.JobInputSource.RunConfigSource(),
          )
      else:
          source = pb.JobInputSource(
              file=pb.JobInputSource.ConfigFileSource(path=file_path),
          )
      request.input_source.CopyFrom(source)
      if input_schema:
          request.input_schema = json_dumps_safer(input_schema)
      return self._publish_job_input(request)
  @abc.abstractmethod
  def _publish_job_input(
      self, request: pb.JobInputRequest
  ) -> MailboxHandle[pb.Result]:
      """Subclass hook that publishes a ``pb.JobInputRequest``.

      Returns a handle for the eventual ``pb.Result``.
      """
      raise NotImplementedError()
  def publish_probe_system_info(self) -> None:
      """Publish an empty ``pb.ProbeSystemInfoRequest``.

      Whatever ``_publish_probe_system_info`` returns is passed back to the
      caller.
      """
      request = pb.ProbeSystemInfoRequest()
      result = self._publish_probe_system_info(request)
      return result
  @abc.abstractmethod
  def _publish_probe_system_info(
      self, probe_system_info: pb.ProbeSystemInfoRequest
  ) -> None:
      """Subclass hook that publishes a ``pb.ProbeSystemInfoRequest``."""
      raise NotImplementedError()
  def join(self) -> None:
      """Deliver a shutdown request and wait up to 30 seconds for the reply.

      Does nothing when ``self._drop`` is set, which indicates the internal
      process has already been shut down. A timeout or an abandoned handle is
      logged as a warning instead of being raised.
      """
      if not self._drop:
          shutdown_handle = self._deliver_shutdown()
          try:
              shutdown_handle.wait_or(timeout=30)
          except TimeoutError:
              # The server may not answer in time because of a bug or because
              # it is very busy.
              logger.warning("timed out communicating shutdown")
          except HandleAbandonedError:
              # The connection to the server may close before the response
              # is read.
              logger.warning("handle abandoned while communicating shutdown")
  @abc.abstractmethod
  def _deliver_shutdown(self) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a shutdown request.

      Returns a handle that ``join`` waits on.
      """
      raise NotImplementedError()
  def deliver_run(self, run: Run) -> MailboxHandle[pb.Result]:
      """Convert ``run`` with ``_make_run`` and deliver the resulting record."""
      return self._deliver_run(self._make_run(run))
  def deliver_finish_sync(self) -> MailboxHandle[pb.Result]:
      """Deliver an empty ``pb.SyncFinishRequest`` and return its handle."""
      return self._deliver_finish_sync(pb.SyncFinishRequest())
  @abc.abstractmethod
  def _deliver_finish_sync(
      self, sync: pb.SyncFinishRequest
  ) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a ``pb.SyncFinishRequest``."""
      raise NotImplementedError()
  @abc.abstractmethod
  def _deliver_run(self, run: pb.RunRecord) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a ``pb.RunRecord``."""
      raise NotImplementedError()
  def deliver_run_start(self, run: Run) -> MailboxHandle[pb.Result]:
      """Wrap ``run`` in a ``pb.RunStartRequest`` and deliver it."""
      run_record = self._make_run(run)
      return self._deliver_run_start(pb.RunStartRequest(run=run_record))
  @abc.abstractmethod
  def _deliver_run_start(
      self, run_start: pb.RunStartRequest
  ) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a ``pb.RunStartRequest``."""
      raise NotImplementedError()
  def deliver_attach(self, attach_id: str) -> MailboxHandle[pb.Result]:
      """Deliver a ``pb.AttachRequest`` for ``attach_id`` and return its handle."""
      return self._deliver_attach(pb.AttachRequest(attach_id=attach_id))
  @abc.abstractmethod
  def _deliver_attach(
      self,
      status: pb.AttachRequest,
  ) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a ``pb.AttachRequest``.

      NOTE: despite its name, ``status`` carries the ``pb.AttachRequest``; the
      name is kept so the signature stays unchanged.
      """
      raise NotImplementedError()
  def deliver_stop_status(self) -> MailboxHandle[pb.Result]:
      """Deliver an empty ``pb.StopStatusRequest`` and return its handle."""
      return self._deliver_stop_status(pb.StopStatusRequest())
  @abc.abstractmethod
  def _deliver_stop_status(
      self,
      status: pb.StopStatusRequest,
  ) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a ``pb.StopStatusRequest``."""
      raise NotImplementedError()
  def deliver_network_status(self) -> MailboxHandle[pb.Result]:
      """Deliver an empty ``pb.NetworkStatusRequest`` and return its handle."""
      return self._deliver_network_status(pb.NetworkStatusRequest())
  @abc.abstractmethod
  def _deliver_network_status(
      self,
      status: pb.NetworkStatusRequest,
  ) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a ``pb.NetworkStatusRequest``."""
      raise NotImplementedError()
  def deliver_internal_messages(self) -> MailboxHandle[pb.Result]:
      """Deliver an empty ``pb.InternalMessagesRequest`` and return its handle."""
      return self._deliver_internal_messages(pb.InternalMessagesRequest())
  @abc.abstractmethod
  def _deliver_internal_messages(
      self, internal_message: pb.InternalMessagesRequest
  ) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a ``pb.InternalMessagesRequest``."""
      raise NotImplementedError()
  def deliver_get_summary(self) -> MailboxHandle[pb.Result]:
      """Deliver an empty ``pb.GetSummaryRequest`` and return its handle."""
      return self._deliver_get_summary(pb.GetSummaryRequest())
  @abc.abstractmethod
  def _deliver_get_summary(
      self,
      get_summary: pb.GetSummaryRequest,
  ) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a ``pb.GetSummaryRequest``."""
      raise NotImplementedError()
  def deliver_get_system_metrics(self) -> MailboxHandle[pb.Result]:
      """Deliver an empty ``pb.GetSystemMetricsRequest`` and return its handle."""
      return self._deliver_get_system_metrics(pb.GetSystemMetricsRequest())
  @abc.abstractmethod
  def _deliver_get_system_metrics(
      self, get_summary: pb.GetSystemMetricsRequest
  ) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a ``pb.GetSystemMetricsRequest``.

      NOTE: the parameter is named ``get_summary`` but carries a
      ``pb.GetSystemMetricsRequest``; the name is kept so the signature stays
      unchanged.
      """
      raise NotImplementedError()
  def deliver_exit(self, exit_code: int | None) -> MailboxHandle[pb.Result]:
      """Deliver the run-exit record built by ``_make_exit(exit_code)``."""
      return self._deliver_exit(self._make_exit(exit_code))
  @abc.abstractmethod
  def _deliver_exit(
      self,
      exit_data: pb.RunExitRecord,
  ) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a ``pb.RunExitRecord``."""
      raise NotImplementedError()
  def deliver_poll_exit(self) -> MailboxHandle[pb.Result]:
      """Deliver an empty ``pb.PollExitRequest`` and return its handle."""
      return self._deliver_poll_exit(pb.PollExitRequest())
  @abc.abstractmethod
  def _deliver_poll_exit(
      self,
      poll_exit: pb.PollExitRequest,
  ) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a ``pb.PollExitRequest``."""
      raise NotImplementedError()
  def deliver_finish_without_exit(self) -> MailboxHandle[pb.Result]:
      """Deliver an empty ``pb.RunFinishWithoutExitRequest`` and return its handle."""
      request = pb.RunFinishWithoutExitRequest()
      return self._deliver_finish_without_exit(request)
  @abc.abstractmethod
  def _deliver_finish_without_exit(
      self, run_finish_without_exit: pb.RunFinishWithoutExitRequest
  ) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a ``pb.RunFinishWithoutExitRequest``."""
      raise NotImplementedError()
  def deliver_request_sampled_history(self) -> MailboxHandle[pb.Result]:
      """Deliver an empty ``pb.SampledHistoryRequest`` and return its handle."""
      return self._deliver_request_sampled_history(pb.SampledHistoryRequest())
  @abc.abstractmethod
  def _deliver_request_sampled_history(
      self, sampled_history: pb.SampledHistoryRequest
  ) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a ``pb.SampledHistoryRequest``."""
      raise NotImplementedError()
  def deliver_request_run_status(self) -> MailboxHandle[pb.Result]:
      """Deliver an empty ``pb.RunStatusRequest`` and return its handle."""
      return self._deliver_request_run_status(pb.RunStatusRequest())
  @abc.abstractmethod
  def _deliver_request_run_status(
      self, run_status: pb.RunStatusRequest
  ) -> MailboxHandle[pb.Result]:
      """Subclass hook that delivers a ``pb.RunStatusRequest``."""
      raise NotImplementedError()