_gqlutils.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. from __future__ import annotations
  2. from contextlib import suppress
  3. from dataclasses import dataclass
  4. from functools import lru_cache
  5. from typing import TYPE_CHECKING
  6. from wandb_gql import gql
  7. from wandb._iterutils import one
  8. from wandb.proto.wandb_internal_pb2 import ServerFeature
  9. from wandb.sdk.internal._generated import SERVER_FEATURES_QUERY_GQL, ServerFeaturesQuery
  10. if TYPE_CHECKING:
  11. from wandb.apis.public import RetryingClient
  12. from wandb.sdk.artifacts._generated.fetch_org_info_from_entity import (
  13. FetchOrgInfoFromEntityEntity,
  14. )
  15. @lru_cache(maxsize=16)
  16. def org_info_from_entity(
  17. client: RetryingClient, entity: str
  18. ) -> FetchOrgInfoFromEntityEntity | None:
  19. """Returns the organization info for a given entity."""
  20. from ._generated import FETCH_ORG_INFO_FROM_ENTITY_GQL, FetchOrgInfoFromEntity
  21. gql_op = gql(FETCH_ORG_INFO_FROM_ENTITY_GQL)
  22. data = client.execute(gql_op, variable_values={"entity": entity})
  23. return FetchOrgInfoFromEntity.model_validate(data).entity
  24. @lru_cache(maxsize=16)
  25. def _server_features(client: RetryingClient) -> dict[str, bool]:
  26. """Returns a mapping of `{server_feature_name (str) -> is_enabled (bool)}`.
  27. Results are cached per client instance.
  28. """
  29. try:
  30. response = client.execute(gql(SERVER_FEATURES_QUERY_GQL))
  31. except Exception as e:
  32. # Unfortunately we currently have to match on the text of the error message,
  33. # as the `gql` client raises `Exception` rather than a more specific error.
  34. if 'Cannot query field "features" on type "ServerInfo".' in str(e):
  35. return {}
  36. raise
  37. result = ServerFeaturesQuery.model_validate(response)
  38. if (server_info := result.server_info) and (features := server_info.features):
  39. return {feat.name: feat.is_enabled for feat in features if feat}
  40. return {}
  41. def server_supports(client: RetryingClient, feature: str | int) -> bool:
  42. """Return whether the current server supports the given feature.
  43. NOTE: This is deprecated. Please use `ServiceApi.feature_enabled()` when
  44. possible, like in all public API code.
  45. Good to use for features that have a fallback mechanism for older servers.
  46. """
  47. # If we're given the protobuf enum value, convert to a string name.
  48. # NOTE: We deliberately use names (str) instead of enum values (int)
  49. # as the keys here, since:
  50. # - the server identifies features by their name, rather than (client-side) enum value
  51. # - the defined list of client-side flags may be behind the server-side list of flags
  52. try:
  53. name = ServerFeature.Name(feature) if isinstance(feature, int) else feature
  54. except ValueError:
  55. return False # Invalid int-like value, assume unsupported
  56. return _server_features(client).get(name) or False
  57. @dataclass(frozen=True)
  58. class OrgInfo:
  59. org_name: str
  60. entity_name: str
  61. def __contains__(self, other: str) -> bool:
  62. return other in {self.org_name, self.entity_name}
  63. def resolve_org_entity_name(
  64. client: RetryingClient,
  65. non_org_entity: str | None,
  66. org_or_entity: str | None = None,
  67. ) -> str:
  68. # Resolve the portfolio's org entity name.
  69. #
  70. # The `org_or_org_entity` parameter may be empty, an org display name, or an
  71. # org entity name.
  72. #
  73. # If the server cannot fetch the portfolio's org name, return the provided
  74. # value or raise an error if it is empty. Otherwise, return the fetched
  75. # value after validating that the given organization, if provided, matches
  76. # either the display or entity name.
  77. if not non_org_entity:
  78. raise ValueError("Entity name is required to resolve org entity name.")
  79. # Fetch candidate orgs to verify or identify the correct orgEntity name.
  80. entity = org_info_from_entity(client, non_org_entity)
  81. # Parse possible organization(s) from the response...
  82. # ----------------------------------------------------------------------------
  83. # If a team entity was provided, a single organization should exist under
  84. # the team/org entity type.
  85. if entity and (org := entity.organization) and (org_entity := org.org_entity):
  86. # Ensure the provided name, if given, matches the org or org entity name before
  87. # returning the org entity.
  88. org_info = OrgInfo(org_name=org.name, entity_name=org_entity.name)
  89. if (not org_or_entity) or (org_or_entity in org_info):
  90. return org_entity.name
  91. # ----------------------------------------------------------------------------
  92. # If a personal entity was provided, the user may belong to multiple
  93. # organizations.
  94. if entity and (user := entity.user) and (orgs := user.organizations):
  95. org_infos = [
  96. OrgInfo(org_name=org.name, entity_name=org_entity.name)
  97. for org in orgs
  98. if (org_entity := org.org_entity)
  99. ]
  100. if org_or_entity:
  101. with suppress(StopIteration):
  102. return next(
  103. info.entity_name for info in org_infos if (org_or_entity in info)
  104. )
  105. if len(org_infos) == 1:
  106. raise ValueError(
  107. f"Expecting the organization name or entity name to match {org_infos[0].org_name!r} "
  108. f"and cannot be linked/fetched with {org_or_entity!r}. "
  109. "Please update the target path with the correct organization name."
  110. )
  111. else:
  112. raise ValueError(
  113. "Personal entity belongs to multiple organizations "
  114. f"and cannot be linked/fetched with {org_or_entity!r}. "
  115. "Please update the target path with the correct organization name "
  116. "or use a team entity in the entity settings."
  117. )
  118. else:
  119. # If no input organization provided, error if entity belongs to:
  120. # - multiple orgs, because we cannot determine which one to use.
  121. # - no orgs, because there's nothing to use.
  122. return one(
  123. (org.entity_name for org in org_infos),
  124. too_short=ValueError(
  125. f"Unable to resolve an organization associated with personal entity: {non_org_entity!r}. "
  126. "This could be because its a personal entity that doesn't belong to any organizations. "
  127. "Please specify the organization in the Registry path or use a team entity in the entity settings."
  128. ),
  129. too_long=ValueError(
  130. f"Personal entity {non_org_entity!r} belongs to multiple organizations "
  131. "and cannot be used without specifying the organization name. "
  132. "Please specify the organization in the Registry path or use a team entity in the entity settings."
  133. ),
  134. )
  135. raise ValueError(f"Unable to find organization for entity {non_org_entity!r}.")