artifacts.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077
  1. """W&B Public API for Artifact objects.
  2. This module provides classes for interacting with W&B artifacts and their
  3. collections.
  4. """
  5. from __future__ import annotations
  6. import json
  7. from collections.abc import Collection, Iterable, Mapping, Sequence
  8. from copy import copy
  9. from functools import lru_cache
  10. from typing import TYPE_CHECKING, Any, ClassVar, List, Literal, TypeVar # noqa: UP035
  11. from typing_extensions import override
  12. from wandb_gql import gql
  13. from wandb._iterutils import always_list
  14. from wandb._pydantic import Connection, ConnectionWithTotal, Edge
  15. from wandb._strutils import nameof
  16. from wandb.apis.normalize import normalize_exceptions
  17. from wandb.apis.paginator import RelayPaginator, SizedRelayPaginator
  18. from wandb.errors.errors import UnsupportedError
  19. from wandb.errors.term import termlog
  20. from wandb.proto import wandb_internal_pb2 as pb
  21. from wandb.proto.wandb_telemetry_pb2 import Deprecated
  22. from wandb.sdk.artifacts._gqlutils import server_supports
  23. from wandb.sdk.artifacts._models import ArtifactCollectionData
  24. from wandb.sdk.lib.deprecation import warn_and_record_deprecation
  25. from .files import File
  26. from .utils import gql_compat
  27. if TYPE_CHECKING:
  28. from wandb_graphql.language.ast import Document
  29. from wandb.apis.public.api import RetryingClient
  30. from wandb.sdk.artifacts._generated import (
  31. ArtifactAliasFragment,
  32. ArtifactCollectionFragment,
  33. ArtifactFragment,
  34. ArtifactTypeFragment,
  35. FileFragment,
  36. )
  37. from wandb.sdk.artifacts._models.pagination import (
  38. ArtifactCollectionConnection,
  39. ArtifactFileConnection,
  40. ArtifactTypeConnection,
  41. )
  42. from wandb.sdk.artifacts.artifact import Artifact
  43. from . import Run
# Type variable for the node type carried by the generic edge/connection
# models defined later in this module (`Edge[TNode]`, `ConnectionWithTotal[TNode]`).
TNode = TypeVar("TNode")
  45. @lru_cache(maxsize=1)
  46. def _run_artifacts_mode_to_gql() -> dict[Literal["logged", "used"], str]:
  47. """Lazily import and cache the run artifact GQL query strings.
  48. This keeps import-time light and only loads the generated GQL
  49. when RunArtifacts is actually used.
  50. """
  51. from wandb.sdk.artifacts._generated import (
  52. RUN_INPUT_ARTIFACTS_GQL,
  53. RUN_OUTPUT_ARTIFACTS_GQL,
  54. )
  55. return {"logged": RUN_OUTPUT_ARTIFACTS_GQL, "used": RUN_INPUT_ARTIFACTS_GQL}
  56. class _ArtifactCollectionAliases(RelayPaginator["ArtifactAliasFragment", str]):
  57. """An internal iterator of collection alias names.
  58. <!-- lazydoc-ignore-init: internal -->
  59. """
  60. QUERY: ClassVar[Document | None] = None
  61. last_response: Connection[ArtifactAliasFragment] | None
  62. def __init__(
  63. self,
  64. client: RetryingClient,
  65. collection_id: str,
  66. per_page: int = 1_000,
  67. ):
  68. if self.QUERY is None:
  69. from wandb.sdk.artifacts._generated import ARTIFACT_COLLECTION_ALIASES_GQL
  70. type(self).QUERY = gql(ARTIFACT_COLLECTION_ALIASES_GQL)
  71. variables = {"id": collection_id}
  72. super().__init__(client, variables=variables, per_page=per_page)
  73. def _update_response(self) -> None:
  74. from wandb.sdk.artifacts._generated import (
  75. ArtifactAliasFragment,
  76. ArtifactCollectionAliases,
  77. )
  78. data = self.client.execute(self.QUERY, variable_values=self.variables)
  79. result = ArtifactCollectionAliases.model_validate(data)
  80. # Extract the inner `*Connection` result for faster/easier access.
  81. if not ((coll := result.artifact_collection) and (conn := coll.aliases)):
  82. raise ValueError(f"Unable to parse {nameof(type(self))!r} response data")
  83. self.last_response = Connection[ArtifactAliasFragment].model_validate(conn)
  84. def _convert(self, node: ArtifactAliasFragment) -> str:
  85. return node.alias
  86. class ArtifactTypes(RelayPaginator["ArtifactTypeFragment", "ArtifactType"]):
  87. """An lazy iterator of `ArtifactType` objects for a specific project.
  88. <!-- lazydoc-ignore-init: internal -->
  89. """
  90. QUERY: ClassVar[Document | None] = None
  91. last_response: ArtifactTypeConnection | None
  92. def __init__(
  93. self,
  94. client: RetryingClient,
  95. entity: str,
  96. project: str,
  97. per_page: int = 50,
  98. ):
  99. if self.QUERY is None:
  100. from wandb.sdk.artifacts._generated import PROJECT_ARTIFACT_TYPES_GQL
  101. type(self).QUERY = gql(PROJECT_ARTIFACT_TYPES_GQL)
  102. self.entity = entity
  103. self.project = project
  104. variables = {"entity": entity, "project": project}
  105. super().__init__(client, variables=variables, per_page=per_page)
  106. @override
  107. def _update_response(self) -> None:
  108. """Fetch and validate the response data for the current page."""
  109. from wandb.sdk.artifacts._generated import ProjectArtifactTypes
  110. from wandb.sdk.artifacts._models.pagination import ArtifactTypeConnection
  111. data = self.client.execute(self.QUERY, variable_values=self.variables)
  112. result = ProjectArtifactTypes.model_validate(data)
  113. # Extract the inner `*Connection` result for faster/easier access.
  114. if not ((proj := result.project) and (conn := proj.artifact_types)):
  115. raise ValueError(f"Unable to parse {nameof(type(self))!r} response data")
  116. self.last_response = ArtifactTypeConnection.model_validate(conn)
  117. def _convert(self, node: ArtifactTypeFragment) -> ArtifactType:
  118. return ArtifactType(
  119. client=self.client,
  120. entity=self.entity,
  121. project=self.project,
  122. type_name=node.name,
  123. attrs=node,
  124. )
  125. class ArtifactType:
  126. """An artifact object that satisfies query based on the specified type.
  127. Args:
  128. client: The client instance to use for querying W&B.
  129. entity: The entity (user or team) that owns the project.
  130. project: The name of the project to query for artifact types.
  131. type_name: The name of the artifact type.
  132. attrs: Optional attributes to initialize the ArtifactType.
  133. If omitted, the object will load its attributes from W&B upon
  134. initialization.
  135. <!-- lazydoc-ignore-init: internal -->
  136. """
  137. _attrs: ArtifactTypeFragment
  138. def __init__(
  139. self,
  140. client: RetryingClient,
  141. entity: str,
  142. project: str,
  143. type_name: str,
  144. attrs: ArtifactTypeFragment | None = None,
  145. ):
  146. from wandb.sdk.artifacts._generated import ArtifactTypeFragment
  147. self.client = client
  148. self.entity = entity
  149. self.project = project
  150. self.type = type_name
  151. # FIXME: Make this lazy, so we don't (re-)fetch the attributes until they are needed
  152. self._attrs = ArtifactTypeFragment.model_validate(attrs or self.load())
  153. def load(self) -> ArtifactTypeFragment:
  154. """Load the artifact type attributes from W&B.
  155. <!-- lazydoc-ignore: internal -->
  156. """
  157. from wandb.sdk.artifacts._generated import (
  158. PROJECT_ARTIFACT_TYPE_GQL,
  159. ArtifactTypeFragment,
  160. ProjectArtifactType,
  161. )
  162. gql_op = gql(PROJECT_ARTIFACT_TYPE_GQL)
  163. gql_vars = {"entity": self.entity, "project": self.project, "type": self.type}
  164. data = self.client.execute(gql_op, variable_values=gql_vars)
  165. result = ProjectArtifactType.model_validate(data)
  166. if not ((proj := result.project) and (artifact_type := proj.artifact_type)):
  167. raise ValueError(f"Could not find artifact type {self.type!r}")
  168. return ArtifactTypeFragment.model_validate(artifact_type)
  169. @property
  170. def id(self) -> str:
  171. """The unique identifier of the artifact type."""
  172. return self._attrs.id
  173. @property
  174. def name(self) -> str:
  175. """The name of the artifact type."""
  176. return self._attrs.name
  177. @normalize_exceptions
  178. def collections(
  179. self,
  180. filters: Mapping[str, Any] | None = None,
  181. order: str | None = None,
  182. per_page: int = 50,
  183. ) -> ArtifactCollections:
  184. """Get all artifact collections associated with this artifact type.
  185. Args:
  186. filters (dict): Optional mapping of filters to apply to the query.
  187. order (str): Optional string to specify the order of the results.
  188. If you prepend order with a + order is ascending (default).
  189. If you prepend order with a - order is descending.
  190. The default order is the collection ID in descending order.
  191. per_page (int): The number of artifact collections to fetch per page.
  192. Default is 50.
  193. """
  194. return ArtifactCollections(
  195. self.client,
  196. entity=self.entity,
  197. project=self.project,
  198. filters=filters,
  199. order=order,
  200. type_name=self.type,
  201. per_page=per_page,
  202. )
  203. def collection(self, name: str) -> ArtifactCollection:
  204. """Get a specific artifact collection by name.
  205. Args:
  206. name (str): The name of the artifact collection to retrieve.
  207. """
  208. return ArtifactCollection(
  209. self.client,
  210. entity=self.entity,
  211. project=self.project,
  212. name=name,
  213. type=self.type,
  214. )
  215. def __repr__(self) -> str:
  216. return f"<ArtifactType {self.type}>"
  217. class ArtifactCollections(
  218. SizedRelayPaginator["ArtifactCollectionFragment", "ArtifactCollection"]
  219. ):
  220. """Artifact collections of a specific type in a project.
  221. Args:
  222. client: The client instance to use for querying W&B.
  223. entity: The entity (user or team) that owns the project.
  224. project: The name of the project to query for artifact collections.
  225. type_name: The name of the artifact type for which to fetch collections.
  226. filters: Optional mapping of filters to apply to the query.
  227. order: Optional string to specify the order of the results.
  228. If you prepend order with a + order is ascending (default).
  229. If you prepend order with a - order is descending.
  230. per_page: The number of artifact collections to fetch per page. Default is 50.
  231. <!-- lazydoc-ignore-init: internal -->
  232. """
  233. QUERY: ClassVar[Document | None] = None
  234. last_response: ArtifactCollectionConnection | None
  235. def __init__(
  236. self,
  237. client: RetryingClient,
  238. entity: str,
  239. project: str,
  240. type_name: str,
  241. filters: Mapping[str, Any] | None = None,
  242. order: str | None = None,
  243. per_page: int = 50,
  244. ):
  245. if self.QUERY is None:
  246. from wandb.sdk.artifacts._generated import (
  247. ARTIFACT_TYPE_ARTIFACT_COLLECTIONS_GQL,
  248. )
  249. type(self).QUERY = gql(ARTIFACT_TYPE_ARTIFACT_COLLECTIONS_GQL)
  250. if (order is not None or filters is not None) and not server_supports(
  251. client, pb.ARTIFACT_COLLECTIONS_FILTERING_SORTING
  252. ):
  253. raise UnsupportedError(
  254. "Filtering and ordering of artifact collections is not supported on this wandb server version. "
  255. "Please upgrade your server version or contact support at support@wandb.com."
  256. )
  257. self.entity = entity
  258. self.project = project
  259. self.type_name = type_name
  260. self.filters = filters
  261. self.order = order
  262. variables = {
  263. "entity": entity,
  264. "project": project,
  265. "type": type_name,
  266. "order": order,
  267. "filters": json.dumps(f) if (f := filters) else None,
  268. }
  269. super().__init__(client, variables=variables, per_page=per_page)
  270. @override
  271. def _update_response(self) -> None:
  272. """Fetch and validate the response data for the current page."""
  273. from wandb.sdk.artifacts._generated import ArtifactTypeArtifactCollections
  274. from wandb.sdk.artifacts._models.pagination import ArtifactCollectionConnection
  275. data = self.client.execute(self.QUERY, variable_values=self.variables)
  276. result = ArtifactTypeArtifactCollections.model_validate(data)
  277. # Extract the inner `*Connection` result for faster/easier access.
  278. if not (
  279. (proj := result.project)
  280. and (artifact_type := proj.artifact_type)
  281. and (conn := artifact_type.artifact_collections)
  282. ):
  283. raise ValueError(f"Unable to parse {nameof(type(self))!r} response data")
  284. self.last_response = ArtifactCollectionConnection.model_validate(conn)
  285. def _convert(self, node: ArtifactCollectionFragment) -> ArtifactCollection | None:
  286. if not node.project:
  287. return None
  288. return ArtifactCollection(
  289. client=self.client,
  290. entity=node.project.entity.name,
  291. project=node.project.name,
  292. name=node.name,
  293. type=node.type.name,
  294. attrs=node,
  295. )
  296. class ProjectArtifactCollections(
  297. SizedRelayPaginator["ArtifactCollectionFragment", "ArtifactCollection"]
  298. ):
  299. """Artifact collections in a project.
  300. Args:
  301. client: The client instance to use for querying W&B.
  302. entity: The entity (user or team) that owns the project.
  303. project: The name of the project to query for artifact collections.
  304. filters: Optional mapping of filters to apply to the query.
  305. order: Optional string to specify the order of the results.
  306. If you prepend order with a + order is ascending (default).
  307. If you prepend order with a - order is descending.
  308. per_page: The number of artifact collections to fetch per page. Default is 50.
  309. <!-- lazydoc-ignore-init: internal -->
  310. """
  311. QUERY: ClassVar[Document | None] = None
  312. last_response: ArtifactCollectionConnection | None
  313. def __init__(
  314. self,
  315. client: RetryingClient,
  316. entity: str,
  317. project: str,
  318. filters: Mapping[str, Any] | None = None,
  319. order: str | None = None,
  320. per_page: int = 50,
  321. ):
  322. if (order is not None or filters is not None) and not server_supports(
  323. client, pb.ARTIFACT_COLLECTIONS_FILTERING_SORTING
  324. ):
  325. raise UnsupportedError(
  326. "Filtering and ordering of artifact collections is not supported on this wandb server version. "
  327. "Please upgrade your server version or contact support at support@wandb.com."
  328. )
  329. if self.QUERY is None:
  330. from wandb.sdk.artifacts._generated import PROJECT_ARTIFACT_COLLECTIONS_GQL
  331. omit_fields = (
  332. None
  333. if server_supports(client, pb.ARTIFACT_COLLECTIONS_FILTERING_SORTING)
  334. else {"totalCount"}
  335. )
  336. omit_variables = (
  337. None
  338. if server_supports(client, pb.ARTIFACT_COLLECTIONS_FILTERING_SORTING)
  339. else {"filters"}
  340. )
  341. type(self).QUERY = gql_compat(
  342. PROJECT_ARTIFACT_COLLECTIONS_GQL,
  343. omit_variables=omit_variables,
  344. omit_fields=omit_fields,
  345. )
  346. self.entity = entity
  347. self.project = project
  348. self.filters = filters
  349. self.order = order
  350. variables = {
  351. "entity": entity,
  352. "project": project,
  353. "order": order,
  354. "filters": json.dumps(f) if (f := filters) else None,
  355. }
  356. super().__init__(client, variables=variables, per_page=per_page)
  357. @override
  358. def _update_response(self) -> None:
  359. """Fetch and validate the response data for the current page."""
  360. from wandb.sdk.artifacts._generated import ProjectArtifactCollections
  361. from wandb.sdk.artifacts._models.pagination import (
  362. ProjectArtifactCollectionConnection,
  363. )
  364. data = self.client.execute(self.QUERY, variable_values=self.variables)
  365. result = ProjectArtifactCollections.model_validate(data)
  366. # Extract the inner `*Connection` result for faster/easier access.
  367. if not ((proj := result.project) and (conn := proj.artifact_collections)):
  368. raise ValueError(f"Unable to parse {nameof(type(self))!r} response data")
  369. self.last_response = ProjectArtifactCollectionConnection.model_validate(conn)
  370. def _convert(self, node: ArtifactCollectionFragment) -> ArtifactCollection | None:
  371. if not node.project:
  372. return None
  373. return ArtifactCollection(
  374. client=self.client,
  375. entity=node.project.entity.name,
  376. project=node.project.name,
  377. name=node.name,
  378. type=node.type.name,
  379. attrs=node,
  380. )
  381. class ArtifactCollection:
  382. """An artifact collection that represents a group of related artifacts.
  383. Args:
  384. client: The client instance to use for querying W&B.
  385. entity: The entity (user or team) that owns the project.
  386. project: The name of the project to query for artifact collections.
  387. name: The name of the artifact collection.
  388. type: The type of the artifact collection (e.g., "dataset", "model").
  389. organization: Optional organization name if applicable.
  390. attrs: Optional mapping of attributes to initialize the artifact collection.
  391. If not provided, the object will load its attributes from W&B upon
  392. initialization.
  393. <!-- lazydoc-ignore-init: internal -->
  394. """
  395. _saved: ArtifactCollectionData
  396. """The saved artifact collection data as last fetched from the W&B server."""
  397. _current: ArtifactCollectionData
  398. """The local, editable artifact collection data."""
  399. def __init__(
  400. self,
  401. client: RetryingClient,
  402. entity: str,
  403. project: str,
  404. name: str,
  405. type: str,
  406. organization: str | None = None,
  407. attrs: ArtifactCollectionFragment | None = None,
  408. ):
  409. self.client = client
  410. # FIXME: Make this lazy, so we don't (re-)fetch the attributes until they are needed
  411. self._update_data(attrs or self.load(entity, project, type, name))
  412. self.organization = organization
  413. def _update_data(self, fragment: ArtifactCollectionFragment) -> None:
  414. """Update the saved/current state of this collection with the given fragment.
  415. Can be used after receiving a GraphQL response with ArtifactCollection data.
  416. """
  417. # Separate "saved" vs "current" copies of the artifact collection data
  418. validated = ArtifactCollectionData.from_fragment(fragment)
  419. self._saved = validated
  420. self._current = validated.model_copy(deep=True)
  421. @property
  422. def id(self) -> str:
  423. """The unique identifier of the artifact collection."""
  424. return self._current.id
  425. @property
  426. def entity(self) -> str:
  427. """The entity (user or team) that owns the project."""
  428. return self._current.entity
  429. @property
  430. def project(self) -> str:
  431. """The project that contains the artifact collection."""
  432. return self._current.project
  433. @normalize_exceptions
  434. def artifacts(self, per_page: int = 50) -> Artifacts:
  435. """Get all artifacts in the collection."""
  436. return Artifacts(
  437. client=self.client,
  438. entity=self.entity,
  439. project=self.project,
  440. # Use the saved name and type, as they're mutable attributes
  441. # and may have been edited locally.
  442. collection_name=self._saved.name,
  443. type=self._saved.type,
  444. per_page=per_page,
  445. )
  446. @property
  447. def aliases(self) -> list[str]:
  448. """The aliases for all artifact versions contained in this collection."""
  449. if self._saved.aliases is None:
  450. aliases = list(
  451. _ArtifactCollectionAliases(self.client, collection_id=self.id)
  452. )
  453. self._saved = self._saved.model_copy(update={"aliases": aliases})
  454. self._current = self._current.model_copy(update={"aliases": aliases})
  455. return list(self._saved.aliases)
  456. @property
  457. def created_at(self) -> str:
  458. """The creation date of the artifact collection."""
  459. return self._saved.created_at
  460. @property
  461. def updated_at(self) -> str | None:
  462. """The date at which the artifact collection was last updated."""
  463. return self._saved.updated_at
  464. def load(
  465. self, entity: str, project: str, type_: str, name: str
  466. ) -> ArtifactCollectionFragment:
  467. """Fetch and return the validated artifact collection data from W&B.
  468. <!-- lazydoc-ignore: internal -->
  469. """
  470. from wandb.sdk.artifacts._generated import (
  471. PROJECT_ARTIFACT_COLLECTION_GQL,
  472. ProjectArtifactCollection,
  473. )
  474. gql_op = gql(PROJECT_ARTIFACT_COLLECTION_GQL)
  475. gql_vars = {"entity": entity, "project": project, "type": type_, "name": name}
  476. data = self.client.execute(gql_op, variable_values=gql_vars)
  477. result = ProjectArtifactCollection.model_validate(data)
  478. if not (
  479. result.project
  480. and (proj := result.project)
  481. and (artifact_type := proj.artifact_type)
  482. and (collection := artifact_type.artifact_collection)
  483. ):
  484. raise ValueError(f"Could not find artifact type {type_!r}")
  485. return collection
  486. @normalize_exceptions
  487. def change_type(self, new_type: str) -> None:
  488. """Deprecated, change type directly with `save` instead."""
  489. from wandb.sdk.artifacts._generated import (
  490. UPDATE_ARTIFACT_SEQUENCE_TYPE_GQL,
  491. MoveArtifactSequenceInput,
  492. )
  493. from wandb.sdk.artifacts._validators import validate_artifact_type
  494. warn_and_record_deprecation(
  495. feature=Deprecated(artifact_collection__change_type=True),
  496. message="ArtifactCollection.change_type(type) is deprecated, use ArtifactCollection.save() instead.",
  497. )
  498. if (old_type := self._saved.type) != new_type:
  499. try:
  500. validate_artifact_type(old_type, self.name)
  501. except ValueError as e:
  502. raise ValueError(
  503. f"The current type {old_type!r} is an internal type and cannot be changed."
  504. ) from e
  505. # Check that the new type is not going to conflict with internal types
  506. new_type = validate_artifact_type(new_type, self.name)
  507. if not self.is_sequence():
  508. raise ValueError("Artifact collection needs to be a sequence")
  509. termlog(f"Changing artifact collection type of {old_type!r} to {new_type!r}")
  510. gql_op = gql(UPDATE_ARTIFACT_SEQUENCE_TYPE_GQL)
  511. gql_input = MoveArtifactSequenceInput(
  512. artifact_sequence_id=self.id,
  513. destination_artifact_type_name=new_type,
  514. )
  515. self.client.execute(gql_op, variable_values={"input": gql_input.model_dump()})
  516. self._saved.type = new_type
  517. self._current.type = new_type
  518. def is_sequence(self) -> bool:
  519. """Return whether the artifact collection is a sequence."""
  520. return self._saved.is_sequence
  521. @normalize_exceptions
  522. def delete(self) -> None:
  523. """Delete the entire artifact collection."""
  524. from wandb.sdk.artifacts._generated import (
  525. DELETE_ARTIFACT_PORTFOLIO_GQL,
  526. DELETE_ARTIFACT_SEQUENCE_GQL,
  527. )
  528. gql_op = gql(
  529. DELETE_ARTIFACT_SEQUENCE_GQL
  530. if self.is_sequence()
  531. else DELETE_ARTIFACT_PORTFOLIO_GQL
  532. )
  533. self.client.execute(gql_op, variable_values={"id": self.id})
  534. @property
  535. def description(self) -> str | None:
  536. """A description of the artifact collection."""
  537. return self._current.description
  538. @description.setter
  539. def description(self, description: str | None) -> None:
  540. """Set the description of the artifact collection."""
  541. self._current.description = description
  542. @property
  543. def tags(self) -> list[str]:
  544. """The tags associated with the artifact collection."""
  545. return self._current.tags
  546. @tags.setter
  547. def tags(self, tags: Collection[str]) -> None:
  548. """Set the tags associated with the artifact collection."""
  549. self._current.tags = tags
  550. @property
  551. def name(self) -> str:
  552. """The name of the artifact collection."""
  553. return self._current.name
  554. @name.setter
  555. def name(self, name: str) -> None:
  556. """Set the name of the artifact collection."""
  557. self._current.name = name
  558. @property
  559. def type(self):
  560. """Returns the type of the artifact collection."""
  561. return self._current.type
  562. @type.setter
  563. def type(self, type: str) -> None:
  564. """Set the type of the artifact collection."""
  565. if not self.is_sequence():
  566. raise ValueError(
  567. "Type can only be changed if the artifact collection is a sequence."
  568. )
  569. self._current.type = type
  570. def _update_collection(self) -> None:
  571. from wandb.sdk.artifacts._generated import (
  572. UPDATE_ARTIFACT_PORTFOLIO_GQL,
  573. UPDATE_ARTIFACT_SEQUENCE_GQL,
  574. UpdateArtifactPortfolioInput,
  575. UpdateArtifactSequenceInput,
  576. )
  577. if self.is_sequence():
  578. gql_op = gql(UPDATE_ARTIFACT_SEQUENCE_GQL)
  579. gql_input = UpdateArtifactSequenceInput(
  580. artifact_sequence_id=self.id,
  581. name=self.name,
  582. description=self.description,
  583. )
  584. else:
  585. gql_op = gql(UPDATE_ARTIFACT_PORTFOLIO_GQL)
  586. gql_input = UpdateArtifactPortfolioInput(
  587. artifact_portfolio_id=self.id,
  588. name=self.name,
  589. description=self.description,
  590. )
  591. self.client.execute(gql_op, variable_values={"input": gql_input.model_dump()})
  592. self._saved.name = self._current.name
  593. self._saved.description = self._current.description
  594. self._saved.updated_at = self._current.updated_at
  595. def _update_sequence_type(self) -> None:
  596. from wandb.sdk.artifacts._generated import (
  597. UPDATE_ARTIFACT_SEQUENCE_TYPE_GQL,
  598. MoveArtifactSequenceInput,
  599. )
  600. gql_op = gql(UPDATE_ARTIFACT_SEQUENCE_TYPE_GQL)
  601. gql_input = MoveArtifactSequenceInput(
  602. artifact_sequence_id=self.id,
  603. destination_artifact_type_name=self.type,
  604. )
  605. self.client.execute(gql_op, variable_values={"input": gql_input.model_dump()})
  606. self._saved.type = self._current.type
  607. def _add_tags(self, tag_names: Iterable[str]) -> None:
  608. from wandb.sdk.artifacts._generated import (
  609. ADD_ARTIFACT_COLLECTION_TAGS_GQL,
  610. CreateArtifactCollectionTagAssignmentsInput,
  611. )
  612. gql_op = gql(ADD_ARTIFACT_COLLECTION_TAGS_GQL)
  613. gql_input = CreateArtifactCollectionTagAssignmentsInput(
  614. entity_name=self.entity,
  615. project_name=self.project,
  616. artifact_collection_name=self._saved.name,
  617. tags=[{"tagName": tag} for tag in tag_names],
  618. )
  619. self.client.execute(gql_op, variable_values={"input": gql_input.model_dump()})
  620. def _delete_tags(self, tag_names: Iterable[str]) -> None:
  621. from wandb.sdk.artifacts._generated import (
  622. DELETE_ARTIFACT_COLLECTION_TAGS_GQL,
  623. DeleteArtifactCollectionTagAssignmentsInput,
  624. )
  625. gql_op = gql(DELETE_ARTIFACT_COLLECTION_TAGS_GQL)
  626. gql_input = DeleteArtifactCollectionTagAssignmentsInput(
  627. entity_name=self.entity,
  628. project_name=self.project,
  629. artifact_collection_name=self._saved.name,
  630. tags=[{"tagName": tag} for tag in tag_names],
  631. )
  632. self.client.execute(gql_op, variable_values={"input": gql_input.model_dump()})
  633. @normalize_exceptions
  634. def save(self) -> None:
  635. """Persist any changes made to the artifact collection."""
  636. from wandb.sdk.artifacts._validators import validate_artifact_type
  637. if (old_type := self._saved.type) != (new_type := self.type):
  638. try:
  639. validate_artifact_type(new_type, self.name)
  640. except ValueError as e:
  641. reason = str(e)
  642. raise ValueError(
  643. f"Failed to save artifact collection {self.name!r}: {reason}"
  644. ) from e
  645. try:
  646. validate_artifact_type(old_type, self.name)
  647. except ValueError as e:
  648. reason = f"The current type {old_type!r} is an internal type and cannot be changed."
  649. raise ValueError(
  650. f"Failed to save artifact collection {self.name!r}: {reason}"
  651. ) from e
  652. # FIXME: Consider consolidating the multiple GQL mutations into a single call.
  653. self._update_collection()
  654. if self.is_sequence() and (old_type != new_type):
  655. self._update_sequence_type()
  656. if (new_tags := set(self._current.tags)) != (old_tags := set(self._saved.tags)):
  657. if added_tags := (new_tags - old_tags):
  658. self._add_tags(added_tags)
  659. if deleted_tags := (old_tags - new_tags):
  660. self._delete_tags(deleted_tags)
  661. self._saved.tags = copy(new_tags)
  662. def __repr__(self) -> str:
  663. return f"<ArtifactCollection {self.name} ({self.type})>"
# Generic edge model: an `Edge` over node type `TNode` with one extra field.
class _ArtifactEdgeGeneric(Edge[TNode]):
    version: str  # Extra field defined only on VersionedArtifactEdge
# Generic connection model: a `ConnectionWithTotal` over node type `TNode`
# whose edges are declared as `_ArtifactEdgeGeneric`.
# NOTE(review): the edge type is not parameterized with `TNode` here; confirm
# that is intentional.
class _ArtifactConnectionGeneric(ConnectionWithTotal[TNode]):
    edges: List[_ArtifactEdgeGeneric]  # noqa: UP006
  668. class Artifacts(SizedRelayPaginator["ArtifactFragment", "Artifact"]):
  669. """An iterable collection of artifact versions associated with a project.
  670. Optionally pass in filters to narrow down the results based on specific criteria.
  671. Args:
  672. client: The client instance to use for querying W&B.
  673. entity: The entity (user or team) that owns the project.
  674. project: The name of the project to query for artifacts.
  675. collection_name: The name of the artifact collection to query.
  676. type: The type of the artifacts to query. Common examples include
  677. "dataset" or "model".
  678. filters: Optional mapping of filters to apply to the query.
  679. order: Optional string to specify the order of the results.
  680. per_page: The number of artifact versions to fetch per page. Default is 50.
  681. tags: Optional string or list of strings to filter artifacts by tags.
  682. <!-- lazydoc-ignore-init: internal -->
  683. """
  684. QUERY: Document # Must be set per-instance
  685. # Loosely-annotated to avoid importing heavy types at module import time.
  686. last_response: _ArtifactConnectionGeneric | None
  687. def __init__(
  688. self,
  689. client: RetryingClient,
  690. entity: str,
  691. project: str,
  692. collection_name: str,
  693. type: str,
  694. filters: Mapping[str, Any] | None = None,
  695. order: str | None = None,
  696. per_page: int = 50,
  697. tags: str | list[str] | None = None,
  698. ):
  699. from wandb.sdk.artifacts._generated import PROJECT_ARTIFACTS_GQL
  700. self.QUERY = gql(PROJECT_ARTIFACTS_GQL)
  701. self.entity = entity
  702. self.collection_name = collection_name
  703. self.type = type
  704. self.project = project
  705. self.filters = {"state": "COMMITTED"} if filters is None else filters
  706. self.tags = always_list(tags or [])
  707. self.order = order
  708. variables = {
  709. "entity": self.entity,
  710. "project": self.project,
  711. "order": self.order,
  712. "type": self.type,
  713. "collection": self.collection_name,
  714. "filters": json.dumps(self.filters),
  715. }
  716. super().__init__(client, variables=variables, per_page=per_page)
  717. @override
  718. def _update_response(self) -> None:
  719. from wandb.sdk.artifacts._generated import ArtifactFragment, ProjectArtifacts
  720. data = self.client.execute(self.QUERY, variable_values=self.variables)
  721. result = ProjectArtifacts.model_validate(data)
  722. # Extract the inner `*Connection` result for faster/easier access.
  723. if not (
  724. (proj := result.project)
  725. and (type_ := proj.artifact_type)
  726. and (collection := type_.artifact_collection)
  727. and (conn := collection.artifacts)
  728. ):
  729. raise ValueError(f"Unable to parse {nameof(type(self))!r} response data")
  730. self.last_response = _ArtifactConnectionGeneric[
  731. ArtifactFragment
  732. ].model_validate(conn)
  733. # FIXME: For now, we deliberately override the signatures of:
  734. # - `_convert()`
  735. # - `convert_objects()`
  736. # ... since the prior implementation must get `version` from the GQL edge
  737. # (i.e. `edge.version`), which lives outside of the GQL node (`edge.node`).
  738. #
  739. # In the future, we should move to fetching artifacts via (GQL) artifactMemberships,
  740. # not (GQL) artifacts, so we don't have to deal with this hack.
  741. @override
  742. def _convert(self, edge: _ArtifactEdgeGeneric[ArtifactFragment]) -> Artifact:
  743. from wandb.sdk.artifacts._validators import FullArtifactPath
  744. from wandb.sdk.artifacts.artifact import Artifact
  745. return Artifact._from_attrs(
  746. path=FullArtifactPath(
  747. prefix=self.entity,
  748. project=self.project,
  749. name=f"{self.collection_name}:{edge.version}",
  750. ),
  751. src_art=edge.node,
  752. client=self.client,
  753. )
  754. @override
  755. def convert_objects(self) -> list[Artifact]:
  756. """Convert the raw response data into a list of wandb.Artifact objects.
  757. <!-- lazydoc-ignore: internal -->
  758. """
  759. if (conn := self.last_response) is None:
  760. return []
  761. artifacts = (self._convert(edge) for edge in conn.edges if edge.node)
  762. required_tags = set(self.tags or [])
  763. return [art for art in artifacts if required_tags.issubset(art.tags)]
  764. class RunArtifacts(SizedRelayPaginator["ArtifactFragment", "Artifact"]):
  765. """An iterable collection of artifacts associated with a specific run.
  766. <!-- lazydoc-ignore-init: internal -->
  767. """
  768. QUERY: Document # Must be set per-instance
  769. last_response: ConnectionWithTotal[ArtifactFragment] | None
  770. def __init__(
  771. self,
  772. client: RetryingClient,
  773. run: Run,
  774. mode: Literal["logged", "used"] = "logged",
  775. per_page: int = 50,
  776. ):
  777. try:
  778. query_str = _run_artifacts_mode_to_gql()[mode]
  779. except LookupError:
  780. raise ValueError("mode must be logged or used")
  781. else:
  782. self.QUERY = gql(query_str)
  783. self.run = run
  784. variables = {"entity": run.entity, "project": run.project, "run": run.id}
  785. super().__init__(client, variables=variables, per_page=per_page)
  786. @override
  787. def _update_response(self) -> None:
  788. from wandb.sdk.artifacts._models.pagination import RunArtifactConnection
  789. data = self.client.execute(self.QUERY, variable_values=self.variables)
  790. # Extract the inner `*Connection` result for faster/easier access.
  791. inner_data = data["project"]["run"]["artifacts"]
  792. self.last_response = RunArtifactConnection.model_validate(inner_data)
  793. def _convert(self, node: ArtifactFragment) -> Artifact | None:
  794. from wandb.sdk.artifacts._validators import FullArtifactPath
  795. from wandb.sdk.artifacts.artifact import Artifact
  796. if node.artifact_sequence.project is None:
  797. return None
  798. return Artifact._from_attrs(
  799. path=FullArtifactPath(
  800. prefix=node.artifact_sequence.project.entity.name,
  801. project=node.artifact_sequence.project.name,
  802. name=f"{node.artifact_sequence.name}:v{node.version_index}",
  803. ),
  804. src_art=node,
  805. client=self.client,
  806. )
  807. class ArtifactFiles(SizedRelayPaginator["FileFragment", "File"]):
  808. """A paginator for files in an artifact.
  809. <!-- lazydoc-ignore-init: internal -->
  810. """
  811. QUERY: Document # Must be set per-instance
  812. last_response: ArtifactFileConnection | None
  813. def __init__(
  814. self,
  815. client: RetryingClient,
  816. artifact: Artifact,
  817. names: Sequence[str] | None = None,
  818. per_page: int = 50,
  819. ):
  820. from wandb.sdk.artifacts._generated import (
  821. GET_ARTIFACT_FILES_GQL,
  822. GET_ARTIFACT_MEMBERSHIP_FILES_GQL,
  823. )
  824. from wandb.sdk.artifacts._gqlutils import server_supports
  825. self.query_via_membership = server_supports(
  826. client, pb.ARTIFACT_COLLECTION_MEMBERSHIP_FILES
  827. )
  828. self.artifact = artifact
  829. if self.query_via_membership:
  830. query_str = GET_ARTIFACT_MEMBERSHIP_FILES_GQL
  831. variables = {
  832. "entity": artifact.entity,
  833. "project": artifact.project,
  834. "collection": artifact.name.split(":")[0],
  835. "alias": artifact.version,
  836. "fileNames": names,
  837. }
  838. else:
  839. query_str = GET_ARTIFACT_FILES_GQL
  840. variables = {
  841. "entity": artifact.source_entity,
  842. "project": artifact.source_project,
  843. "name": artifact.source_name,
  844. "type": artifact.type,
  845. "fileNames": names,
  846. }
  847. omit_fields = (
  848. None
  849. if server_supports(client, pb.TOTAL_COUNT_IN_FILE_CONNECTION)
  850. else {"totalCount"}
  851. )
  852. self.QUERY = gql_compat(query_str, omit_fields=omit_fields)
  853. super().__init__(client, variables=variables, per_page=per_page)
  854. @override
  855. def _update_response(self) -> None:
  856. from wandb.sdk.artifacts._generated import (
  857. GetArtifactFiles,
  858. GetArtifactMembershipFiles,
  859. )
  860. from wandb.sdk.artifacts._models.pagination import ArtifactFileConnection
  861. data = self.client.execute(self.QUERY, variable_values=self.variables)
  862. # Extract the inner `*Connection` result for faster/easier access.
  863. if self.query_via_membership:
  864. result = GetArtifactMembershipFiles.model_validate(data)
  865. conn = result.project.artifact_collection.artifact_membership.files
  866. else:
  867. result = GetArtifactFiles.model_validate(data)
  868. conn = result.project.artifact_type.artifact.files
  869. if conn is None:
  870. raise ValueError(f"Unable to parse {nameof(type(self))!r} response data")
  871. self.last_response = ArtifactFileConnection.model_validate(conn)
  872. @property
  873. def path(self) -> list[str]:
  874. """Returns the path of the artifact."""
  875. return [self.artifact.entity, self.artifact.project, self.artifact.name]
  876. def _convert(self, node: FileFragment) -> File:
  877. return File(self.client, attrs=node.model_dump(exclude_unset=True))
  878. def __repr__(self) -> str:
  879. path_str = "/".join(self.path)
  880. try:
  881. total = len(self)
  882. except NotImplementedError:
  883. # Older server versions don't correctly support totalCount
  884. return f"<ArtifactFiles {path_str}>"
  885. else:
  886. return f"<ArtifactFiles {path_str} ({total})>"