artifact.py 104 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729
  1. """Artifact class."""
  2. from __future__ import annotations
  3. import atexit
  4. import contextlib
  5. import json
  6. import logging
  7. import multiprocessing.dummy
  8. import os
  9. import re
  10. import shutil
  11. import stat
  12. import tempfile
  13. import time
  14. from collections import deque
  15. from collections.abc import Iterator, Sequence
  16. from concurrent.futures import Executor, ThreadPoolExecutor, as_completed
  17. from copy import copy
  18. from dataclasses import asdict, replace
  19. from datetime import timedelta
  20. from itertools import filterfalse
  21. from pathlib import Path, PurePosixPath
  22. from typing import ( # noqa: UP035 (can't use `type` - shadows `Artifact.type`)
  23. IO,
  24. TYPE_CHECKING,
  25. Any,
  26. Callable,
  27. Final,
  28. Literal,
  29. Type,
  30. )
  31. from urllib.parse import quote, urljoin, urlparse
  32. from pydantic import NonNegativeInt
  33. import wandb
  34. from wandb import data_types, env
  35. from wandb._iterutils import one, unique_list
  36. from wandb._pydantic import from_json
  37. from wandb._strutils import nameof
  38. from wandb.apis.normalize import normalize_exceptions
  39. from wandb.apis.public import ArtifactCollection, ArtifactFiles, Run
  40. from wandb.apis.public.utils import gql_compat
  41. from wandb.data_types import WBValue
  42. from wandb.errors import CommError
  43. from wandb.errors.errors import UnsupportedError
  44. from wandb.errors.term import termerror, termlog, termwarn
  45. from wandb.proto import wandb_internal_pb2 as pb
  46. from wandb.proto.wandb_telemetry_pb2 import Deprecated
  47. from wandb.sdk import wandb_setup
  48. from wandb.sdk.data_types._dtypes import Type as WBType
  49. from wandb.sdk.data_types._dtypes import TypeRegistry
  50. from wandb.sdk.lib import retry, telemetry
  51. from wandb.sdk.lib.deprecation import warn_and_record_deprecation
  52. from wandb.sdk.lib.filesystem import check_exists, system_preferred_path
  53. from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, md5_file_b64
  54. from wandb.sdk.lib.paths import FilePathStr, LogicalPath, StrPath, URIStr
  55. from wandb.sdk.lib.runid import generate_fast_id, generate_id
  56. from wandb.sdk.mailbox import MailboxHandle
  57. from wandb.util import (
  58. alias_is_version_index,
  59. artifact_to_json,
  60. fsync_open,
  61. json_dumps_safer,
  62. uri_from_path,
  63. vendor_setup,
  64. )
  65. from ._factories import make_storage_policy
  66. from ._gqlutils import org_info_from_entity, resolve_org_entity_name, server_supports
  67. from ._validators import ensure_logged, ensure_not_finalized
  68. from .artifact_download_logger import ArtifactDownloadLogger
  69. from .artifact_instance_cache import (
  70. artifact_instance_cache,
  71. artifact_instance_cache_by_client_id,
  72. )
  73. from .artifact_manifest import ArtifactManifest
  74. from .artifact_manifest_entry import ArtifactManifestEntry
  75. from .artifact_manifests.artifact_manifest_v1 import ArtifactManifestV1
  76. from .artifact_state import ArtifactState
  77. from .artifact_ttl import ArtifactTTL
  78. from .exceptions import (
  79. ArtifactNotLoggedError,
  80. TooFewItemsError,
  81. TooManyItemsError,
  82. WaitTimeoutError,
  83. )
  84. from .staging import get_staging_dir
  85. from .storage_handlers.gcs_handler import _GCSIsADirectoryError
  86. from .storage_policies._factories import make_http_session
  87. from .storage_policies._multipart import should_multipart_download
  88. reset_path = vendor_setup()
  89. from wandb_gql import gql # noqa: E402
  90. reset_path()
  91. if TYPE_CHECKING:
  92. from collections.abc import Iterable
  93. from wandb.apis.public import RetryingClient
  94. from ._generated import ArtifactFragment, ArtifactMembershipFragment
  95. from ._models.pagination import FileWithUrlConnection
  96. from ._validators import FullArtifactPath, LinkArtifactFields
  97. logger = logging.getLogger(__name__)
  98. _MB: Final[int] = 1024 * 1024
  99. _FILE_EXECUTOR_WORKERS = 32
  100. _MP_EXECUTOR_WORKERS = 32
  101. class Artifact:
  102. """Flexible and lightweight building block for dataset and model versioning.
  103. Construct an empty W&B Artifact. Populate an artifact's contents with methods that
  104. begin with `add`. Once the artifact has all the desired files, you can call
  105. `run.log_artifact()` to log it.
  106. Args:
  107. name (str): A human-readable name for the artifact. Use the name to identify
  108. a specific artifact in the W&B App UI or programmatically. You can
  109. interactively reference an artifact with the `use_artifact` Public API.
  110. A name can contain letters, numbers, underscores, hyphens, and dots.
  111. The name must be unique across a project.
  112. type (str): The artifact's type. Use the type of an artifact to both organize
  113. and differentiate artifacts. You can use any string that contains letters,
  114. numbers, underscores, hyphens, and dots. Common types include `dataset` or
  115. `model`. Include `model` within your type string if you want to link the
  116. artifact to the W&B Model Registry.
  117. Note that some types are reserved for internal use and cannot be set by users.
  118. Such types include `job` and types that start with `wandb-`.
  119. description (str | None) = None: A description of the artifact. For Model or
  120. Dataset Artifacts, add documentation for your standardized team model or
  121. dataset card. View an artifact's description programmatically with the
  122. `Artifact.description` attribute or interactively in the W&B App UI.
  123. W&B renders the description as markdown in the W&B App.
  124. metadata (dict[str, Any] | None) = None: Additional information about an artifact.
  125. Specify metadata as a dictionary of key-value pairs. You can specify no more
  126. than 100 total keys.
  127. incremental: Use `Artifact.new_draft()` method instead to modify an
  128. existing artifact.
  129. use_as: Deprecated.
  130. Returns:
  131. An `Artifact` object.
  132. """
  133. _TMP_DIR = tempfile.TemporaryDirectory("wandb-artifacts")
  134. atexit.register(_TMP_DIR.cleanup)
  def __init__(
      self,
      name: str,
      type: str,
      description: str | None = None,
      metadata: dict[str, Any] | None = None,
      incremental: bool = False,
      use_as: str | None = None,
      storage_region: str | None = None,
  ) -> None:
      """Create a new, not-yet-logged artifact in the `PENDING` state.

      Args:
          name: Human-readable artifact name. May only contain letters, digits,
              underscores, hyphens and dots.
          type: The artifact's type; validated together with `name`.
          description: Optional free-text description.
          metadata: Optional dictionary of key-value pairs; validated before
              being stored.
          incremental: Stored on the instance. When true (and this is not an
              `InternalArtifact`) a warning about the experimental argument is
              printed.
          use_as: Deprecated. A non-`None` value only triggers a deprecation
              warning; the stored value is always `None`.
          storage_region: Forwarded as `region` to `make_storage_policy()` when
              building the initial, empty manifest.

      Raises:
          ValueError: If `name` contains any character other than letters,
              digits, underscores, hyphens and dots. The validators called
              below may raise as well (their behavior is defined elsewhere).
      """
      from wandb.sdk.artifacts._internal_artifact import InternalArtifact

      from ._validators import (
          validate_artifact_name,
          validate_artifact_type,
          validate_metadata,
      )

      # `fullmatch` requires the *entire* string to consist of allowed
      # characters. The previous `re.match(r"^...$")` also accepted a name ending
      # in "\n", because `$` matches just before a trailing newline.
      if not re.fullmatch(r"[a-zA-Z0-9_\-.]+", name):
          raise ValueError(
              f"Artifact name may only contain alphanumeric characters, dashes, "
              f"underscores, and dots. Invalid name: {name!r}"
          )

      if incremental and not isinstance(self, InternalArtifact):
          termwarn("Using experimental arg `incremental`")

      # Internal.
      self._client: RetryingClient | None = None

      self._tmp_dir: tempfile.TemporaryDirectory | None = None
      self._added_objs: dict[int, tuple[WBValue, ArtifactManifestEntry]] = {}
      self._added_local_paths: dict[str, ArtifactManifestEntry] = {}
      self._save_handle: MailboxHandle[pb.Result] | None = None
      self._download_roots: set[str] = set()
      # Set by new_draft(), otherwise the latest artifact will be used as the base.
      self._base_id: str | None = None

      # Properties.
      self._id: str | None = None

      # Client IDs don't need cryptographic strength, so use a faster implementation.
      self._client_id: str = generate_fast_id(128)
      self._sequence_client_id: str = generate_fast_id(128)

      self._entity: str | None = None
      self._project: str | None = None
      self._name: str = validate_artifact_name(name)  # includes version after saving
      self._version: str | None = None
      self._source_entity: str | None = None
      self._source_project: str | None = None
      self._source_name: str = name  # includes version after saving
      self._source_version: str | None = None
      self._source_artifact: Artifact | None = None
      self._is_link: bool = False
      self._type: str = validate_artifact_type(type, name)
      self._description: str | None = description
      self._metadata: dict[str, Any] = validate_metadata(metadata)
      self._ttl_duration_seconds: int | None = None
      self._ttl_is_inherited: bool = True
      self._ttl_changed: bool = False
      self._aliases: list[str] = []
      self._saved_aliases: list[str] = []
      self._tags: list[str] = []
      self._saved_tags: list[str] = []
      self._distributed_id: str | None = None
      self._incremental: bool = incremental

      if use_as is not None:
          warn_and_record_deprecation(
              feature=Deprecated(artifact__init_use_as=True),
              message=(
                  "`use_as` argument is deprecated and does not affect the behaviour of `wandb.Artifact()`"
              ),
          )
      # The argument is ignored on purpose: only the warning above is emitted.
      self._use_as: str | None = None

      self._state: ArtifactState = ArtifactState.PENDING

      # NOTE: These fields only reflect the last fetched response from the
      # server, if any. If the ArtifactManifest has already been fetched and/or
      # populated locally, it should take priority when determining these values.
      self._size: NonNegativeInt | None = None
      self._digest: str | None = None

      self._manifest: ArtifactManifest | None = ArtifactManifestV1(
          storage_policy=make_storage_policy(region=storage_region)
      )

      self._commit_hash: str | None = None
      self._file_count: int | None = None
      self._created_at: str | None = None
      self._updated_at: str | None = None
      self._final: bool = False
      self._history_step: int | None = None
      self._linked_artifacts: list[Artifact] = []
      self._fetch_file_urls_decorated: Callable[..., Any] | None = None

      # Cache: make this instance discoverable by its client ID.
      artifact_instance_cache_by_client_id[self._client_id] = self
  def __repr__(self) -> str:
      """Return a short debug label: the artifact ID if it is truthy, else the name."""
      label = self.id or self.name
      return f"<Artifact {label}>"
  @classmethod
  def _from_id(cls, artifact_id: str, client: RetryingClient) -> Artifact | None:
      """Look up an artifact by ID, consulting the instance cache first.

      Args:
          artifact_id: ID of the artifact to fetch.
          client: Client used to run the GraphQL query on a cache miss.

      Returns:
          The cached or newly built `Artifact`, or `None` when the response
          holds no artifact for this ID.
      """
      from ._generated import ARTIFACT_BY_ID_GQL, ArtifactByID
      from ._validators import FullArtifactPath

      cached = artifact_instance_cache.get(artifact_id)
      if cached:
          return cached

      query = gql(ARTIFACT_BY_ID_GQL)
      response = client.execute(query, variable_values={"id": artifact_id})
      fragment = ArtifactByID.model_validate(response).artifact
      if fragment is None:
          return None

      # The path is built from the source sequence; entity/project fall back to
      # empty strings when the sequence carries no project.
      sequence = fragment.artifact_sequence
      sequence_project = sequence.project
      if sequence_project:
          entity_name = sequence_project.entity.name
          project_name = sequence_project.name
      else:
          entity_name = ""
          project_name = ""

      full_path = FullArtifactPath(
          prefix=entity_name,
          project=project_name,
          name=f"{sequence.name}:v{fragment.version_index}",
      )
      return cls._from_attrs(full_path, fragment, client)
  @classmethod
  def _membership_from_name(
      cls, *, path: FullArtifactPath, client: RetryingClient
  ) -> Artifact:
      """Fetch an artifact through its collection membership, looked up by name.

      Args:
          path: Entity (prefix), project and artifact name to query.
          client: Client used to run the GraphQL query.

      Returns:
          The `Artifact` built from the returned membership.

      Raises:
          UnsupportedError: If the server does not support querying collection
              memberships.
          ValueError: If the project or the membership is missing from the
              response.
      """
      from ._generated import (
          ARTIFACT_MEMBERSHIP_BY_NAME_GQL,
          ArtifactMembershipByName,
      )

      membership_supported = server_supports(
          client, pb.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
      )
      if not membership_supported:
          raise UnsupportedError(
              "Querying for the artifact collection membership is not supported "
              "by this version of wandb server. Consider updating to the latest version."
          )

      query = gql(ARTIFACT_MEMBERSHIP_BY_NAME_GQL)
      variables = {"entity": path.prefix, "project": path.project, "name": path.name}
      response = client.execute(query, variable_values=variables)

      project = ArtifactMembershipByName.model_validate(response).project
      if not project:
          raise ValueError(
              f"project {path.project!r} not found under entity {path.prefix!r}"
          )

      membership = project.artifact_collection_membership
      if not membership:
          location = f"{path.prefix}/{path.project}"
          raise ValueError(
              f"artifact membership {path.name!r} not found in {location!r}"
          )

      return cls._from_membership(membership, target=path, client=client)
  @classmethod
  def _from_name(
      cls,
      *,
      path: FullArtifactPath,
      client: RetryingClient,
      enable_tracking: bool = False,
  ) -> Artifact:
      """Fetch an artifact by its entity/project/name path.

      When the server supports collection memberships, the lookup is delegated
      to `_membership_from_name()` (which does not receive `enable_tracking`).
      Otherwise the artifact is queried directly by name.

      Args:
          path: Entity (prefix), project and artifact name to query.
          client: Client used to run the GraphQL query.
          enable_tracking: Sent as the `enableTracking` query variable on the
              direct by-name lookup only.

      Returns:
          The fetched `Artifact`.

      Raises:
          ValueError: If the project or the artifact is missing from the response.
      """
      from ._generated import ARTIFACT_BY_NAME_GQL, ArtifactByName

      if server_supports(client, pb.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP):
          return cls._membership_from_name(path=path, client=client)

      variables = {
          "entity": path.prefix,
          "project": path.project,
          "name": path.name,
          "enableTracking": enable_tracking,
      }
      query = gql(ARTIFACT_BY_NAME_GQL)
      response = client.execute(query, variable_values=variables)

      project = ArtifactByName.model_validate(response).project
      if not project:
          raise ValueError(
              f"project {path.project!r} not found in entity {path.prefix!r}"
          )

      fragment = project.artifact
      if not fragment:
          location = f"{path.prefix}/{path.project}"
          raise ValueError(f"artifact {path.name!r} not found in {location!r}")

      return cls._from_attrs(path, fragment, client)
  @classmethod
  def _from_membership(
      cls,
      membership: ArtifactMembershipFragment,
      target: FullArtifactPath,
      client: RetryingClient,
  ) -> Artifact:
      """Build an `Artifact` from a collection-membership response fragment.

      Args:
          membership: Membership fragment returned by the server.
          target: The path the caller originally asked for.
          client: Client to attach to the resulting artifact.

      Returns:
          The `Artifact`, whose entity/project are taken from the response
          rather than from `target`.

      Raises:
          ValueError: If the collection, its name, its project, or the artifact
              itself is missing from the response.
      """
      from ._validators import is_artifact_registry_project

      missing_msg = "Missing artifact collection project in GraphQL response"
      collection = membership.artifact_collection
      if not collection:
          raise ValueError(missing_msg)
      name = collection.name
      if not name:
          raise ValueError(missing_msg)
      project = collection.project
      if not project:
          raise ValueError(missing_msg)

      # A legacy "model-registry" path that resolved to a registry project means
      # the request was redirected; tell the user to update the path.
      if is_artifact_registry_project(project.name) and target.project == "model-registry":
          wandb.termwarn(
              "This model registry has been migrated and will be discontinued. "
              f"Your request was redirected to the corresponding artifact {name!r} in the new registry. "
              f"Please update your paths to point to the migrated registry directly, '{project.name}/{name}'."
          )

      # Use the actual project/entity names from the response in case they differ
      # from the requested ones (e.g. letter case, migrated legacy model registry).
      resolved_target = replace(
          target, prefix=project.entity.name, project=project.name
      )

      fragment = membership.artifact
      if not fragment:
          raise ValueError(f"Artifact {target.to_str()!r} not found in response")

      return cls._from_attrs(resolved_target, fragment, client, membership=membership)
  @classmethod
  def _from_attrs(
      cls,
      path: FullArtifactPath,
      src_art: ArtifactFragment,
      client: RetryingClient,
      *,
      # aliases/version_index are taken from the membership, if given
      membership: ArtifactMembershipFragment | None = None,
  ) -> Artifact:
      """Build a finalized `Artifact` from a server fragment and register it in the cache.

      Args:
          path: Entity (prefix), project and name to assign to the artifact.
          src_art: Artifact fragment from the server response.
          client: Client to attach to the artifact.
          membership: Optional membership fragment, forwarded to
              `_assign_attrs()`.

      Returns:
          The populated `Artifact`, also stored in `artifact_instance_cache`
          under its ID.
      """
      # Placeholder name/type are required to skip validation; the real values
      # are assigned right below.
      instance = cls("placeholder", type="placeholder")
      instance._client = client
      instance._entity = path.prefix
      instance._project = path.project
      instance._name = path.name
      instance._assign_attrs(src_art, membership=membership)

      instance.finalize()

      # Cache the instance under its ID.
      assert instance.id is not None
      artifact_instance_cache[instance.id] = instance
      return instance
  345. # TODO: Eventually factor out is_link. We currently have to use it because some ways of fetching the artifact
  346. # don't make it clear whether the artifact is a link, so it has to be set manually.
  def _assign_attrs(
      self,
      src_art: ArtifactFragment,
      *,
      # aliases/version_index are taken from the membership, if given
      membership: ArtifactMembershipFragment | None = None,
      is_link: bool | None = None,
  ) -> None:
      """Overwrite this artifact's attributes with values from a server response.

      Args:
          src_art: Artifact fragment from the server response.
          membership: Optional membership fragment. When given, its aliases are
              used, and its version index is the fallback version when no
              version alias is present.
          is_link: Explicit link flag. When `None`, the flag is derived by
              comparing entity, project and collection name with the source's.

      Raises:
          ValueError: If more than one version alias is found.
      """
      from ._validators import validate_metadata, validate_ttl_duration_seconds

      self._id = src_art.id

      # Source (sequence) identity; entity/project default to "" when the
      # sequence carries no project.
      sequence = src_art.artifact_sequence
      sequence_project = sequence.project
      if sequence_project:
          self._source_entity = sequence_project.entity.name
          self._source_project = sequence_project.name
      else:
          self._source_entity = ""
          self._source_project = ""
      self._source_name = f"{sequence.name}:v{src_art.version_index}"
      self._source_version = f"v{src_art.version_index}"

      # Keep values that are already set; otherwise fall back to the source's.
      if not self._entity:
          self._entity = self._source_entity
      if not self._project:
          self._project = self._source_project
      if not self._name:
          self._name = self._source_name

      # TODO: Refactor artifact query to fetch artifact via membership instead
      # and get the collection type
      if is_link is not None:
          self._is_link = is_link
      else:
          # A link is any artifact whose entity, project, or collection name
          # differs from its source's.
          matches_source = (
              self._entity == self._source_entity
              and self._project == self._source_project
              and self._name.split(":")[0] == self._source_name.split(":")[0]
          )
          self._is_link = not matches_source

      self._type = src_art.artifact_type.name
      self._description = src_art.description

      # The future of aliases is to move all alias fetches to the membership level
      # so we don't have to do the collection fetches below
      if membership:
          aliases = [item.alias for item in membership.aliases]
      elif src_art.aliases:
          # Keep only the aliases that belong to this artifact's own collection.
          entity = self._entity
          project = self._project
          collection = self._name.split(":")[0]
          aliases = []
          for item in src_art.aliases:
              alias_coll = item.artifact_collection
              if not alias_coll:
                  continue
              alias_proj = alias_coll.project
              if not alias_proj:
                  continue
              if (
                  alias_proj.entity.name == entity
                  and alias_proj.name == project
                  and alias_coll.name == collection
              ):
                  aliases.append(item.alias)
      else:
          aliases = []

      version_aliases = [a for a in aliases if alias_is_version_index(a)]
      other_aliases = [a for a in aliases if not alias_is_version_index(a)]

      try:
          version = one(
              version_aliases, too_short=TooFewItemsError, too_long=TooManyItemsError
          )
      except TooFewItemsError:
          # No version alias: default to the membership version if one was passed
          # to this method, otherwise fall back to the source version.
          membership_index = membership.version_index if membership else None
          if membership_index is not None:
              version = f"v{membership_index}"
          else:
              version = f"v{src_art.version_index}"
      except TooManyItemsError:
          msg = f"Expected at most one version alias, got {len(version_aliases)}: {version_aliases!r}"
          raise ValueError(msg) from None

      self._version = version
      # Append the version only when the name does not already carry one.
      if ":" not in self._name:
          self._name = f"{self._name}:{version}"

      # Working and saved copies are independent lists.
      self._aliases = list(other_aliases)
      self._saved_aliases = list(other_aliases)

      self._tags = [tag.name for tag in src_art.tags]
      self._saved_tags = list(self._tags)

      self._metadata = validate_metadata(src_art.metadata)

      self._ttl_duration_seconds = validate_ttl_duration_seconds(
          src_art.ttl_duration_seconds
      )
      self._ttl_is_inherited = src_art.ttl_is_inherited

      self._state = ArtifactState(src_art.state)
      self._size = src_art.size
      self._digest = src_art.digest

      # Any previously held manifest is dropped here.
      self._manifest = None

      self._commit_hash = src_art.commit_hash
      self._file_count = src_art.file_count
      self._created_at = src_art.created_at
      self._updated_at = src_art.updated_at
      self._history_step = src_art.history_step
  @ensure_logged
  def new_draft(self) -> Artifact:
      """Create a new draft artifact with the same content as this committed artifact.

      Modifying an existing artifact creates a new artifact version known
      as an "incremental artifact". The returned artifact can be extended or
      modified and then logged as a new version.

      Returns:
          An `Artifact` object.

      Raises:
          ArtifactNotLoggedError: If the artifact is not logged.
      """
      # The draft is named after, and lives in, the *source* collection: if it is
      # saved, it must be saved to the source sequence.
      collection_name = self.source_name.split(":")[0]
      draft = Artifact(collection_name, self.type)
      draft._entity = draft._source_entity = self._source_entity
      draft._project = draft._source_project = self._source_project

      # This artifact is the parent of the draft.
      draft._base_id = self.id

      # Reuse the client and copy the attributes that are neither
      # version-dependent nor dependent on having been logged.
      draft._client = self._client
      draft._description = self.description
      # NOTE(review): assigned as-is; if `self.metadata` returns the underlying
      # dict, the draft shares it with this artifact. Confirm against the property.
      draft._metadata = self.metadata

      # Round-trip the manifest through its JSON form to get a separate object.
      manifest_json = self.manifest.to_manifest_json()
      draft._manifest = ArtifactManifest.from_manifest_json(manifest_json)
      return draft
  465. # Properties (Python Class managed attributes).
  @property
  def id(self) -> str | None:
      """The artifact's ID, or `None` while the artifact is a draft."""
      if not self.is_draft():
          assert self._id is not None
          return self._id
      return None
  @property
  @ensure_logged
  def entity(self) -> str:
      """The name of the entity that owns the artifact collection.

      For a link artifact, this is the entity of the linked artifact.
      """
      entity_name = self._entity
      assert entity_name is not None
      return entity_name
  @property
  @ensure_logged
  def project(self) -> str:
      """The name of the project that owns the artifact collection.

      For a link artifact, this is the project of the linked artifact.
      """
      project_name = self._project
      assert project_name is not None
      return project_name
  @property
  def name(self) -> str:
      """The artifact name and version of the artifact.

      Formatted as `{collection}:{alias}`. When read before the artifact is
      logged/saved, the alias part is absent. For a link artifact, this is the
      name of the linked artifact.
      """
      artifact_name = self._name
      return artifact_name
  @property
  def qualified_name(self) -> str:
      """The `entity/project/name` path of the artifact.

      For a link artifact, this is the qualified name of the linked artifact
      path.
      """
      entity, project, name = self.entity, self.project, self.name
      return f"{entity}/{project}/{name}"
  @property
  @ensure_logged
  def version(self) -> str:
      """The artifact's version, formatted as `v{number}`.

      For a link artifact, the version comes from the linked collection.
      """
      current_version = self._version
      assert current_version is not None
      return current_version
  @property
  @ensure_logged
  def collection(self) -> ArtifactCollection:
      """The collection this artifact is retrieved from.

      A collection is an ordered group of artifact versions. When the artifact
      was retrieved through a collection it is linked to, that collection is
      returned; otherwise the collection the artifact version originates from
      (its source sequence) is returned.

      Raises:
          RuntimeError: If no client is attached to this artifact.
      """
      client = self._client
      if client is None:
          raise RuntimeError("Client not initialized")

      # Drop the ":alias" suffix, keeping only the collection part of the name.
      collection_name = self.name.partition(":")[0]
      return ArtifactCollection(
          client, self.entity, self.project, collection_name, self.type
      )
  @property
  @ensure_logged
  def source_entity(self) -> str:
      """The name of the entity of the source artifact."""
      entity_name = self._source_entity
      assert entity_name is not None
      return entity_name
  @property
  @ensure_logged
  def source_project(self) -> str:
      """The name of the project of the source artifact."""
      project_name = self._source_project
      assert project_name is not None
      return project_name
  @property
  def source_name(self) -> str:
      """The artifact name and version of the source artifact.

      Formatted as `{source_collection}:{alias}`. Until the artifact is saved
      the version is not yet known, so only the name is present.
      """
      name_in_source = self._source_name
      return name_in_source
  @property
  def source_qualified_name(self) -> str:
      """The `source_entity/source_project/source_name` path of the source artifact."""
      entity, project, name = self.source_entity, self.source_project, self.source_name
      return f"{entity}/{project}/{name}"
  @property
  @ensure_logged
  def source_version(self) -> str:
      """The source artifact's version, formatted as `v{number}`."""
      version_in_source = self._source_version
      assert version_in_source is not None
      return version_in_source
  @property
  @ensure_logged
  def source_collection(self) -> ArtifactCollection:
      """The artifact's source collection.

      The source collection is the collection that the artifact was logged
      from.

      Raises:
          RuntimeError: If no client is attached to this artifact.
      """
      client = self._client
      if client is None:
          raise RuntimeError("Client not initialized")

      # Drop the ":alias" suffix, keeping only the collection part of the name.
      collection_name = self.source_name.partition(":")[0]
      return ArtifactCollection(
          client,
          self.source_entity,
          self.source_project,
          collection_name,
          self.type,
      )
  @property
  def is_link(self) -> bool:
      """Whether the artifact is a link artifact.

      `True` means the artifact is a link to a source artifact; `False` means
      it is itself a source artifact.
      """
      link_flag = self._is_link
      return link_flag
  @property
  @ensure_logged
  def linked_artifacts(self) -> list[Artifact]:
      """Returns a list of all the linked artifacts of a source artifact.

      For a link artifact (`artifact.is_link == True`) nothing is fetched and
      the stored list is returned, which the original documentation describes
      as empty. For a source artifact the list is re-fetched on every access.

      Limited to 500 results.
      """
      if self.is_link:
          return self._linked_artifacts

      self._linked_artifacts = self._fetch_linked_artifacts()
      return self._linked_artifacts
  @property
  @ensure_logged
  def source_artifact(self) -> Artifact:
      """Returns the source artifact, which is the original logged artifact.

      A source artifact (`artifact.is_link == False`) returns itself. A link
      artifact fetches its source once and reuses the stored result on later
      accesses.

      Raises:
          ValueError: If the client is not initialized, or if fetching the
              source artifact fails for any reason.
      """
      from ._validators import FullArtifactPath

      if not self.is_link:
          return self

      # Already resolved on a previous access.
      if self._source_artifact is not None:
          return self._source_artifact

      if (client := self._client) is None:
          raise ValueError("Client is not initialized")

      try:
          source_path = FullArtifactPath(
              prefix=self.source_entity,
              project=self.source_project,
              name=self.source_name,
          )
          self._source_artifact = self._from_name(path=source_path, client=client)
      except Exception as e:
          raise ValueError(
              f"Unable to fetch source artifact for linked artifact {self.name}"
          ) from e
      return self._source_artifact
  @property
  def type(self) -> str:
      """The artifact's type, for example `dataset` or `model`."""
      artifact_type = self._type
      return artifact_type
  @property
  @ensure_logged
  def url(self) -> str:
      """Constructs the URL of the artifact.

      Returns:
          str: The URL of the artifact, or an empty string when the client's
          `app_url` cannot be read.
      """
      from ._validators import is_artifact_registry_project

      try:
          base_url = self._client.app_url  # type: ignore[union-attr]
      except AttributeError:
          return ""

      # Only link artifacts can resolve to a registry-style URL; everything
      # else falls through to the standard artifact URL.
      if self.is_link:
          if is_artifact_registry_project(self.project):
              return self._construct_registry_url(base_url)
          if self._type == "model" or self.project == "model-registry":
              return self._construct_model_registry_url(base_url)
      return self._construct_standard_url(base_url)
  def _construct_standard_url(self, base_url: str) -> str:
      """Build the standard artifact URL under `base_url`.

      Args:
          base_url: The application base URL the artifact path is joined onto.

      Returns:
          The joined URL, or an empty string if the base URL or any path
          component (entity, project, type, collection name, version) is
          missing/empty.
      """
      # `self.collection` is read exactly once: each access to that property
      # constructs a fresh collection object, so the name is resolved up front
      # and reused for both the guard and the URL.
      collection_name = self.collection.name
      required = (
          base_url,
          self.entity,
          self.project,
          self._type,
          collection_name,
          self._version,
      )
      if not all(required):
          return ""
      return urljoin(
          base_url,
          f"{self.entity}/{self.project}/artifacts/{quote(self._type)}/{quote(collection_name)}/{self._version}",
      )
  def _construct_registry_url(self, base_url: str) -> str:
      """Build the registry URL for this artifact under `base_url`.

      Args:
          base_url: The application base URL the registry path is joined onto.

      Returns:
          The joined URL, or an empty string if a required component is
          missing/empty or the organization name cannot be resolved.
      """
      from ._validators import remove_registry_prefix

      # `self.collection` is read exactly once: each access to that property
      # constructs a fresh collection object, so the name is resolved up front
      # and reused for both the guard and the URL.
      collection_name = self.collection.name
      required = (
          base_url,
          self.entity,
          self.project,
          collection_name,
          self._version,
      )
      if not all(required):
          return ""

      try:
          org_name = org_info_from_entity(self._client, self.entity).organization.name  # type: ignore[union-attr]
      except (AttributeError, ValueError):
          return ""

      selection_path = quote(
          f"{self.entity}/{self.project}/{collection_name}", safe=""
      )
      # Use `_version` (the value the guard above checked), matching the
      # standard and model-registry URL builders.
      return urljoin(
          base_url,
          f"orgs/{org_name}/registry/{remove_registry_prefix(self.project)}?selectionPath={selection_path}&view=membership&version={self._version}",
      )
  def _construct_model_registry_url(self, base_url: str) -> str:
      """Build the model-registry URL for this artifact under `base_url`.

      Args:
          base_url: The application base URL the registry path is joined onto.

      Returns:
          The joined URL, or an empty string if the base URL or any of entity,
          project, collection name or version is missing/empty.
      """
      # `self.collection` is read exactly once: each access to that property
      # constructs a fresh collection object, so the name is resolved up front
      # and reused for both the guard and the URL.
      collection_name = self.collection.name
      required = (
          base_url,
          self.entity,
          self.project,
          collection_name,
          self._version,
      )
      if not all(required):
          return ""

      # The selection path is a single query value, so "/" is escaped as well.
      selection_path = quote(
          f"{self.entity}/{self.project}/{collection_name}", safe=""
      )
      return urljoin(
          base_url,
          f"{self.entity}/registry/model?selectionPath={selection_path}&view=membership&version={self._version}",
      )
  @property
  def description(self) -> str | None:
      """A description of the artifact."""
      current = self._description
      return current

  @description.setter
  def description(self, description: str | None) -> None:
      """Set the description of the artifact.

      For model or dataset Artifacts, add documentation for your
      standardized team model or dataset card. In the W&B UI the
      description is rendered as markdown.

      Editing the description will apply the changes to the source artifact
      and all linked artifacts associated with it; a warning is emitted when
      this artifact is a link.

      Args:
          description: Free text that offers a description of the artifact.
      """
      editing_through_link = self.is_link
      if editing_through_link:
          wandb.termwarn(
              "Editing the description of this linked artifact will edit the description for the source artifact and it's linked artifacts as well."
          )
      self._description = description
  @property
  def metadata(self) -> dict:
      """User-defined artifact metadata.

      Structured data associated with the artifact.
      """
      current = self._metadata
      return current

  @metadata.setter
  def metadata(self, metadata: dict) -> None:
      """User-defined artifact metadata.

      Metadata set this way will eventually be queryable and plottable in the
      UI; e.g. the class distribution of a dataset.

      Note: There is currently a limit of 100 total keys.

      Editing the metadata will apply the changes to the source artifact
      and all linked artifacts associated with it; a warning is emitted when
      this artifact is a link.

      Args:
          metadata: Structured data associated with the artifact.
      """
      from ._validators import validate_metadata

      editing_through_link = self.is_link
      if editing_through_link:
          wandb.termwarn(
              "Editing the metadata of this linked artifact will edit the metadata for the source artifact and it's linked artifacts as well."
          )
      # Only the validated form of the metadata is stored.
      validated = validate_metadata(metadata)
      self._metadata = validated
  @property
  def ttl(self) -> timedelta | None:
      """The time-to-live (TTL) policy of an artifact.

      Artifacts are deleted shortly after a TTL policy's duration passes.
      If set to `None`, the artifact deactivates TTL policies and will not be
      scheduled for deletion, even if there is a team default TTL.
      An artifact inherits a TTL policy from the team default if the team
      administrator defines a default TTL and there is no custom policy set
      on an artifact.

      Raises:
          ArtifactNotLoggedError: Unable to fetch inherited TTL if the
              artifact has not been logged or saved.
      """
      # An inherited TTL is only known once the server has reported it, i.e.
      # not for drafts and not after a local, unsaved change to the policy.
      if self._ttl_is_inherited and (self.is_draft() or self._ttl_changed):
          raise ArtifactNotLoggedError(f"{nameof(type(self))}.ttl", self)
      if self._ttl_duration_seconds is None:
          return None
      return timedelta(seconds=self._ttl_duration_seconds)

  @ttl.setter
  def ttl(self, ttl: timedelta | ArtifactTTL | None) -> None:
      """The time-to-live (TTL) policy of an artifact.

      Artifacts are deleted shortly after a TTL policy's duration passes.
      If set to `None`, the artifact has no TTL policy set and it is not
      scheduled for deletion. An artifact inherits a TTL policy from
      the team default if the team administrator defines a default
      TTL and there is no custom policy set on an artifact.

      The new value is fully validated before any state is changed, so a
      rejected value leaves the artifact's TTL settings untouched.

      Args:
          ttl: The duration as a positive `datetime.timedelta` that represents
              how long the artifact will remain active from its creation.

      Raises:
          ValueError: If the artifact type is `wandb-history`, the artifact is
              a link, the duration is not positive, or the `ArtifactTTL`
              value is not handled.
      """
      if self.type == "wandb-history":
          raise ValueError("Cannot set artifact TTL for type wandb-history")
      if self.is_link:
          raise ValueError(
              "Cannot set TTL for link artifact. "
              "Unlink the artifact first then set the TTL for the source artifact"
          )

      # Resolve and validate the requested policy first; nothing on `self` is
      # modified until the value is known to be acceptable.
      duration_seconds: int | None = None
      if ttl is None:
          inherited = False
      elif isinstance(ttl, ArtifactTTL):
          if ttl != ArtifactTTL.INHERIT:
              raise ValueError(f"Unhandled ArtifactTTL enum {ttl}")
          inherited = True
      else:
          total_seconds = ttl.total_seconds()
          if total_seconds <= 0:
              raise ValueError(
                  f"Artifact TTL Duration has to be positive. ttl: {total_seconds}"
              )
          inherited = False
          duration_seconds = int(total_seconds)

      # Commit the change.
      self._ttl_changed = True
      self._ttl_is_inherited = inherited
      if not inherited:
          # Switching to "inherit" keeps the previously stored duration as is.
          self._ttl_duration_seconds = duration_seconds
  @property
  @ensure_logged
  def aliases(self) -> list[str]:
      """List of one or more semantically-friendly references or
      identifying "nicknames" assigned to an artifact version.

      Aliases are mutable references that you can programmatically reference.
      Change an artifact's alias with the W&B App UI or programmatically.
      See [Create new artifact versions](https://docs.wandb.ai/models/artifacts/create-a-new-artifact-version)
      for more information.
      """
      current = self._aliases
      return current

  @aliases.setter
  @ensure_logged
  def aliases(self, aliases: list[str]) -> None:
      """Set the aliases associated with this artifact.

      The given aliases are passed through `validate_aliases` and the result
      is what gets stored.
      """
      from ._validators import validate_aliases

      validated = validate_aliases(aliases)
      self._aliases = validated
  @property
  @ensure_logged
  def tags(self) -> list[str]:
      """List of one or more tags assigned to this artifact version."""
      current = self._tags
      return current

  @tags.setter
  @ensure_logged
  def tags(self, tags: list[str]) -> None:
      """Set the tags associated with this artifact.

      Editing tags will apply the changes to the source artifact
      and all linked artifacts associated with it; a warning is emitted when
      this artifact is a link.
      """
      from ._validators import validate_tags

      editing_through_link = self.is_link
      if editing_through_link:
          wandb.termwarn(
              "Editing tags will apply the changes to the source artifact and all linked artifacts associated with it."
          )
      # Only the validated form of the tags is stored.
      validated = validate_tags(tags)
      self._tags = validated
  @property
  def distributed_id(self) -> str | None:
      """The distributed ID of the artifact.

      <!-- lazydoc-ignore: internal -->
      """
      current = self._distributed_id
      return current

  @distributed_id.setter
  def distributed_id(self, distributed_id: str | None) -> None:
      # Stored as given; no validation is applied here.
      self._distributed_id = distributed_id
  @property
  def incremental(self) -> bool:
      """Whether the artifact is an incremental artifact.

      <!-- lazydoc-ignore: internal -->
      """
      incremental_flag = self._incremental
      return incremental_flag
  @property
  def use_as(self) -> str | None:
      """Deprecated."""
      # Every read reports the deprecation before returning the stored value.
      deprecation_message = "The use_as property of Artifact is deprecated."
      warn_and_record_deprecation(
          feature=Deprecated(artifact__use_as=True),
          message=deprecation_message,
      )
      return self._use_as
  @property
  def state(self) -> str:
      """The status of the artifact. One of: "PENDING", "COMMITTED", or "DELETED"."""
      # Expose the `.value` of the stored state rather than the state object.
      current_state = self._state
      return current_state.value
  @property
  def manifest(self) -> ArtifactManifest:
      """The artifact's manifest.

      The manifest lists all of its contents, and can't be changed once the
      artifact has been logged. It is fetched on first access and kept on the
      instance afterwards.
      """
      loaded = self._manifest
      if loaded is None:
          loaded = self._manifest = self._fetch_manifest()
      return loaded
  def _fetch_manifest(self) -> ArtifactManifest:
      """Fetch, parse, and load the full ArtifactManifest.

      Raises:
          RuntimeError: If no client is attached to this artifact.
          ValueError: If the query result has no artifact or no current
              manifest.
      """
      from ._generated import FETCH_ARTIFACT_MANIFEST_GQL, FetchArtifactManifest

      client = self._client
      if client is None:
          raise RuntimeError("Client not initialized for artifact queries")

      # Step 1: ask the GraphQL API for the (expiring) directUrl of the manifest.
      query = gql(FETCH_ARTIFACT_MANIFEST_GQL)
      variables = {"id": self.id}
      data = client.execute(query, variable_values=variables)
      result = FetchArtifactManifest.model_validate(data)

      artifact = result.artifact
      manifest = artifact.current_manifest if artifact else None
      if not manifest:
          raise ValueError("Failed to fetch artifact manifest")

      # Step 2: download the actual manifest contents from the directUrl.
      #
      # Create a short lived session instead of using requests.get()
      # because make_http_session() adds http headers from env vars.
      # Artifact manifest json is also downloaded from object storage
      # using presigned urls like artifact files, which requires adding
      # extra http headers when user specifies them in env vars.
      #
      # FIXME: For successive/repeated calls to `manifest`, figure out
      # how to reuse a single `requests.Session` within the constraints
      # of the current API. Creating a new session for _each_ fetch is
      # wasteful and introduces noticeable perf overhead when e.g.
      # downloading many artifacts sequentially or concurrently. The
      # storage policy's session is also not reused across different
      # artifacts.
      with make_http_session() as session:
          response = session.get(manifest.file.direct_url)
          return ArtifactManifest.from_manifest_json(from_json(response.content))
  @property
  def digest(self) -> str:
      """The logical digest of the artifact.

      The digest is the checksum of the artifact's contents. If an artifact has
      the same digest as the current `latest` version, then `log_artifact` is a
      no-op.
      """
      # The last fetched `Artifact.digest` is trusted ONLY while no manifest
      # has been fetched and/or populated locally. Once a manifest exists its
      # contents may have been modified locally, so the digest is recalculated
      # from the manifest itself.
      if self._manifest is None and self._digest is not None:
          return self._digest
      return self.manifest.digest()
  @property
  def size(self) -> int:
      """The total size of the artifact in bytes.

      Includes any references tracked by this artifact.
      """
      # The last fetched `Artifact.size` is trusted ONLY while no manifest has
      # been fetched and/or populated locally. Once a manifest exists its
      # contents may have been modified locally, so the size is recalculated
      # from the manifest itself.
      #
      # NOTE on choice of GQL field: `Artifact.size` counts references, while
      # `Artifact.storageBytes` does not.
      if self._manifest is None and self._size is not None:
          return self._size
      return self.manifest.size()
  @property
  @ensure_logged
  def commit_hash(self) -> str:
      """The hash returned when this artifact was committed."""
      stored_hash = self._commit_hash
      assert stored_hash is not None
      return stored_hash
  @property
  @ensure_logged
  def file_count(self) -> int:
      """The number of files (including references)."""
      count = self._file_count
      assert count is not None
      return count
  @property
  @ensure_logged
  def created_at(self) -> str:
      """Timestamp when the artifact was created."""
      created = self._created_at
      assert created is not None
      return created
  @property
  @ensure_logged
  def updated_at(self) -> str:
      """The time when the artifact was last updated.

      Falls back to the creation timestamp when no update time is stored.
      """
      created = self._created_at
      assert created is not None
      return self._updated_at or created
  @property
  @ensure_logged
  def history_step(self) -> int | None:
      """The nearest step which logged history metrics for this artifact's source run.

      Returns `None` when no history step is stored; otherwise the stored step
      minus one, never below zero.

      Examples:
      ```python
      run = artifact.logged_by()
      if run and (artifact.history_step is not None):
          history = run.sample_history(
              min_step=artifact.history_step,
              max_step=artifact.history_step + 1,
              keys=["my_metric"],
          )
      ```
      """
      raw_step = self._history_step
      return None if raw_step is None else max(0, raw_step - 1)
  967. # State management.
  def finalize(self) -> None:
      """Finalize the artifact version.

      A finalized artifact version can no longer be modified, because it is
      logged as a specific artifact version. To log more data, create a new
      artifact version. Logging the artifact with `log_artifact` finalizes it
      automatically.
      """
      # Finalization is recorded as a flag on the instance.
      setattr(self, "_final", True)
  def is_draft(self) -> bool:
      """Check if artifact is not saved.

      Returns:
          Boolean. `True` while the artifact's state is pending (not saved),
          `False` once it is saved.
      """
      pending = ArtifactState.PENDING
      return self._state is pending
  def _is_draft_save_started(self) -> bool:
      """Return `True` once a save handle has been attached to this artifact."""
      handle = self._save_handle
      return handle is not None
  def save(
      self,
      project: str | None = None,
      settings: wandb.Settings | None = None,
  ) -> None:
      """Persist any changes made to the artifact.

      An artifact that is no longer pending only has its changes pushed via an
      update. A pending artifact is logged by the most recent active run if
      there is one; otherwise a run of type "auto" is created to track it.

      Args:
          project: A project to use for the artifact in the case that a run is
              not already in context.
          settings: A settings object to use when initializing an automatic
              run. Most commonly used in testing harness.
      """
      if self._state is not ArtifactState.PENDING:
          return self._update()

      if self._incremental:
          with telemetry.context() as tel:
              tel.feature.artifact_incremental = True

      active_run = wandb_setup.singleton().most_recent_active_run
      if active_run:
          # TODO: Deprecate and encourage explicit log_artifact().
          active_run.log_artifact(self)
          return

      # No active run: start a short-lived "auto" run to log the artifact.
      run_settings = wandb.Settings(silent="true") if settings is None else settings
      with wandb.init(  # type: ignore
          entity=self._source_entity,
          project=project or self._source_project,
          job_type="auto",
          settings=run_settings,
      ) as auto_run:
          # Telemetry is recorded again here because no run existed when the
          # first telemetry block above executed.
          if self._incremental:
              with telemetry.context(run=auto_run) as tel:
                  tel.feature.artifact_incremental = True
          auto_run.log_artifact(self)
  def _set_save_handle(
      self,
      save_handle: MailboxHandle[pb.Result],
      client: RetryingClient,
  ) -> None:
      """Attach the in-flight save handle and the client to this artifact."""
      self._save_handle, self._client = save_handle, client
  def wait(self, timeout: int | None = None) -> Artifact:
      """If needed, wait for this artifact to finish logging.

      An artifact that is not a draft is returned immediately.

      Args:
          timeout: The time, in seconds, to wait.

      Returns:
          An `Artifact` object.

      Raises:
          ArtifactNotLoggedError: If the artifact is a draft and no save has
              been started.
          WaitTimeoutError: If waiting on the save handle times out.
          ValueError: If the log-artifact response carries an error message.
      """
      if not self.is_draft():
          return self

      handle = self._save_handle
      if handle is None:
          raise ArtifactNotLoggedError(nameof(self.wait), self)

      try:
          result = handle.wait_or(timeout=timeout)
      except TimeoutError as e:
          raise WaitTimeoutError(
              "Artifact upload wait timed out, failed to fetch Artifact response"
          ) from e

      log_response = result.response.log_artifact_response
      if log_response.error_message:
          raise ValueError(log_response.error_message)

      # Refresh this instance's attributes from the saved artifact.
      self._populate_after_save(log_response.artifact_id)
      return self
  def _populate_after_save(self, artifact_id: str) -> None:
      """Fetch the artifact by ID and assign its attributes to this instance.

      Args:
          artifact_id: ID of the artifact to fetch.

      Raises:
          RuntimeError: If no client is attached to this artifact.
          ValueError: If the query returns no artifact for `artifact_id`.
      """
      from ._generated import ARTIFACT_BY_ID_GQL, ArtifactByID

      client = self._client
      if client is None:
          raise RuntimeError("Client not initialized for artifact queries")

      query = gql(ARTIFACT_BY_ID_GQL)
      data = client.execute(query, variable_values={"id": artifact_id})
      fetched = ArtifactByID.model_validate(data).artifact
      if not fetched:
          raise ValueError(f"Unable to fetch artifact with id: {artifact_id!r}")

      # _populate_after_save is only called on source artifacts, not linked artifacts
      # We have to manually set is_link because we aren't fetching the collection
      # the artifact. That requires greater refactoring for commitArtifact to return
      # the artifact collection type.
      self._assign_attrs(fetched, is_link=False)
  @normalize_exceptions
  def _update(self) -> None:
      """Persists artifact changes to the wandb backend.

      Alias changes are sent first (additions, then deletions), followed by a
      single update carrying description, metadata, TTL and tag changes. The
      response is then assigned back onto this instance.

      Raises (as raised inside this method; the `normalize_exceptions`
      decorator may translate them -- TODO confirm):
          RuntimeError: If no client is attached to this artifact.
          ValueError: If the update response contains no artifact.
      """
      from ._generated import UPDATE_ARTIFACT_GQL, UpdateArtifact, UpdateArtifactInput
      from ._validators import FullArtifactPath, validate_tags

      if (client := self._client) is None:
          raise RuntimeError("Client not initialized for artifact mutations")

      # The collection is the artifact name without its ":alias" suffix.
      entity, project, collection = self.entity, self.project, self.name.split(":")[0]

      # Diff the current aliases against the last saved ones and send only
      # the differences.
      old_aliases, new_aliases = set(self._saved_aliases), set(self.aliases)

      target = FullArtifactPath(prefix=entity, project=project, name=collection)
      self._add_aliases(new_aliases - old_aliases, target=target)
      self._delete_aliases(old_aliases - new_aliases, target=target)
      # The current aliases become the new saved baseline.
      self._saved_aliases = copy(self.aliases)

      # Tags are diffed the same way, but travel inside the update input below.
      old_tags, new_tags = set(self._saved_tags), set(self.tags)

      gql_op = gql(UPDATE_ARTIFACT_GQL)
      gql_input = UpdateArtifactInput(
          artifact_id=self.id,
          description=self.description,
          metadata=json_dumps_safer(self.metadata),
          ttl_duration_seconds=self._ttl_duration_seconds_to_gql(),
          tags_to_add=[{"tagName": t} for t in validate_tags(new_tags - old_tags)],
          tags_to_delete=[{"tagName": t} for t in validate_tags(old_tags - new_tags)],
      )
      gql_vars = {"input": gql_input.model_dump()}
      data = client.execute(gql_op, variable_values=gql_vars)

      result = UpdateArtifact.model_validate(data).result
      if not (result and (artifact := result.artifact)):
          raise ValueError("Unable to parse updateArtifact response")
      # Refresh this instance from the artifact returned by the update.
      self._assign_attrs(artifact)

      self._ttl_changed = False # Reset after updating artifact
  def _add_aliases(self, alias_names: set[str], target: FullArtifactPath) -> None:
      """Add the given aliases to this artifact within the target collection.

      Args:
          alias_names: Aliases to add. When empty, no request is made.
          target: Path whose prefix, project and name identify the entity,
              project and artifact collection the aliases are added in.

      Raises:
          RuntimeError: If no client is attached to this artifact.
          CommError: If the request fails with a `CommError`; it is re-raised
              with a permission-focused message.
      """
      from ._generated import ADD_ALIASES_GQL, AddAliasesInput

      client = self._client
      if client is None:
          raise RuntimeError("Client not initialized for artifact mutations")

      # If there aren't any aliases to add, we can skip the GraphQL call.
      if not alias_names:
          return

      target_props = {
          "entityName": target.prefix,
          "projectName": target.project,
          "artifactCollectionName": target.name,
      }
      alias_inputs = [dict(target_props, alias=alias) for alias in alias_names]

      mutation = gql(ADD_ALIASES_GQL)
      mutation_input = AddAliasesInput(artifact_id=self.id, aliases=alias_inputs)
      variables = {"input": mutation_input.model_dump()}
      try:
          client.execute(mutation, variable_values=variables)
      except CommError as e:
          msg = (
              "You do not have permission to add"
              f" {'at least one of the following aliases' if len(alias_names) > 1 else 'the following alias'}"
              f" to this artifact: {alias_names!r}"
          )
          raise CommError(msg) from e
  def _delete_aliases(self, alias_names: set[str], target: FullArtifactPath) -> None:
      """Delete the given aliases from this artifact within the target collection.

      Args:
          alias_names: Aliases to delete. When empty, no request is made.
          target: Path whose prefix, project and name identify the entity,
              project and artifact collection the aliases are deleted from.

      Raises:
          RuntimeError: If no client is attached to this artifact.
          CommError: If the request fails with a `CommError`; it is re-raised
              with a permission-focused message.
      """
      from ._generated import DELETE_ALIASES_GQL, DeleteAliasesInput

      client = self._client
      if client is None:
          raise RuntimeError("Client not initialized for artifact mutations")

      # If there aren't any aliases to delete, we can skip the GraphQL call.
      if not alias_names:
          return

      target_props = {
          "entityName": target.prefix,
          "projectName": target.project,
          "artifactCollectionName": target.name,
      }
      alias_inputs = [dict(target_props, alias=alias) for alias in alias_names]

      mutation = gql(DELETE_ALIASES_GQL)
      mutation_input = DeleteAliasesInput(artifact_id=self.id, aliases=alias_inputs)
      variables = {"input": mutation_input.model_dump()}
      try:
          client.execute(mutation, variable_values=variables)
      except CommError as e:
          msg = (
              f"You do not have permission to delete"
              f" {'at least one of the following aliases' if len(alias_names) > 1 else 'the following alias'}"
              f" from this artifact: {alias_names!r}"
          )
          raise CommError(msg) from e
  1141. # Adding, removing, getting entries.
  def __getitem__(self, name: str) -> WBValue | None:
      """Get the WBValue object located at the artifact relative `name`.

      Subscript form of `get(name)`.

      Args:
          name: The artifact relative name to get.

      Returns:
          W&B object that can be logged with `run.log()` and visualized in the
          W&B UI.

      Raises:
          ArtifactNotLoggedError: If the artifact isn't logged or the run is
              offline.
      """
      value = self.get(name)
      return value
  def __setitem__(self, name: str, item: WBValue) -> ArtifactManifestEntry:
      """Add `item` to the artifact at path `name`.

      Subscript-assignment form of `add(item, name)`.

      Args:
          name: The path within the artifact to add the object.
          item: The object to add.

      Returns:
          The added manifest entry

      Raises:
          ArtifactFinalizedError: You cannot make changes to the current
              artifact version because it is finalized. Log a new artifact
              version instead.
      """
      # Note the argument order: `add` takes the object first, then the path.
      entry = self.add(item, name)
      return entry
  @contextlib.contextmanager
  @ensure_not_finalized
  def new_file(
      self, name: str, mode: str = "x", encoding: str | None = None
  ) -> Iterator[IO]:
      """Open a new temporary file and add it to the artifact.

      Args:
          name: The name of the new file to add to the artifact.
          mode: The file access mode to use to open the new file.
          encoding: The encoding used to open the new file.

      Returns:
          A new file object that can be written to. Upon closing, the file
          is automatically added to the artifact.

      Raises:
          ArtifactFinalizedError: You cannot make changes to the current
              artifact version because it is finalized. Log a new artifact
              version instead.
          ValueError: If a file with this name already exists at the
              temporary location (raised from the underlying
              `FileExistsError`).
      """
      # Any mode without "x" (exclusive creation) may replace an existing entry.
      overwrite: bool = "x" not in mode

      # The temporary directory is created lazily and kept on the instance.
      if self._tmp_dir is None:
          self._tmp_dir = tempfile.TemporaryDirectory()
      # Strip leading "/" so the name always resolves inside the temp directory.
      path = os.path.join(self._tmp_dir.name, name.lstrip("/"))

      Path(path).parent.mkdir(parents=True, exist_ok=True)
      try:
          with fsync_open(path, mode, encoding) as f:
              yield f
      except FileExistsError as e:
          # Chain explicitly so the original OS error is reported as the cause.
          raise ValueError(
              f"File with name {name!r} already exists at {path!r}"
          ) from e
      except UnicodeEncodeError as e:
          termerror(
              f"Failed to open the provided file ({nameof(type(e))}: {e}). Please "
              f"provide the proper encoding."
          )
          raise

      # Reached only when the caller's `with` body finished without error.
      self.add_file(
          path, name=name, policy="immutable", skip_cache=True, overwrite=overwrite
      )
  @ensure_not_finalized
  def add_file(
      self,
      local_path: str,
      name: str | None = None,
      is_tmp: bool | None = False,
      skip_cache: bool | None = False,
      policy: Literal["mutable", "immutable"] | None = "mutable",
      overwrite: bool = False,
  ) -> ArtifactManifestEntry:
      """Add a local file to the artifact.

      Args:
          local_path: The path to the file being added.
          name: The path within the artifact to use for the file being added.
              Defaults to the basename of the file.
          is_tmp: If true, then the file is renamed deterministically to avoid
              collisions.
          skip_cache: If `True`, do not copy files to the cache
              after uploading.
          policy: By default, set to "mutable". If set to "mutable",
              create a temporary copy of the file to prevent corruption
              during upload. If set to "immutable", disable
              protection and rely on the user not to delete or change the
              file.
          overwrite: If `True`, overwrite the file if it already exists.

      Returns:
          The added manifest entry.

      Raises:
          ArtifactFinalizedError: You cannot make changes to the current
              artifact version because it is finalized. Log a new artifact
              version instead.
          ValueError: Policy must be "mutable" or "immutable"
          ValueError: If `local_path` is not a file.
      """
      if not os.path.isfile(local_path):
          raise ValueError(f"Path is not a file: {local_path!r}")

      name = LogicalPath(name or os.path.basename(local_path))
      digest = md5_file_b64(local_path)

      if is_tmp:
          # Replace everything before the first "." of the file name with a
          # prefix derived from the content digest; extensions are kept.
          parent_dir, base_name = os.path.split(name)
          _, dot, extensions = base_name.partition(".")
          name = os.path.join(
              parent_dir, b64_to_hex_id(digest)[:20] + dot + extensions
          )

      return self._add_local_file(
          name,
          local_path,
          digest=digest,
          skip_cache=skip_cache,
          policy=policy,
          overwrite=overwrite,
      )
  1252. @ensure_not_finalized
  1253. def add_dir(
  1254. self,
  1255. local_path: str,
  1256. name: str | None = None,
  1257. skip_cache: bool | None = False,
  1258. policy: Literal["mutable", "immutable"] | None = "mutable",
  1259. merge: bool = False,
  1260. ) -> None:
  1261. """Add a local directory to the artifact.
  1262. Args:
  1263. local_path: The path of the local directory.
  1264. name: The subdirectory name within an artifact. The name you
  1265. specify appears in the W&B App UI nested by artifact's `type`.
  1266. Defaults to the root of the artifact.
  1267. skip_cache: If set to `True`, W&B will not copy/move files to
  1268. the cache while uploading
  1269. policy: By default, "mutable".
  1270. - mutable: Create a temporary copy of the file to prevent
  1271. corruption during upload.
  1272. - immutable: Disable protection, rely on the user not to delete
  1273. or change the file.
  1274. merge: If `False` (default), throws ValueError if a file was already added
  1275. in a previous add_dir call and its content has changed. If `True`,
  1276. overwrites existing files with changed content. Always adds new files
  1277. and never removes files. To replace an entire directory, pass a name
  1278. when adding the directory using `add_dir(local_path, name=my_prefix)`
  1279. and call `remove(my_prefix)` to remove the directory, then add it again.
  1280. Raises:
  1281. ArtifactFinalizedError: You cannot make changes to the current
  1282. artifact version because it is finalized. Log a new artifact
  1283. version instead.
  1284. ValueError: Policy must be "mutable" or "immutable"
  1285. """
  1286. if not os.path.isdir(local_path):
  1287. raise ValueError(f"Path is not a directory: {local_path!r}")
  1288. termlog(
  1289. f"Adding directory to artifact ({Path('.', local_path)})... ",
  1290. newline=False,
  1291. )
  1292. start_time = time.monotonic()
  1293. paths: deque[tuple[str, str]] = deque()
  1294. logical_root = name or "" # shared prefix, if any, for logical paths
  1295. for dirpath, _, filenames in os.walk(local_path, followlinks=True):
  1296. for fname in filenames:
  1297. physical_path = os.path.join(dirpath, fname)
  1298. logical_path = os.path.relpath(physical_path, start=local_path)
  1299. logical_path = os.path.join(logical_root, logical_path)
  1300. paths.append((logical_path, physical_path))
  1301. def add_manifest_file(logical_pth: str, physical_pth: str) -> None:
  1302. self._add_local_file(
  1303. name=logical_pth,
  1304. path=physical_pth,
  1305. skip_cache=skip_cache,
  1306. policy=policy,
  1307. overwrite=merge,
  1308. )
  1309. num_threads = 8
  1310. pool = multiprocessing.dummy.Pool(num_threads)
  1311. pool.starmap(add_manifest_file, paths)
  1312. pool.close()
  1313. pool.join()
  1314. termlog("Done. %.1fs" % (time.monotonic() - start_time), prefix=False)
  1315. @ensure_not_finalized
  1316. def add_reference(
  1317. self,
  1318. uri: ArtifactManifestEntry | str,
  1319. name: StrPath | None = None,
  1320. checksum: bool = True,
  1321. max_objects: int | None = None,
  1322. ) -> Sequence[ArtifactManifestEntry]:
  1323. """Add a reference denoted by a URI to the artifact.
  1324. Unlike files or directories that you add to an artifact, references are not
  1325. uploaded to W&B. For more information,
  1326. see [Track external files](https://docs.wandb.ai/models/artifacts/track-external-files).
  1327. By default, the following schemes are supported:
  1328. - http(s): The size and digest of the file will be inferred by the
  1329. `Content-Length` and the `ETag` response headers returned by the server.
  1330. - s3: The checksum and size are pulled from the object metadata.
  1331. If bucket versioning is enabled, then the version ID is also tracked.
  1332. - gs: The checksum and size are pulled from the object metadata. If bucket
  1333. versioning is enabled, then the version ID is also tracked.
  1334. - https, domain matching `*.blob.core.windows.net`
  1335. - Azure: The checksum and size are be pulled from the blob metadata.
  1336. If storage account versioning is enabled, then the version ID is
  1337. also tracked.
  1338. - file: The checksum and size are pulled from the file system. This scheme
  1339. is useful if you have an NFS share or other externally mounted volume
  1340. containing files you wish to track but not necessarily upload.
  1341. For any other scheme, the digest is just a hash of the URI and the size is left
  1342. blank.
  1343. Args:
  1344. uri: The URI path of the reference to add. The URI path can be an object
  1345. returned from `Artifact.get_entry` to store a reference to another
  1346. artifact's entry.
  1347. name: The path within the artifact to place the contents of this reference.
  1348. checksum: Whether or not to checksum the resource(s) located at the
  1349. reference URI. Checksumming is strongly recommended as it enables
  1350. automatic integrity validation. Disabling checksumming will speed up
  1351. artifact creation but reference directories will not iterated through so
  1352. the objects in the directory will not be saved to the artifact.
  1353. We recommend setting `checksum=False` when adding reference objects,
  1354. in which case a new version will only be created if the reference URI
  1355. changes.
  1356. max_objects: The maximum number of objects to consider when adding a
  1357. reference that points to directory or bucket store prefix.
  1358. By default, the maximum number of objects allowed for Amazon S3,
  1359. GCS, Azure, and local files is 10,000,000. Other URI schemas
  1360. do not have a maximum.
  1361. Returns:
  1362. The added manifest entries.
  1363. Raises:
  1364. ArtifactFinalizedError: You cannot make changes to the current
  1365. artifact version because it is finalized. Log a new artifact
  1366. version instead.
  1367. """
  1368. if name is not None:
  1369. name = LogicalPath(name)
  1370. # This is a bit of a hack, we want to check if the uri is a of the type
  1371. # ArtifactManifestEntry. If so, then recover the reference URL.
  1372. if isinstance(uri, ArtifactManifestEntry):
  1373. uri_str = uri.ref_url()
  1374. elif isinstance(uri, str):
  1375. uri_str = uri
  1376. url = urlparse(str(uri_str))
  1377. if not url.scheme:
  1378. raise ValueError(
  1379. "References must be URIs. To reference a local file, use file://"
  1380. )
  1381. manifest_entries = self.manifest.storage_policy.store_reference(
  1382. self,
  1383. URIStr(uri_str),
  1384. name=name,
  1385. checksum=checksum,
  1386. max_objects=max_objects,
  1387. )
  1388. for entry in manifest_entries:
  1389. self.manifest.add_entry(entry)
  1390. return manifest_entries
    @ensure_not_finalized
    def add(
        self, obj: WBValue, name: StrPath, overwrite: bool = False
    ) -> ArtifactManifestEntry:
        """Add wandb.WBValue `obj` to the artifact.

        The object is serialized with `obj.to_json(self)` and written as a JSON
        file inside the artifact. If the object already belongs to another
        artifact, a reference to that entry is added instead of a new file.

        Args:
            obj: The object to add. Currently support one of Bokeh, JoinedTable,
                PartitionedTable, Table, Classes, ImageMask, BoundingBoxes2D,
                Audio, Image, Video, Html, Object3D, Molecule
            name: The path within the artifact to add the object.
            overwrite: If True, overwrite existing objects with the same file
                path if applicable.

        Returns:
            The added manifest entry

        Raises:
            ArtifactFinalizedError: You cannot make changes to the current
                artifact version because it is finalized. Log a new artifact
                version instead.
            TypeError: `obj` is not an instance of one of the allowed types.
        """
        name = LogicalPath(name)

        # This is a "hack" to automatically rename tables added to
        # the wandb /media/tables directory to their sha-based name.
        # TODO: figure out a more appropriate convention.
        is_tmp_name = name.startswith("media/tables")

        # Validate that the object is one of the correct wandb.Media types
        # TODO: move this to checking subclass of wandb.Media once all are
        # generally supported
        allowed_types = (
            data_types.Bokeh,
            data_types.JoinedTable,
            data_types.PartitionedTable,
            data_types.Table,
            data_types.Classes,
            data_types.ImageMask,
            data_types.BoundingBoxes2D,
            data_types.Audio,
            data_types.Image,
            data_types.Video,
            data_types.Html,
            data_types.Object3D,
            data_types.Molecule,
            data_types._SavedModel,
        )
        if not isinstance(obj, allowed_types):
            raise TypeError(
                f"Found object of type {obj.__class__}, expected one of:"
                f" {allowed_types}"
            )

        # The same Python object added twice yields the entry created the
        # first time rather than a second copy.
        obj_id = id(obj)
        if obj_id in self._added_objs:
            return self._added_objs[obj_id][1]

        # If the object is coming from another artifact, save it as a reference
        ref_path = obj._get_artifact_entry_ref_url()
        if ref_path is not None:
            return self.add_reference(ref_path, type(obj).with_suffix(name))[0]

        val = obj.to_json(self)
        name = obj.with_suffix(name)
        # Without `overwrite`, an entry already stored under the suffixed name
        # is returned unchanged.
        entry = self.manifest.get_entry_by_path(name)
        if (not overwrite) and (entry is not None):
            return entry

        if is_tmp_name:
            # Written to a scratch location keyed by this artifact's id(); the
            # file is removed again at the end of this method.
            file_path = os.path.join(self._TMP_DIR.name, str(id(self)), name)
            folder_path, _ = os.path.split(file_path)
            os.makedirs(folder_path, exist_ok=True)
            with open(file_path, "w", encoding="utf-8") as tmp_f:
                json.dump(val, tmp_f, sort_keys=True)
        else:
            # "x" (exclusive creation) is used unless overwriting was requested.
            filemode = "w" if overwrite else "x"
            with self.new_file(name, mode=filemode, encoding="utf-8") as f:
                json.dump(val, f, sort_keys=True)
                file_path = f.name

        # Note, we add the file from our temp directory.
        # It will be added again later on finalize, but succeed since
        # the checksum should match
        entry = self.add_file(file_path, name, is_tmp_name)
        # We store a reference to the obj so that its id doesn't get reused.
        self._added_objs[obj_id] = (obj, entry)
        if obj._artifact_target is None:
            obj._set_artifact_target(self, entry.path)

        if is_tmp_name:
            with contextlib.suppress(FileNotFoundError):
                os.remove(file_path)

        return entry
  1474. def _add_local_file(
  1475. self,
  1476. name: StrPath,
  1477. path: StrPath,
  1478. digest: B64MD5 | None = None,
  1479. skip_cache: bool | None = False,
  1480. policy: Literal["mutable", "immutable"] | None = "mutable",
  1481. overwrite: bool = False,
  1482. ) -> ArtifactManifestEntry:
  1483. policy = policy or "mutable"
  1484. if policy not in ["mutable", "immutable"]:
  1485. raise ValueError(
  1486. f"Invalid policy {policy!r}. Policy may only be `mutable` or `immutable`."
  1487. )
  1488. upload_path = path
  1489. if policy == "mutable":
  1490. with tempfile.NamedTemporaryFile(dir=get_staging_dir(), delete=False) as f:
  1491. staging_path = f.name
  1492. shutil.copyfile(path, staging_path)
  1493. # Set as read-only to prevent changes to the file during upload process
  1494. os.chmod(staging_path, stat.S_IRUSR)
  1495. upload_path = staging_path
  1496. entry = ArtifactManifestEntry(
  1497. path=name,
  1498. digest=digest or md5_file_b64(upload_path),
  1499. size=os.path.getsize(upload_path),
  1500. local_path=upload_path,
  1501. skip_cache=skip_cache,
  1502. )
  1503. self.manifest.add_entry(entry, overwrite=overwrite)
  1504. self._added_local_paths[os.fspath(path)] = entry
  1505. return entry
  1506. @ensure_not_finalized
  1507. def remove(self, item: StrPath | ArtifactManifestEntry) -> None:
  1508. """Remove an item from the artifact.
  1509. Args:
  1510. item: The item to remove. Can be a specific manifest entry
  1511. or the name of an artifact-relative path. If the item
  1512. matches a directory all items in that directory will be removed.
  1513. Raises:
  1514. ArtifactFinalizedError: You cannot make changes to the current
  1515. artifact version because it is finalized. Log a new artifact
  1516. version instead.
  1517. FileNotFoundError: If the item isn't found in the artifact.
  1518. """
  1519. if isinstance(item, ArtifactManifestEntry):
  1520. self.manifest.remove_entry(item)
  1521. return
  1522. path = str(PurePosixPath(item))
  1523. if entry := self.manifest.get_entry_by_path(path):
  1524. return self.manifest.remove_entry(entry)
  1525. entries = self.manifest.get_entries_in_directory(path)
  1526. if not entries:
  1527. raise FileNotFoundError(f"No such file or directory: {path}")
  1528. for entry in entries:
  1529. self.manifest.remove_entry(entry)
  1530. def get_path(self, name: StrPath) -> ArtifactManifestEntry:
  1531. """Deprecated. Use `get_entry(name)`."""
  1532. warn_and_record_deprecation(
  1533. feature=Deprecated(artifact__get_path=True),
  1534. message="Artifact.get_path(name) is deprecated, use Artifact.get_entry(name) instead.",
  1535. )
  1536. return self.get_entry(name)
  1537. @ensure_logged
  1538. def get_entry(self, name: StrPath) -> ArtifactManifestEntry:
  1539. """Get the entry with the given name.
  1540. Args:
  1541. name: The artifact relative name to get
  1542. Returns:
  1543. A `W&B` object.
  1544. Raises:
  1545. ArtifactNotLoggedError: if the artifact isn't logged or the run is offline.
  1546. KeyError: if the artifact doesn't contain an entry with the given name.
  1547. """
  1548. name = LogicalPath(name)
  1549. entry = self.manifest.entries.get(name) or self._get_obj_entry(name)[0]
  1550. if entry is None:
  1551. raise KeyError(f"Path not contained in artifact: {name}")
  1552. entry._parent_artifact = self
  1553. return entry
  1554. @ensure_logged
  1555. def get(self, name: str) -> WBValue | None:
  1556. """Get the WBValue object located at the artifact relative `name`.
  1557. Args:
  1558. name: The artifact relative name to retrieve.
  1559. Returns:
  1560. W&B object that can be logged with `run.log()` and
  1561. visualized in the W&B UI.
  1562. Raises:
  1563. ArtifactNotLoggedError: if the artifact isn't logged or the
  1564. run is offline.
  1565. """
  1566. entry, wb_class = self._get_obj_entry(name)
  1567. if entry is None or wb_class is None:
  1568. return None
  1569. # If the entry is a reference from another artifact, then get it directly from
  1570. # that artifact.
  1571. if referenced_id := entry._referenced_artifact_id():
  1572. assert self._client is not None
  1573. artifact = self._from_id(referenced_id, client=self._client)
  1574. assert artifact is not None
  1575. return artifact.get(uri_from_path(entry.ref))
  1576. # Special case for wandb.Table. This is intended to be a short term
  1577. # optimization. Since tables are likely to download many other assets in
  1578. # artifact(s), we eagerly download the artifact using the parallelized
  1579. # `artifact.download`. In the future, we should refactor the deserialization
  1580. # pattern such that this special case is not needed.
  1581. if wb_class == wandb.Table:
  1582. self.download()
  1583. # Get the ArtifactManifestEntry
  1584. item = self.get_entry(entry.path)
  1585. item_path = item.download()
  1586. # Load the object from the JSON blob
  1587. with open(item_path) as file:
  1588. json_obj = json.load(file)
  1589. result = wb_class.from_json(json_obj, self)
  1590. result._set_artifact_source(self, name)
  1591. return result
  1592. def get_added_local_path_name(self, local_path: str) -> str | None:
  1593. """Get the artifact relative name of a file added by a local filesystem path.
  1594. Args:
  1595. local_path: The local path to resolve into an artifact relative name.
  1596. Returns:
  1597. The artifact relative name.
  1598. """
  1599. if entry := self._added_local_paths.get(local_path):
  1600. return entry.path
  1601. return None
  1602. def _get_obj_entry(
  1603. self, name: str
  1604. ) -> tuple[ArtifactManifestEntry, Type[WBValue]] | tuple[None, None]: # noqa: UP006 # `type` shadows `Artifact.type`
  1605. """Return an object entry by name, handling any type suffixes.
  1606. When objects are added with `.add(obj, name)`, the name is typically changed to
  1607. include the suffix of the object type when serializing to JSON. So we need to be
  1608. able to resolve a name, without tasking the user with appending .THING.json.
  1609. This method returns an entry if it exists by a suffixed name.
  1610. Args:
  1611. name: name used when adding
  1612. """
  1613. for wb_class in WBValue.type_mapping().values():
  1614. wandb_file_name = wb_class.with_suffix(name)
  1615. if entry := self.manifest.entries.get(wandb_file_name):
  1616. return entry, wb_class
  1617. return None, None
  1618. # Downloading.
  1619. @ensure_logged
  1620. def download(
  1621. self,
  1622. root: StrPath | None = None,
  1623. allow_missing_references: bool = False,
  1624. skip_cache: bool | None = None,
  1625. path_prefix: StrPath | None = None,
  1626. multipart: bool | None = None,
  1627. ) -> FilePathStr:
  1628. """Download the contents of the artifact to the specified root directory.
  1629. Existing files located within `root` are not modified. Explicitly delete `root`
  1630. before you call `download` if you want the contents of `root` to exactly match
  1631. the artifact.
  1632. Args:
  1633. root: The directory W&B stores the artifact's files.
  1634. allow_missing_references: If set to `True`, any invalid reference paths
  1635. will be ignored while downloading referenced files.
  1636. skip_cache: If set to `True`, the artifact cache will be skipped when
  1637. downloading and W&B will download each file into the default root or
  1638. specified download directory.
  1639. path_prefix: If specified, only files with a path that starts with the given
  1640. prefix will be downloaded. Uses unix format (forward slashes).
  1641. multipart: If set to `None` (default), the artifact will be downloaded
  1642. in parallel using multipart download if individual file size is greater
  1643. than 2GB. If set to `True` or `False`, the artifact will be downloaded in
  1644. parallel or serially regardless of the file size.
  1645. Returns:
  1646. The path to the downloaded contents.
  1647. Raises:
  1648. ArtifactNotLoggedError: If the artifact is not logged.
  1649. """
  1650. root = self._add_download_root(root)
  1651. # TODO: download artifacts using core when implemented
  1652. # if is_require_core():
  1653. # return self._download_using_core(
  1654. # root=root,
  1655. # allow_missing_references=allow_missing_references,
  1656. # skip_cache=bool(skip_cache),
  1657. # path_prefix=path_prefix,
  1658. # )
  1659. return self._download(
  1660. root=root,
  1661. allow_missing_references=allow_missing_references,
  1662. skip_cache=skip_cache,
  1663. path_prefix=path_prefix,
  1664. multipart=multipart,
  1665. )
    def _download_using_core(
        self,
        root: str,
        allow_missing_references: bool = False,
        skip_cache: bool = False,
        path_prefix: StrPath | None = None,
    ) -> FilePathStr:
        """Download the artifact by delegating to a backend interface.

        When no run is active, a backend is set up for a freshly generated
        stream id; otherwise the active run's backend is reused. A download
        request is delivered to the backend, `_download` is also run in this
        process, and then the backend's response is awaited.

        NOTE(review): `multipart` is not forwarded to `_download` here, so its
        default applies. The temporary directory created with
        `tempfile.mkdtemp()` is not removed by this method.

        Args:
            root: Directory to download into.
            allow_missing_references: Passed to both the backend request and
                `_download`.
            skip_cache: Passed to both the backend request and `_download`.
            path_prefix: Passed to both the backend request and `_download`.

        Returns:
            `root`, wrapped in `FilePathStr`.

        Raises:
            ValueError: The backend response carries an error message.
        """
        import pathlib

        from wandb.sdk.backend.backend import Backend

        # TODO: Create a special stream instead of relying on an existing run.
        if wandb.run is None:
            wl = wandb_setup.singleton()

            stream_id = generate_id()

            settings = wl.settings.to_proto()
            # TODO: remove this
            tmp_dir = pathlib.Path(tempfile.mkdtemp())

            # Point the sync directory/file and run id of the settings at the
            # temporary stream.
            settings.sync_dir.value = str(tmp_dir)
            settings.sync_file.value = str(tmp_dir / f"{stream_id}.wandb")
            settings.run_id.value = stream_id

            service = wl.ensure_service()
            service.inform_init(settings=settings, run_id=stream_id)

            backend = Backend(settings=wl.settings, service=service)
            backend.ensure_launched()

            assert backend.interface
            backend.interface._stream_id = stream_id  # type: ignore
        else:
            assert wandb.run._backend
            backend = wandb.run._backend

        assert backend.interface
        handle = backend.interface.deliver_download_artifact(
            self.id,  # type: ignore
            root,
            allow_missing_references,
            skip_cache,
            path_prefix,  # type: ignore
        )
        # TODO: Start the download process in the user process too, to handle reference downloads
        self._download(
            root=root,
            allow_missing_references=allow_missing_references,
            skip_cache=skip_cache,
            path_prefix=path_prefix,
        )
        # Block (no timeout) until the backend reports the outcome.
        result = handle.wait_or(timeout=None)

        response = result.response.download_artifact_response
        if response.error_message:
            raise ValueError(f"Error downloading artifact: {response.error_message}")

        return FilePathStr(root)
    def _download(
        self,
        root: str,
        allow_missing_references: bool = False,
        skip_cache: bool | None = None,
        path_prefix: StrPath | None = None,
        multipart: bool | None = None,
    ) -> FilePathStr:
        """Download this artifact's files into `root` using thread pools.

        File URLs are fetched page by page through `_fetch_file_urls`. Each
        entry whose path starts with `path_prefix` (or every entry when no
        prefix is given) is submitted to a file-level thread pool; a second
        pool is handed to entries for which `should_multipart_download`
        returns true. Progress is logged only when the artifact has more than
        5000 files or is larger than 50 MB.

        Args:
            root: Directory passed to each entry's `download` call.
            allow_missing_references: If `True`, a `FileNotFoundError` raised
                while downloading an entry is reported as a warning and the
                entry is skipped; otherwise it propagates.
            skip_cache: Forwarded to each entry's `download` call.
            path_prefix: If set, only entries whose path starts with
                `str(path_prefix)` are downloaded.
            multipart: Forwarded to `should_multipart_download` as `override`.

        Returns:
            `root`, wrapped in `FilePathStr`.
        """
        nfiles = len(self.manifest.entries)
        size_mb = self.size / _MB

        # `log` doubles as the flag for printing the summary line at the end.
        if log := (nfiles > 5000 or size_mb > 50):
            termlog(
                f"Downloading large artifact {self.name!r}, {size_mb:.2f}MB. {nfiles!r} files...",
            )
            start_time = time.monotonic()

        download_logger = ArtifactDownloadLogger(nfiles=nfiles)

        def _download_entry(
            entry: ArtifactManifestEntry, mp_executor: Executor
        ) -> None:
            # Only hand the multipart executor to entries that qualify for a
            # multipart download.
            multipart_executor = (
                mp_executor
                if should_multipart_download(entry.size, override=multipart)
                else None
            )
            try:
                entry.download(root, skip_cache=skip_cache, executor=multipart_executor)
            except FileNotFoundError as e:
                if allow_missing_references:
                    wandb.termwarn(str(e))
                    return
                raise
            except _GCSIsADirectoryError as e:
                logger.debug(str(e))
                return
            except IsADirectoryError:
                wandb.termwarn(
                    f"Unable to download file {entry.path!r} as there is a directory with the same path, skipping."
                )
                return
            except NotADirectoryError:
                wandb.termwarn(
                    f"Unable to download file {entry.path!r} as there is a file with the same path as a directory this file is expected to be in, skipping."
                )
                return
            # Reached only when the download did not raise; skipped entries
            # return early above and are not counted.
            download_logger.notify_downloaded()

        with (
            ThreadPoolExecutor(max_workers=_FILE_EXECUTOR_WORKERS) as file_executor,
            ThreadPoolExecutor(max_workers=_MP_EXECUTOR_WORKERS) as mp_executor,
        ):
            batch_size = env.get_artifact_fetch_file_url_batch_size()

            active_futures = set()
            cursor, has_more = None, True
            while has_more:
                files_page = self._fetch_file_urls(cursor=cursor, per_page=batch_size)

                has_more = files_page.page_info.has_next_page
                cursor = files_page.page_info.end_cursor

                # `File` nodes are formally nullable, so filter them out just in case.
                file_nodes = (e.node for e in files_page.edges if e.node)
                for node in file_nodes:
                    entry = self.get_entry(node.name)
                    # TODO: uncomment once artifact downloads are supported in core
                    # if require_core and entry.ref is None:
                    #     # Handled by core
                    #     continue
                    entry._download_url = node.direct_url
                    if (not path_prefix) or entry.path.startswith(str(path_prefix)):
                        active_futures.add(
                            file_executor.submit(
                                _download_entry, entry, mp_executor=mp_executor
                            )
                        )

                # Wait for download threads to catch up.
                #
                # Extra context and observations (tonyyli):
                # - Even though the ThreadPoolExecutor limits the number of
                #   concurrently-executed tasks, its internal task queue is unbounded.
                #   The code below seems intended to ensure that at most `batch_size`
                #   "backlogged" futures are held in memory at any given time.  This seems
                #   like a reasonable safeguard against unbounded memory consumption.
                #
                # - We should probably use a builtin bounded Queue or Semaphore instead.
                #   Consider this for a future change, or (depending on appetite for risk)
                #   managing this logic via asyncio instead, if viable.
                if len(active_futures) > batch_size:
                    for future in as_completed(active_futures):
                        future.result()  # check for errors
                        active_futures.remove(future)
                        if len(active_futures) <= batch_size:
                            break

            # Check for errors.
            for future in as_completed(active_futures):
                future.result()

        if log:
            # If you're wondering if we can display a `timedelta`, note that it
            # doesn't really support custom string format specifiers (compared to
            # e.g. `datetime` objs).  To truncate the number of decimal places for
            # the seconds part, we manually convert/format each part below.
            dt_secs = abs(time.monotonic() - start_time)
            # After the first divmod, `mins` temporarily holds the leftover
            # seconds within the hour; the second divmod splits it into whole
            # minutes and seconds.
            hrs, mins = divmod(dt_secs, 3600)
            mins, secs = divmod(mins, 60)
            termlog(
                f"Done. {int(hrs):02d}:{int(mins):02d}:{secs:04.1f} ({size_mb / dt_secs:.1f}MB/s)",
                prefix=False,
            )
        return FilePathStr(root)
  1819. def _build_fetch_file_urls_wrapper(self) -> Callable[..., Any]:
  1820. import requests
  1821. @retry.retriable(
  1822. retry_timedelta=timedelta(minutes=3),
  1823. retryable_exceptions=(requests.RequestException),
  1824. )
  1825. def _impl(cursor: str | None, per_page: int = 5000) -> FileWithUrlConnection:
  1826. from ._generated import (
  1827. GET_ARTIFACT_FILE_URLS_GQL,
  1828. GET_ARTIFACT_MEMBERSHIP_FILE_URLS_GQL,
  1829. GetArtifactFileUrls,
  1830. GetArtifactMembershipFileUrls,
  1831. )
  1832. from ._models.pagination import FileWithUrlConnection
  1833. if self._client is None:
  1834. raise RuntimeError("Client not initialized")
  1835. if server_supports(self._client, pb.ARTIFACT_COLLECTION_MEMBERSHIP_FILES):
  1836. query = gql(GET_ARTIFACT_MEMBERSHIP_FILE_URLS_GQL)
  1837. gql_vars = {
  1838. "entity": self.entity,
  1839. "project": self.project,
  1840. "collection": self.name.split(":")[0],
  1841. "alias": self.version,
  1842. "cursor": cursor,
  1843. "perPage": per_page,
  1844. }
  1845. data = self._client.execute(query, variable_values=gql_vars, timeout=60)
  1846. result = GetArtifactMembershipFileUrls.model_validate(data)
  1847. if not (
  1848. (project := result.project)
  1849. and (collection := project.artifact_collection)
  1850. and (membership := collection.artifact_membership)
  1851. and (files := membership.files)
  1852. ):
  1853. raise ValueError(
  1854. f"Unable to fetch files for artifact: {self.name!r}"
  1855. )
  1856. return FileWithUrlConnection.model_validate(files)
  1857. else:
  1858. query = gql(GET_ARTIFACT_FILE_URLS_GQL)
  1859. gql_vars = {"id": self.id, "cursor": cursor, "perPage": per_page}
  1860. data = self._client.execute(query, variable_values=gql_vars, timeout=60)
  1861. result = GetArtifactFileUrls.model_validate(data)
  1862. if not ((artifact := result.artifact) and (files := artifact.files)):
  1863. raise ValueError(
  1864. f"Unable to fetch files for artifact: {self.name!r}"
  1865. )
  1866. return FileWithUrlConnection.model_validate(files)
  1867. return _impl
  1868. def _fetch_file_urls(
  1869. self, cursor: str | None, per_page: int = 5000
  1870. ) -> FileWithUrlConnection:
  1871. if self._fetch_file_urls_decorated is None:
  1872. self._fetch_file_urls_decorated = self._build_fetch_file_urls_wrapper()
  1873. return self._fetch_file_urls_decorated(cursor, per_page)
  1874. @ensure_logged
  1875. def checkout(self, root: str | None = None) -> str:
  1876. """Replace the specified root directory with the contents of the artifact.
  1877. WARNING: This will delete all files in `root` that are not included in the
  1878. artifact.
  1879. Args:
  1880. root: The directory to replace with this artifact's files.
  1881. Returns:
  1882. The path of the checked out contents.
  1883. Raises:
  1884. ArtifactNotLoggedError: If the artifact is not logged.
  1885. """
  1886. root = root or self._default_root(include_version=False)
  1887. for dirpath, _, files in os.walk(root):
  1888. for file in files:
  1889. full_path = os.path.join(dirpath, file)
  1890. artifact_path = os.path.relpath(full_path, start=root)
  1891. try:
  1892. self.get_entry(artifact_path)
  1893. except KeyError:
  1894. # File is not part of the artifact, remove it.
  1895. os.remove(full_path)
  1896. return self.download(root=root)
  1897. @ensure_logged
  1898. def verify(self, root: str | None = None) -> None:
  1899. """Verify that the contents of an artifact match the manifest.
  1900. All files in the directory are checksummed and the checksums are then
  1901. cross-referenced against the artifact's manifest. References are not verified.
  1902. Args:
  1903. root: The directory to verify. If None artifact will be downloaded to
  1904. './artifacts/self.name/'.
  1905. Raises:
  1906. ArtifactNotLoggedError: If the artifact is not logged.
  1907. ValueError: If the verification fails.
  1908. """
  1909. root = root or self._default_root()
  1910. for dirpath, _, files in os.walk(root):
  1911. for file in files:
  1912. full_path = os.path.join(dirpath, file)
  1913. artifact_path = os.path.relpath(full_path, start=root)
  1914. try:
  1915. self.get_entry(artifact_path)
  1916. except KeyError:
  1917. raise ValueError(
  1918. f"Found file {full_path} which is not a member of artifact {self.name}"
  1919. )
  1920. ref_count = 0
  1921. for entry in self.manifest.entries.values():
  1922. if entry.ref is None:
  1923. if md5_file_b64(os.path.join(root, entry.path)) != entry.digest:
  1924. raise ValueError(f"Digest mismatch for file: {entry.path}")
  1925. else:
  1926. ref_count += 1
  1927. if ref_count > 0:
  1928. termwarn(f"skipped verification of {ref_count} refs")
  1929. @ensure_logged
  1930. def file(self, root: str | None = None) -> StrPath:
  1931. """Download a single file artifact to the directory you specify with `root`.
  1932. Args:
  1933. root: The root directory to store the file. Defaults to
  1934. `./artifacts/self.name/`.
  1935. Returns:
  1936. The full path of the downloaded file.
  1937. Raises:
  1938. ArtifactNotLoggedError: If the artifact is not logged.
  1939. ValueError: If the artifact contains more than one file.
  1940. """
  1941. if root is None:
  1942. root = os.path.join(".", "artifacts", self.name)
  1943. if len(self.manifest.entries) > 1:
  1944. raise ValueError(
  1945. "This artifact contains more than one file, call `.download()` to get "
  1946. 'all files or call .get_entry("filename").download()'
  1947. )
  1948. return self.get_entry(list(self.manifest.entries)[0]).download(root)
  1949. @ensure_logged
  1950. def files(
  1951. self, names: list[str] | None = None, per_page: int = 50
  1952. ) -> ArtifactFiles:
  1953. """Iterate over all files stored in this artifact.
  1954. Args:
  1955. names: The filename paths relative to the root of the artifact you wish to
  1956. list.
  1957. per_page: The number of files to return per request.
  1958. Returns:
  1959. An iterator containing `File` objects.
  1960. Raises:
  1961. ArtifactNotLoggedError: If the artifact is not logged.
  1962. """
  1963. if (client := self._client) is None:
  1964. raise RuntimeError("Client not initialized")
  1965. return ArtifactFiles(client, self, names, per_page)
  def _default_root(self, include_version: bool = True) -> FilePathStr:
      """Return the default local directory for this artifact.

      Args:
          include_version: When False, only the part of `source_name` before
              the first ":" is used as the directory name.

      Returns:
          A path under the configured artifact directory.
      """
      if include_version:
          dir_name = self.source_name
      else:
          dir_name = self.source_name.partition(":")[0]
      candidate = os.path.join(env.get_artifact_dir(), dir_name)

      # In case we're on a system where the artifact dir has a name corresponding
      # to an unexpected filesystem, check for alternate roots first. If one
      # exists use it, otherwise fall back to the system-preferred path.
      existing = check_exists(candidate)
      return FilePathStr(existing or system_preferred_path(candidate))
  1973. def _add_download_root(self, dir_path: StrPath | None) -> FilePathStr:
  1974. root = str(dir_path or self._default_root())
  1975. self._download_roots.add(os.path.abspath(root))
  1976. return root
  1977. def _local_path_to_name(self, file_path: str) -> str | None:
  1978. """Convert a local file path to a path entry in the artifact."""
  1979. abs_file_path = os.path.abspath(file_path)
  1980. abs_file_parts = abs_file_path.split(os.sep)
  1981. for i in range(len(abs_file_parts) + 1):
  1982. if os.path.join(os.sep, *abs_file_parts[:i]) in self._download_roots:
  1983. return os.path.join(*abs_file_parts[i:])
  1984. return None
  1985. # Others.
  @ensure_logged
  def delete(self, delete_aliases: bool = False) -> None:
      """Delete an artifact and its files.

      When called on a linked artifact, only the link is removed and the
      source artifact is left untouched.

      Prefer `Artifact.unlink()` over `Artifact.delete()` for removing a link
      between a source artifact and a collection.

      Args:
          delete_aliases: If `True`, delete all aliases associated with the
              artifact. If `False`, raise an exception if the artifact has
              existing aliases. Ignored when the artifact was retrieved from
              a collection it is linked to.

      Raises:
          ArtifactNotLoggedError: If the artifact is not logged.
      """
      if not self.is_link:
          self._delete(delete_aliases)
          return

      # Link artifacts are only unlinked; warn so the caller is not surprised.
      wandb.termwarn(
          "Deleting a link artifact will only unlink the artifact from the source artifact and not delete the source artifact and the data of the source artifact."
      )
      self._unlink()
  @normalize_exceptions
  def _delete(self, delete_aliases: bool = False) -> None:
      """Execute the delete-artifact mutation for this artifact.

      Args:
          delete_aliases: Forwarded as `delete_aliases` in the mutation input.
      """
      from ._generated import DELETE_ARTIFACT_GQL, DeleteArtifactInput

      client = self._client
      if client is None:
          raise RuntimeError("Client not initialized for artifact mutations")

      mutation = gql(DELETE_ARTIFACT_GQL)
      payload = DeleteArtifactInput(
          artifact_id=self.id,
          delete_aliases=delete_aliases,
      )
      variables = {"input": payload.model_dump()}
      client.execute(mutation, variable_values=variables)
  @normalize_exceptions
  def link(self, target_path: str, aliases: Iterable[str] | None = None) -> Artifact:
      """Link this artifact to a collection.

      A draft artifact is saved first (unless a save was already started) and
      waited on before the link is created.

      Args:
          target_path: The path of the collection. Path consists of the prefix
              "wandb-registry-" along with the registry name and the
              collection name `wandb-registry-{REGISTRY_NAME}/{COLLECTION_NAME}`.
          aliases: Add one or more aliases to the linked artifact. The
              "latest" alias is automatically applied to the most recent artifact
              you link.

      Raises:
          ArtifactNotLoggedError: If the artifact is not logged.

      Returns:
          The linked artifact.
      """
      from wandb import Api
      from wandb.sdk.internal.internal_api import Api as InternalApi

      from ._generated import LINK_ARTIFACT_GQL, LinkArtifact, LinkArtifactInput
      from ._validators import ArtifactPath, FullArtifactPath, validate_aliases

      if self.is_link:
          wandb.termwarn(
              "Linking to a link artifact will result in directly linking to the source artifact of that link artifact."
          )

      # A draft has to be saved and committed before it can be linked.
      if self.is_draft():
          if not self._is_draft_save_started():
              # The private `_source_project` is used on purpose: the public
              # `.source_project` property requires a logged artifact.
              self.save(project=self._source_project)
          self.wait()

      client = self._client
      if client is None:
          raise RuntimeError("Client not initialized for artifact mutations")

      # FIXME: Find a way to avoid using InternalApi here, due to the perf overhead
      api_settings = InternalApi().settings()

      parsed_path = ArtifactPath.from_str(target_path)
      target = parsed_path.with_defaults(
          project=api_settings.get("project") or "uncategorized",
      )

      # How the entity (first path component) is resolved depends on whether
      # the target is a registry.
      if target.is_registry_path():
          # For registry links the entity is used to look up the organization,
          # so the source artifact's entity is what gets passed to the backend.
          org = target.prefix or api_settings.get("organization") or None
          target.prefix = resolve_org_entity_name(client, self.source_entity, org)
      else:
          target = target.with_defaults(prefix=self.source_entity)

      # Convert explicitly to FullArtifactPath so every field is present.
      target = FullArtifactPath(**asdict(target))

      # Build the validated mutation input.
      alias_inputs = [
          {"artifactCollectionName": target.name, "alias": alias}
          for alias in validate_aliases(aliases or [])
      ]
      link_input = LinkArtifactInput(
          artifact_id=self.id,
          artifact_portfolio_name=target.name,
          entity_name=target.prefix,
          project_name=target.project,
          aliases=alias_inputs,
      )
      variables = {"input": link_input.model_dump()}

      # Newer servers can return `artifactMembership` in the response, which
      # avoids re-fetching the linked artifact. Strip it for older servers.
      if server_supports(client, pb.ARTIFACT_MEMBERSHIP_IN_LINK_ARTIFACT_RESPONSE):
          omit_variables = omit_fields = None
      else:
          omit_variables = {"includeAliases"}
          omit_fields = {"artifactMembership"}

      mutation = gql_compat(
          LINK_ARTIFACT_GQL, omit_variables=omit_variables, omit_fields=omit_fields
      )
      response = client.execute(mutation, variable_values=variables)
      result = LinkArtifact.model_validate(response).result

      # Fast path: the membership came back directly in the response.
      if result and (membership := result.artifact_membership):
          return self._from_membership(membership, target=target, client=client)

      # Fallback for older servers: re-fetch the linked artifact by version.
      if not result or (version_idx := result.version_index) is None:
          raise ValueError("Unable to parse linked artifact version from response")

      link_name = f"{target.to_str()}:v{version_idx}"
      api = Api(overrides={"entity": self.source_entity})
      return api._artifact(link_name)
  @ensure_logged
  def unlink(self) -> None:
      """Unlink this artifact if it is a linked member of an artifact collection.

      Raises:
          ArtifactNotLoggedError: If the artifact is not logged.
          ValueError: If the artifact is not linked to any collection.
      """
      if self.is_link:
          self._unlink()
          return

      # Not a link: fail and point the caller at `delete` instead.
      raise ValueError(
          f"Artifact {self.qualified_name!r} is not a linked artifact and cannot be unlinked. "
          f"To delete it, use {nameof(self.delete)!r} instead."
      )
  @normalize_exceptions
  def _unlink(self) -> None:
      """Execute the unlink mutation for this artifact and its collection."""
      from ._generated import UNLINK_ARTIFACT_GQL, UnlinkArtifactInput

      client = self._client
      if client is None:
          raise RuntimeError("Client not initialized for artifact mutations")

      unlink_op = gql(UNLINK_ARTIFACT_GQL)
      unlink_input = UnlinkArtifactInput(
          artifact_id=self.id,
          artifact_portfolio_id=self.collection.id,
      )
      variables = {"input": unlink_input.model_dump()}

      try:
          client.execute(unlink_op, variable_values=variables)
      except CommError as err:
          # Any CommError from the mutation is re-raised as a permission error.
          raise CommError(
              f"You do not have permission to unlink the artifact {self.qualified_name!r}"
          ) from err
  @ensure_logged
  def used_by(self) -> list[Run]:
      """Get a list of the runs that have used this artifact and its linked artifacts.

      Returns:
          A list of `Run` objects; empty when the response carries no usage
          edges.

      Raises:
          ArtifactNotLoggedError: If the artifact is not logged.
      """
      from ._generated import ARTIFACT_USED_BY_GQL, ArtifactUsedBy

      client = self._client
      if client is None:
          raise RuntimeError("Client not initialized for artifact queries")

      response = client.execute(
          gql(ARTIFACT_USED_BY_GQL), variable_values={"id": self.id}
      )
      parsed = ArtifactUsedBy.model_validate(response)

      # Bail out as soon as any level of the response is missing or empty.
      if not (artifact := parsed.artifact):
          return []
      if not (usage := artifact.used_by):
          return []
      if not (edges := usage.edges):
          return []

      runs = []
      for edge in edges:
          node = edge.node
          # Nodes without a project are skipped.
          if proj := node.project:
              runs.append(Run(client, proj.entity.name, proj.name, node.name))
      return runs
  @ensure_logged
  def logged_by(self) -> Run | None:
      """Get the W&B run that originally logged the artifact.

      Returns:
          A `Run` for the creator of the artifact, or `None` when the response
          has no artifact, no creator, no creator name, or no creator project.

      Raises:
          ArtifactNotLoggedError: If the artifact is not logged.
      """
      from ._generated import ARTIFACT_CREATED_BY_GQL, ArtifactCreatedBy

      client = self._client
      if client is None:
          raise RuntimeError("Client not initialized for artifact queries")

      query = gql(ARTIFACT_CREATED_BY_GQL)
      response = client.execute(query, variable_values={"id": self.id})
      parsed = ArtifactCreatedBy.model_validate(response)

      # Each attribute is read only once the previous one proved truthy.
      if not (artifact := parsed.artifact):
          return None
      if not (creator := artifact.created_by):
          return None
      if not (run_name := creator.name):
          return None
      if not (project := creator.project):
          return None
      return Run(client, project.entity.name, project.name, run_name)
  @ensure_logged
  def json_encode(self) -> dict[str, Any]:
      """Return the artifact encoded to the JSON format.

      Returns:
          A `dict` with `string` keys representing attributes of the artifact.
      """
      encoded = artifact_to_json(self)
      return encoded
  @staticmethod
  def _expected_type(
      entity_name: str, project_name: str, name: str, client: RetryingClient
  ) -> str | None:
      """Return the expected type for a given artifact name and project.

      Args:
          entity_name: Entity passed to the query.
          project_name: Project passed to the query.
          name: Artifact name; ":latest" is appended when it contains no ":".
          client: Client used to execute the query.

      Returns:
          The artifact type name, or `None` when the response has no project
          or no artifact.
      """
      from ._generated import ARTIFACT_TYPE_GQL, ArtifactType

      if ":" not in name:
          name = f"{name}:latest"

      query = gql(ARTIFACT_TYPE_GQL)
      variables = {"entity": entity_name, "project": project_name, "name": name}
      response = client.execute(query, variable_values=variables)
      parsed = ArtifactType.model_validate(response)

      project = parsed.project
      if not project:
          return None
      artifact = project.artifact
      if not artifact:
          return None
      return artifact.artifact_type.name
  2204. def _ttl_duration_seconds_to_gql(self) -> int | None:
  2205. # Set the artifact TTL to `ttl_duration_seconds` if the user provided a value.
  2206. # Otherwise, use `ttl_status` to indicate backend values INHERIT (-1) or
  2207. # DISABLED (-2) when the TTL is None.
  2208. # When `ttl_change is None`, nothing changed and this is a no-op.
  2209. INHERIT = -1 # noqa: N806
  2210. DISABLED = -2 # noqa: N806
  2211. if not self._ttl_changed:
  2212. return None
  2213. if self._ttl_is_inherited:
  2214. return INHERIT
  2215. return self._ttl_duration_seconds or DISABLED
  def _fetch_linked_artifacts(self) -> list[Artifact]:
      """Fetch all linked artifacts from the server.

      Only memberships whose collection typename matches
      `ArtifactPortfolioTypeFields` are turned into linked artifacts.

      Returns:
          One copy of this artifact per matching membership, re-labelled with
          that membership's entity, project, collection name, version and
          aliases.

      Raises:
          ValueError: If the artifact has no ID, the client is missing, the
              response has no memberships, or a membership lacks its
              project/entity fields.
      """
      from wandb._pydantic import gql_typename

      from ._generated import (
          FETCH_LINKED_ARTIFACTS_GQL,
          ArtifactPortfolioTypeFields,
          FetchLinkedArtifacts,
      )
      from ._validators import LinkArtifactFields

      if self.id is None:
          raise ValueError(
              "Unable to find any artifact memberships for artifact without an ID"
          )
      client = self._client
      if client is None:
          raise ValueError("Client is not initialized")

      query = gql_compat(FETCH_LINKED_ARTIFACTS_GQL)
      response = client.execute(query, variable_values={"artifactID": self.id})
      parsed = FetchLinkedArtifacts.model_validate(response)

      artifact = parsed.artifact
      memberships = artifact.artifact_memberships if artifact else None
      membership_edges = memberships.edges if memberships else None
      if not membership_edges:
          raise ValueError("Unable to find any artifact memberships for artifact")

      links: list[Artifact] = []
      for edge in membership_edges:
          # Keep only memberships that belong to a portfolio-type collection.
          node = edge.node
          if not node:
              continue
          col = node.artifact_collection
          if not col:
              continue
          if not (col.typename__ == gql_typename(ArtifactPortfolioTypeFields)):
              continue

          # The version string always appears in the alias list exactly once.
          alias_names = unique_list(a.alias for a in node.aliases)
          version = f"v{node.version_index}"
          aliases = list(alias_names)
          if version not in alias_names:
              aliases.append(version)

          proj = col.project
          if not (proj and proj.entity.name and proj.name):
              raise ValueError("Unable to fetch fields for linked artifact")

          fields = LinkArtifactFields(
              entity_name=proj.entity.name,
              project_name=proj.name,
              name=f"{col.name}:{version}",
              version=version,
              aliases=aliases,
          )
          links.append(self._create_linked_artifact_using_source_artifact(fields))
      return links
  2275. def _create_linked_artifact_using_source_artifact(
  2276. self,
  2277. link_fields: LinkArtifactFields,
  2278. ) -> Artifact:
  2279. """Copies the source artifact to a linked artifact."""
  2280. linked_artifact = copy(self)
  2281. linked_artifact._version = link_fields.version
  2282. linked_artifact._aliases = link_fields.aliases
  2283. linked_artifact._saved_aliases = copy(link_fields.aliases)
  2284. linked_artifact._name = link_fields.name
  2285. linked_artifact._entity = link_fields.entity_name
  2286. linked_artifact._project = link_fields.project_name
  2287. linked_artifact._is_link = link_fields.is_link
  2288. linked_artifact._linked_artifacts = link_fields.linked_artifacts
  2289. return linked_artifact
  class _ArtifactVersionType(WBType):
      """`WBType` subclass named "artifactVersion" that lists `Artifact` in `types`."""

      # Identifier of this type.
      name = "artifactVersion"
      # Python classes associated with this type.
      # NOTE(review): presumably used by TypeRegistry to match values - confirm.
      types = [Artifact]


  # Add the type to the registry as a side effect of importing this module.
  TypeRegistry.add(_ArtifactVersionType)