registry.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING, Any, Literal
  3. from pydantic import PositiveInt
  4. from typing_extensions import Self, assert_never
  5. from wandb_gql import gql
  6. import wandb
  7. from wandb._analytics import tracked
  8. from wandb._strutils import nameof
  9. from wandb.apis.public.teams import Team
  10. from wandb.apis.public.users import User
  11. from wandb.proto import wandb_internal_pb2 as pb
  12. from wandb.sdk.artifacts._models import RegistryData
  13. from ._freezable_list import AddOnlyArtifactTypesList
  14. from ._members import (
  15. MemberId,
  16. MemberKind,
  17. MemberRole,
  18. TeamMember,
  19. UserMember,
  20. parse_member_ids,
  21. )
  22. from ._utils import (
  23. Visibility,
  24. fetch_org_entity_from_organization,
  25. prepare_artifact_types_input,
  26. )
  27. from .registries_search import Collections, Versions
  28. if TYPE_CHECKING:
  29. from wandb.apis.public.api import RetryingClient
  30. from wandb.sdk.artifacts._generated import RegistryFragment
  31. class Registry:
  32. """A single registry in the Registry."""
  33. _saved: RegistryData
  34. """The saved registry data as last fetched from the W&B server."""
  35. _current: RegistryData
  36. """The local, editable registry data."""
  37. def __init__(
  38. self,
  39. client: RetryingClient,
  40. organization: str,
  41. entity: str,
  42. name: str,
  43. attrs: RegistryFragment | None = None,
  44. ):
  45. self.client = client
  46. if attrs is None:
  47. # FIXME: This is awkward and bypasses validation which seems shaky.
  48. # Reconsider the init signature of `Registry` so this isn't necessary?
  49. draft = RegistryData.model_construct(
  50. organization=organization, entity=entity, name=name
  51. )
  52. self._saved = draft
  53. self._current = draft.model_copy(deep=True)
  54. else:
  55. self._update_attributes(attrs)
  56. def _update_attributes(self, fragment: RegistryFragment) -> None:
  57. """Update instance attributes from a GraphQL fragment."""
  58. saved = RegistryData.from_fragment(fragment)
  59. self._saved = saved
  60. self._current = saved.model_copy(deep=True)
  61. @property
  62. def id(self) -> str:
  63. """The unique ID for this registry."""
  64. return self._current.id
  65. @property
  66. def full_name(self) -> str:
  67. """Full name of the registry including the `wandb-registry-` prefix."""
  68. return self._current.full_name
  69. @property
  70. def name(self) -> str:
  71. """Name of the registry without the `wandb-registry-` prefix."""
  72. return self._current.name
  73. @name.setter
  74. def name(self, value: str):
  75. self._current.name = value
  76. @property
  77. def entity(self) -> str:
  78. """Organization entity of the registry."""
  79. return self._current.entity
  80. @property
  81. def organization(self) -> str:
  82. """Organization name of the registry."""
  83. return self._current.organization
  84. @property
  85. def description(self) -> str | None:
  86. """Description of the registry."""
  87. return self._current.description
  88. @description.setter
  89. def description(self, value: str) -> None:
  90. """Set the description of the registry."""
  91. self._current.description = value
  92. @property
  93. def allow_all_artifact_types(self) -> bool:
  94. """Return whether all artifact types are allowed in the registry.
  95. If `True`, artifacts of any type can be added. If `False`, artifacts are
  96. restricted to the types listed in `artifact_types`.
  97. """
  98. return self._current.allow_all_artifact_types
  99. @allow_all_artifact_types.setter
  100. def allow_all_artifact_types(self, value: bool) -> None:
  101. """Set whether all artifact types are allowed in the registry."""
  102. self._current.allow_all_artifact_types = value
  103. @property
  104. def artifact_types(self) -> AddOnlyArtifactTypesList:
  105. """Returns the artifact types allowed in the registry.
  106. If `allow_all_artifact_types` is `True` then `artifact_types` reflects the
  107. types previously saved or currently used in the registry.
  108. If `allow_all_artifact_types` is `False` then artifacts are restricted to the
  109. types in `artifact_types`.
  110. Note:
  111. Previously saved artifact types cannot be removed.
  112. Example:
  113. ```python
  114. import wandb
  115. registry = wandb.Api().create_registry()
  116. registry.artifact_types.append("model")
  117. registry.save() # once saved, the artifact type `model` cannot be removed
  118. registry.artifact_types.append("accidentally_added")
  119. registry.artifact_types.remove(
  120. "accidentally_added"
  121. ) # Types can only be removed if it has not been saved yet
  122. ```
  123. """
  124. return self._current.artifact_types
  125. @property
  126. def created_at(self) -> str:
  127. """Timestamp of when the registry was created."""
  128. return self._current.created_at
  129. @property
  130. def updated_at(self) -> str:
  131. """Timestamp of when the registry was last updated."""
  132. return self._current.updated_at
  133. @property
  134. def path(self) -> list[str]:
  135. return [self.entity, self.full_name]
  136. @property
  137. def visibility(self) -> Literal["organization", "restricted"]:
  138. """Visibility of the registry.
  139. Returns:
  140. Literal["organization", "restricted"]: The visibility level.
  141. - "organization": Anyone in the organization can view this registry.
  142. You can edit their roles later from the settings in the UI.
  143. - "restricted": Only invited members via the UI can access this registry.
  144. Public sharing is disabled.
  145. """
  146. return self._current.visibility.name
  147. @visibility.setter
  148. def visibility(self, value: Literal["organization", "restricted"]):
  149. """Set the visibility of the registry.
  150. Args:
  151. value: The visibility level. Options are:
  152. - "organization": Anyone in the organization can view this registry.
  153. You can edit their roles later from the settings in the UI.
  154. - "restricted": Only invited members via the UI can access this registry.
  155. Public sharing is disabled.
  156. """
  157. self._current.visibility = value
  158. @tracked
  159. def collections(
  160. self, filter: dict[str, Any] | None = None, per_page: PositiveInt = 100
  161. ) -> Collections:
  162. """Returns the collections belonging to the registry."""
  163. return Collections(
  164. client=self.client,
  165. organization=self.organization,
  166. registry_filter={"name": self.full_name},
  167. collection_filter=filter,
  168. per_page=per_page,
  169. )
  170. @tracked
  171. def versions(
  172. self, filter: dict[str, Any] | None = None, per_page: PositiveInt = 100
  173. ) -> Versions:
  174. """Returns the versions belonging to the registry."""
  175. return Versions(
  176. client=self.client,
  177. organization=self.organization,
  178. registry_filter={"name": self.full_name},
  179. collection_filter=None,
  180. artifact_filter=filter,
  181. per_page=per_page,
  182. )
  183. @classmethod
  184. @tracked
  185. def create(
  186. cls,
  187. client: RetryingClient,
  188. organization: str,
  189. name: str,
  190. visibility: Literal["organization", "restricted"],
  191. description: str | None = None,
  192. artifact_types: list[str] | None = None,
  193. ) -> Self:
  194. """Create a new registry.
  195. The registry name must be unique within the organization.
  196. This function should be called using `api.create_registry()`
  197. Args:
  198. client: The GraphQL client.
  199. organization: The name of the organization.
  200. name: The name of the registry (without the `wandb-registry-` prefix).
  201. visibility: The visibility level ('organization' or 'restricted').
  202. description: An optional description for the registry.
  203. artifact_types: An optional list of allowed artifact types.
  204. Returns:
  205. Registry: The newly created Registry object.
  206. Raises:
  207. ValueError: If a registry with the same name already exists in the
  208. organization or if the creation fails.
  209. """
  210. from wandb.sdk.artifacts._generated import (
  211. UPSERT_REGISTRY_GQL,
  212. UpsertModelInput,
  213. UpsertRegistry,
  214. )
  215. from wandb.sdk.artifacts._validators import (
  216. REGISTRY_PREFIX,
  217. validate_project_name,
  218. )
  219. failed_msg = (
  220. f"Failed to create registry {name!r} in organization {organization!r}."
  221. )
  222. org_entity = fetch_org_entity_from_organization(client, organization)
  223. gql_op = gql(UPSERT_REGISTRY_GQL)
  224. gql_input = UpsertModelInput(
  225. description=description,
  226. entity_name=org_entity,
  227. name=validate_project_name(f"{REGISTRY_PREFIX}{name}"),
  228. access=Visibility.from_python(visibility).value,
  229. allow_all_artifact_types_in_registry=not artifact_types,
  230. artifact_types=prepare_artifact_types_input(artifact_types),
  231. )
  232. gql_vars = {"input": gql_input.model_dump()}
  233. try:
  234. data = client.execute(gql_op, variable_values=gql_vars)
  235. result = UpsertRegistry.model_validate(data).upsert_model
  236. except Exception as e:
  237. raise ValueError(failed_msg) from e
  238. if not (result and result.inserted and (registry_project := result.project)):
  239. raise ValueError(failed_msg)
  240. return cls(
  241. client,
  242. organization=organization,
  243. entity=org_entity,
  244. name=name,
  245. attrs=registry_project,
  246. )
  247. @tracked
  248. def delete(self) -> None:
  249. """Delete the registry. This is irreversible."""
  250. from wandb.sdk.artifacts._generated import DELETE_REGISTRY_GQL, DeleteRegistry
  251. failed_msg = f"Failed to delete registry {self.name!r} in organization {self.organization!r}"
  252. gql_op = gql(DELETE_REGISTRY_GQL)
  253. gql_vars = {"id": self.id}
  254. try:
  255. data = self.client.execute(gql_op, variable_values=gql_vars)
  256. result = DeleteRegistry.model_validate(data).delete_model
  257. except Exception as e:
  258. raise ValueError(failed_msg) from e
  259. if not (result and result.success):
  260. raise ValueError(failed_msg)
  261. @tracked
  262. def load(self) -> None:
  263. """Load registry attributes from the backend."""
  264. from wandb.sdk.artifacts._generated import FETCH_REGISTRY_GQL, FetchRegistry
  265. failed_msg = (
  266. f"Failed to load registry {self.name!r} in organization"
  267. f" {self.organization!r}."
  268. )
  269. gql_op = gql(FETCH_REGISTRY_GQL)
  270. gql_vars = {"name": self.full_name, "entity": self.entity}
  271. try:
  272. data = self.client.execute(gql_op, variable_values=gql_vars)
  273. result = FetchRegistry.model_validate(data)
  274. except Exception as e:
  275. raise ValueError(failed_msg) from e
  276. if not ((entity := result.entity) and (registry_project := entity.project)):
  277. raise ValueError(failed_msg)
  278. self._update_attributes(registry_project)
  279. @tracked
  280. def save(self) -> None:
  281. """Save registry attributes to the backend."""
  282. from wandb.sdk.artifacts._generated import (
  283. RENAME_REGISTRY_GQL,
  284. UPSERT_REGISTRY_GQL,
  285. RenameProjectInput,
  286. RenameRegistry,
  287. UpsertModelInput,
  288. UpsertRegistry,
  289. )
  290. from wandb.sdk.artifacts._gqlutils import server_supports
  291. from wandb.sdk.artifacts._validators import validate_project_name
  292. if not server_supports(
  293. self.client, pb.INCLUDE_ARTIFACT_TYPES_IN_REGISTRY_CREATION
  294. ):
  295. raise RuntimeError(
  296. "Saving the registry is not enabled on this wandb server version. "
  297. "Please upgrade your server version or contact support at support@wandb.com."
  298. )
  299. # If `artifact_types.draft` has items, the user added types that are not
  300. # yet saved.
  301. if (
  302. new_artifact_types := self.artifact_types.draft
  303. ) and self.allow_all_artifact_types:
  304. raise ValueError(
  305. f"Cannot update artifact types when `allows_all_artifact_types` is {True!r}. Set it to {False!r} first."
  306. )
  307. failed_msg = f"Failed to save registry {self.name!r} in organization {self.organization!r}"
  308. old_project_name = validate_project_name(self._saved.full_name)
  309. new_project_name = validate_project_name(self._current.full_name)
  310. upsert_op = gql(UPSERT_REGISTRY_GQL)
  311. upsert_input = UpsertModelInput(
  312. description=self.description,
  313. entity_name=self.entity,
  314. name=old_project_name,
  315. access=self._current.visibility.value,
  316. allow_all_artifact_types_in_registry=self.allow_all_artifact_types,
  317. artifact_types=prepare_artifact_types_input(new_artifact_types),
  318. )
  319. upsert_vars = {"input": upsert_input.model_dump()}
  320. try:
  321. data = self.client.execute(upsert_op, variable_values=upsert_vars)
  322. result = UpsertRegistry.model_validate(data).upsert_model
  323. except Exception as e:
  324. raise ValueError(failed_msg) from e
  325. if result and result.inserted:
  326. # This should only trigger if `_saved_name` was modified unexpectedly.
  327. wandb.termlog(
  328. f"Created registry {self.name!r} in organization {self.organization!r} on save"
  329. )
  330. if not (result and (registry_project := result.project)):
  331. raise ValueError(failed_msg)
  332. self._update_attributes(registry_project)
  333. # Update the name of the registry if it has changed
  334. if old_project_name != new_project_name:
  335. rename_op = gql(RENAME_REGISTRY_GQL)
  336. rename_input = RenameProjectInput(
  337. entity_name=self.entity,
  338. old_project_name=old_project_name,
  339. new_project_name=new_project_name,
  340. )
  341. rename_vars = {"input": rename_input.model_dump()}
  342. data = self.client.execute(rename_op, variable_values=rename_vars)
  343. result = RenameRegistry.model_validate(data).rename_project
  344. if not (result and (registry_project := result.project)):
  345. raise ValueError(failed_msg)
  346. if result.inserted:
  347. # This should only trigger if `_saved_name` was modified unexpectedly.
  348. wandb.termlog(f"Created new registry {self.name!r} on save")
  349. self._update_attributes(registry_project)
  350. def members(self) -> list[UserMember | TeamMember]:
  351. """Returns the current members (users and teams) of this registry."""
  352. return [*self.user_members(), *self.team_members()]
  353. def user_members(self) -> list[UserMember]:
  354. """Returns the current member users of this registry."""
  355. from wandb.sdk.artifacts._generated import (
  356. REGISTRY_USER_MEMBERS_GQL,
  357. RegistryUserMembers,
  358. )
  359. gql_op = gql(REGISTRY_USER_MEMBERS_GQL)
  360. gql_vars = {"project": self.full_name, "entity": self.entity}
  361. data = self.client.execute(gql_op, variable_values=gql_vars)
  362. result = RegistryUserMembers.model_validate(data)
  363. if not (project := result.project):
  364. raise ValueError(f"Failed to fetch user members for registry {self.name!r}")
  365. return [
  366. UserMember(
  367. user=User(
  368. client=self.client,
  369. # The `User` class requires an unstructured attribute dict.
  370. # Exclude `.role`, which is specific to this registry membership.
  371. attrs=m.model_dump(exclude_none=True, exclude={"role"}),
  372. ),
  373. role=m.role.name,
  374. )
  375. for m in project.members
  376. ]
  377. def team_members(self) -> list[TeamMember]:
  378. """Returns the current member teams of this registry."""
  379. from wandb.sdk.artifacts._generated import (
  380. REGISTRY_TEAM_MEMBERS_GQL,
  381. RegistryTeamMembers,
  382. )
  383. gql_op = gql(REGISTRY_TEAM_MEMBERS_GQL)
  384. gql_vars = {"project": self.full_name, "entity": self.entity}
  385. data = self.client.execute(gql_op, variable_values=gql_vars)
  386. result = RegistryTeamMembers.model_validate(data)
  387. if not (project := result.project):
  388. raise ValueError(f"Failed to fetch team members for registry {self.name!r}")
  389. return [
  390. TeamMember(
  391. team=Team(
  392. client=self.client,
  393. name=m.team.name,
  394. # The `Team` class currently requires an unstructured attribute dict.
  395. attrs=m.team.model_dump(exclude_none=True),
  396. ),
  397. role=m.role.name,
  398. )
  399. for m in project.team_members
  400. ]
  401. def add_members(
  402. self, *members: User | UserMember | Team | TeamMember | str
  403. ) -> Self:
  404. """Adds users or teams to this registry.
  405. Args:
  406. members: The users or teams to add to the registry. Accepts
  407. `User` objects, `Team` objects, or their string IDs.
  408. Returns:
  409. This registry for further method chaining, if needed.
  410. Raises:
  411. TypeError: If no members are passed as arguments.
  412. ValueError: If unable to infer or parse the user or team IDs.
  413. Examples:
  414. ```python
  415. import wandb
  416. api = wandb.Api()
  417. # Fetch an existing registry
  418. registry = api.registry(name="my-registry", organization="my-org")
  419. user1 = api.user(username="some-user")
  420. user2 = api.user(username="other-user")
  421. registry.add_members(user1, user2)
  422. my_team = api.team(name="my-team")
  423. registry.add_members(my_team)
  424. ```
  425. """
  426. from wandb.sdk.artifacts._generated import (
  427. CREATE_REGISTRY_MEMBERS_GQL,
  428. CreateProjectMembersInput,
  429. CreateRegistryMembers,
  430. )
  431. if not members:
  432. raise TypeError(
  433. f"Must provide at least one member to {nameof(self.add_members)!r}."
  434. )
  435. user_ids, team_ids = parse_member_ids(members)
  436. gql_op = gql(CREATE_REGISTRY_MEMBERS_GQL)
  437. gql_input = CreateProjectMembersInput(
  438. user_ids=user_ids, team_ids=team_ids, project_id=self.id
  439. )
  440. gql_vars = {"input": gql_input.model_dump()}
  441. data = self.client.execute(gql_op, variable_values=gql_vars)
  442. result = CreateRegistryMembers.model_validate(data).result
  443. if not (result and result.success):
  444. raise ValueError(f"Failed to add members to registry {self.name!r}")
  445. return self
  446. def remove_members(
  447. self, *members: User | UserMember | Team | TeamMember | str
  448. ) -> Self:
  449. """Removes users or teams from this registry.
  450. Args:
  451. members: The users or teams to remove from the registry. Accepts
  452. `User` objects, `Team` objects, or their string IDs.
  453. Returns:
  454. This registry for further method chaining, if needed.
  455. Raises:
  456. TypeError: If no members are passed as arguments.
  457. ValueError: If unable to infer or parse the user or team IDs.
  458. Examples:
  459. ```python
  460. import wandb
  461. api = wandb.Api()
  462. # Fetch an existing registry
  463. registry = api.registry(name="my-registry", organization="my-org")
  464. user1 = api.user(username="some-user")
  465. user2 = api.user(username="other-user")
  466. registry.remove_members(user1, user2)
  467. old_team = api.team(name="old-team")
  468. registry.remove_members(old_team)
  469. ```
  470. """
  471. from wandb.sdk.artifacts._generated import (
  472. DELETE_REGISTRY_MEMBERS_GQL,
  473. DeleteProjectMembersInput,
  474. DeleteRegistryMembers,
  475. )
  476. if not members:
  477. raise TypeError(
  478. f"Must provide at least one member to {nameof(self.add_members)!r}."
  479. )
  480. user_ids, team_ids = parse_member_ids(members)
  481. gql_op = gql(DELETE_REGISTRY_MEMBERS_GQL)
  482. gql_input = DeleteProjectMembersInput(
  483. user_ids=user_ids, team_ids=team_ids, project_id=self.id
  484. )
  485. gql_vars = {"input": gql_input.model_dump()}
  486. data = self.client.execute(gql_op, variable_values=gql_vars)
  487. result = DeleteRegistryMembers.model_validate(data).result
  488. if not (result and result.success):
  489. raise ValueError(f"Failed to remove members from registry {self.name!r}")
  490. return self
  491. def update_member(
  492. self,
  493. member: User | UserMember | Team | TeamMember | str,
  494. role: MemberRole | str,
  495. ) -> Self:
  496. """Updates the role of a member (user or team) within this registry.
  497. Args:
  498. member: The user or team to update the role of.
  499. Accepts a `User` object, `Team` object, or their string ID.
  500. role: The new role to assign to the member. May be one of:
  501. - "admin"
  502. - "member"
  503. - "viewer"
  504. - "restricted_viewer" (if supported by the W&B server)
  505. Returns:
  506. This registry for further method chaining, if needed.
  507. Raises:
  508. ValueError: If unable to infer the user or team ID.
  509. Examples:
  510. Make all users in the registry admins
  511. ```python
  512. import wandb
  513. api = wandb.Api()
  514. # Fetch an existing registry
  515. registry = api.registry(name="my-registry", organization="my-org")
  516. for member in registry.user_members():
  517. registry.update_member(member.user, role="admin")
  518. ```
  519. """
  520. from wandb.sdk.artifacts._generated import (
  521. UPDATE_TEAM_REGISTRY_ROLE_GQL,
  522. UPDATE_USER_REGISTRY_ROLE_GQL,
  523. UpdateProjectMemberInput,
  524. UpdateProjectTeamMemberInput,
  525. UpdateTeamRegistryRole,
  526. UpdateUserRegistryRole,
  527. )
  528. id_ = MemberId.from_obj(member)
  529. if id_.kind is MemberKind.USER:
  530. gql_op = gql(UPDATE_USER_REGISTRY_ROLE_GQL)
  531. gql_input = UpdateProjectMemberInput(
  532. user_id=id_.encode(), project_id=self.id, user_project_role=role
  533. )
  534. result_cls = UpdateUserRegistryRole
  535. elif id_.kind is MemberKind.ENTITY:
  536. gql_op = gql(UPDATE_TEAM_REGISTRY_ROLE_GQL)
  537. gql_input = UpdateProjectTeamMemberInput(
  538. team_id=id_.encode(), project_id=self.id, team_project_role=role
  539. )
  540. result_cls = UpdateTeamRegistryRole
  541. else:
  542. assert_never(id_.kind)
  543. gql_vars = {"input": gql_input.model_dump()}
  544. data = self.client.execute(gql_op, variable_values=gql_vars)
  545. result = result_cls.model_validate(data).result
  546. if not (result and result.success):
  547. raise ValueError(
  548. f"Failed to update member {member!r} role to {role!r} in registry {self.name!r}"
  549. )
  550. return self