_members.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. """Types and helpers for managing registry members."""
  2. from __future__ import annotations
  3. from collections import defaultdict
  4. from collections.abc import Iterable
  5. from enum import Enum
  6. from functools import singledispatchmethod
  7. from typing import Literal, Union
  8. from pydantic.dataclasses import dataclass as pydantic_dataclass
  9. from wandb._strutils import b64decode_ascii, b64encode_ascii, nameof
  10. from wandb.sdk.artifacts._models import ArtifactsBase
  11. from ..teams import Team
  12. from ..users import User
class MemberKind(str, Enum):
    """Identifies what kind of object a registry member is.

    The values are the type prefix of a member's decoded GraphQL ID, i.e. the
    `"User"` / `"Entity"` part of `"<kind>:<index>"` as built by
    `MemberId.encode`. Teams are identified by the `"Entity"` kind.
    """

    USER = "User"
    ENTITY = "Entity"
    # Same value as ENTITY, so this is an enum alias:
    # `MemberKind.TEAM is MemberKind.ENTITY`.
    TEAM = ENTITY  # Convenience alias
class MemberRole(str, Enum):
    """Identifies the role of a member.

    Because this subclasses `str`, each member compares equal to its plain
    string value (e.g. `MemberRole.ADMIN == "admin"`).
    """

    ADMIN = "admin"
    MEMBER = "member"
    VIEWER = "viewer"
    RESTRICTED_VIEWER = "restricted_viewer"
class UserMember(ArtifactsBase, arbitrary_types_allowed=True):
    # A registry member that is a user, together with that user's role.
    # NOTE(review): `arbitrary_types_allowed=True` is presumably needed because
    # `User` is not a pydantic-compatible type — confirm against `..users`.

    # Fixed tag marking this as a user (not team) member.
    kind: Literal[MemberKind.USER] = MemberKind.USER
    # The user who is the member.
    user: User
    # Either a known `MemberRole` or a plain role string.
    role: Union[MemberRole, str]  # noqa: UP007
class TeamMember(ArtifactsBase, arbitrary_types_allowed=True):
    # A registry member that is a team, together with that team's role.
    # NOTE(review): `arbitrary_types_allowed=True` is presumably needed because
    # `Team` is not a pydantic-compatible type — confirm against `..teams`.

    # Fixed tag marking this as a team member; teams use the "Entity" kind.
    kind: Literal[MemberKind.ENTITY] = MemberKind.ENTITY
    # The team that is the member.
    team: Team
    # Either a known `MemberRole` or a plain role string.
    role: Union[MemberRole, str]  # noqa: UP007
  32. MemberOrId = Union[User, Team, UserMember, TeamMember, str]
  33. """Type hint for a registry member argument that accepts a User, Team, or their ID."""
  34. def parse_member_ids(members: Iterable[MemberOrId]) -> tuple[list[str], list[str]]:
  35. """Returns a tuple of (user_ids, team_ids) from parsing the given objects."""
  36. ids_by_kind: dict[MemberKind, set[str]] = defaultdict(set)
  37. for parsed in map(MemberId.from_obj, members):
  38. ids_by_kind[parsed.kind].add(parsed.encode())
  39. user_ids = ids_by_kind[MemberKind.USER]
  40. team_ids = ids_by_kind[MemberKind.ENTITY]
  41. # Ordering shouldn't matter, but sort anyway for reproducibility and testing
  42. return sorted(user_ids), sorted(team_ids)
  43. @pydantic_dataclass
  44. class MemberId:
  45. kind: MemberKind
  46. index: int
  47. def encode(self) -> str:
  48. """Converts this parsed ID to a base64-encoded GraphQL ID."""
  49. return b64encode_ascii(f"{self.kind.value}:{self.index}")
  50. @singledispatchmethod
  51. @classmethod
  52. def from_obj(cls, obj: MemberOrId, /) -> MemberId:
  53. """Parses `User` or `Team` ID from the argument."""
  54. # Fallback for unexpected types
  55. raise TypeError(
  56. f"Member arg must be a {nameof(User)!r}, {nameof(Team)!r}, or a user/team ID. "
  57. f"Got: {nameof(type(obj))!r}"
  58. )
  59. @from_obj.register(User)
  60. @from_obj.register(Team)
  61. @classmethod
  62. def _from_obj_with_id(cls, obj: User | Team, /) -> MemberId:
  63. # Use the object's string (base64-encoded) GraphQL ID
  64. return cls._from_id(obj.id)
  65. @from_obj.register(UserMember)
  66. @classmethod
  67. def _from_user_member(cls, member: UserMember, /) -> MemberId:
  68. return cls._from_id(member.user.id)
  69. @from_obj.register(TeamMember)
  70. @classmethod
  71. def _from_team_member(cls, member: TeamMember, /) -> MemberId:
  72. return cls._from_id(member.team.id)
  73. @from_obj.register(str)
  74. @classmethod
  75. def _from_id(cls, id_: str, /) -> MemberId:
  76. # Parse the ID to figure out if it's a team or user ID
  77. kind, index = b64decode_ascii(id_).split(":", maxsplit=1)
  78. return cls(kind=kind, index=index)