_utils.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. from __future__ import annotations
  2. from collections.abc import Collection
  3. from enum import Enum
  4. from functools import lru_cache, partial
  5. from typing import TYPE_CHECKING, Any
  6. from wandb_gql import gql
  7. from wandb._strutils import ensureprefix
  8. if TYPE_CHECKING:
  9. from wandb.apis.public.api import RetryingClient
  10. class Visibility(str, Enum):
  11. # names are what users see/pass into Python methods
  12. # values are what's expected by backend API
  13. organization = "PRIVATE"
  14. restricted = "RESTRICTED"
  15. @classmethod
  16. def _missing_(cls, value: object) -> Any:
  17. # Allow instantiation from enum names too (e.g. "organization" or "restricted")
  18. return cls.__members__.get(value)
  19. @classmethod
  20. def from_gql(cls, value: str) -> Visibility:
  21. """Convert a GraphQL `visibility` value to a Visibility enum."""
  22. try:
  23. return cls(value)
  24. except ValueError:
  25. expected = ",".join(repr(e.value) for e in cls)
  26. raise ValueError(
  27. f"Invalid visibility {value!r} from backend. Expected one of: {expected}"
  28. ) from None
  29. @classmethod
  30. def from_python(cls, name: str) -> Visibility:
  31. """Convert a visibility string to a `Visibility` enum."""
  32. try:
  33. return cls(name)
  34. except ValueError:
  35. expected = ",".join(repr(e.name) for e in cls)
  36. raise ValueError(
  37. f"Invalid visibility {name!r}. Expected one of: {expected}"
  38. ) from None
  39. def prepare_artifact_types_input(
  40. artifact_types: Collection[str] | None,
  41. ) -> list[dict[str, str]] | None:
  42. """Format the artifact types for the GQL input.
  43. Args:
  44. artifact_types: The artifact types to add to the registry.
  45. Returns:
  46. The artifact types for the GQL input.
  47. """
  48. from wandb.sdk.artifacts._validators import validate_artifact_types
  49. if artifact_types:
  50. return [{"name": typ} for typ in validate_artifact_types(artifact_types)]
  51. return None
  52. def ensure_registry_prefix_on_names(query: Any, in_name: bool = False) -> Any:
  53. """Recursively the registry prefix to values under "name" keys, excluding regex ops.
  54. - in_name: True if we are under a "name" key (or propagating from one).
  55. EX: {"name": "model"} -> {"name": "wandb-registry-model"}
  56. """
  57. from wandb.sdk.artifacts._validators import REGISTRY_PREFIX
  58. if isinstance((txt := query), str):
  59. return ensureprefix(txt, REGISTRY_PREFIX) if in_name else txt
  60. if isinstance((dct := query), dict):
  61. new_dict = {}
  62. for key, obj in dct.items():
  63. if key == "$regex":
  64. # For regex operator, we skip transformation of its value.
  65. new_dict[key] = obj
  66. elif key == "name":
  67. new_dict[key] = ensure_registry_prefix_on_names(obj, in_name=True)
  68. else:
  69. # For any other key, propagate flags as-is.
  70. new_dict[key] = ensure_registry_prefix_on_names(obj, in_name=in_name)
  71. return new_dict
  72. if isinstance((seq := query), (list, tuple)):
  73. return list(map(partial(ensure_registry_prefix_on_names, in_name=in_name), seq))
  74. return query
  75. @lru_cache(maxsize=10)
  76. def fetch_org_entity_from_organization(
  77. client: RetryingClient, organization: str
  78. ) -> str:
  79. """Fetch the org entity from the organization.
  80. Args:
  81. client (Client): Graphql client.
  82. organization (str): The organization to fetch the org entity for.
  83. """
  84. from wandb.sdk.artifacts._generated import (
  85. FETCH_ORG_ENTITY_FROM_ORGANIZATION_GQL,
  86. FetchOrgEntityFromOrganization,
  87. )
  88. gql_op = gql(FETCH_ORG_ENTITY_FROM_ORGANIZATION_GQL)
  89. try:
  90. data = client.execute(gql_op, variable_values={"organization": organization})
  91. except Exception as e:
  92. msg = f"Error fetching org entity for organization: {organization!r}"
  93. raise ValueError(msg) from e
  94. result = FetchOrgEntityFromOrganization.model_validate(data)
  95. if (
  96. not (org := result.organization)
  97. or not (org_entity := org.org_entity)
  98. or not (org_name := org_entity.name)
  99. ):
  100. raise ValueError(f"Organization entity for {organization!r} not found.")
  101. return org_name